#!/usr/bin/python3
# Like what you see? Join us!
# https://www.univention.com/about-us/careers/vacancies/
#
# Copyright 2016-2023 Univention GmbH
#
# https://www.univention.de/
#
# All rights reserved.
#
# The source code of this program is made available
# under the terms of the GNU Affero General Public License version 3
# (GNU AGPL V3) as published by the Free Software Foundation.
#
# Binary versions of this program provided by Univention to you as
# well as other copyrighted, protected or trademarked materials like
# Logos, graphics, fonts, specific documentations and configurations,
# cryptographic keys etc. are subject to a license agreement between
# you and Univention and not subject to the GNU AGPL V3.
#
# In the case you use this program under the terms of the GNU AGPL V3,
# the program is provided in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public
# License with the Debian GNU/Linux or Univention distribution in file
# /usr/share/common-licenses/AGPL-3; if not, see
# <https://www.gnu.org/licenses/>.

"""Configure UCS-domain and forward DNS servers."""

from __future__ import absolute_import

from argparse import SUPPRESS, ArgumentParser, Namespace
from collections import OrderedDict
from ipaddress import IPv4Address, IPv6Address, ip_address
from logging import DEBUG, ERROR, INFO, WARNING, basicConfig, getLogger
from os import environ
from subprocess import CalledProcessError, check_call, check_output
from sys import exit, stderr
from typing import Dict, List, Optional, Union

from DNS import DnsRequest, SocketError, TimeoutError

from univention.config_registry import ConfigRegistry
from univention.config_registry.frontend import ucr_update
from univention.config_registry.interfaces import Interfaces


# An IP address object of either family, as returned by `ip_address()`.
IPAddress = Union[IPv4Address, IPv6Address]
# Mapping keyed by server address; this script only ever stores `None` as
# value, so the mappings effectively serve as ordered sets of addresses.
DictAddr2Str = Dict[IPAddress, Optional[str]]

# Names of the UCR variables holding the external DNS forwarders ...
UCR_VARS_FWD = ['dns/forwarder{}'.format(num) for num in (1, 2, 3)]
# ... and of those holding the DNS servers of the UCS domain.
UCR_VARS_DNS = ['nameserver{}'.format(num) for num in (1, 2, 3)]
# Address of the DNS server queried for NS and host address records.
LOCAL = '127.0.0.1'  # or ::1 for IPv6


# Parsed command line options; `main()` replaces this before it is read.
options = Namespace()


def main() -> None:
    """
    Fix the name server settings in Univention Configuration Registry.

    The configured name servers and forwarders are collected and checked.
    Only if one of the checks reports a problem (or `--force-self` is given)
    the lists are rebuilt and written back to UCR.
    """
    global options
    options = parse_args()
    setup_logging()
    log = getLogger(__name__)

    if options.run_tests:
        run_tests()

    ucr = ConfigRegistry()
    ucr.load()
    if ucr.is_true('nameserver/external'):
        log.fatal('Using external DNS - aborting')
        exit(0)

    nameservers: DictAddr2Str = OrderedDict()
    forwarders: DictAddr2Str = OrderedDict()
    # The list is evaluated eagerly from top to bottom: every step has to
    # run as it also fills respectively re-sorts the two mappings.
    findings = [
        get_nameservers_cli(nameservers),
        get_forwarders(forwarders, ucr),
        get_nameservers_ucr(nameservers, ucr),
        validate_servers(nameservers, forwarders, ucr['domainname']),
        either_or(nameservers, forwarders),
        not nameservers,
        options.force_self,
    ]
    if not any(findings):
        log.info("No action required.")
        return

    add_self(nameservers, ucr, options.own_ip)
    add_nameservers(nameservers, ucr['domainname'])
    add_master(nameservers, ucr['ldap/master'])
    move_nameservers(forwarders, nameservers)

    if nameservers:
        update_ucr(ucr, nameservers, forwarders)
        return

    log.fatal('No nameserver remains - aborting')
    exit(1)


def parse_args() -> Namespace:
    """
    Build the command line parser and evaluate the process arguments.

    :returns: namespace holding the values of all options.
    """
    parser = ArgumentParser(description=__doc__)

    def flag(target, *names, **kwargs):
        """Register a boolean switch on `target`, which is `False` unless given."""
        target.add_argument(*names, action='store_true', **kwargs)

    # Counting starts at 2, which `setup_logging()` maps to level INFO.
    parser.add_argument('--verbose', '-v', action='count', default=2, help='Increase verbosity')
    flag(parser, '--no-act', '-d', help='Enable dry-run mode')
    # Stores the tuple of DNS record types accepted for address lookups.
    parser.add_argument(
        '--ipv6', '-6',
        dest='rrtyp',
        action='store_const',
        const=('A', 'AAAA'),
        default=('A',),
        help='Also add IPv6 addresses')

    # Skipping and forcing self contradict each other.
    exclusive = parser.add_mutually_exclusive_group()
    flag(exclusive, '--no-self', '-S', help='Do not add self as name-server')
    flag(exclusive, '--force-self', '-f', help='Force adding self as name-server')

    parser.add_argument('--own-ip', help='Specify own IP address')
    flag(parser, '--add-master', '-m', help='Add domaincontroller_master as name-server')
    flag(parser, '--add-nameservers', '-n', help='Add other name-servers')
    flag(parser, '--no-validation', '-V', help='Do not validate DNS servers')
    flag(parser, '--xor', '-x', help='Remove name-servers from forwarders')
    # Hidden from the usage text.
    flag(parser, '--run-tests', help=SUPPRESS)
    flag(parser, '--no-ucr', help='Do not load nameservers and forwarders from UCR variables')
    parser.add_argument(
        '--dnsserver',
        dest="dnsservers",
        action="append",
        default=[],
        help='Specify nameserver delivered e.g. via DHCP. May be specified multiple times.')

    return parser.parse_args()


def setup_logging() -> None:
    """Configure logging to STDERR using the verbosity requested on the command line."""
    levels = (ERROR, WARNING, INFO, DEBUG)
    try:
        chosen = levels[options.verbose]
    except IndexError:
        # More `-v` given than levels exist: stay at the most verbose one.
        chosen = levels[-1]
    basicConfig(
        stream=stderr,
        level=chosen,
        format='%(asctime)-15s %(levelname)-7s %(name)-17s %(message)s',
    )


def get_nameservers_cli(nameservers: DictAddr2Str) -> bool:
    """
    Collect the DNS servers given through `--dnsserver`.

    :param nameservers: Mapping receiving each address as key with value `None`.
    :returns: `True` if at least one server was given.
    :raises ValueError: if an argument is not a valid IP address.
    """
    log = getLogger(__name__).getChild('cli/ns')
    log.debug('Reading UCS domain servers from CLI...')

    count = 0
    for count, text in enumerate(options.dnsservers, 1):
        addr = ip_address(str(text))
        log.info('Added server %s via CLI argument', addr)
        nameservers[addr] = None

    return count > 0


def get_forwarders(forwarders: DictAddr2Str, ucr: Dict[str, str]) -> bool:
    """
    Collect the external DNS servers currently configured in UCR.

    Addresses belonging to the local host are not collected.

    :param forwarders: Mapping receiving each address as key with value `None`.
    :param ucr: UCR instance.
    :returns: `True` if a local address was configured as forwarder and got dropped.
    :raises ValueError: if a variable does not contain a valid IP address.
    """
    log = getLogger(__name__).getChild('ucr/fwd')
    if options.no_ucr:
        log.info('Skip reading forwarders from UCR')
        return False
    log.debug('Reading external DNS forwarders from UCR...')

    dropped_local = False
    for var in UCR_VARS_FWD:
        text = ucr.get(var, '').strip()
        if text:
            fwd = ip_address(str(text))
            if not is_self(fwd):
                log.info('Found forwarder %s from UCRV %s', fwd, var)
                forwarders[fwd] = None
            else:
                # Forwarding to oneself makes no sense: report it for fixing.
                log.error("Dropping local address %s from UCRV %s", fwd, var)
                dropped_local = True

    return dropped_local


def get_nameservers_ucr(nameservers: DictAddr2Str, ucr: Dict[str, str]) -> bool:
    """
    Collect the internal DNS servers currently configured in UCR.

    :param nameservers: Mapping receiving each address as key with value `None`.
    :param ucr: UCR instance.
    :returns: always `False`, as reading alone never requires a fix.
    :raises ValueError: if a variable does not contain a valid IP address.
    """
    log = getLogger(__name__).getChild('ucr/ns')
    if options.no_ucr:
        log.info('Skip reading nameservers from UCR')
        return False
    log.debug('Reading UCS domain servers from UCR...')

    for var in UCR_VARS_DNS:
        text = ucr.get(var, '').strip()
        if text:
            addr = ip_address(str(text))
            log.info('Found server %s from UCRV %s', addr, var)
            nameservers[addr] = None

    return False


def validate_servers(nameservers: DictAddr2Str, forwarders: DictAddr2Str, domain: str) -> bool:
    """
    Check each internal DNS server and turn those not serving the domain into forwarders.

    Servers which cannot be reached are left untouched.

    :param nameservers: Mapping of internal DNS servers.
    :param forwarders: Mapping of external DNS servers.
    :param domain: DNS domain name.
    :returns: `True` if any DNS server is re-categorized.
    """
    log = getLogger(__name__).getChild('val')
    if options.no_validation:
        log.info('Skip validation of DNS servers')
        return False
    log.debug('Validating UCS domain servers...')

    recategorized = False
    # Iterate over a snapshot as entries are removed on the way.
    for server in tuple(nameservers):
        try:
            known = query_master_src_record(domain, server)
        except (SocketError, TimeoutError) as exc:
            log.warning('Connection check to %s (%s) failed, maybe down?!', server, exc.args[0])
            log.info('Leaving it configured as nameserver anyway')
            continue

        if known:
            log.info('Validated UCS domain server: %s', server)
            continue

        log.warning('UCS Primary Directory Node SRV record is unknown at %s, converting into forwarder', server)
        recategorized = True
        del nameservers[server]
        forwarders[server] = None

    return recategorized


def either_or(nameservers: DictAddr2Str, forwarders: DictAddr2Str) -> bool:
    """
    Drop every forwarder which is also used as internal DNS server.

    Only done when `--xor` is given.

    :param nameservers: Mapping of internal DNS servers.
    :param forwarders: Mapping of external DNS servers.
    :returns: `True` if any DNS forwarder was removed.
    """
    log = getLogger(__name__).getChild('xor')
    if not options.xor:
        log.info('Skip removing nameservers from forwarders')
        return False
    log.info('Removing UCS domain servers from forwarders...')

    removed = False
    for server in nameservers:
        if server in forwarders:
            del forwarders[server]
            log.info('Removed UCS domain server %s from forwarders', server)
            removed = True

    return removed


def add_self(nameservers: DictAddr2Str, ucr: Dict[str, str], own_ip: Optional[str] = None) -> None:
    """
    Add self as first internal DNS server (on DCs).

    Nothing is done if `--no-self` is given or a local address is already
    used. Unless `--force-self` is given the own address is probed first:
    if the probe fails, self is only added when no other server is left.

    :param nameservers: Mapping of internal DNS servers.
    :param ucr: UCR instance.
    :param own_ip: Own IP address to use instead of the default address configured in UCR.
    :raises ValueError: if `own_ip` is not a valid IP address.
    """
    log = getLogger(__name__).getChild('ucr/self')
    if options.no_self:
        log.info('Skip adding self')
        return

    if any(is_self(addr) for addr in nameservers):
        log.info('Already using self')
        return

    if own_ip:
        # Convert to an address object like every other key of the mapping;
        # a plain string would never match an already present address.
        myself = ip_address(u'%s' % (own_ip,))
        log.info('Own IP address given via CLI option: %s', myself)
    else:
        iface = Interfaces(ucr)
        # NOTE(review): assumes UCR has a default address configured -
        # confirm what get_default_ip_address() returns otherwise.
        mynet = iface.get_default_ip_address()
        myself = mynet.ip
        log.info('Default IP address configured in UCR: %s', myself)

    domain = ucr['domainname']
    if not options.force_self:
        try:
            usable = query_master_src_record(domain, myself)
        except (SocketError, TimeoutError) as exc:
            # An unreachable local server counts as a failed probe, the
            # same way validate_servers() tolerates unreachable servers.
            log.warning('Connection check to %s (%s) failed, maybe down?!', myself, exc)
            usable = False
        if not usable:
            log.warning('Failed to query local server %s for %s', myself, domain)
            if nameservers:
                return
            log.warning('Adding anyway as no other nameserer remains.')

    # Rebuild the mapping so that self becomes the very first entry.
    old = list(nameservers.items())
    nameservers.clear()
    nameservers[myself] = None
    nameservers.update(old)


def add_nameservers(nameservers: DictAddr2Str, domain: str) -> None:
    """
    Add the DNS servers listed in the zone as internal DNS servers.

    The local server is asked for the NS records of the domain; the
    addresses are taken from the additional section of its answer. Only
    done when `--add-nameservers` is given.

    :param nameservers: Mapping of internal DNS servers.
    :param domain: DNS domain name.
    """
    log = getLogger(__name__).getChild('ns')
    if not options.add_nameservers:
        log.info('Skip adding NS')
        return

    log.debug('Querying %s for additional NS records in %s', LOCAL, domain)
    reply = DnsRequest(domain, qtype='NS', server=[LOCAL], aa=1, rd=0).req()
    log.debug('header=%r', reply.header)

    # Only accept an authoritative and successful answer.
    if not (reply.header['status'] == 'NOERROR' and reply.header['aa']):
        log.error('DNS lookup of NS records in %s against %s failed', domain, LOCAL)
        return

    # Host names of name servers still lacking an address.
    pending = {rr['data'] for rr in reply.answers}
    log.debug('servers=%r', pending)
    for rr in reply.additional:
        log.debug('rr=%r', rr)
        name = rr['name']
        if rr['typename'] not in options.rrtyp or name not in pending:
            continue

        ip = get_ip(rr)
        if is_self(ip):
            log.info('Skipping local interface address %s found for NS record %s', ip, name)
            continue

        log.info('Adding server found in NS: %s=%s', name, ip)
        nameservers[ip] = None
        # One address per name server is sufficient.
        pending.remove(name)


def add_master(nameservers: DictAddr2Str, master: str) -> None:
    """
    Add the Primary Directory Node as internal DNS server.

    The local server is asked for the address of the host; the first
    address which is not local gets added. Only done when `--add-master`
    is given.

    :param nameservers: Mapping of internal DNS servers.
    :param master: Fully qualified host name of Primary Directory Node.
    """
    log = getLogger(__name__).getChild('ldap')
    if not options.add_master:
        log.info('Skip adding Primary Directory Node')
        return

    log.debug('Querying %s for address of Primary Directory Node %s', LOCAL, master)
    reply = DnsRequest(master, qtype='ANY', server=[LOCAL], aa=1, rd=0).req()
    log.debug('header=%r', reply.header)

    # Only accept an authoritative and successful answer.
    if not (reply.header['status'] == 'NOERROR' and reply.header['aa']):
        log.error('DNS lookup of %s against %s failed', master, LOCAL)
        return

    for rr in reply.answers:
        log.debug('rr=%r', rr)
        if rr['typename'] not in options.rrtyp:
            continue

        ip = get_ip(rr)
        if is_self(ip):
            log.info('Skipping local interface address %s found for ldap/master %s', ip, master)
            continue

        log.info('Adding Primary Directory Node %s', ip)
        nameservers[ip] = None
        # A single address is sufficient.
        break


def move_nameservers(forwarders: DictAddr2Str, nameservers: DictAddr2Str) -> None:
    """
    Append all forwarders to the nameservers and empty the forwarders.

    This is only done when `--no-self` is given.

    :param forwarders: Mapping of external DNS servers.
    :param nameservers: Mapping of internal DNS servers.
    """
    log = getLogger(__name__).getChild('move')
    if options.no_self:
        log.info('Moving forwarders %s to nameservers %s ...', list(forwarders), list(nameservers))
        nameservers.update(forwarders)
        forwarders.clear()
    else:
        log.info('Skip moving forwarders to nameservers')


def update_ucr(ucr: Dict[str, str], nameservers: DictAddr2Str, forwarders: DictAddr2Str) -> None:
    """
    Write the internal and external DNS servers to the UCR settings.

    Each list is cut to the number of available variables, remaining
    variables are unset. Nothing is written in dry-run mode or when the
    values did not change. BIND is reconfigured afterwards unless
    `--no-self` is given.

    :param ucr: UCR instance.
    :param nameservers: Mapping of internal DNS servers.
    :param forwarders: Mapping of external DNS servers.
    """
    log = getLogger(__name__).getChild('ucr')
    pending: Dict[str, Optional[str]] = {}

    groups = (
        (UCR_VARS_FWD, forwarders, 'forwarders'),
        (UCR_VARS_DNS, nameservers, 'nameservers'),
    )
    for variables, servers, typ in groups:
        log.debug('%s=%r', typ, list(servers))
        texts: List[Optional[str]] = [str(addr) for addr in servers]
        surplus = texts[len(variables):]
        if surplus:
            log.warning('Skipping extra %s: %r', typ, surplus)
        # Pad with `None` to unset the unused variables (no-op if already full).
        texts += [None] * (len(variables) - len(texts))
        # zip() drops the surplus values.
        pending.update(zip(variables, texts))

    log.info('Updating %r', pending)

    if options.no_act:
        return

    changed = False
    for key in sorted(pending):
        old, new = ucr.get(key), pending[key]
        if old != new:
            log.info('Updating %r: %r -> %r', key, old, new)
            changed = True

    if changed:
        ucr_update(ucr, pending)

        # we assume no BIND is running on an unjoined DC or MemberServer
        if not options.no_self:
            log.info('Reloading BIND')
            check_call(('rndc', 'reconfig'))


def query_master_src_record(domain: str, server: IPAddress) -> Union[bool, str]:
    """
    Ask a DNS server for the SRV record of the Primary DC of the domain.

    :param domain: DNS domain name.
    :param server: DNS server to query.
    :returns: `False` if the query did not succeed, otherwise the value of the `aa` flag of the answer.
    """
    log = getLogger(__name__).getChild('dns/srv')

    # Strip any trailing dot first to get exactly one at the end.
    record = '_domaincontroller_master._tcp.{}.'.format(domain.rstrip('.'))
    log.debug('Querying %s for SRV %s', server, record)

    request = DnsRequest(record, qtype='SRV', server=[str(server)], aa=1, rd=0)
    header = request.req().header
    log.debug('header=%r', header)

    return header['status'] == 'NOERROR' and header['aa']  # type: ignore


def get_ip(rr: Dict[str, str]) -> IPAddress:
    r"""
    Convert the data of a DNS address Resource Record into an IP address object.

    :param rr: DNS RR with the keys `typename` and `data`.
    :returns: resolved IP address.
    :raises TypeError: if the RR is neither an `A` nor `AAAA` record.

    >>> get_ip({'typename': 'A', 'data': '127.0.0.1'})
    IPv4Address('127.0.0.1')
    >>> get_ip({'typename': 'AAAA', 'data': '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01'})
    IPv6Address('::1')
    """
    typ = rr['typename']
    data = rr['data']
    if typ not in ('A', 'AAAA'):
        raise TypeError(typ)

    if typ == 'A':
        return ip_address(str(data))  # type: ignore

    # Work-around bug in python-pydns, which does not unpack IPv6 addresses:
    # the data is the packed address of 16 bytes, maybe given as text.
    assert len(data) == 16
    packed = data if isinstance(data, bytes) else bytes(data, 'charmap')  # type: ignore
    return ip_address(packed)  # type: ignore


def is_self(addr: Union[str, IPv4Address, IPv6Address]) -> bool:
    """
    Check if given address is associated with the local host.

    The check runs `ip route get` and tests whether the printed route
    starts with `local`. Any failure to get that answer - the command
    exiting with an error or not being executable at all - is logged and
    treated as "not local".

    :param addr: An IP address or domain name.
    :returns: `True` if the address is local.

    >>> is_self('127.0.0.1')
    True
    >>> is_self('::1')
    True
    >>> is_self('8.8.8.8')
    False
    >>> is_self('0.0.0.1')
    False
    """
    log = getLogger(__name__).getChild('ip')

    # Force untranslated output as its prefix is matched below.
    env = dict(environ)
    env['LC_ALL'] = 'C'
    cmd = ['ip', 'route', 'get', '%s' % addr]
    log.debug('calling %r', cmd)
    try:
        out = check_output(cmd, env=env).decode('UTF-8')
    except (CalledProcessError, OSError) as ex:
        # OSError covers a missing or non-executable `ip` binary.
        log.warning('Failed to determine route: %s', ex)
        return False
    return out.startswith('local ')


def run_tests() -> None:
    """
    Run internal test suite and terminate the process.

    The doctests of this module are executed. The function never returns:
    otherwise the caller would go on and modify the DNS settings of the
    system although only the self-test was requested.

    :raises SystemExit: always; with status 0 on success and 1 if any test failed.
    """
    import doctest
    result = doctest.testmod()
    exit(int(bool(result.failed)))


# Only run when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
