#!/usr/bin/python3
# SPDX-FileCopyrightText: 2016-2025 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only
"""Configure UCS-domain and forward DNS servers."""

from __future__ import annotations

import sys
from argparse import SUPPRESS, ArgumentParser, Namespace
from collections import OrderedDict
from ipaddress import IPv4Address, IPv6Address, ip_address
from logging import DEBUG, ERROR, INFO, WARNING, basicConfig, getLogger
from os import environ
from subprocess import CalledProcessError, check_call, check_output

import dns.exception
import dns.flags
import dns.message
import dns.query
import dns.rcode
import dns.rdatatype

from univention.config_registry import ucr, ucr_factory
from univention.config_registry.frontend import ucr_update
from univention.config_registry.interfaces import Interfaces


IPAddress = IPv4Address | IPv6Address
DictAddr2Str = dict[IPAddress, str | None]

UCR_VARS_FWD = [f'dns/forwarder{i}' for i in range(1, 4)]
UCR_VARS_DNS = [f'nameserver{i}' for i in range(1, 4)]
LOCAL = '127.0.0.1'  # or ::1 for IPv6
TIMEOUT = 10.0


options = Namespace()


def main() -> None:
    """
    Fix the name server settings stored in Univention Configuration Registry.

    Collects UCS domain name servers and external forwarders from the command
    line and from UCR, checks them, and rewrites the UCR variables only if one
    of the checks (or `--force-self`) asks for it.
    """
    global options
    options = parse_args()
    setup_logging()
    log = getLogger(__name__)

    if options.run_tests:
        run_tests()

    if ucr.is_true('nameserver/external'):
        log.fatal('Using external DNS - aborting')
        sys.exit(0)

    nameservers: DictAddr2Str = OrderedDict()
    forwarders: DictAddr2Str = OrderedDict()

    # Every check runs, in exactly this order, because each one fills or
    # re-sorts the two mappings; a single positive answer triggers the fix.
    reasons_to_fix = [
        get_nameservers_cli(nameservers),
        get_forwarders(forwarders, ucr),
        get_nameservers_ucr(nameservers, ucr),
        validate_servers(nameservers, forwarders, ucr['domainname']),
        either_or(nameservers, forwarders),
        not nameservers,
        options.force_self,
    ]
    if not any(reasons_to_fix):
        log.info("No action required.")
        return

    add_self(nameservers, ucr, options.own_ip)
    add_nameservers(nameservers, ucr['domainname'])
    add_master(nameservers, ucr['ldap/master'])
    move_nameservers(forwarders, nameservers)

    if nameservers:
        update_ucr(ucr, nameservers, forwarders)
    else:
        log.fatal('No nameserver remains - aborting')
        sys.exit(1)


def parse_args() -> Namespace:
    """
    Parse command line arguments

    :returns: parsed arguments. `rrtyp` is always a tuple of DNS record types,
        so callers can test membership with `in`.
    """
    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        '--verbose', '-v',
        action='count', default=2,
        help='Increase verbosity')
    parser.add_argument(
        '--no-act', '-d',
        action='store_true',
        help='Enable dry-run mode')
    # The default must be a 1-tuple: `(dns.rdatatype.A)` is just the bare
    # type, and `rr.rdtype in options.rrtyp` would raise TypeError on it.
    parser.add_argument(
        '--ipv6', '-6',
        action='store_const', const=(dns.rdatatype.A, dns.rdatatype.AAAA), default=(dns.rdatatype.A,), dest='rrtyp',  # type: ignore[attr-defined]
        help='Also add IPv6 addresses')
    self_group = parser.add_mutually_exclusive_group()
    self_group.add_argument(
        '--no-self', '-S',
        action='store_true',
        help='Do not add self as name-server')
    self_group.add_argument(
        '--force-self', '-f',
        action='store_true',
        help='Force adding self as name-server')
    parser.add_argument(
        '--own-ip',
        type=ip_address,
        help='Specify own IP address',
        metavar="IP",
    )
    parser.add_argument(
        '--add-master', '-m',
        action='store_true',
        help='Add domaincontroller_master as name-server')
    parser.add_argument(
        '--add-nameservers', '-n',
        action='store_true',
        help='Add other name-servers')
    parser.add_argument(
        '--no-validation', '-V',
        action='store_true',
        help='Do not validate DNS servers')
    parser.add_argument(
        '--xor', '-x',
        action='store_true',
        help='Remove name-servers from forwarders')
    parser.add_argument(
        '--run-tests',
        action='store_true',
        help=SUPPRESS)
    parser.add_argument(
        '--no-ucr',
        action='store_true',
        help='Do not load nameservers and forwarders from UCR variables')
    parser.add_argument(
        '--dnsserver',
        action="append",
        type=ip_address,
        default=[],
        help='Specify nameserver delivered e.g. via DHCP. May be specified multiple times.',
        metavar="IP",
        dest="dnsservers",
    )

    # Named `args` to avoid shadowing the module-level `options` global.
    args = parser.parse_args()
    return args


def setup_logging() -> None:
    """
    Configure logging to stderr.

    The `--verbose` count picks ERROR, WARNING, INFO or DEBUG; any count past
    the end of that sequence falls back to DEBUG.
    """
    levels = (ERROR, WARNING, INFO, DEBUG)
    try:
        chosen = levels[options.verbose]
    except IndexError:
        chosen = DEBUG
    basicConfig(
        format='%(levelname)-8s %(funcName)-20s %(message)s',
        level=chosen,
        stream=sys.stderr,
    )


def get_nameservers_cli(nameservers: DictAddr2Str) -> bool:
    """
    Collect UCS domain DNS servers passed via `--dnsserver`.

    :param nameservers: Dictionary receiving mapping IP address to `None`.
    :returns: `True` if at least one server was given on the command line.
    """
    log = getLogger(__name__).getChild('cli/ns')
    log.debug('Reading UCS domain servers from CLI...')

    given = options.dnsservers
    for server in given:
        log.info('Added server %s via CLI argument', server)
        nameservers[server] = None

    return bool(given)


def get_forwarders(forwarders: DictAddr2Str, ucr: dict[str, str]) -> bool:
    """
    Get currently configured external DNS servers from UCR.

    Addresses of the local host and values which are not valid IP addresses
    are dropped and reported, so that the caller rewrites the UCR variables.

    :param forwarders: Dictionary receiving mapping IP address to `None`.
    :param ucr: UCR instance.
    :returns: `True` if self is configured as forwarder or an invalid value was found.
    """
    log = getLogger(__name__).getChild('ucr/fwd')
    if options.no_ucr:
        log.info('Skip reading forwarders from UCR')
        return False
    log.debug('Reading external DNS forwarders from UCR...')

    need_fixing = False
    for var in UCR_VARS_FWD:
        fwd_str = ucr.get(var, '').strip()
        if not fwd_str:
            continue
        try:
            fwd = ip_address(fwd_str)
        except ValueError:
            # Not an IP address: report and drop it instead of aborting the whole fix.
            log.error('Dropping invalid address %r from UCRV %s', fwd_str, var)
            need_fixing = True
            continue
        if is_self(fwd):
            log.error("Dropping local address %s from UCRV %s", fwd, var)
            need_fixing = True
            continue
        log.info('Found forwarder %s from UCRV %s', fwd, var)
        forwarders[fwd] = None

    return need_fixing


def get_nameservers_ucr(nameservers: DictAddr2Str, ucr: dict[str, str]) -> bool:
    """
    Get currently configured internal DNS servers from UCR.

    Values which are not valid IP addresses are dropped and reported, so that
    the caller rewrites the UCR variables.

    :param nameservers: Dictionary receiving mapping IP address to `None`.
    :param ucr: UCR instance.
    :returns: `True` if an invalid value was dropped, `False` otherwise.
    """
    log = getLogger(__name__).getChild('ucr/ns')
    if options.no_ucr:
        log.info('Skip reading nameservers from UCR')
        return False
    log.debug('Reading UCS domain servers from UCR...')

    need_fixing = False
    for var in UCR_VARS_DNS:
        ns_str = ucr.get(var, '').strip()
        if not ns_str:
            continue
        try:
            # Named `server` (not `dns`) to avoid shadowing the `dns` package.
            server = ip_address(ns_str)
        except ValueError:
            # Not an IP address: report and drop it instead of aborting the whole fix.
            log.error('Dropping invalid address %r from UCRV %s', ns_str, var)
            need_fixing = True
            continue
        log.info('Found server %s from UCRV %s', server, var)
        nameservers[server] = None

    return need_fixing


def validate_servers(nameservers: DictAddr2Str, forwarders: DictAddr2Str, domain: str) -> bool:
    """
    Check DNS servers being internal or external and re-categorize.

    A name server that answers the Primary Directory Node SRV query, but not
    authoritatively, is moved to the forwarders. A server that cannot be
    queried at all (timeout, network error, malformed reply) is kept as name
    server.

    :param nameservers: Mapping of internal DNS servers.
    :param forwarders: Mapping of external DNS servers.
    :param domain: DNS domain name.
    :returns: `True` if any DNS server is re-categorized.
    """
    log = getLogger(__name__).getChild('val')
    if options.no_validation:
        log.info('Skip validation of DNS servers')
        return False
    log.debug('Validating UCS domain servers...')

    need_fixing = False
    # Iterate over a copy, as entries may be deleted from `nameservers`.
    for server in list(nameservers):
        try:
            known = query_master_srv_record(domain, server)
        except (dns.exception.DNSException, OSError) as exc:
            # DNSException covers Timeout as well as bad/unexpected replies;
            # OSError covers e.g. an unreachable network. Keep such servers.
            log.warning('Connection check to %s (%s) failed, maybe down?!', server, exc)
            log.info('Leaving it configured as nameserver anyway')
            continue

        if known:
            log.info('Validated UCS domain server: %s', server)
        else:
            log.warning('UCS Primary Directory Node SRV record is unknown at %s, converting into forwarder', server)
            need_fixing = True
            del nameservers[server]
            forwarders[server] = None

    return need_fixing


def either_or(nameservers: DictAddr2Str, forwarders: DictAddr2Str) -> bool:
    """
    Drop every forwarder that is also listed as UCS domain DNS server.

    Only active with `--xor`.

    :param nameservers: Mapping of internal DNS servers.
    :param forwarders: Mapping of external DNS servers (modified in place).
    :returns: `True` if at least one forwarder was dropped.
    """
    log = getLogger(__name__).getChild('xor')
    if not options.xor:
        log.info('Skip removing nameservers from forwarders')
        return False
    log.info('Removing UCS domain servers from forwarders...')

    duplicates = [server for server in nameservers if server in forwarders]
    for server in duplicates:
        del forwarders[server]
        log.info('Removed UCS domain server %s from forwarders', server)

    return bool(duplicates)


def add_self(nameservers: DictAddr2Str, ucr: dict[str, str], own_ip: IPAddress | None = None) -> None:
    """
    Add self as internal DNS server (on DCs).

    The own address is inserted in front of all other name servers. Unless
    `--force-self` is given, it is only added when the local server answers
    authoritatively for the Primary Directory Node SRV record, or when no
    other name server remains.

    :param nameservers: Mapping of internal DNS servers.
    :param ucr: UCR instance.
    :param own_ip: Own IP address; if not given, the default IP address
        configured in UCR is used.
    """
    log = getLogger(__name__).getChild('ucr/self')
    if options.no_self:
        log.info('Skip adding self')
        return

    if any(is_self(addr) for addr in nameservers):
        log.info('Already using self')
        return

    if own_ip:
        myself = own_ip
        log.info('Own IP address given via CLI option: %s', myself)
    else:
        iface = Interfaces(ucr)
        mynet = iface.get_default_ip_address()
        myself = mynet.ip
        log.info('Default IP address configured in UCR: %s', myself)

    domain = ucr['domainname']
    if not options.force_self:
        try:
            usable = query_master_srv_record(domain, myself)
        except (dns.exception.DNSException, OSError) as exc:
            # E.g. the local DNS server does not answer: treat it like a failed
            # check instead of aborting with a traceback.
            log.warning('Connection check to %s (%s) failed, maybe down?!', myself, exc)
            usable = False
        if not usable:
            log.warning('Failed to query local server %s for %s', myself, domain)
            if nameservers:
                return
            log.warning('Adding anyway as no other nameserver remains.')

    # Rebuild the mapping so that the own address comes first.
    old = list(nameservers.items())
    nameservers.clear()
    nameservers[myself] = None
    nameservers.update(old)


def add_nameservers(nameservers: DictAddr2Str, domain: str) -> None:
    """
    Add DNS servers from zone as internal DNS servers.

    Asks the local DNS server (`LOCAL`) for the NS records of *domain* and
    adds addresses from the additional section, limited to the record types
    selected by `--ipv6`. At most one address per NS name is added, and
    addresses of the local host are skipped.

    :param nameservers: Mapping of internal DNS servers.
    :param domain: DNS domain name.
    """
    log = getLogger(__name__).getChild('ns')
    if not options.add_nameservers:
        log.info('Skip adding NS')
        return

    log.debug('Querying %s for additional NS records in %s', LOCAL, domain)

    req = dns.message.make_query(domain, "NS")
    req.flags |= dns.flags.AA  # type: ignore[attr-defined]
    try:
        res = dns.query.udp(req, LOCAL, timeout=TIMEOUT, ignore_trailing=True)
    except dns.exception.Timeout:
        log.error('DNS lookup of NS records in %s against %s failed', domain, LOCAL)
        return

    log.debug('header=%r', res)

    if res.rcode() == dns.rcode.NOERROR and res.flags & dns.flags.AA:  # type: ignore[attr-defined]
        # Host names of all NS targets in the answer section.
        names = {item.target for rr in res.answer for item in rr}
        log.debug('servers=%r', names)
        for rr in res.additional:
            log.debug('rr=%s', rr)
            name = rr.name
            if rr.rdtype in options.rrtyp and name in names:
                # Index the RRset itself: in dnspython 2.x `rr.items` is a
                # dict and cannot be indexed by position.
                ip = ip_address(str(rr[0]))
                if is_self(ip):
                    log.info('Skipping local interface address %s found for NS record %s', ip, name)
                    continue
                log.info('Adding server found in NS: %s=%s', name, ip)
                nameservers[ip] = None
                # Only the first usable address per NS name is taken.
                names.remove(name)
    else:
        log.error('DNS lookup of NS records in %s against %s failed', domain, LOCAL)


def add_master(nameservers: DictAddr2Str, master: str) -> None:
    """
    Add Primary DC as internal DNS server.

    Asks the local DNS server (`LOCAL`) for all records of *master*. From the
    answer RRsets of a type selected by `--ipv6`, it adds the first address
    that does not belong to the local host.

    :param nameservers: Mapping of internal DNS servers.
    :param master: Fully qualified host name of Primary Directory Node.
    """
    log = getLogger(__name__).getChild('ldap')
    if not options.add_master:
        log.info('Skip adding Primary Directory Node')
        return

    log.debug('Querying %s for address of Primary Directory Node %s', LOCAL, master)
    req = dns.message.make_query(master, "ANY")
    req.flags |= dns.flags.AA  # type: ignore[attr-defined]
    try:
        res = dns.query.udp(req, LOCAL, timeout=TIMEOUT, ignore_trailing=True)
    except dns.exception.Timeout:
        log.error('DNS lookup of %r against %r failed', master, LOCAL)
        return

    log.debug('header=%r', res)

    if res.rcode() == dns.rcode.NOERROR and res.flags & dns.flags.AA:  # type: ignore[attr-defined]
        for rr in res.answer:
            log.debug('rr=%s', rr)
            if rr.rdtype in options.rrtyp:
                # Index the RRset itself: in dnspython 2.x `rr.items` is a
                # dict and cannot be indexed by position.
                ip = ip_address(str(rr[0]))
                if is_self(ip):
                    log.info('Skipping local interface address %s found for ldap/master %s', ip, master)
                    continue
                log.info('Adding Primary Directory Node %s', ip)
                nameservers[ip] = None
                break
    else:
        log.error('DNS lookup of %s against %s failed', master, LOCAL)


def move_nameservers(forwarders: DictAddr2Str, nameservers: DictAddr2Str) -> None:
    """
    Turn all forwarders into name servers (only with `--no-self`).

    The forwarders are appended to the name servers, and the forwarder
    mapping is emptied afterwards.

    :param forwarders: Mapping of external DNS servers (emptied).
    :param nameservers: Mapping of internal DNS servers (extended).
    """
    log = getLogger(__name__).getChild('move')
    if options.no_self:
        log.info('Moving forwarders %s to nameservers %s ...', list(forwarders), list(nameservers))
        nameservers.update(forwarders)
        forwarders.clear()
    else:
        log.info('Skip moving forwarders to nameservers')


def update_ucr(ucr: dict[str, str], nameservers: DictAddr2Str, forwarders: DictAddr2Str) -> None:
    """
    Write the given DNS servers back into their UCR variables.

    Each mapping is spread over its fixed list of UCR variables in order.
    Unused variables are set to `None`, and surplus servers are dropped with a
    warning. Nothing is written in dry-run mode or when every value already
    matches UCR. Unless `--no-self` is given, `rndc reconfig` is run after a
    write.

    :param ucr: UCR instance.
    :param nameservers: Mapping of internal DNS servers.
    :param forwarders: Mapping of external DNS servers.
    """
    log = getLogger(__name__).getChild('ucr')

    new_ucr_settings: dict[str, str | None] = {}
    for ucr_vars, servers, typ in (
        (UCR_VARS_FWD, forwarders, 'forwarders'),
        (UCR_VARS_DNS, nameservers, 'nameservers'),
    ):
        log.debug('%s=%r', typ, list(servers))
        values: list[str | None] = [str(server) for server in servers]
        surplus = values[len(ucr_vars):]
        if surplus:
            log.warning('Skipping extra %s: %r', typ, surplus)
        # Pad with None so that no longer used variables get unset.
        values.extend([None] * (len(ucr_vars) - len(values)))
        new_ucr_settings.update(zip(ucr_vars, values))

    log.info('Updating %r', new_ucr_settings)

    if options.no_act:
        return

    current = [(key, ucr.get(key), val) for key, val in sorted(new_ucr_settings.items())]
    changes = [(key, old, val) for key, old, val in current if old != val]
    for key, old, val in changes:
        log.info('Updating %r: %r -> %r', key, old, val)

    if not changes:
        return

    ucr_update(ucr_factory(), new_ucr_settings)

    # we assume no BIND is running on an unjoined DC or MemberServer
    if not options.no_self:
        log.info('Reloading BIND')
        check_call(('rndc', 'reconfig'))


def query_master_srv_record(domain: str, server: IPAddress) -> bool:
    """
    Check whether *server* answers authoritatively for the Primary DC SRV record.

    Sends a UDP query for ``_domaincontroller_master._tcp.<domain>.`` of type
    SRV to *server*.

    :param domain: DNS domain name.
    :param server: DNS server to query.
    :returns: `True` if the response has RCODE NOERROR and the AA flag set,
        `False` otherwise. The answer section itself is not inspected.
    :raises dns.exception.Timeout: if *server* does not answer within `TIMEOUT`
        seconds; other errors raised by `dns.query.udp` are passed on as well.
    """
    log = getLogger(__name__).getChild('dns/srv')

    rec = f"_domaincontroller_master._tcp.{domain.rstrip('.')}."
    log.debug('Querying %s for SRV %s', server, rec)

    req = dns.message.make_query(rec, "SRV")
    req.flags |= dns.flags.AA  # type: ignore[attr-defined]
    res = dns.query.udp(req, str(server), timeout=TIMEOUT, ignore_trailing=True)

    log.debug('header=%r', res)

    # Return a real bool instead of leaking the raw flag value (or 0).
    return res.rcode() == dns.rcode.NOERROR and bool(res.flags & dns.flags.AA)  # type: ignore[attr-defined]


def is_self(addr: str | IPv4Address | IPv6Address) -> bool:
    """
    Tell whether *addr* is an address of the local host.

    Runs ``ip route get`` on the address; it counts as local when the
    resulting route is of type ``local``. If the command fails, a warning is
    logged and the address is treated as non-local.

    :param addr: An IP address, as object or string.
    :returns: `True` if the address is local.

    >>> is_self('127.0.0.1')
    True
    >>> is_self('::1')
    True
    >>> is_self('8.8.8.8')
    False
    >>> is_self('0.0.0.1')
    False
    """
    log = getLogger(__name__).getChild('ip')

    cmd = ['ip', 'route', 'get', str(addr)]
    log.debug('calling %r', cmd)
    try:
        # Force the C locale so that the output keyword is not translated.
        raw = check_output(cmd, env={**environ, 'LC_ALL': 'C'})
    except CalledProcessError as ex:
        log.warning('Failed to determine route: %s', ex)
        return False
    return raw.decode('UTF-8').startswith('local ')


def run_tests() -> None:
    """
    Run the internal doctest suite and exit.

    The exit status is 1 if any doctest failed and 0 otherwise. Exiting here
    keeps the hidden `--run-tests` option from going on to rewrite the UCR
    settings and reload BIND.
    """
    import doctest
    failed, _attempted = doctest.testmod()
    sys.exit(1 if failed else 0)


# Run the fixer only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
