#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# Copyright 2021 Univention GmbH
#
# https://www.univention.de/
#
# All rights reserved.
#
# The source code of this program is made available
# under the terms of the GNU Affero General Public License version 3
# (GNU AGPL V3) as published by the Free Software Foundation.
#
# Binary versions of this program provided by Univention to you as
# well as other copyrighted, protected or trademarked materials like
# Logos, graphics, fonts, specific documentations and configurations,
# cryptographic keys etc. are subject to a license agreement between
# you and Univention and not subject to the GNU AGPL V3.
#
# In the case you use this program under the terms of the GNU AGPL V3,
# the program is provided in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public
# License with the Debian GNU/Linux or Univention distribution in file
# /usr/share/common-licenses/AGPL-3; if not, see
# <https://www.gnu.org/licenses/>.
#

from __future__ import print_function

import os
import re
from hashlib import md5
from glob import glob
import subprocess
from argparse import ArgumentParser, SUPPRESS
import sys

from univention.uldap import getMachineConnection
from univention.ldap_cache.cache import get_cache
from univention.ldap_cache.cache.shard_config import add_shard_to_config, rm_shard_from_config

# Source template for the generated listener modules. create_listener_modules()
# fills the five positional "{}" placeholders via str.format(), in this order:
# repr() of the listener name (twice), the comma separated db names of the
# shards, repr() of the LDAP filter and repr() of the sorted attribute list.
listener_template = '''#!/usr/bin/python3
from __future__ import absolute_import

from univention.ldap_cache.listener_module import LdapCacheHandler

name = {}

class QueryHandler(LdapCacheHandler):
    class Configuration(LdapCacheHandler.Configuration):
        name = {}
        def get_description(self):
            return 'Automatically created listener module for the univention-group-membership-member cache ({})'
        def get_ldap_filter(self):
            return {}
        def get_attributes(self):
            return {}
'''


def _md5sum(value):
    """Return the hexadecimal MD5 digest of the text ``value``, encoded as UTF-8."""
    return md5(value.encode('utf-8')).hexdigest()


def _query_objects(query, attrs):
    """Search for ``query`` on the connection returned by getMachineConnection().

    :param query: the LDAP filter handed to ``search()``.
    :param attrs: iterable of attribute names; ``bytes`` entries are decoded
        as UTF-8, all other entries are passed on unchanged.
    :returns: whatever ``search()`` of the connection returns.
    """
    connection = getMachineConnection()
    attr_names = []
    for attr in attrs:
        if isinstance(attr, bytes):
            attr = attr.decode('utf-8')
        attr_names.append(attr)
    print('Querying', query, 'with', attr_names)
    return connection.search(query, attr=attr_names)


def _get_queries(caches, cache_names=None):
    """Group the shards of ``caches`` by their LDAP filter.

    :param caches: iterable of ``(name, cache)`` pairs; every cache provides
        a ``shards`` iterable.
    :param cache_names: optional container of cache names; if given (not
        ``None``), caches whose name is not contained are skipped.
    :returns: dict mapping each ``shard.ldap_filter`` to a tuple
        ``(shards, attrs)``: the list of shards using that filter and the
        set of their ``key``, ``value`` and ``attributes`` entries.
    """
    queries = {}
    for cache_name, cache in caches:
        selected = cache_names is None or cache_name in cache_names
        if not selected:
            continue
        for shard in cache.shards:
            # setdefault() hands back the stored tuple, new or existing
            shard_list, attr_names = queries.setdefault(shard.ldap_filter, ([], set()))
            shard_list.append(shard)
            attr_names.update((shard.key, shard.value))
            attr_names.update(shard.attributes)
    return queries


def add_cache(args):
    """Handler of the ``add-cache`` action: pass the parsed shard options to add_shard_to_config()."""
    shard_options = (args.db_name, args.single_value, args.reverse, args.key, args.value, args.ldap_filter)
    add_shard_to_config(*shard_options)


def rm_cache(args):
    """Handler of the ``rm-cache`` action: pass the parsed shard options to rm_shard_from_config()."""
    shard_options = (args.db_name, args.single_value, args.reverse, args.key, args.value, args.ldap_filter)
    rm_shard_from_config(*shard_options)


def cleanup(args):
    """Call ``cleanup()`` on every sub cache yielded by get_cache().

    ``args`` is accepted so that the function can serve as an action
    handler; it is not used.
    """
    for _name, sub_cache in get_cache():
        sub_cache.cleanup()


def rebuild(args):
    """Clear the selected sub caches and refill them from LDAP.

    ``args.cache_name`` lists the sub caches to rebuild; if it is empty,
    all sub caches are rebuilt. One search is done per distinct LDAP
    filter and every object found is added to all shards using that filter.
    """
    caches = get_cache()
    cache_names = args.cache_name or sorted(name for name, _cache in caches)
    print('Rebuilding', cache_names)
    for name, cache in caches:
        if name in cache_names:
            cache.clear()
    queries = _get_queries(caches, cache_names)
    for ldap_filter, (shards, attrs) in queries.items():
        print('Searching for', ldap_filter)
        attrs.discard('dn')
        count = 0  # stays 0 if the search yields nothing
        for count, obj in enumerate(_query_objects(ldap_filter, attrs), 1):
            if count % 1000 == 0:
                # progress output on a single line, plus a cleanup() run every 1000 objects
                print('\rProcessing object #', count, end='')
                sys.stdout.flush()
                cleanup(args)
            for shard in shards:
                shard.add_object(obj)
        if count >= 1000:
            # terminate the progress line written above
            print()
        print('Added', count, 'objects')


def create_listener_modules(args):
    """Write one listener module per LDAP filter and remove stale ones.

    For every distinct LDAP filter of the configured shards a module named
    ``ldap-cache-<md5 of the filter>.py`` is rendered from
    ``listener_template``. Previously existing ``ldap-cache-*`` files that
    were not rewritten are deleted, the univention-directory-listener
    service is restarted and the handler state files of the deleted
    modules are removed. ``args`` is not used.
    """
    listener_dir = '/usr/lib/univention-directory-listener/system/'
    handlers_dir = '/var/lib/univention-directory-listener/handlers/'
    # start with all existing modules; whatever is not rewritten below is obsolete
    obsolete = set(glob(os.path.join(listener_dir, 'ldap-cache-*')))
    for ldap_filter, (shards, attrs) in _get_queries(get_cache()).items():
        attrs.difference_update(('dn', 'entryUUID'))
        listener_name = 'ldap-cache-%s' % _md5sum(ldap_filter)
        fname = os.path.join(listener_dir, '%s.py' % listener_name)
        obsolete.discard(fname)
        print('Writing', fname, 'for', ldap_filter)
        with open(fname, 'w') as fd:
            db_names = ', '.join(sorted({shard.db_name for shard in shards}))
            module_source = listener_template.format(
                repr(listener_name),
                repr(listener_name),
                db_names,
                repr(ldap_filter),
                repr(sorted(attrs)),
            )
            fd.write(module_source)
    for fname in obsolete:
        print('Removing', fname)
        os.unlink(fname)
    subprocess.call(['service', 'univention-directory-listener', 'restart'])
    # Also remove the status "initialized" so that the listener may be added
    # again in the future. This needs to be done after the listener is
    # running again.
    for fname in obsolete:
        # strip the trailing ".py" to get the handler name
        handler_name = os.path.basename(fname)[:-3]
        try:
            os.unlink(os.path.join(handlers_dir, handler_name))
        except EnvironmentError:
            # best effort: a state file that cannot be removed is ignored
            pass


def list_caches(args):
    """Print every sub cache together with the LDAP filter and mapping of each shard.

    Multi-valued sides of a mapping are shown in square brackets; for
    reverse shards key and value are swapped in the output. ``args`` is
    not used.
    """
    for name, cache in get_cache():
        print(name)
        print(' The following objects store data:')
        for shard in cache.shards:
            print('  ', shard.ldap_filter)
            if shard.reverse:
                key, value = shard.value, '[{}]'.format(shard.key)
            elif shard.single_value:
                key, value = shard.key, shard.value
            else:
                key, value = shard.key, '[{}]'.format(shard.value)
            print('    ', key, '=>', value)


def query(args):
    """Print the entries of the sub cache ``args.cache_name`` in sorted key order.

    If ``args.pattern`` is set, only keys in which that regular expression
    is found (``re.search`` semantics) are printed. An unknown cache name
    or an invalid pattern results in a message and no further output.
    """
    cache = get_cache().get_sub_cache(args.cache_name)
    if not cache:
        print('No cache named', args.cache_name)
        return
    data = cache.load()
    if args.pattern:
        try:
            matches = re.compile(args.pattern).search
        except re.error:
            print('Broken pattern')
            return
    else:
        # without a pattern every key is shown
        matches = lambda key: True  # noqa: E731
    for key in sorted(data):
        if matches(key):
            print(key, '=>', data[key])


def main():
    """Parse the command line and run the handler of the selected action.

    Every action is registered as a sub parser whose ``func`` default is the
    handler function; the handler is called with the parsed arguments.
    If no action was given, the parser reports "too few arguments" and exits.
    """
    usage = '%(prog)s'
    description = 'The LDAP cache stores some portions of the LDAP database so that it can be accessed fast and reliably.'
    parser = ArgumentParser(usage=usage, description=description)
    subparsers = parser.add_subparsers(description='type %(prog)s <action> --help for further help and possible arguments', metavar='action')

    subparser = subparsers.add_parser('query', description='Queries a sub cache and shows the value(s). Mainly for test purposes, this tool is not meant for real applications', help='Query the cache')
    subparser.add_argument('cache_name', help='The name of the sub cache. See "list"')
    subparser.add_argument('pattern', nargs='?', help='Queries the key with this regexp')
    subparser.set_defaults(func=query)

    subparser = subparsers.add_parser('list', description='Lists all sub caches of the cache. Each sub cache may be fed from multiple sources', help='List all parts of the cache')
    subparser.set_defaults(func=list_caches)

    subparser = subparsers.add_parser('rebuild', description='Rebuild the cache completely, retrieve the objects and overwrite all previous data', help='Rebuild the cache')
    subparser.add_argument('cache_name', nargs='*', help='The cache consists of different parts. You can only rebuild certain parts of the cache. See "list"')
    subparser.set_defaults(func=rebuild)

    subparser = subparsers.add_parser('create-listener-modules', description='Automatically creates listener modules that will eventually fill the cache (and removes unnecessary); restarts the univention-directory-listener. May be needed after shards are added to /removed from the cache', help='Create listener modules')
    subparser.set_defaults(func=create_listener_modules)

    # add-cache and rm-cache take exactly the same arguments and differ only
    # in the handler; they are registered without help text.
    for action, handler in (('add-cache', add_cache), ('rm-cache', rm_cache)):
        subparser = subparsers.add_parser(action, description=SUPPRESS)
        subparser.add_argument('db_name')
        subparser.add_argument('--single-value', action='store_true')
        subparser.add_argument('--reverse', action='store_true')
        subparser.add_argument('key')
        subparser.add_argument('value')
        subparser.add_argument('ldap_filter')
        subparser.set_defaults(func=handler)

    subparser = subparsers.add_parser('cleanup', description='Over the time, the database may become polluted. Shrink all caches to their optimal size if possible.', help='Clean up cache')
    subparser.set_defaults(func=cleanup)

    args = parser.parse_args()
    # "func" is only set if an action was given on the command line. Look it
    # up explicitly instead of catching AttributeError around the call, so
    # that an AttributeError raised inside a handler is not misreported as a
    # usage error.
    func = getattr(args, 'func', None)
    if func is None:
        parser.error("too few arguments")
    func(args)


# Run the command line interface only when executed as a script.
if __name__ == '__main__':
    main()
