#!/usr/bin/python3 -u
# Like what you see? Join us!
# https://www.univention.com/about-us/careers/vacancies/
#
# SPDX-FileCopyrightText: 2024-2025 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only

"""UCS Testrunner - run UCS test in sane environment."""

from __future__ import annotations

import argparse
import logging
import operator
import sys
from time import time
from typing import TYPE_CHECKING

import univention.testing.format
from univention.testing.codes import Reason
from univention.testing.coverage import Coverage
from univention.testing.data import TestCase, TestEnvironment, TestFormatInterface, TestResult
from univention.testing.errors import TestError
from univention.testing.internal import LOG_BASE, get_sections, get_tests, setup_debug, setup_environment
from univention.testing.pytest import PytestRunner


if TYPE_CHECKING:
    from collections.abc import Iterable, Mapping


class StoreTag(argparse.Action):
    """
    argparse action recording ``{tag: tag_level}`` in a dictionary.

    Several options may share the same ``dest``; each subclass sets
    :attr:`tag_level` to the level it records for the given tag value.
    If the same tag is given more than once, the last option wins.
    """

    tag_level = ""  # level stored for each tag; overridden by subclasses

    def __call__(self, parser, namespace, value, option_string=None):
        # Work on a copy instead of mutating in place: argparse seeds the
        # namespace with the action's ``default`` object itself, so updating
        # it directly would leak tags into that default (and into any later
        # parse with the same parser).  ``or {}`` also copes with a missing
        # or ``None`` default instead of failing with a TypeError.
        dest = dict(getattr(namespace, self.dest, None) or {})
        dest[value] = self.tag_level
        setattr(namespace, self.dest, dest)


class StoreRequired(StoreTag):
    """Tag action that records the given tag with level ``required``."""

    tag_level = 'required'


class StoreProhibited(StoreTag):
    """Tag action that records the given tag with level ``prohibited``."""

    tag_level = 'prohibited'


class StoreIgnored(StoreTag):
    """Tag action that records the given tag with level ``ignored``."""

    tag_level = 'ignored'


def parse_options(sections: Iterable[str]) -> tuple[argparse.Namespace, list[str]]:
    """
    Build the command line parser and parse the process arguments.

    :param sections: Allowed values for ``-s/--section``.
    :returns: Tuple of the parsed namespace and the list of arguments the
        parser did not recognise (from ``parse_known_args``).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "-H", "--hold", action="store_true",
        help="Stop execution after the first failed test")
    ap.add_argument(
        "-t", "--timeout", type=int, default=3600,
        help="Abort single test after [%(default)s]s")

    selection = ap.add_argument_group('Test selection')
    selection.add_argument(
        "-s", "--section", dest="sections", action="append",
        choices=sections, metavar="SECTION",
        help="Run tests only from this section")
    # All tag options share dest="tags"; each action stores its own level.
    tag_options = (
        (("-p", "--prohibit"), StoreProhibited, "Skip tests with this tag"),
        (("-r", "--require"), StoreRequired, "Only run tests with this tag"),
        (("-g", "--ignore"), StoreIgnored, "Neither require nor prohibit this tag"),
    )
    for flags, tag_action, tag_help in tag_options:
        selection.add_argument(
            *flags, dest="tags", action=tag_action, default={},
            metavar="TAG", help=tag_help)
    selection.add_argument(
        "-E", "--exposure", choices=('safe', 'careful', 'dangerous'),
        help="Run more dangerous tests")

    output = ap.add_argument_group('Output options')
    output.add_argument(
        "-n", "--dry-run", dest="dry", action="store_true",
        help="Only show which tests would run")
    output.add_argument(
        "-f", "--filtered", dest="filter", action="store_true",
        help="Hide tests with unmatched pre-conditions")
    output.add_argument(
        "-F", "--format", default='text',
        choices=univention.testing.format.FORMATS,
        help="Select output format [%(default)s]")
    output.add_argument(
        "-v", "--verbose", action="count",
        help="Increase verbosity")
    output.add_argument(
        "-i", "--interactive", action="store_true",
        help="Run test connected to terminal")
    output.add_argument(
        "-c", "--count", action="store_true",
        help="Prefix tests by count")
    output.add_argument(
        "-l", "--logfile", default=LOG_BASE % (time(),),
        help="Path to log file [%(default)s]")

    # Let the pytest and coverage integrations register their own options.
    for extension in (PytestRunner, Coverage):
        extension.get_argument_group(ap)
    return ap.parse_known_args()


class TestSet:
    """Container for the selected tests, grouped by section, plus the loop that runs them."""

    def __init__(self, tests: Mapping[str, Iterable[str]]) -> None:
        """
        :param tests: Mapping of section name to the test file names of that section.
        """
        # Materialize every section once: the cases are iterated twice (here for
        # counting and again in run_tests), which would silently lose all tests
        # if a one-shot iterator such as a generator were passed in.
        self.tests: dict[str, list[str]] = {section: list(cases) for section, cases in tests.items()}
        self.max_count = sum(len(cases) for cases in self.tests.values())
        self.test_environment: TestEnvironment | None = None
        self.format: TestFormatInterface | None = None
        self.prefix = ''

    def set_environment(self, test_environment: TestEnvironment) -> None:
        """Set environment for running tests."""
        self.test_environment = test_environment

    def set_format(self, format: str) -> None:
        """Select output format by instantiating ``univention.testing.format.format_<format>``."""
        formatter = getattr(univention.testing.format, f'format_{format}')
        self.format = formatter()

    def set_prefix(self, prefix: bool) -> None:
        """
        Enable or disable test numbering.

        When enabled, :attr:`prefix` becomes a ``%``-format string such as
        ``'%02d/%02d '`` whose field width is the number of digits of
        :attr:`max_count`; when disabled it is the empty string.
        """
        if prefix:
            width = len(str(self.max_count))
            self.prefix = f'%0{width}d/%0{width}d '
        else:
            self.prefix = ''

    def run_tests(self, filter_condition: bool = False, dry_run: bool = False, stop_on_failure: bool = False) -> int | None:
        """
        Run the selected tests and report them through the configured format.

        :param filter_condition: Do not report tests whose pre-condition check fails.
        :param dry_run: Only report the tests; do not execute them.
        :param stop_on_failure: Return right after the first executed test whose
            result code is ``E`` or ``F``.
        :returns: 1 when stopped early because of `stop_on_failure`, otherwise 0.
        """
        assert self.format
        assert self.test_environment
        logger = logging.getLogger('test')
        self.format.begin_run(self.test_environment, self.max_count)
        try:
            count = 0
            # Sections are ordered by their list of test paths, not by name.
            # NOTE(review): presumably this keeps the on-disk order of the
            # section directories - confirm before changing the key.
            for section, cases in sorted(self.tests.items(), key=operator.itemgetter(1)):
                self.format.begin_section(section)
                try:
                    for fname in cases:
                        count += 1
                        test_case = TestCase(fname)
                        test_result = TestResult(test_case, self.test_environment)
                        try:
                            test_case.load()
                        except TestError as ex:
                            failed_message = f'Failed to load test "{fname}": {ex}'
                            logger.critical(failed_message)
                            # Tests that fail to load are always reported, even when filtering.
                            check = True
                        else:
                            failed_message = ""
                            check = test_result.check()

                        if filter_condition and not check:
                            continue

                        # Only wait for the S4 connector before a test that is actually
                        # executed; a dry run or a test that failed to load must not block.
                        if section == 's4connector' and not dry_run and not failed_message:
                            from univention.testing.utils import wait_for_s4_connector_to_be_inactive
                            wait_for_s4_connector_to_be_inactive()

                        if self.prefix:
                            self.format.begin_test(test_case, self.prefix % (count, self.max_count))
                        else:
                            self.format.begin_test(test_case)

                        try:
                            if failed_message:
                                test_result.reason = Reason.FAIL
                                test_result.attach('stdout', 'text/plain', failed_message)
                            elif not dry_run:
                                test_result.run()
                                # Compare whole codes; ``in 'EF'`` would also match '' and 'EF'.
                                if stop_on_failure and test_result.reason.eofs in ('E', 'F'):
                                    return 1
                        finally:
                            self.format.end_test(test_result)
                finally:
                    self.format.end_section()
        finally:
            self.format.end_run()
        return 0


def main() -> int | None:
    """
    Run the UCS test suite as configured on the command line.

    Exits with status 2 when unrecognised command line arguments are given.

    :returns: The return value of :meth:`TestSet.run_tests`.
    """
    all_sections = get_sections()

    options, args = parse_options(all_sections.keys())
    if args:
        logging.getLogger('test').error('Unused arguments: %r', args)
        sys.exit(2)

    setup_environment()
    setup_debug(options.verbose)

    selected_sections = options.sections or all_sections.keys()
    tests = get_tests(selected_sections)

    coverage = Coverage(options)
    coverage.start()
    # Guard everything after start() so coverage is always stopped, even if
    # the setup below (e.g. constructing the test environment) raises.
    try:
        PytestRunner.set_arguments(options)
        test_set = TestSet(tests)

        if options.dry:
            # A dry run is not given a log file.
            test_environment = TestEnvironment(interactive=options.interactive)
        else:
            test_environment = TestEnvironment(
                interactive=options.interactive,
                logfile=options.logfile)
        tags_required = [tag for tag, level in options.tags.items() if level == "required"]
        tags_ignored = [tag for tag, level in options.tags.items() if level == "ignored"]
        tags_prohibited = [tag for tag, level in options.tags.items() if level == "prohibited"]
        test_environment.tag(
            require=tags_required,
            ignore=tags_ignored,
            prohibit=tags_prohibited)
        if options.exposure:
            test_environment.set_exposure(options.exposure)
        test_environment.set_timeout(options.timeout)
        test_set.set_environment(test_environment)
        test_set.set_prefix(options.count)
        test_set.set_format(options.format)
        return test_set.run_tests(options.filter, options.dry, options.hold)
    finally:
        coverage.stop()


if __name__ == '__main__':
    # Ctrl-C aborts the whole run with exit status 1.
    try:
        status = main()
    except KeyboardInterrupt:
        status = 1
    sys.exit(status)
