#!/usr/bin/python3
#
# Univention Samba
#  helper script: kerberize a user account
#
# SPDX-FileCopyrightText: 2001-2025 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only


import codecs
import getopt
import hashlib
import os
import sys
from datetime import datetime

import ldap
from ldap.filter import filter_format

import univention.config_registry
import univention.debug as ud
import univention.lib.policy_result
import univention.uldap


# Initialise univention.debug; its output is sent to /dev/null.
ud.init('/dev/null', ud.FLUSH, ud.FUNCTION)

# Univention Config Registry values, loaded once when the script starts.
configRegistry = univention.config_registry.ConfigRegistry()
configRegistry.load()

# LDAP connection used by main() to search for and modify the account.
lo = univention.uldap.getAdminConnection()
# NOTE(review): krbbase is not referenced in the visible part of this file.
krbbase = 'ou=krb5,' + configRegistry['ldap/base']
# Kerberos realm; main() appends it to the uid to build krb5PrincipalName.
realm = configRegistry['kerberos/realm']

# Number of seconds in one day.
UNIXDAY = 3600 * 24


def _get_samba_password_history(newpassword, smbpwhistory, smbpwhlen):
    # calculate the password hash & salt
    # in binary for calculating the md5:
    salt = os.urandom(16)
    # we have to have that in hex:
    hexsalt = codecs.encode(salt, 'hex').upper().decode('ASCII')
    # we need the ntpwd binary data to
    pwd = codecs.decode(newpassword, 'hex')
    # calculating hash. stored as a 32byte hex in sambaPasswordHistory,
    # syntax like that: [Salt][MD5(Salt+Hash)]
    #    First 16bytes  ^     ^ last 16bytes.
    pwdhash = hashlib.md5(salt + pwd).hexdigest().upper()
    smbpwhash = hexsalt + pwdhash

    # split the history
    pwlist = smbpwhistory.strip().split(' ')
    # append new hash
    pwlist.append(smbpwhash)
    # strip old hashes
    pwlist = pwlist[-smbpwhlen:]
    # build history
    smbpwhistory = ''.join(pwlist)
    return smbpwhistory


def unixdayToKrb5Date(unixday):
    """Convert a Unix timestamp to a ``YYYYmmddHHMMSSZ`` time string.

    :param unixday: seconds since the epoch as ASCII encoded ``bytes``.
    :returns: the corresponding point in time in UTC as ASCII ``bytes``.
    """
    from datetime import timezone

    timestamp = float(unixday.decode('ASCII'))
    # The trailing 'Z' denotes UTC, so the conversion must not depend on
    # the local time zone of the host running this script.
    date = datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime("%Y%m%d%H%M%S") + 'Z'
    return date.encode('ASCII')


def nt_password_to_arcfour_hmac_md5(nt_password):
    """Build a krb5Key value from a hex encoded NT password hash.

    The first 16 two-character hex pairs of ``nt_password`` are converted
    to raw bytes and appended to the fixed prefix that every
    arcfour-hmac-md5 key starts with.
    """
    # fixed leading bytes shared by all arcfour-hmac-md5 keys
    header = b'0\x1d\xa1\x1b0\x19\xa0\x03\x02\x01\x17\xa1\x12\x04\x10'
    hex_pairs = (nt_password[pos:pos + 2] for pos in range(0, 32, 2))
    return header + bytes(int(pair, 16) for pair in hex_pairs)


def main():
    """Add Kerberos attributes to the Samba account named by ``-u``.

    Searches for ``sambaSamAccount`` entries with the given uid that have a
    ``sambaNTPassword`` and are neither ``univentionWindows`` nor
    ``univentionHost`` objects. For every match the missing Kerberos object
    classes and attributes are added, ``krb5Key`` is derived from the NT
    hash, the password expiry attributes are recalculated from
    ``sambaPwdLastSet`` and the password policies, and the changes are
    written with ``lo.modify``.

    Exits with status 0 without doing anything when no ``-u`` is given.
    """
    username = None
    optlist, _mail_user = getopt.getopt(sys.argv[1:], 'u:')
    for option, value in optlist:
        if option == '-u':
            username = value

    if not username:
        sys.exit(0)

    # placeholder stored by Samba when no password hash is set
    no_password = b"NO PASSWORDXXXXXXXXXXXXXXXXXXXXX"

    ldap_attr = ['uid', 'krb5Key', 'sambaNTPassword', 'sambaLMPassword', 'sambaAcctFlags', 'objectClass', 'userPassword', 'krb5PasswordEnd', 'sambaPwdCanChange', 'sambaPwdMustChange', 'sambaPwdLastSet', 'krb5KDCFlags', 'shadowLastChange', 'sambaPasswordHistory', 'shadowMax']
    search_filter = filter_format('(&(objectClass=sambaSamAccount)(sambaNTPassword=*)(uid=%s)(!(objectClass=univentionWindows)))', [username])
    for dn, attrs in lo.search(filter=search_filter, attr=ldap_attr):
        object_classes = attrs['objectClass']
        if b'univentionHost' in object_classes:
            continue

        nt_password = attrs['sambaNTPassword'][0]
        if nt_password == no_password:
            print('Can not add Kerberos key for %s...' % repr(dn), end=' ')
            print('no password set')
            continue

        uid = attrs['uid'][0]
        if uid == b'root':
            print('Skipping user root ')
            continue

        # check if the user was disabled
        sambaAcctFlags = attrs.get('sambaAcctFlags', [b''])[0]
        disabled = b'D' in sambaAcctFlags

        ocs = []
        ml = []
        if b'krb5Principal' not in object_classes:
            ocs.append('krb5Principal')
            principal = b'%s@%s' % (uid, realm.encode('UTF-8'))
            ml.append(('krb5PrincipalName', None, principal))

        # 126 is the value for an enabled principal; 254 is 126 with the
        # additional "invalid" bit (128) set, which disables the principal.
        flag = b'254' if disabled else b'126'

        if not sambaAcctFlags:
            ml.append(('sambaAcctFlags', None, b'[U          ]'))

        policy_result = univention.lib.policy_result.policy_result(dn)[0]
        if not attrs.get('sambaPasswordHistory', False):
            smbpwhlen = int(policy_result.get('univentionPWHistoryLen', [0])[0])
            history = _get_samba_password_history(nt_password, '', smbpwhlen)
            # the helper returns str; the modlist carries bytes like every
            # other value written here
            ml.append(('sambaPasswordHistory', None, history.encode('ASCII')))

        if attrs.get('sambaPwdLastSet', False):
            usersPWExpireInterval = int(policy_result.get('univentionPWExpiryInterval', [0])[0])

            oldkrb5PasswordEndValue = attrs.get('krb5PasswordEnd', [None])[0]
            oldsambaPwdCanChangeValue = attrs.get('sambaPwdCanChange', [None])[0]
            oldsambaPwdMustChangeValue = attrs.get('sambaPwdMustChange', [None])[0]
            oldshadowMaxValue = attrs.get('shadowMax', [None])[0]
            oldshadowLastChangeValue = attrs.get('shadowLastChange', [None])[0]
            sambaPwdLastSetValue = int(attrs['sambaPwdLastSet'][0].decode('ASCII'))

            # shadowLastChange is counted in days, the samba values in seconds
            shadowLastChangeValue = str(sambaPwdLastSetValue // UNIXDAY).encode('ASCII')
            sambaPwdCanChangeValue = str(sambaPwdLastSetValue + UNIXDAY).encode('ASCII')
            if usersPWExpireInterval:
                # expiry policy is set: calculate the expiry dates
                sambaPwdMustChangeValue = str(sambaPwdLastSetValue + int(usersPWExpireInterval * UNIXDAY)).encode('ASCII')
                # unixdayToKrb5Date() already returns bytes; wrapping it in
                # str() would store the repr "b'...'" in LDAP
                krb5PasswordEndValue = unixdayToKrb5Date(sambaPwdMustChangeValue)
                shadowMaxValue = str(usersPWExpireInterval).encode('ASCII')
            else:
                # no expiry policy: remove expiry intervals and dates
                sambaPwdMustChangeValue = None
                krb5PasswordEndValue = None
                shadowMaxValue = None
            ml.extend([
                ('krb5PasswordEnd', oldkrb5PasswordEndValue, krb5PasswordEndValue),
                ('sambaPwdCanChange', oldsambaPwdCanChangeValue, sambaPwdCanChangeValue),
                ('sambaPwdMustChange', oldsambaPwdMustChangeValue, sambaPwdMustChangeValue),
                ('shadowMax', oldshadowMaxValue, shadowMaxValue),
                ('shadowLastChange', oldshadowLastChangeValue, shadowLastChangeValue),
            ])
        else:
            print('Could not find attribute "sambaPwdLastSet". Skipping generating of "krb5PasswordEnd".')

        if b'krb5KDCEntry' not in object_classes:
            ocs.append('krb5KDCEntry')
            ml.extend([
                ('krb5MaxLife', None, b'86400'),
                ('krb5MaxRenew', None, b'604800'),
                ('krb5KeyVersionNumber', None, b'1'),
            ])

        old_flag = attrs.get('krb5KDCFlags', [])
        old_keys = attrs.get('krb5Key', [])

        ml.extend([
            ('krb5Key', old_keys, nt_password_to_arcfour_hmac_md5(nt_password)),
            ('krb5KDCFlags', old_flag, flag),
        ])

        lm_password = attrs.get('sambaLMPassword', [None])[0]
        if lm_password not in [no_password, None]:
            old_password = attrs.get('userPassword', [])
            # a leading '!' in front of the hash marks the password of a
            # disabled account
            lanman_format = b'{LANMAN}!%s' if disabled else b'{LANMAN}%s'
            ml.append(('userPassword', old_password, [lanman_format % lm_password]))

        if ocs:
            print('Adding Kerberos key for %r...' % (dn,), end=' ')
            # object classes have to be added before their attributes
            ml.insert(0, ('objectClass', None, [x.encode('UTF-8') for x in ocs]))

        try:
            lo.modify(dn, ml)
        except ldap.ALREADY_EXISTS:
            print('already exists')
        else:
            print('done')


# Run main() only when this file is executed as a script.
if __name__ == '__main__':
    main()
