#!/usr/bin/python3
#
# Univention AD Connector
#  Resync object from AD to OpenLDAP
#
# SPDX-FileCopyrightText: 2018-2026 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only


import os
import sqlite3
import sys
import time
from argparse import ArgumentParser

import ldap
from samba.dcerpc import misc
from samba.ndr import ndr_unpack

import univention.connector.ad
from univention.config_registry import ConfigRegistry
from univention.dn import DN


class GUIDNotFound(Exception):
    """Raised when an AD search yields no entry to read an objectGUID from.

    Derives from :class:`Exception` (not :class:`BaseException`) because it
    is an ordinary application error, not an interpreter-exit signal.

    ``args`` as raised in this file are either
    ``(1, missing_dns, found_dns)`` for DN based lookups or
    ``(2, "No match")`` for filter based searches.
    """


class DNNotFound(Exception):
    """Raised when an AD search base or DN cannot be searched.

    Derives from :class:`Exception` (not :class:`BaseException`) because it
    is an ordinary application error, not an interpreter-exit signal.

    ``args`` as raised in this file are either
    ``(1, [(dn, error_text), ...], found_dns)`` for DN based lookups or
    ``(2, ldapbase)`` for filter based searches.
    """


class ad(univention.connector.ad.ad):
    """AD connector subclass that schedules AD objects for a resync.

    For every selected AD object the cached state is dropped from
    ``adcache.sqlite`` and the object is entered into the ``AD rejected``
    table of ``internal.sqlite``.

    NOTE: the methods below call ``self.__search_ad``. Inside a class named
    ``ad`` Python mangles that name to ``_ad__search_ad``; this class does
    not define it, so the lookup presumably resolves to the private method of
    the base class, which only works while both classes are named ``ad``.
    Do not rename this class.

    All methods read the module global ``CONFIGBASENAME``, which is set by
    the script body before the class is used.
    """

    def _remove_cache_entries(self, guid):
        """Remove the cached data of one AD object from ``adcache.sqlite``.

        Looks up the row id stored for ``str(guid)`` in table ``GUIDS`` and
        deletes that row together with the ``DATA`` rows referencing it.
        Returns without doing anything if the cache file does not exist.

        :param guid: objectGUID of the AD object (compared via ``str()``).
        """
        cache_filename = f'/etc/univention/{CONFIGBASENAME}/adcache.sqlite'
        if not os.path.exists(cache_filename):
            return
        cache_db = sqlite3.connect(cache_filename)
        try:
            c = cache_db.cursor()
            c.execute("SELECT id FROM GUIDS WHERE guid=?", (str(guid),))
            guid_ids = c.fetchone()
            if guid_ids:
                guid_id = guid_ids[0]
                c.execute("DELETE from DATA where guid_id = ?", (guid_id,))
                c.execute("DELETE from GUIDS where id = ?", (guid_id,))
                cache_db.commit()
        finally:
            # Always release the connection, even if a statement failed.
            cache_db.close()
        # The mode must be an octal literal: decimal 640 is 0o1200, i.e. the
        # sticky bit plus owner-write only, not rw-r-----.
        os.chmod(cache_filename, 0o640)

    def _add_object_to_rejected(self, ad_dn, usn):
        """Record an AD object in the ``AD rejected`` table of ``internal.sqlite``.

        Inserts (or replaces) a row with ``key=usn`` and ``value=ad_dn``.
        Presumably the running connector retries objects listed there, which
        is what triggers the resync -- verify against the connector.

        :param ad_dn: DN of the AD object.
        :param usn: ``uSNChanged`` value of the AD object.
        """
        state_filename = f'/etc/univention/{CONFIGBASENAME}/internal.sqlite'
        db = sqlite3.connect(state_filename)
        try:
            c = db.cursor()
            c.execute("INSERT OR REPLACE INTO 'AD rejected' (key, value) VALUES (?, ?);", (usn, ad_dn))
            db.commit()
        finally:
            # Always release the connection, even if the insert failed.
            db.close()
        # Octal literal on purpose; decimal 640 would set mode 0o1200.
        os.chmod(state_filename, 0o640)

    def resync(self, ad_dns=None, ldapfilter=None, ldapbase=None):
        """Schedule all matching AD objects for a resync.

        :param ad_dns: optional list of AD DNs to resync.
        :param ldapfilter: optional LDAP filter.
        :param ldapbase: optional search base, used when no DNs are given.
        :returns: list of the AD DNs that were treated, in processing order.
        :raises DNNotFound: propagated from :meth:`search_ad`.
        :raises GUIDNotFound: propagated from :meth:`search_ad`.
        """
        result = self.search_ad(ad_dns, ldapfilter, ldapbase)

        # If a DN to resync happens to be a subtree DN, we might also want to resync the ancestors
        # (the list is extended in place).
        self.prepend_ancestors_of_allowed_subtrees(result, ldapfilter, ldapbase)

        treated_dns = []
        for ad_dn, guid, usn in result:
            self._remove_cache_entries(guid)
            self._add_object_to_rejected(ad_dn, usn)
            treated_dns.append(ad_dn)

        return treated_dns

    def search_ad(self, ad_dns=None, ldapfilter=None, ldapbase=None):
        """Search AD and return DN, objectGUID and uSNChanged of each match.

        If ``ad_dns`` is non-empty, every DN is looked up individually with
        base scope; otherwise a subtree search below ``ldapbase`` is done
        (default: the configured ``<CONFIGBASENAME>/ad/ldap/base``).
        ``ldapfilter`` defaults to ``(objectClass=*)``.

        :param ad_dns: optional list of AD DNs.
        :param ldapfilter: optional LDAP filter.
        :param ldapbase: optional search base for the subtree search.
        :returns: list of ``(dn, guid, usn)`` tuples, where ``dn`` and
            ``usn`` are strings and ``guid`` is the unpacked objectGUID.
        :raises DNNotFound: ``(1, [(dn, error), ...], found_dns)`` if a DN
            lookup failed, or ``(2, ldapbase)`` if the subtree search failed
            with a referral or an invalid DN.
        :raises GUIDNotFound: ``(1, missing_dns, found_dns)`` if a DN lookup
            returned no entry, or ``(2, "No match")`` if the subtree search
            returned no entry.
        """
        # Both search modes share the same default filter.
        if not ldapfilter:
            ldapfilter = '(objectClass=*)'

        search_result = []
        if ad_dns:
            error_dns = []
            missing_dns = []
            for targetdn in ad_dns:
                # Stays None if the lookup returns no real entry.
                guid = None
                try:
                    res = self.__search_ad(base=targetdn, scope=ldap.SCOPE_BASE, filter=ldapfilter, attrlist=["objectGUID", "uSNChanged"])

                    for msg in res:
                        if not msg[0]:  # Referral
                            continue
                        guid_blob = msg[1]["objectGUID"][0]
                        guid = ndr_unpack(misc.GUID, guid_blob)
                        usn = msg[1]["uSNChanged"][0].decode('ASCII')
                        search_result.append((str(msg[0]), guid, usn))
                    if not guid:
                        missing_dns.append(targetdn)
                except (ldap.NO_SUCH_OBJECT, ldap.REFERRAL, ldap.INVALID_DN_SYNTAX) as ex:
                    # Collect the failure and continue with the remaining
                    # DNs so that all problems are reported at once.
                    error_dns.append((targetdn, str(ex)))
            # The third argument carries the DNs that were found so the
            # caller can still report them.
            if error_dns:
                raise DNNotFound(1, error_dns, [r[0] for r in search_result])
            if missing_dns:
                raise GUIDNotFound(1, missing_dns, [r[0] for r in search_result])
        else:
            if not ldapbase:
                ldapbase = self.configRegistry[f'{CONFIGBASENAME}/ad/ldap/base']

            guid = None
            try:
                res = self.__search_ad(base=ldapbase, scope=ldap.SCOPE_SUBTREE, filter=ldapfilter, attrlist=["objectGUID", "uSNChanged"])

                for msg in res:
                    if not msg[0]:  # Referral
                        continue
                    guid_blob = msg[1]["objectGUID"][0]
                    guid = ndr_unpack(misc.GUID, guid_blob)
                    usn = msg[1]["uSNChanged"][0].decode('ASCII')
                    search_result.append((str(msg[0]), guid, usn))
            except (ldap.REFERRAL, ldap.INVALID_DN_SYNTAX) as ex:
                raise DNNotFound(2, ldapbase) from ex

            if not guid:
                raise GUIDNotFound(2, "No match")

        return search_result

    def _get_allowed_subtrees(self) -> list[DN]:
        """Return the AD side DNs of all configured allowed subtrees.

        Collects the values of every config registry variable whose name
        starts with ``<CONFIGBASENAME>/ad/mapping/allowsubtree`` and ends
        with ``/ad``.
        """
        prefix = f'{CONFIGBASENAME}/ad/mapping/allowsubtree'
        return [
            DN(self.configRegistry[key])
            for key in self.configRegistry
            if key.startswith(prefix) and key.endswith('/ad')
        ]

    def prepend_ancestors_of_allowed_subtrees(self, ad_search_result: list[tuple] | None, ldapfilter, ldapbase):
        """Prepend the ancestors of allowed-subtree DNs to a search result.

        For every entry of ``ad_search_result`` whose DN is itself a
        configured allowed subtree, all parent DNs up to (but excluding) the
        AD LDAP base are looked up via :meth:`search_ad` and inserted at the
        front of ``ad_search_result``, shortest DN string first.

        Nothing happens if ``<CONFIGBASENAME>/ad/mapping/allow-subtree-ancestors``
        is set to a false value or if ``ad_search_result`` is ``None``.

        :param ad_search_result: list of ``(dn, guid, usn)`` tuples; it is
            modified in place.
        :param ldapfilter: passed through to :meth:`search_ad`.
        :param ldapbase: passed through to :meth:`search_ad`.
        :returns: ``None``.
        :raises DNNotFound: propagated from :meth:`search_ad`.
        :raises GUIDNotFound: propagated from :meth:`search_ad`.
        """
        if self.configRegistry.is_false(f"{CONFIGBASENAME}/ad/mapping/allow-subtree-ancestors", False) or ad_search_result is None:
            return

        # Use the selected config basename like everywhere else in this
        # class instead of a hard-coded "connector" prefix.
        ad_ldap_base = DN(self.configRegistry.get(f"{CONFIGBASENAME}/ad/ldap/base"))

        allowed_subtrees = self._get_allowed_subtrees()

        ancestor_list = []
        for entry_dn, _, _ in ad_search_result:
            subtree_dn = DN(entry_dn)
            if subtree_dn not in allowed_subtrees:
                continue

            # Walk upwards until the AD base (or the top of the tree).
            parent_dn = subtree_dn.parent
            while parent_dn and parent_dn != ad_ldap_base:
                parent = str(parent_dn)
                if parent not in ancestor_list:
                    print(f"{subtree_dn} is an allowed subtree. Adding ancestor DN {parent} to resync list.")
                    ancestor_list.insert(0, parent)
                parent_dn = parent_dn.parent

        if not ancestor_list:
            return

        ancestor_ad_search_result = self.search_ad(ancestor_list, ldapfilter, ldapbase)
        # Shorter DN strings first, so that ancestors precede their descendants.
        ancestor_ad_search_result.sort(key=lambda x: len(x[0]))
        ad_search_result[:0] = ancestor_ad_search_result


if __name__ == '__main__':
    # Command line handling: either a DN or an LDAP filter must be given.
    parser = ArgumentParser(description="Resync object from AD to UCS")
    parser.add_argument("-f", "--filter", dest="ldapfilter", help="LDAP search filter")
    parser.add_argument("-b", "--base", dest="ldapbase", help="LDAP search base")
    parser.add_argument("-c", "--configbasename", help="Config basename", metavar="CONFIGBASENAME", default="connector")
    parser.add_argument("dn", nargs='?', default=None, help="Active Directory DN to resync")
    options = parser.parse_args()

    # Module global read by the methods of class ``ad``.
    CONFIGBASENAME = options.configbasename
    state_directory = f'/etc/univention/{CONFIGBASENAME}'
    if not os.path.exists(state_directory):
        parser.error(f"Invalid configbasename, directory {state_directory} does not exist")

    if not options.dn and not options.ldapfilter:
        parser.print_help()
        sys.exit(2)

    configRegistry = ConfigRegistry()
    configRegistry.load()

    # Fall back to 5 seconds (the same default the delay estimate below uses)
    # when the variable is unset or not an integer, instead of aborting with
    # a traceback before anything was done.
    try:
        poll_sleep = int(configRegistry.get(f'{CONFIGBASENAME}/ad/poll/sleep', 5))
    except ValueError:
        poll_sleep = 5

    # The positional DN is optional; drop it from the list if it is missing.
    ad_dns = list(filter(None, [options.dn]))

    treated_dns = []

    try:
        resync = ad.main(configRegistry, CONFIGBASENAME)
        resync.init_ldap_connections()
        treated_dns = resync.resync(ad_dns, options.ldapfilter, options.ldapbase)
    except ldap.SERVER_DOWN:
        # NOTE(review): after the sleep the script falls through, prints
        # 'No matching objects.' and exits with status 0 -- confirm that this
        # is the intended behaviour for an unreachable server.
        print("Warning: Can't initialize LDAP-Connections, wait...")
        sys.stdout.flush()
        time.sleep(poll_sleep)
    except DNNotFound as ex:
        print(f'ERROR: The AD object was not found: {ex.args[1]}')
        # A third argument carries the DNs that were found nevertheless.
        if len(ex.args) == 3:
            treated_dns = ex.args[2]
        sys.exit(1)
    except GUIDNotFound as ex:
        print(f'ERROR: The AD search for objectGUID failed: {ex.args[1]}')
        # A third argument carries the DNs that were found nevertheless.
        if len(ex.args) == 3:
            treated_dns = ex.args[2]
        sys.exit(1)
    finally:
        # Runs on the error exits above as well.
        for dn in treated_dns:
            print(f'resync triggered for {dn}')

    if treated_dns:
        # ``resync`` is bound here: a non-empty list is only assigned on the
        # success path (the error paths exit before reaching this point).
        estimated_delay = 60
        try:
            estimated_delay = int(resync.configRegistry.get(f'{CONFIGBASENAME}/ad/retryrejected', 10)) * int(resync.configRegistry.get(f'{CONFIGBASENAME}/ad/poll/sleep', 5))
        except ValueError:
            pass

        print(f'Estimated sync in {estimated_delay} seconds.')
    else:
        print('No matching objects.')

    sys.exit(0)
