#!/usr/bin/python3 -u
# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2022-2023 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only

"""USI - collect support information of UCS systems"""

import argparse
import glob
import io
import os
import shutil
import smtplib
import string
import subprocess
import sys
import urllib.parse
from contextlib import contextmanager
from email.mime.text import MIMEText
from tempfile import mkdtemp, mkstemp
from typing import cast, Generator, Optional, List, Union

import lxml.html
import requests
from gnupg import GPG, ImportResult

# Upstream location of the latest univention-support-info script; its
# detached signature is fetched from the same URL with a ".gpg" suffix.
USI = "https://updates.software-univention.de/download/scripts/univention-support-info"
# Local path the verified script is installed to and executed from.
USI_SCRIPT = "/usr/bin/univention-support-info"

# Univention upload service that receives the resulting archive.
UPLOAD = "https://upload.univention.de/"


@contextmanager
def temp_dir(suffix='', prefix='', permissions=0o700):
    # type: (Optional[str], Optional[str], int) -> Generator[str, None, None]
    """
    Create a temporary directory to work with. It will be removed
    automatically when the ``with`` block is left, also when the block
    raises an exception.

    :param suffix: Suffix of the directory name.
    :param prefix: Prefix of the directory name.
    :param permissions: Mode applied to the directory right after creation.
    :return: Path to newly created temporary directory
    """
    dir_path = mkdtemp(suffix=suffix, prefix=prefix)
    try:
        os.chmod(dir_path, permissions)
        yield dir_path
    finally:
        # always clean up, even if the caller's block raised
        shutil.rmtree(dir_path)


class GnuPG(object):
    """
    Small convenience layer over :class:`gnupg.GPG`: imports key files on
    construction, remembers a key fingerprint and verifies detached
    signatures.
    """

    def __init__(self, fingerprint=None, password=None, gnupg_home=None, key_paths_to_import=None):
        # type: (Optional[str], Optional[str], Optional[str], Optional[Union[str,List[str]]]) -> None
        """
        Set up the GnuPG instance and import the given key files.

        :param fingerprint: Fingerprint of the key to use. When omitted, the
            first fingerprint listed after the import is taken (if any).
        :param password: Passphrase of the key. Default is None.
        :param gnupg_home: GnuPG home directory. Default is whatever gpg defaults to.
        :param key_paths_to_import: Path or list of paths of key files to
            import before the fingerprint is selected.
        :raises RuntimeError: if a key import fails, or if *fingerprint* is
            given but not present in the keyring.
        """
        self._gpg = GPG(gnupghome=gnupg_home)
        self.import_keys(key_paths_to_import)

        known = self._gpg.list_keys().fingerprints
        if fingerprint:
            if fingerprint not in known:
                raise RuntimeError('No %r fingerprint found' % fingerprint)
        elif known:
            fingerprint = known[0]

        self._fingerprint = fingerprint
        self._password = password

    def import_keys(self, key_paths):
        # type: (Optional[Union[str, List[str]]]) -> None
        """
        Import public or private GPG keys.

        :param key_paths: A single path or a list of paths; a falsy value
            imports nothing.
        :raises RuntimeError: if a file does not yield exactly one imported key.
        """
        if not key_paths:
            return
        paths = key_paths if isinstance(key_paths, list) else [key_paths]

        for path in paths:
            with open(path, 'rb') as handle:
                self._check_import_result(self._gpg.import_keys(handle.read()))

    @classmethod
    def _check_import_result(cls, result):
        # type: (ImportResult) -> None
        """
        Validate the outcome of a key import.

        :param result: Object returned by ``GPG.import_keys``.
        :raises RuntimeError: if *result* is not an ImportResult or its
            count is not exactly one.
        """
        if not isinstance(result, ImportResult):
            raise RuntimeError('ImportResult is expected as parameter')
        if result.count != 1:
            raise RuntimeError('Key not imported: %r' % result.stderr)

    def verify_detached_signature(self, signature_file, path=None):
        # type: (str, Optional[str]) -> bool
        """
        Verify a detached signature against its data file.

        :param signature_file: Path to the signature file.
        :param path: Path to the signed data file. When omitted and
            *signature_file* ends with ``.sig``, the name without that
            suffix is used.
        :return: The ``valid`` flag reported by gnupg.
        """
        sig_suffix = '.sig'
        if not path and signature_file.endswith(sig_suffix):
            path = signature_file[:-len(sig_suffix)]

        with open(signature_file, 'rb') as sig_handle:
            verified = self._gpg.verify_file(sig_handle, path)
        return cast(bool, verified.valid)


class Main(object):
    """
    Download (and verify) the latest USI script, run it and optionally
    upload the resulting archive and notify Univention by mail.
    """

    def __init__(self, args):  # type: (argparse.Namespace) -> None
        self.args = args

    def execute(self):  # type: () -> int
        """Update the USI script if possible, then run it; return the exit code."""
        self.download_script()
        return self.run_script()

    def download_script(self):  # type: () -> None
        """
        Fetch the latest USI script and install it as USI_SCRIPT if its
        detached signature verifies against the Univention archive keys.

        On any failure the previously installed script is kept. The stored
        ETag is then discarded, so the next run downloads and verifies
        again instead of getting "304 Not Modified" for a script that was
        never installed.
        """
        etag_path = '/var/lib/univention-support-info/.etag'
        # download files within temporary directory
        with temp_dir() as tmp_dir:  # type: str
            script_path = os.path.join(tmp_dir, 'usi.py')

            res = self._download_file(url=USI, local_path=script_path, etag_path=etag_path)
            if 200 == res:
                print('Collected new Univention Support Info', file=sys.stderr)
            elif 304 == res:
                print('Using cached Univention Support Info', file=sys.stderr)
                return
            else:
                print('Failed to download USI script. Trying to use cached one!', file=sys.stderr)
                return

            # download signature
            signature_url = '%s.gpg' % USI
            signature_path = os.path.join(tmp_dir, 'usi.py.sig')
            res = self._download_file(url=signature_url, local_path=signature_path)
            print('Download signature %s = %s' % (signature_url, res,))
            if res != 200:
                # without a signature file the verification below could not even start
                print('Failed to download signature. Trying to use cached one!', file=sys.stderr)
                self._forget_etag(etag_path)
                return

            # check signature against an isolated keyring that only holds the
            # Univention archive keys (keeps root's own keyring untouched and
            # ensures no other key can validate the script)
            try:
                gpg_home = os.path.join(tmp_dir, 'gnupg')
                os.mkdir(gpg_home, 0o700)
                gpg = GnuPG(gnupg_home=gpg_home, key_paths_to_import=glob.glob('/etc/apt/trusted.gpg.d/univention-archive-key-ucs-*.gpg'))
                valid = gpg.verify_detached_signature(path=script_path, signature_file=signature_path)
            except (OSError, RuntimeError) as exc:
                print('Could not verify signature: %s' % (exc,), file=sys.stderr)
                valid = False

            if not valid:
                print('Signature not valid', file=sys.stderr)
                self._forget_etag(etag_path)
            else:
                # we have a new version downloaded, copy it to proper place
                shutil.copyfile(script_path, USI_SCRIPT, follow_symlinks=False)
                # copyfile() copies no permission bits; make sure it is executable
                os.chmod(USI_SCRIPT, 0o755)
                print('New version of the script installed', file=sys.stderr)

    def run_script(self):  # type: () -> int
        """
        Run the installed USI script and post-process its archive(s).

        :return: 1 if no archive was created or the upload failed, 1 if the
            mail could not be sent, otherwise the exit code of the USI script.
        """
        keep = bool(self.args.directory)
        fd, usi_file = mkstemp(prefix='univention-support-info-', suffix='.tar.bz2')
        # only the unique name is needed; the USI script writes the file itself
        os.close(fd)
        if self.args.directory:
            # reuse the unique name inside the requested directory and drop
            # the placeholder mkstemp() created in the temp dir
            os.remove(usi_file)
            usi_file = os.path.join(self.args.directory, os.path.basename(usi_file))

        cmd = [USI_SCRIPT]
        if self.args.encrypt:
            cmd.append('--encrypt')
        if self.args.debug:
            cmd.append('--debug')
        if self.args.flat:
            cmd.append('--flat')
        if self.args.full_logs:
            cmd.append('--full-logs')
        if not self.args.quiet:
            cmd.append('--verbose')
        cmd.extend(['--output', usi_file])
        print('Starting Univention Support Info...', file=sys.stderr)
        sys.stderr.flush()

        returncode = subprocess.call(cmd)
        archives = [usi_file]
        if self.args.encrypt:
            # the encrypted archive goes first so it is the one being uploaded
            archives.insert(0, '%s.gpg' % (usi_file,))
        if not all(os.path.exists(archive) for archive in archives):
            print('No archive could be created!', file=sys.stderr)
            return 1

        try:
            if self.args.upload_to_univention:
                print('Uploading archive to Univention...', file=sys.stderr)
                try:
                    archive = archives[0]  # will be the encrypted one
                    archive_id = self.upload_archive(archive)
                except Exception as exc:
                    print('Could not upload archive: %s' % (exc,), file=sys.stderr)
                    keep = True  # do not clean up an archive that was not delivered
                    return 1

                print('Archive has been uploaded with ID %s' % (archive_id,), file=sys.stderr)

                if self.args.sender and '@' in self.args.sender:
                    print('Sending mail to Univention...', file=sys.stderr)
                    try:
                        self.send_mail(archive_id)
                    except Exception as exc:
                        print('The mail could not be send: %s' % (exc,), file=sys.stderr)
                        returncode = 1
        finally:
            if self.args.cleanup and not keep:
                for archive in archives:
                    if os.path.isfile(archive):
                        os.remove(archive)
                        print('Cleaned up file %s' % (archive,))

        return returncode

    def upload_archive(self, archive):  # type: (str) -> str
        """
        Upload *archive* through the multipart form on the UPLOAD page.

        :return: The archive ID reported on the result page.
        :raises requests.HTTPError: if the form page or the upload answers
            with an HTTP error status.
        """
        response = requests.get(UPLOAD)
        response.raise_for_status()
        tree = lxml.html.parse(io.StringIO(response.text))
        form = tree.getroot().xpath('//form[@enctype="multipart/form-data"][@method="post"]')[0]
        upload_uri = urllib.parse.urljoin(UPLOAD, form.action)
        # relative XPath ('.//'): only the fields of the upload form itself
        data = [(x.name, x.value) for x in form.xpath('.//input') if x.type not in ('file', 'submit')]
        name = form.xpath('.//input[@type="file"]')[0].name
        with open(archive, 'rb') as fd:
            response = requests.post(upload_uri, files={name: fd}, data=data)
        response.raise_for_status()
        return cast(str, lxml.html.parse(io.StringIO(response.text)).getroot().xpath('//div[@id="page-body"]/b')[0].text)

    @staticmethod
    def _ticket_number(ticket):  # type: (str) -> str
        """Strip '#', spaces and ASCII letters (e.g. a 'Ticket#' prefix) from both ends of *ticket*."""
        return ticket.strip('# ' + string.ascii_letters)

    def send_mail(self, archive_id):  # type: (str) -> None
        """Send the upload ID to the recipient via the local SMTP server."""
        subject = 'Univention Support Info Upload'
        if self.args.ticket:
            subject = '[Ticket#%s] %s' % (self._ticket_number(self.args.ticket), subject)
        msg = MIMEText('''
A new Univention support info archive has been uploaded.

Archive ID: %s''' % (archive_id,))
        msg['Subject'] = subject
        msg['From'] = self.args.sender
        msg['To'] = self.args.recipient
        # SMTP() without a host does not connect and sendmail() would fail
        with smtplib.SMTP('localhost') as s:
            s.sendmail(self.args.sender, [self.args.recipient], msg.as_string())

    @staticmethod
    def _forget_etag(etag_path):  # type: (str) -> None
        """Remove the stored ETag (if any) so the next run downloads again."""
        try:
            os.remove(etag_path)
        except FileNotFoundError:
            pass

    @classmethod
    def _download_file(cls, url, local_path, etag_path=None):  # type: (str, str, Optional[str]) -> int
        """
        Download *url* to *local_path* (mode 0755) on HTTP 200.

        When *etag_path* is given, the stored ETag is sent as
        ``If-None-Match`` and the ETag of a fresh download is stored there.

        :return: The HTTP status code, or 500 if the request itself failed.
        """
        etag = None
        if etag_path:
            try:
                with open(etag_path) as fd:
                    etag = fd.read().strip()
            except FileNotFoundError:
                pass

        headers = {}
        if etag:
            headers['If-None-Match'] = etag

        try:
            response = requests.get(url, headers=headers)
        except requests.exceptions.RequestException:
            if etag_path:
                cls._forget_etag(etag_path)
            return 500

        if response.status_code == 200:
            with open(local_path, 'wb') as fd:
                fd.write(response.content)
            os.chmod(local_path, 0o755)
            if etag_path:
                new_etag = response.headers.get('ETag')
                if new_etag:
                    # caching is best effort: a write failure must not abort the download
                    try:
                        with open(etag_path, 'w') as fd:
                            fd.write(new_etag)
                    except OSError as exc:
                        print('Could not store ETag: %s' % (exc,), file=sys.stderr)
                else:
                    # a stale ETag would no longer describe the downloaded file
                    cls._forget_etag(etag_path)

        return response.status_code


if __name__ == '__main__':
    # Command line interface: wrapper options first, then the options that
    # are passed through to the univention-support-info script.
    parser = argparse.ArgumentParser(description='Wrapper for USI script. Downloads the latest USI version, executes it, uploads the resulting archive to Univention and attaches it to a support ticket.')
    parser.add_argument('--upload-to-univention', action='store_true', help='Uploads the resulting USI archive to the Univention upload service.')
    parser.add_argument('--cleanup', action='store_true', help='Remove the locally stored archive file. Has no effect together with --output.')
    parser.add_argument('--sender', help='The mail address of the sender. If given together with --upload-to-univention, the upload ID is sent to the Univention Support.')
    parser.add_argument('--recipient', help='Recipient of the mail (default: %(default)s).', default='support@univention.de')
    parser.add_argument('--add-to-ticket', metavar='TICKET', dest='ticket', help='Adds the file to the ticket number instead of creating a new one.')
    parser.add_argument('--output', metavar='DIRECTORY', dest='directory', help='Keep the resulting USI archive in the specified directory.')

    usiparser = parser.add_argument_group('univention-support-info arguments')
    usiparser.add_argument('--encrypt', action='store_true', help='Encrypt the archive and send only the encrypted version to Univention')
    usiparser.add_argument('--full-logs', action='store_true', help='collect also rotated logfiles')
    usiparser.add_argument('--flat', action='store_true', help='flatten the directory structure')
    usiparser.add_argument('--quiet', action='store_true', help='Almost no output', default=False)
    usiparser.add_argument('--debug', action='store_true', help='enable debug', default=False)

    args = parser.parse_args()
    # installing the script to /usr/bin and writing the ETag cache require root
    if os.getuid() != 0:
        parser.error('Must be executed as root!')
    if args.directory and not os.path.isdir(args.directory):
        parser.error('Folder does not exist or is not a directory.')
    sys.exit(Main(args).execute())
