#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# Univention ucslint
"""Check UCS packages for policy compliance."""
#
# Like what you see? Join us!
# https://www.univention.com/about-us/careers/vacancies/
#
# Copyright 2008-2023 Univention GmbH
#
# https://www.univention.de/
#
# All rights reserved.
#
# The source code of this program is made available
# under the terms of the GNU Affero General Public License version 3
# (GNU AGPL V3) as published by the Free Software Foundation.
#
# Binary versions of this program provided by Univention to you as
# well as other copyrighted, protected or trademarked materials like
# Logos, graphics, fonts, specific documentations and configurations,
# cryptographic keys etc. are subject to a license agreement between
# you and Univention and not subject to the GNU AGPL V3.
#
# In the case you use this program under the terms of the GNU AGPL V3,
# the program is provided in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public
# License with the Debian GNU/Linux or Univention distribution in file
# /usr/share/common-licenses/AGPL-3; if not, see
# <https://www.gnu.org/licenses/>.

import os
import re
import sys
from argparse import ArgumentParser, ArgumentTypeError, FileType, Namespace  # noqa: F401
from errno import ENOENT
from fnmatch import fnmatch
from importlib.util import module_from_spec, spec_from_file_location
from os.path import abspath, dirname, exists, expanduser, isdir, join, normpath, relpath
from types import ModuleType  # noqa: F401
from typing import IO, Container, Dict, List, Optional, Sequence, Set, Tuple  # noqa: F401

import univention.ucslint.base as uub


try:
    # Optional dependency: TestSuite is only used by DebianPackageCheck.printResult()
    # when a JUnit XML output file was requested.
    from junit_xml import TestSuite  # type: ignore
except ImportError:
    # Without junit_xml the name TestSuite stays undefined. parse_args() refuses
    # --junit-xml unless uub.JUNIT is set (presumably mirroring the availability
    # of junit_xml - verify in univention.ucslint.base).
    pass


# Syntax of one non-comment line of debian/ucslint.overrides:
#   MODULE-ID[:[ PATTERN[: LINENUMBER]]]
# "module" is the message identifier (digits, dash, optional category letter out
# of BEFNW, digits); "pattern" and "linenumber" are both optional and a line
# number can only be given together with a pattern.
RE_OVERRIDE = re.compile(
    r'''^
    (?P<module> \d+-[BEFNW]?\d+)
    (?: [:]
        (?: \s* (?P<pattern> .+?) \s*
            (?: [:] \s* (?P<linenumber> \d+)
            )?
        )?
    )?
    $''', re.VERBOSE)


class DebianPackageCheck(object):
    """
    Check Debian package for policy compliance.

    The plugins are loaded by the constructor; :py:meth:`check` runs them and
    collects their messages, :py:meth:`printResult` filters and prints them.

    :param path: Base directory of Debian package to check.
    :param plugindirs: List of directories containing plugins to load.
    :param enabled_modules: List of enabled modules; if empty, all modules are enabled.
    :param disabled_modules: List of disabled modules.
    :param debuglevel: Verbosity level.
    """

    def __init__(self, path: str, plugindirs: Sequence[str], enabled_modules: Optional[Container[str]] = None, disabled_modules: Optional[Container[str]] = None, debuglevel: int = 0) -> None:
        self.path = path
        self.plugindirs = plugindirs
        # Loaded plugin modules by module name (file name without `.py`).
        self.pluginlist = {}  # type: Dict[str, ModuleType]
        # Messages collected from all plugins by check().
        self.msglist = []  # type: List[uub.UPCMessage]
        self.enabled_modules = enabled_modules or ()  # type: Container[str]
        self.disabled_modules = disabled_modules or ()  # type: Container[str]
        self.debuglevel = debuglevel
        # Message ID -> (level, message text) as announced by the plugins.
        self.msgidlist = {}  # type: Dict[str, Tuple[int, str]]
        # Entries (message ID, file name pattern, line number) from the overrides file.
        self.overrides = set()  # type: Set[Tuple[str, Optional[str], Optional[int]]]
        self.loadplugins()

    def loadplugins(self) -> None:
        """
        Load modules from plugin directory.

        Only files named `NNNN*.py` (four leading digits) are considered. They
        are loaded in sorted file name order, so that plugins run - and their
        messages are printed - in a reproducible order.

        :raises Exception: the error of a failing plugin is re-raised when debugging is enabled; otherwise it is only reported on STDERR.
        """
        for plugindir in self.plugindirs:
            if not os.path.exists(plugindir):
                if self.debuglevel:
                    print('WARNING: plugindir %s does not exist' % plugindir, file=sys.stderr)
                continue

            # os.listdir() returns the entries in arbitrary order.
            for f in sorted(os.listdir(plugindir)):
                modid = f[0:4]
                if not (f.endswith('.py') and modid.isdigit()):
                    continue

                # No enabled modules given ==> load all modules,
                # otherwise load only the listed modules.
                selected = not self.enabled_modules or modid in self.enabled_modules
                if not selected or modid in self.disabled_modules:
                    if self.debuglevel:
                        print('Module %s is not enabled' % f, file=sys.stderr)
                    continue

                modname = f[0:-3]
                fn = os.path.join(plugindir, f)
                try:
                    spec = spec_from_file_location(modname, fn)
                    # Validate before first use and without `assert`, which is stripped by `-O`.
                    if spec is None or spec.loader is None:
                        raise ImportError('no import loader available for %s' % (fn,))
                    module = module_from_spec(spec)
                    spec.loader.exec_module(module)
                except Exception as exc:
                    print('ERROR: Loading module %s failed: %s' % (f, exc), file=sys.stderr)
                    if self.debuglevel:
                        raise
                else:
                    self.pluginlist[modname] = module

    def check(self) -> None:
        """
        Run plugin on files in path.

        Each plugin module must provide a class `UniventionPackageCheck`. Its
        message IDs are merged into `self.msgidlist` and its results are
        appended to `self.msglist`.
        """
        for plugin in self.pluginlist.values():
            obj = plugin.UniventionPackageCheck()  # type: ignore
            self.msgidlist.update(obj.getMsgIds())
            obj.setdebug(self.debuglevel)
            obj.postinit(self.path)
            try:
                obj.check(self.path)
            except uub.UCSLintException as ex:
                # A failing plugin must not prevent the remaining plugins from running.
                print(ex, file=sys.stderr)
            self.msglist.extend(obj.result())

    def loadOverrides(self) -> None:
        """
        Parse :file:`debian/ucslint.overrides` file.

        Empty lines and lines starting with `#` are skipped; lines not matching
        `RE_OVERRIDE` are reported on STDERR and ignored. A missing file is no
        error, any other I/O error only prints a warning.
        """
        self.overrides = set()
        fn = os.path.join(self.path, 'debian', 'ucslint.overrides')
        try:
            with open(fn) as overrides:
                for row, line in enumerate(overrides, start=1):
                    line = line.strip()
                    if not line or line.startswith('#'):
                        continue
                    result = RE_OVERRIDE.match(line)
                    if not result:
                        print('IGNORED: debian/ucslint.overrides:%d: %s' % (row, line), file=sys.stderr)
                        continue

                    module, pattern, linenumber = result.groups()
                    override = (module, pattern, int(linenumber) if pattern and linenumber else None)
                    self.overrides.add(override)
        except EnvironmentError as ex:
            if ex.errno != ENOENT:
                print('WARNING: load debian/ucslint.overrides: %s' % (ex,), file=sys.stderr)

    def in_overrides(self, msg: uub.UPCMessage) -> bool:
        """
        Check message against overrides.

        An override applies when its message ID equals the ID of the message
        and - if given - its pattern matches the file name relative to the
        package directory and its line number equals the row of the message.

        :param msg: Message to check.
        :returns: `True` when the check should be ignored, `False` otherwise.
        """
        filepath = normpath(relpath(msg.filename, self.path)) if msg.filename else ''
        for (modulename, pattern, linenumber) in self.overrides:
            if modulename != msg.getId():
                continue
            if pattern and not fnmatch(filepath, pattern):
                continue
            if linenumber is not None and linenumber != msg.row:
                continue
            return True
        return False

    def printResult(self, ignore_IDs: Container[str], display_only_IDs: Container[str], display_only_categories: str, exitcode_categories: str, junit: Optional[IO[str]] = None) -> Tuple[int, int]:
        """
        Print result of checks.

        Messages are dropped when they are listed in `ignore_IDs`, are missing
        from a non-empty `display_only_IDs` or are matched by an entry of the
        overrides file, which is re-read on every call.

        :param ignore_IDs: List of message identifiers to ignore.
        :param display_only_IDs: List of message identifiers to display; if empty, all are displayed.
        :param display_only_categories: String of message categories to display; if empty, all are displayed.
        :param exitcode_categories: String of message categories to signal as fatal; if empty, all are fatal.
        :param junit: Generate JUnit XML output to given file.
        :returns: 2-tuple (incident-count, exitcode-count)
        """
        incident_cnt = 0
        exitcode_cnt = 0

        self.loadOverrides()
        test_cases = []  # type: List[uub.TestCase]

        for msg in self.msglist:
            # Every message gets a test case, the filtered ones are marked as skipped.
            tc = msg.junit()
            test_cases.append(tc)

            msgid = msg.getId()
            if msgid in ignore_IDs:
                tc.add_skipped_info('ignored')
                continue
            if display_only_IDs and msgid not in display_only_IDs:
                tc.add_skipped_info('hidden')
                continue
            if self.in_overrides(msg):
                # ignore msg if mentioned in overrides files
                tc.add_skipped_info('overridden')
                continue

            try:
                lvl, _ = self.msgidlist[msgid]
                category = uub.RESULT_INT2STR[lvl]
            except LookupError:
                # Neither the message ID nor its level is known.
                category = 'FIXME'

            if category in display_only_categories or display_only_categories == '':
                print('%s:%s' % (category, msg))
                incident_cnt += 1

            if category in exitcode_categories or exitcode_categories == '':
                exitcode_cnt += 1

        if junit:
            ts = TestSuite("ucslint", test_cases)
            TestSuite.to_file(junit, [ts], prettyprint=False)

        return incident_cnt, exitcode_cnt


def clean_id(idstr: str) -> str:
    """
    Normalize a message identifier of the form `MODULE-MESSAGE`.

    Surrounding white space is removed, then the module number is padded to
    four digits and leading zeros are removed from the message number.

    :param idstr: message identifier.
    :returns: formatted message identifier.
    :raises ValueError: if the dash is missing or one part is not numeric.

    >>> clean_id('1-2')
    '0001-2'
    """
    if '-' in idstr:
        # Only the first dash separates module from message number.
        module_part, _, message_part = idstr.strip().partition('-')
        return '-'.join((clean_modid(module_part), clean_msgid(message_part)))
    raise ValueError('no valid id (%s) - missing dash' % idstr)


def clean_modid(modid: str) -> str:
    """
    Normalize a module number to at least four digits.

    :param modid: module number.
    :returns: formatted module number.
    :raises ValueError: if `modid` does not consist of digits only.

    >>> clean_modid('1')
    '0001'
    """
    if modid.isdigit():
        # Going through int() drops surplus leading zeros before padding again.
        return str(int(modid)).zfill(4)
    raise ValueError('modid contains invalid characters: %s' % modid)


def clean_msgid(msgid: str) -> str:
    """
    Normalize a message number by removing leading zeros.

    :param msgid: message number.
    :returns: formatted message number.
    :raises ValueError: if `msgid` does not consist of digits only.

    >>> clean_msgid('01')
    '1'
    """
    if msgid.isdigit():
        return str(int(msgid))
    raise ValueError('msgid contains invalid characters: %s' % msgid)


def parse_args() -> Namespace:
    """
    Build the command line parser and evaluate the command line.

    JUnit output is rejected when `uub.JUNIT` is not set. If no package
    directory is given, the current directory is used after it passed the
    same validation as explicitly given directories.

    :returns: parsed options.
    """
    parser = ArgumentParser()
    add = parser.add_argument

    add(
        '-d', '--debug',
        metavar='LEVEL', default=0, type=int,
        help='if set, debugging is activated and set to the specified level',
    )

    # The four list options only differ in their names, destination and help text.
    for short_opt, long_opt, dest, help_text in (
        ('-m', '--modules', 'enabled_modules', 'list of modules to be loaded (e.g. -m 0009,27)'),
        ('-x', '--exclude-modules', 'disabled_modules', 'list of modules to be disabled (e.g. -x 9,027)'),
        ('-o', '--display-only', 'display_only_IDs', 'list of IDs to be displayed (e.g. -o 9-1,0027-12)'),
        ('-i', '--ignore', 'ignore_IDs', 'list of IDs to be ignored (e.g. -i 0003-4,19-27)'),
    ):
        add(
            short_opt, long_opt,
            metavar='MODULES', dest=dest, action='append', default=[],
            help=help_text,
        )

    # Shares its destination with --ignore and appends a fixed list of IDs to it.
    add(
        '--skip-univention', '-U',
        dest='ignore_IDs', action='append_const',
        const='0007-2,0010-2,0010-3,0010-4,0011-3,0011-4,0011-5,0011-13',
        help='Ignore Univention specific tests',
    )
    add(
        '-p', '--plugindir',
        metavar='DIRECTORY', action='append', default=[],
        help='override plugin directory with <plugindir>',
    )
    add(
        '-c', '--display-categories',
        metavar='CATEGORIES', dest='display_only_categories', default='',
        help='categories to be displayed (e.g. -c EWIS)',
    )
    add(
        '-e', '--exitcode-categories',
        metavar='CATEGORIES', default='E',
        help='categories that cause an exitcode != 0 (e.g. -e EWIS)',
    )
    add(
        '--junit-xml', '-j',
        metavar='FILE', type=FileType("w"),
        help='generate JUnit-XML output',
    )
    add(
        'pkgpath',
        type=debian_dir, nargs='*',
        help="Source package directory",
    )

    args = parser.parse_args()

    if args.junit_xml and not uub.JUNIT:
        parser.error("Missing Python support for JUNIT_XML")

    if not args.pkgpath:
        # Default to the current directory, validated like an explicit argument.
        try:
            args.pkgpath = [debian_dir('.')]
        except ArgumentTypeError as ex:
            parser.error(str(ex))

    return args


def debian_dir(pkgpath: str) -> str:
    """
    Validate that the given path is the base directory of a Debian package.

    The path must exist, be a directory and contain a sub-directory
    :file:`debian`.

    :param pkgpath: base path.
    :returns: same path.
    :raises ArgumentTypeError: when it is no Debian package.
    """
    if not exists(pkgpath):
        problem = "directory %r does not exist!" % (pkgpath,)
    elif not isdir(pkgpath):
        problem = "%r is no directory!" % (pkgpath,)
    else:
        packaging = join(pkgpath, 'debian')
        if isdir(packaging):
            return pkgpath
        problem = "%r does not exist or is not a directory!" % (packaging,)

    raise ArgumentTypeError(problem)


def main() -> None:
    """
    Run checks.

    Every package directory from the command line is checked with all
    selected plugins. The process exits with status 2 if any package has
    messages of an exit code category. A malformed message or module ID on
    the command line ends the process with an error message and status 1.
    """
    options = parse_args()
    if options.debug:
        print('Using univention.ucslint.base from %s' % (uub.__file__,))

    # Without explicit plugin directories use the per-user directory and the one shipping the base module.
    plugindirs = [abspath(path) for path in (options.plugindir or [expanduser('~/.ucslint'), dirname(uub.__file__)])]
    try:
        # Each option may be repeated and each occurrence may hold a comma separated list.
        ignore_IDs = [clean_id(x) for ign in options.ignore_IDs for x in ign.split(',')]
        display_only_IDs = [clean_id(x) for dsp in options.display_only_IDs for x in dsp.split(',')]
        enabled_modules = [clean_modid(x) for mod in options.enabled_modules for x in mod.split(',')]
        disabled_modules = [clean_modid(x) for mod in options.disabled_modules for x in mod.split(',')]
    except ValueError as ex:
        # Report invalid user input as a plain message instead of a traceback;
        # sys.exit() with a string prints it to STDERR and exits with status 1.
        sys.exit('ucslint: error: %s' % (ex,))

    fail = False
    # NOTE(review): with several package paths every printResult() call writes
    # another test suite to the same JUnit file - confirm this is intended.
    for pkgpath in options.pkgpath or ['.']:
        chk = DebianPackageCheck(pkgpath, plugindirs, enabled_modules=enabled_modules, disabled_modules=disabled_modules, debuglevel=options.debug)
        try:
            chk.check()
        except uub.UCSLintException as ex:
            print(ex, file=sys.stderr)

        _, exitcode_cnt = chk.printResult(ignore_IDs, display_only_IDs, options.display_only_categories, options.exitcode_categories, options.junit_xml)
        fail |= bool(exitcode_cnt)

    if fail:
        sys.exit(2)


# Run the checks only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
