#!/usr/bin/python3
# Copyright 2016-2021 Univention GmbH
#
# https://www.univention.de/
#
# All rights reserved.
#
# The source code of this program is made available
# under the terms of the GNU Affero General Public License version 3
# (GNU AGPL V3) as published by the Free Software Foundation.
#
# Binary versions of this program provided by Univention to you as
# well as other copyrighted, protected or trademarked materials like
# Logos, graphics, fonts, specific documentations and configurations,
# cryptographic keys etc. are subject to a license agreement between
# you and Univention and not subject to the GNU AGPL V3.
#
# In the case you use this program under the terms of the GNU AGPL V3,
# the program is provided in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public
# License with the Debian GNU/Linux or Univention distribution in file
# /usr/share/common-licenses/AGPL-3; if not, see
# <https://www.gnu.org/licenses/>.

"""
Configure UCS-domain and forward DNS servers.
"""

from __future__ import absolute_import

from argparse import SUPPRESS, ArgumentParser, Namespace
from collections import OrderedDict
from ipaddress import IPv4Address, IPv6Address, ip_address
from logging import DEBUG, ERROR, INFO, WARNING, basicConfig, getLogger
from os import environ
from subprocess import CalledProcessError, check_call, check_output
from sys import exit, stderr
from typing import Dict, List, Optional, Union

from DNS import DnsRequest, SocketError, TimeoutError

from univention.config_registry import ConfigRegistry
from univention.config_registry.frontend import ucr_update
from univention.config_registry.interfaces import Interfaces

# An IP address of either family, as produced by `ipaddress.ip_address()`.
IPAddress = Union[IPv4Address, IPv6Address]
# Mapping of server address to an optional string; this script only ever
# stores `None` as value and uses the (ordered) keys as the list of servers.
DictAddr2Str = Dict[IPAddress, Optional[str]]

# UCR variables holding the external DNS forwarders: dns/forwarder1 .. dns/forwarder3
UCR_VARS_FWD = ['dns/forwarder%d' % (i,) for i in range(1, 4)]
# UCR variables holding the domain DNS servers: nameserver1 .. nameserver3
UCR_VARS_DNS = ['nameserver%d' % (i,) for i in range(1, 4)]
# DNS server queried by add_nameservers() and add_master().
LOCAL = '127.0.0.1'  # or ::1 for IPv6


# Parsed command line options; an empty placeholder until main() replaces it
# with the result of parse_args().
options = Namespace()


def main() -> None:
	"""
	Inspect the DNS settings in Univention Configuration Registry and repair them.

	The configured domain name servers and external forwarders are collected
	and checked.  If any check reports a problem (or no name server is left,
	or `--force-self` is given) both lists are rebuilt and written back.
	Nothing is changed if UCR variable `nameserver/external` is enabled.
	"""
	global options
	options = parse_args()
	setup_logging()
	log = getLogger(__name__)

	if options.run_tests:
		run_tests()

	registry = ConfigRegistry()
	registry.load()
	if registry.is_true('nameserver/external'):
		log.fatal('Using external DNS - aborting')
		exit(0)

	nameservers: DictAddr2Str = OrderedDict()
	forwarders: DictAddr2Str = OrderedDict()
	# Every step must run because each one logs and fills/modifies the
	# mappings; a list is evaluated left to right without short-circuiting.
	findings = [
		get_nameservers_cli(nameservers),
		get_forwarders(forwarders, registry),
		get_nameservers_ucr(nameservers, registry),
		validate_servers(nameservers, forwarders, registry['domainname']),
		either_or(nameservers, forwarders),
		not nameservers,
		options.force_self,
	]
	if not any(findings):
		log.info("No action required.")
		return

	# Rebuild the list of domain name servers.
	add_self(nameservers, registry, options.own_ip)
	add_nameservers(nameservers, registry['domainname'])
	add_master(nameservers, registry['ldap/master'])
	move_nameservers(forwarders, nameservers)

	if not nameservers:
		log.fatal('No nameserver remains - aborting')
		exit(1)

	update_ucr(registry, nameservers, forwarders)


def parse_args() -> Namespace:
	"""
	Build the command line parser and evaluate `sys.argv`.

	:returns: the namespace holding all parsed options.
	"""
	parser = ArgumentParser(description=__doc__)

	def flag(container, *names, **extra):
		# All boolean switches share the same action.
		container.add_argument(*names, action='store_true', **extra)

	parser.add_argument('--verbose', '-v', action='count', default=2, help='Increase verbosity')
	flag(parser, '--no-act', '-d', help='Enable dry-run mode')
	parser.add_argument(
		'--ipv6', '-6',
		dest='rrtyp',
		action='store_const',
		const=('A', 'AAAA'),
		default=('A',),
		help='Also add IPv6 addresses')

	# Adding and not adding self contradict each other.
	exclusive = parser.add_mutually_exclusive_group()
	flag(exclusive, '--no-self', '-S', help='Do not add self as name-server')
	flag(exclusive, '--force-self', '-f', help='Force adding self as name-server')

	parser.add_argument('--own-ip', help='Specify own IP address')
	flag(parser, '--add-master', '-m', help='Add domaincontroller_master as name-server')
	flag(parser, '--add-nameservers', '-n', help='Add other name-servers')
	flag(parser, '--no-validation', '-V', help='Do not validate DNS servers')
	flag(parser, '--xor', '-x', help='Remove name-servers from forwarders')
	# Hidden option: not shown in the usage text.
	flag(parser, '--run-tests', help=SUPPRESS)
	flag(parser, '--no-ucr', help='Do not load nameservers and forwarders from UCR variables')
	parser.add_argument(
		'--dnsserver',
		dest="dnsservers",
		action="append",
		default=[],
		help='Specify nameserver delivered e.g. via DHCP. May be specified multiple times.')

	return parser.parse_args()


def setup_logging() -> None:
	"""
	Configure logging to standard error.

	The log level is selected by the `--verbose` counter: 0 is ERROR,
	1 is WARNING, 2 is INFO and 3 is DEBUG; any larger count selects DEBUG too.
	"""
	levels = (ERROR, WARNING, INFO, DEBUG)
	try:
		chosen = levels[options.verbose]
	except IndexError:
		# More `-v` than levels available: stay at the most verbose one.
		chosen = DEBUG
	basicConfig(
		format='%(asctime)-15s %(levelname)-7s %(name)-17s %(message)s',
		level=chosen,
		stream=stderr,
	)


def get_nameservers_cli(nameservers: DictAddr2Str) -> bool:
	"""
	Collect the DNS servers given via `--dnsserver` on the command line.

	:param nameservers: Dictionary receiving mapping IP address to `None`.
	:returns: `True` if at least one server was given.
	:raises ValueError: if a given server is not a valid IP address.
	"""
	log = getLogger(__name__).getChild('cli/ns')
	log.debug('Reading UCS domain servers from CLI...')

	added = 0
	for raw in options.dnsservers:
		addr = ip_address(str(raw))
		log.info('Added server %s via CLI argument', addr)
		nameservers[addr] = None
		added += 1

	return added > 0


def get_forwarders(forwarders: DictAddr2Str, ucr: Dict[str, str]) -> bool:
	"""
	Read the currently configured external DNS servers from UCR.

	Addresses belonging to the local host are not accepted as forwarder.
	Nothing is read if `--no-ucr` is given.

	:param forwarders: Dictionary receiving mapping IP address to `None`.
	:param ucr: UCR instance.
	:returns: `True` if a local address was configured as forwarder and got dropped.
	:raises ValueError: if a configured value is not a valid IP address.
	"""
	log = getLogger(__name__).getChild('ucr/fwd')
	if options.no_ucr:
		log.info('Skip reading forwarders from UCR')
		return False
	log.debug('Reading external DNS forwarders from UCR...')

	dropped_self = False
	for name in UCR_VARS_FWD:
		text = ucr.get(name, '').strip()
		if text:
			addr = ip_address(str(text))
			if not is_self(addr):
				log.info('Found forwarder %s from UCRV %s', addr, name)
				forwarders[addr] = None
			else:
				log.error("Dropping local address %s from UCRV %s", addr, name)
				dropped_self = True

	return dropped_self


def get_nameservers_ucr(nameservers: DictAddr2Str, ucr: Dict[str, str]) -> bool:
	"""
	Read the currently configured internal DNS servers from UCR.

	Nothing is read if `--no-ucr` is given.

	:param nameservers: Dictionary receiving mapping IP address to `None`.
	:param ucr: UCR instance.
	:returns: always `False`; reading alone never requires a fix.
	:raises ValueError: if a configured value is not a valid IP address.
	"""
	log = getLogger(__name__).getChild('ucr/ns')
	if options.no_ucr:
		log.info('Skip reading nameservers from UCR')
		return False
	log.debug('Reading UCS domain servers from UCR...')

	for name in UCR_VARS_DNS:
		text = ucr.get(name, '').strip()
		if text:
			addr = ip_address(str(text))
			log.info('Found server %s from UCRV %s', addr, name)
			nameservers[addr] = None

	return False


def validate_servers(nameservers: DictAddr2Str, forwarders: DictAddr2Str, domain: str) -> bool:
	"""
	Probe each internal DNS server and turn those not serving the domain into forwarders.

	A server which cannot be reached is left untouched.
	Nothing is checked if `--no-validation` is given.

	:param nameservers: Mapping of internal DNS servers.
	:param forwarders: Mapping of external DNS servers.
	:param domain: DNS domain name.
	:returns: `True` if any DNS server is re-categorized.
	"""
	log = getLogger(__name__).getChild('val')
	if options.no_validation:
		log.info('Skip validation of DNS servers')
		return False
	log.debug('Validating UCS domain servers...')

	recategorized = False
	# Iterate over a snapshot as entries are deleted from the mapping while looping.
	for server in tuple(nameservers):
		try:
			serves_domain = query_master_src_record(domain, server)
		except (SocketError, TimeoutError) as exc:
			log.warning('Connection check to %s (%s) failed, maybe down?!', server, exc.args[0])
			log.info('Leaving it configured as nameserver anyway')
			continue

		if serves_domain:
			log.info('Validated UCS domain server: %s', server)
			continue

		log.warning('UCS Primary Directory Node SRV record is unknown at %s, converting into forwarder', server)
		recategorized = True
		del nameservers[server]
		forwarders[server] = None

	return recategorized


def either_or(nameservers: DictAddr2Str, forwarders: DictAddr2Str) -> bool:
	"""
	Drop every forwarder which is also listed as internal DNS server.

	Only done if `--xor` is given.

	:param nameservers: Mapping of internal DNS servers.
	:param forwarders: Mapping of external DNS servers.
	:returns: `True` if any DNS forwarder was removed.
	"""
	log = getLogger(__name__).getChild('xor')
	if not options.xor:
		log.info('Skip removing nameservers from forwarders')
		return False
	log.info('Removing UCS domain servers from forwarders...')

	removed = False
	for server in nameservers:
		if server in forwarders:
			del forwarders[server]
			log.info('Removed UCS domain server %s from forwarders', server)
			removed = True

	return removed


def add_self(nameservers: DictAddr2Str, ucr: Dict[str, str], own_ip: Optional[str] = None) -> None:
	"""
	Add self as internal DNS server (on DCs).

	The own address is put in front of all other servers.  Unless
	`--force-self` is given the local server is queried first; if that query
	fails the address is only added when no other name server is known.
	Nothing is done if `--no-self` is given or a local address is already listed.

	:param nameservers: Mapping of internal DNS servers; modified in place.
	:param ucr: UCR instance.
	:param own_ip: Own IP address as given on the command line; if empty the default address from the UCR interface settings is used.
	:raises ValueError: if `own_ip` is not a valid IPv4 or IPv6 address.
	"""
	log = getLogger(__name__).getChild('ucr/self')
	if options.no_self:
		log.info('Skip adding self')
		return

	if any(is_self(addr) for addr in nameservers):
		log.info('Already using self')
		return

	if own_ip:
		# Convert to an address object: all other keys of the mapping are
		# address objects, so a plain string would never compare equal to
		# them, and a malformed value must not end up in UCR.
		myself = ip_address(u'%s' % (own_ip,))
		log.info('Own IP address given via CLI option: %s', myself)
	else:
		iface = Interfaces(ucr)
		mynet = iface.get_default_ip_address()
		myself = mynet.ip
		log.info('Default IP address configured in UCR: %s', myself)

	domain = ucr['domainname']
	if not options.force_self and not query_master_src_record(domain, myself):
		log.warning('Failed to query local server %s for %s', myself, domain)
		if nameservers:
			return
		log.warning('Adding anyway as no other nameserer remains.')

	# Re-build the mapping to get self as first entry.
	old = list(nameservers.items())
	nameservers.clear()
	nameservers[myself] = None
	nameservers.update(old)


def add_nameservers(nameservers: DictAddr2Str, domain: str) -> None:
	"""
	Add the servers named in the NS records of the zone as internal DNS servers.

	The local DNS server is asked for the NS records; addresses are taken
	from the additional section of the answer.  At most one non-local address
	is added per server name.  Only done if `--add-nameservers` is given.

	:param nameservers: Mapping of internal DNS servers.
	:param domain: DNS domain name.
	"""
	log = getLogger(__name__).getChild('ns')
	if not options.add_nameservers:
		log.info('Skip adding NS')
		return

	log.debug('Querying %s for additional NS records in %s', LOCAL, domain)
	response = DnsRequest(domain, qtype='NS', server=[LOCAL], aa=1, rd=0).req()
	log.debug('header=%r', response.header)

	if not (response.header['status'] == 'NOERROR' and response.header['aa']):
		log.error('DNS lookup of NS records in %s against %s failed', domain, LOCAL)
		return

	# Names of the servers for which an address is still wanted.
	pending = {answer['data'] for answer in response.answers}
	log.debug('servers=%r', pending)
	for record in response.additional:
		log.debug('rr=%r', record)
		host = record['name']
		if record['typename'] not in options.rrtyp or host not in pending:
			continue
		addr = get_ip(record)
		if is_self(addr):
			log.info('Skipping local interface address %s found for NS record %s', addr, host)
			continue
		log.info('Adding server found in NS: %s=%s', host, addr)
		nameservers[addr] = None
		pending.remove(host)


def add_master(nameservers: DictAddr2Str, master: str) -> None:
	"""
	Add the Primary Directory Node as internal DNS server.

	The local DNS server is asked for the address of the host; the first
	non-local address of a requested record type is added.
	Only done if `--add-master` is given.

	:param nameservers: Mapping of internal DNS servers.
	:param master: Fully qualified host name of Primary Directory Node.
	"""
	log = getLogger(__name__).getChild('ldap')
	if not options.add_master:
		log.info('Skip adding Primary Directory Node')
		return

	log.debug('Querying %s for address of Primary Directory Node %s', LOCAL, master)
	response = DnsRequest(master, qtype='ANY', server=[LOCAL], aa=1, rd=0).req()
	log.debug('header=%r', response.header)

	if not (response.header['status'] == 'NOERROR' and response.header['aa']):
		log.error('DNS lookup of %s against %s failed', master, LOCAL)
		return

	for record in response.answers:
		log.debug('rr=%r', record)
		if record['typename'] not in options.rrtyp:
			continue
		addr = get_ip(record)
		if is_self(addr):
			log.info('Skipping local interface address %s found for ldap/master %s', addr, master)
			continue
		log.info('Adding Primary Directory Node %s', addr)
		nameservers[addr] = None
		# One address is enough.
		return


def move_nameservers(forwarders: DictAddr2Str, nameservers: DictAddr2Str) -> None:
	"""
	Append all forwarders to the nameservers and empty the forwarders.

	Only done if `--no-self` is given, otherwise both mappings are left unchanged.

	:param forwarders: Mapping of external DNS servers.
	:param nameservers: Mapping of internal DNS servers.
	"""
	log = getLogger(__name__).getChild('move')
	if options.no_self:
		log.info('Moving forwarders %s to nameservers %s ...', list(forwarders), list(nameservers))
		for server, value in forwarders.items():
			nameservers[server] = value
		forwarders.clear()
	else:
		log.info('Skip moving forwarders to nameservers')


def update_ucr(ucr: Dict[str, str], nameservers: DictAddr2Str, forwarders: DictAddr2Str) -> None:
	"""
	Write internal and external DNS servers to the UCR settings.

	Each list is mapped onto its UCR variables: unused variables are unset,
	surplus servers are skipped with a warning.  UCR is only touched if a
	value really changes and `--no-act` is not given.  After a change BIND
	gets reconfigured unless `--no-self` is given.

	:param ucr: UCR instance.
	:param nameservers: Mapping of internal DNS servers.
	:param forwarders: Mapping of external DNS servers.
	:raises subprocess.CalledProcessError: if `rndc reconfig` fails.
	"""
	log = getLogger(__name__).getChild('ucr')
	pending = {}

	def assign(variables: List[str], servers: DictAddr2Str, label: str) -> None:
		log.debug('%s=%r', label, list(servers))
		texts: List[Optional[str]] = [str(addr) for addr in servers]
		unused = len(variables) - len(texts)
		if unused < 0:
			log.warning('Skipping extra %s: %r', label, texts[len(variables):])
		elif unused > 0:
			# `None` unsets the variable.
			texts.extend([None] * unused)
		# zip() stops at the shorter sequence, which drops surplus servers.
		for variable, text in zip(variables, texts):
			pending[variable] = text

	assign(UCR_VARS_FWD, forwarders, 'forwarders')
	assign(UCR_VARS_DNS, nameservers, 'nameservers')
	log.info('Updating %r', pending)

	if options.no_act:
		return

	dirty = False
	for key in sorted(pending):
		new = pending[key]
		current = ucr.get(key)
		if current == new:
			continue
		log.info('Updating %r: %r -> %r', key, current, new)
		dirty = True

	if not dirty:
		return

	ucr_update(ucr, pending)

	if not options.no_self:
		# with --no-self we assume no BIND is running (unjoined DC or MemberServer)
		log.info('Reloading BIND')
		check_call(('rndc', 'reconfig'))


def query_master_src_record(domain: str, server: IPAddress) -> Union[bool, str]:
	"""
	Ask a DNS server for the SRV record of the Primary Directory Node of the domain.

	:param domain: DNS domain name.
	:param server: DNS server to query.
	:returns: `False` if the response status is not `NOERROR`, otherwise the authoritative-answer flag of the response header, which is true if the server is authoritative for the zone.
	"""
	log = getLogger(__name__).getChild('dns/srv')

	# Fully qualified name, independent of a trailing dot of the domain.
	name = '_domaincontroller_master._tcp.%s.' % (domain.rstrip('.'),)
	log.debug('Querying %s for SRV %s', server, name)

	response = DnsRequest(name, qtype='SRV', server=[str(server)], aa=1, rd=0).req()
	header = response.header
	log.debug('header=%r', header)

	if header['status'] != 'NOERROR':
		return False
	return header['aa']  # type: ignore


def get_ip(rr: Dict[str, str]) -> Union[IPv4Address, IPv6Address]:
	r"""
	Resolve DNS Resource Record to IP address.

	:param rr: DNS RR with at least the keys `typename` and `data`.
	:returns: resolved IP address.
	:raises TypeError: if the RR is neither an `A` nor `AAAA` record.
	:raises ValueError: if the data of the record is not a valid address, e.g. an `AAAA` record not consisting of exactly 16 bytes.

	>>> get_ip({'typename': 'A', 'data': '127.0.0.1'})
	IPv4Address('127.0.0.1')
	>>> get_ip({'typename': 'AAAA', 'data': '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01'})
	IPv6Address('::1')
	"""
	typ, data = rr['typename'], rr['data']
	if typ == 'A':
		return ip_address(u'%s' % (data,))  # type: ignore
	if typ == 'AAAA':
		# Work-around bug in python-pydns, which does not unpack IPv6 addresses
		# An explicit check instead of `assert`, which is removed by `python -O`;
		# 4 packed bytes would otherwise be accepted as IPv4 address.
		if len(data) != 16:
			raise ValueError('Invalid AAAA record data of length %d: %r' % (len(data), data))
		if not isinstance(data, bytes):
			# 'charmap' maps each character 1:1 to the byte of same value.
			data = bytes(data, 'charmap')  # type: ignore
		return IPv6Address(data)
	raise TypeError(typ)


def is_self(addr: Union[str, IPv4Address, IPv6Address]) -> bool:
	"""
	Check if given address is associated with the local host.

	The kernel routing table is consulted by calling `ip route get`; the
	address is local if the printed route starts with `local`.

	:param addr: An IP address or domain name.
	:returns: `True` if the address is local, `False` if not or if the command failed.

	>>> is_self('127.0.0.1')
	True
	>>> is_self('::1')
	True
	>>> is_self('8.8.8.8')
	False
	>>> is_self('0.0.0.1')
	False
	"""
	log = getLogger(__name__).getChild('ip')

	command = ['ip', 'route', 'get', str(addr)]
	# Force the C locale for the output which gets parsed below.
	environment = dict(environ, LC_ALL='C')
	log.debug('calling %r', command)
	try:
		output = check_output(command, env=environment)
	except CalledProcessError as ex:
		log.warning('Failed to determine route: %s', ex)
		return False
	return output.decode('UTF-8').startswith('local ')


def run_tests() -> None:
	"""
	Run internal test suite and terminate the program.

	The doctests of the main module are executed.  This function never
	returns: the process exits with status 0 if all tests passed and 1
	otherwise, so that a test run never continues into the actual
	re-configuration done by the caller.

	:raises SystemExit: always.
	"""
	import doctest
	result = doctest.testmod()
	exit(int(bool(result.failed)))


# Only run when executed as a script, not when imported as a module.
if __name__ == '__main__':
	main()
