#!/usr/share/ucs-test/runner python3
## desc: "GPO Security Descriptor sync"
## exposure: dangerous
## packages:
##   - univention-config
##   - univention-directory-manager-tools
##   - univention-samba4
##   - univention-s4-connector
#
#  Bug #33768


from __future__ import print_function

import time
import subprocess
import re

import univention.uldap
from univention.config_registry import ConfigRegistry
from univention.s4connector import configdb
import univention.testing.utils as utils
from univention.testing.strings import random_username
import univention.testing.udm as udm_test
import s4connector

from ldap.filter import filter_format
import ldb
from samba.param import LoadParm
from samba.credentials import Credentials
from samba.samdb import SamDB
from samba.auth import system_session
from samba.sd_utils import SDUtils
from samba.ndr import ndr_unpack
from samba.dcerpc import security


def set_ucr(ucr_set, ucr_unset=None, ucr=None):
	"""Apply UCR changes and report what is needed to undo them.

	``ucr_set`` holds ``name=value`` assignments, ``ucr_unset`` holds plain
	variable names; each may be a single string or an iterable of strings.
	``ucr`` is the registry used to look up the current values; when it is
	missing (or falsy) a fresh ``ConfigRegistry`` is created and loaded.

	Returns a tuple ``(previous_ucr_set, previous_ucr_unset)``:

	* ``previous_ucr_set``: ``name=value`` strings carrying the values that
	  were found before the change,
	* ``previous_ucr_unset``: names of variables that had no value before
	  they were set.

	Passing both lists back into this function reverts the change.
	"""
	if not ucr:
		ucr = ConfigRegistry()
		ucr.load()

	restore_set = []
	restore_unset = []

	if ucr_set:
		if isinstance(ucr_set, str):
			ucr_set = (ucr_set,)

		for assignment in ucr_set:
			pieces = assignment.split("=", 1)
			name = pieces[0]
			wanted = pieces[1]
			current = ucr.get(name)
			if wanted == current:
				# Nothing changes for this variable, so nothing to remember.
				continue

			if current is None:
				restore_unset.append(u'%s' % (name,))
			else:
				restore_set.append(u'%s=%s' % (name, current))

		univention.config_registry.handler_set(ucr_set)

	if ucr_unset:
		if isinstance(ucr_unset, str):
			ucr_unset = (ucr_unset,)

		for name in ucr_unset:
			current = ucr.get(name)
			if current is None:
				continue
			restore_set.append(u'%s=%s' % (name, current))

		univention.config_registry.handler_unset(ucr_unset)

	return (restore_set, restore_unset)


class Testclass_GPO_Security_Descriptor(object):
	"""Common fixture for the GPO security descriptor synchronisation tests.

	Instances are used as context managers: entering sets the UCR variable
	``connector/s4/mapping/gpo/ntsd`` to ``true`` (restarting the S4
	connector when that changed anything); leaving restores the previous
	UCR state and removes the GPO created by :meth:`create_gpo`.
	"""

	# Seconds a synchronisation may take before the test is failed.
	SYNC_TIMEOUT = 120
	# Seconds slept after a change was observed
	# (presumably to let the connector finish its work -- TODO confirm).
	SYNC_SETTLE_TIME = 15

	def __init__(self, udm, ucr=None):
		"""Open the Samba and UCS LDAP connections used by the test.

		:param udm: object offering ``_build_udm_cmdline`` (see :meth:`modify_udm_object`).
		:param ucr: optional registry; a fresh ``ConfigRegistry`` is loaded when omitted.
		"""
		self.SAM_LDAP_FILTER_GPO = "(&(objectclass=grouppolicycontainer)(cn=%s))"
		self.gpo_ldap_filter = None
		self.gponame = None

		# Filled by activate_ntsd_sync(); initialised here so that __exit__
		# and other readers never hit a missing attribute.
		self.previous_ucr_set = []
		self.previous_ucr_unset = []

		self.udm = udm

		if ucr:
			self.ucr = ucr
		else:
			self.ucr = ConfigRegistry()
			self.ucr.load()

		self.adminaccount = utils.UCSTestDomainAdminCredentials()
		self.machine_ucs_ldap = univention.uldap.getMachineConnection()

		self.fqdn = ".".join((self.ucr["hostname"], self.ucr["domainname"]))

		self.lp = LoadParm()
		self.lp.load_default()

		# Read access to sam.ldb happens with the machine account ...
		self.samba_machine_creds = Credentials()
		self.samba_machine_creds.guess(self.lp)
		self.samba_machine_creds.set_machine_account(self.lp)
		self.machine_samdb = SamDB(url="/var/lib/samba/private/sam.ldb", session_info=system_session(), credentials=self.samba_machine_creds, lp=self.lp)
		self.domain_sid = security.dom_sid(self.machine_samdb.get_domain_sid())
		self.DA_SID = security.dom_sid("%s-%d" % (self.domain_sid, security.DOMAIN_RID_ADMINS))
		self.DU_SID = security.dom_sid("%s-%d" % (self.domain_sid, security.DOMAIN_RID_USERS))

		# ... while security descriptors are modified with the admin account.
		self.samba_admin_creds = Credentials()
		self.samba_admin_creds.guess(self.lp)
		self.samba_admin_creds.parse_string(self.adminaccount.username)
		self.samba_admin_creds.set_password(self.adminaccount.bindpw)
		self.admin_samdb = SamDB(url="/var/lib/samba/private/sam.ldb", session_info=system_session(), credentials=self.samba_admin_creds, lp=self.lp)
		self.admin_samdb_sdutil = SDUtils(self.admin_samdb)

	@staticmethod
	def _run_command(cmd):
		"""Run *cmd* and return ``(returncode, output)``.

		stderr is folded into stdout and the combined output is decoded as
		UTF-8 with undecodable bytes replaced.
		"""
		proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
		output, _ = proc.communicate()
		return proc.returncode, output.decode('UTF-8', 'replace')

	def restart_s4_connector(self):
		"""Restart the S4 connector via its init script; fail the test on a non-zero exit code."""
		cmd = ("/etc/init.d/univention-s4-connector", "restart")
		returncode, output = self._run_command(cmd)
		if returncode != 0:
			utils.fail("Error restarting S4 Connector: %s\nCommand was: %s" % (output, cmd))

	def activate_ntsd_sync(self):
		"""Enable the GPO security descriptor mapping and remember the previous UCR state."""
		ucr_set = ["connector/s4/mapping/gpo/ntsd=true", ]
		self.previous_ucr_set, self.previous_ucr_unset = set_ucr(ucr_set, ucr=self.ucr)
		# The connector is only restarted when the variable really changed.
		if self.previous_ucr_unset or self.previous_ucr_set:
			self.restart_s4_connector()

	def __enter__(self):
		self.activate_ntsd_sync()
		return self

	def __exit__(self, exc_type, exc_value, traceback):
		"""Restore the UCR state and remove the test GPO.

		The GPO removal sits in a ``finally`` clause so the GPO does not
		stay behind when restoring UCR or restarting the connector fails.
		"""
		if exc_type:
			print('GPO Cleanup after exception: %s %s' % (exc_type, exc_value))
		try:
			if self.previous_ucr_unset or self.previous_ucr_set:
				set_ucr(self.previous_ucr_set, self.previous_ucr_unset, ucr=self.ucr)
				self.restart_s4_connector()
		finally:
			self.remove_gpo()

	def get_ldb_object(self, dn=None, ldap_filter=None, attrs=None):
		"""Return the first SAM search result, or ``None`` when nothing matches.

		:param dn: when given, a base-scope search on this DN is done.
		:param ldap_filter: search expression, defaults to ``(objectClass=*)``.
		:param attrs: attributes to fetch, defaults to all (``["*"]``).
		"""
		if not attrs:
			attrs = ["*"]
		if not ldap_filter:
			ldap_filter = "(objectClass=*)"

		if dn:
			res = self.machine_samdb.search(base=dn, scope=ldb.SCOPE_BASE, expression=ldap_filter, attrs=attrs)
		else:
			res = self.machine_samdb.search(expression=ldap_filter, attrs=attrs)

		for ldb_msg in res:
			return ldb_msg
		return None

	def get_ldb_gpo(self, gponame):
		"""Look up the group policy container *gponame* in SAM.

		Only ``nTSecurityDescriptor`` and ``uSNChanged`` are fetched.
		Returns ``None`` when no such container exists.
		"""
		ldap_filter = filter_format(self.SAM_LDAP_FILTER_GPO, (gponame,))
		attrs = ["nTSecurityDescriptor", "uSNChanged"]
		return self.get_ldb_object(ldap_filter=ldap_filter, attrs=attrs)

	def get_ntsd(self, obj):
		"""Turn *obj* into a ``security.descriptor``.

		Accepted inputs:

		* ``ldb.Message``: its NDR encoded ``nTSecurityDescriptor`` is unpacked,
		* ``(dn, attrs)`` tuple: ``msNTSecurityDescriptor`` is parsed as SDDL,
		* ``str`` / ``bytes``: parsed as SDDL directly.

		:raises ValueError: when a tuple carries no ``msNTSecurityDescriptor``
			or *obj* is of any other type (e.g. ``None`` from a failed lookup).
		"""
		if isinstance(obj, ldb.Message):
			ntsd_ndr = obj["nTSecurityDescriptor"][0]
			ntsd = ndr_unpack(security.descriptor, ntsd_ndr)
		elif isinstance(obj, tuple):
			ntsd_sddl = obj[1].get("msNTSecurityDescriptor", [None])[0]
			if not ntsd_sddl:
				raise ValueError("No msNTSecurityDescriptor synchronized")
			ntsd = security.descriptor.from_sddl(ntsd_sddl.decode('ASCII'), self.domain_sid)
		elif isinstance(obj, str):
			ntsd = security.descriptor.from_sddl(obj, self.domain_sid)
		elif isinstance(obj, bytes):
			ntsd = security.descriptor.from_sddl(obj.decode('ASCII'), self.domain_sid)
		else:
			# Name the offending type so that e.g. a failed GPO lookup is recognisable.
			raise ValueError("Cannot read a security descriptor from object of type %s" % (type(obj).__name__,))

		return ntsd

	def assert_owner(self, ntsd, expected_sid, logtag='assert_owner'):
		"""Fail the test unless the owner SID of *ntsd* equals *expected_sid*."""
		if ntsd.owner_sid != expected_sid:
			utils.fail("ERROR: %s: Unexpected owner SID! Expected: %s, Found: %s" % (logtag, expected_sid, ntsd.owner_sid))

	def get_ucs_ldap_object(self, ucs_dn):
		"""Return the first result of a base-scope search on *ucs_dn* in UCS LDAP."""
		res = self.machine_ucs_ldap.search(base=ucs_dn, scope="base", attr=["*"])
		return res[0]

	def wait_for_s4connector_sync_to_ucs(self, ldb_msg, logtag="wait_for_s4connector_sync_to_ucs"):
		"""Block until the connector's ``lastUSN`` has reached the ``uSNChanged`` of *ldb_msg*.

		The value is polled once per second from the connector's internal
		database; the test is failed after ``SYNC_TIMEOUT`` seconds.
		"""
		usn = int(ldb_msg["uSNChanged"][0])

		configdbfile = '/etc/univention/connector/s4internal.sqlite'
		s4c_internaldb = configdb(configdbfile)

		t0 = time.time()
		while int(s4c_internaldb.get("S4", "lastUSN")) < usn:
			if time.time() - t0 > self.SYNC_TIMEOUT:
				utils.fail("ERROR: %s: Replication takes too long, aborting" % logtag)
			time.sleep(1)
		time.sleep(self.SYNC_SETTLE_TIME)

	def wait_for_object_usn_change(self, ldb_msg, logtag="wait_for_object_usn_change"):
		"""Block until the SAM object behind *ldb_msg* shows a different ``uSNChanged``.

		The object is re-read once per second; the test is failed after
		``SYNC_TIMEOUT`` seconds or when the object cannot be read any more.
		"""
		dn = str(ldb_msg.dn)
		initial_usn = int(ldb_msg["uSNChanged"][0])
		usn = initial_usn

		t0 = time.time()
		while usn == initial_usn:
			time.sleep(1)
			if time.time() - t0 > self.SYNC_TIMEOUT:
				utils.fail("ERROR: %s: Replication takes too long, aborting" % logtag)
			current = self.get_ldb_object(dn=dn, attrs=["uSNChanged"])
			if current is None:
				utils.fail("ERROR: %s: Object %s not found while waiting for a USN change" % (logtag, dn))
			usn = int(current["uSNChanged"][0])
		time.sleep(self.SYNC_SETTLE_TIME)

	def remove_gpo(self, critical=True):
		"""Delete the GPO created by :meth:`create_gpo`, if there is one.

		:param critical: when true, a failing ``samba-tool`` call fails the
			test; otherwise the failure is ignored and ``gponame`` is kept.
		"""
		if not self.gponame:
			return

		cmd = (
			"samba-tool", "gpo", "del", self.gponame,
			"-k", "no",
			"-H", "ldap://%s" % (self.fqdn,),
			"--username", self.adminaccount.username,
			"--password", self.adminaccount.bindpw)

		returncode, output = self._run_command(cmd)
		if returncode == 0:
			# Forget the name so that a second call becomes a no-op.
			self.gponame = None
		elif critical:
			utils.fail("Error removing GPO using samba-tool: %s\nCommand was: %s" % (output, cmd))

	def create_gpo(self, logtag="create_gpo"):
		"""Create a GPO with a random display name via ``samba-tool``.

		On success ``gponame`` holds the ``{...}`` reference found in the
		tool's output and ``gpo_ldap_filter`` the matching SAM filter.
		"""
		display_name = 'ucs_test_gpo_' + random_username(8)

		cmd = (
			"samba-tool", "gpo", "create", display_name,
			"-k", "no",
			"-H", "ldap://%s" % (self.fqdn,),
			"--username", self.adminaccount.username,
			"--password", self.adminaccount.bindpw)

		returncode, output = self._run_command(cmd)
		if returncode != 0:
			utils.fail("ERROR: %s: creating GPO using samba-tool: %s\nCommand was: %s" % (logtag, output, cmd))

		stdout = output.rstrip()
		try:
			# re.search() yields None when no "{...}" is present, hence AttributeError.
			self.gponame = '{' + re.search('{(.+?)}', stdout).group(1) + '}'
			self.gpo_ldap_filter = filter_format(self.SAM_LDAP_FILTER_GPO, (self.gponame,))
		except AttributeError as ex:
			utils.fail("Could not find the GPO reference in the STDOUT '%s' of the 'samba-tool', error: '%s'" % (stdout, ex))

	def modify_udm_object(self, modulename, **kwargs):
		"""Run a UDM ``modify`` command line built by ``self.udm``.

		:raises udm_test.UCSTestUDM_ModifyUDMObjectFailed: on a non-zero exit code.
		"""
		cmd = self.udm._build_udm_cmdline(modulename, 'modify', kwargs)
		child = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
		(stdout, stderr) = child.communicate()

		if child.returncode:
			raise udm_test.UCSTestUDM_ModifyUDMObjectFailed({'module': modulename, 'kwargs': kwargs, 'returncode': child.returncode, 'stdout': stdout.decode('UTF-8', 'replace'), 'stderr': stderr.decode('UTF-8', 'replace')})

	def modify_sd_on_ucs_ldap_gpo(self, ucs_dn, ucs_ntsd):
		"""Write *ucs_ntsd* (as SDDL) to the ``container/msgpo`` object *ucs_dn* through UDM."""
		self.modify_udm_object('container/msgpo', dn=ucs_dn, msNTSecurityDescriptor=ucs_ntsd.as_sddl())


class Testcase_GPO_Security_Descriptor_UDM_to_SAM(Testclass_GPO_Security_Descriptor):
	"""Check that an owner change made through UDM arrives on the GPO object in SAM."""

	def _read_ucs_ntsd(self, ucs_dn, phase):
		"""Return the security descriptor stored on *ucs_dn* in UCS LDAP; fail the test if there is none."""
		uldap_msg = self.get_ucs_ldap_object(ucs_dn)
		try:
			return self.get_ntsd(uldap_msg)
		except ValueError as ex:
			utils.fail("ERROR: %s: %s" % (phase, ex.args[0]))

	def run(self):
		"""Run the three phases of the test.

		* preparation: create a GPO, wait for it to be synchronised and
		  verify that SAM and UCS hold the same descriptor,
		* test: set the owner to the "Domain Users" SID through UDM and
		  verify that SAM follows,
		* cleanup: set the owner back to the "Domain Admins" SID through UDM
		  and verify that SAM follows again.
		"""
		sync_from = "UDM"
		sync_to = "SAM"
		print("GPO Security Descriptor sync from %s to %s" % (sync_from, sync_to))
		PHASE = "preparation"

		self.create_gpo(logtag=PHASE)
		print('GPO Name: %s' % self.gponame)
		ldb_msg = self.get_ldb_gpo(self.gponame)
		sam_ntsd = self.get_ntsd(ldb_msg)
		self.assert_owner(sam_ntsd, self.DA_SID, logtag=PHASE)
		self.wait_for_s4connector_sync_to_ucs(ldb_msg, logtag=PHASE)

		# Re-read the GPO from SAM: checking the message fetched before the
		# sync again could never fail, and its uSNChanged is the baseline for
		# wait_for_object_usn_change() below, so it has to be current.
		ldb_msg = self.get_ldb_gpo(self.gponame)
		sam_ntsd = self.get_ntsd(ldb_msg)
		self.assert_owner(sam_ntsd, self.DA_SID, logtag=PHASE)

		# we need the exact case of the DN otherwise udm cli will fail
		temp_dn = str(ldb_msg.dn).lower().replace(self.ucr["samba4/ldap/base"].lower(), self.ucr["ldap/base"])
		ucs_dn = self.machine_ucs_ldap.searchDn(base=temp_dn, scope='base')[0]

		ucs_ntsd = self._read_ucs_ntsd(ucs_dn, PHASE)
		if ucs_ntsd.as_sddl() != sam_ntsd.as_sddl():
			utils.fail("ERROR: %s: NT Security descriptor differs between %s and %s" % (PHASE, sync_from, sync_to))

		PHASE = "test"

		ucs_ntsd.owner_sid = self.DU_SID
		self.modify_sd_on_ucs_ldap_gpo(ucs_dn, ucs_ntsd)

		# Make sure the change really landed in UCS LDAP before waiting for SAM.
		ucs_ntsd = self._read_ucs_ntsd(ucs_dn, PHASE)
		self.assert_owner(ucs_ntsd, self.DU_SID, logtag=PHASE)

		self.wait_for_object_usn_change(ldb_msg, logtag=PHASE)

		ldb_msg = self.get_ldb_gpo(self.gponame)
		sam_ntsd = self.get_ntsd(ldb_msg)

		if ucs_ntsd.as_sddl() != sam_ntsd.as_sddl():
			utils.fail("ERROR: %s: NT Security descriptor not synchronized from %s to %s" % (PHASE, sync_from, sync_to))

		PHASE = "cleanup"

		ucs_ntsd.owner_sid = self.DA_SID
		self.modify_sd_on_ucs_ldap_gpo(ucs_dn, ucs_ntsd)

		ucs_ntsd = self._read_ucs_ntsd(ucs_dn, PHASE)
		self.assert_owner(ucs_ntsd, self.DA_SID, logtag=PHASE)

		self.wait_for_object_usn_change(ldb_msg, logtag=PHASE)
		ldb_msg = self.get_ldb_gpo(self.gponame)
		sam_ntsd = self.get_ntsd(ldb_msg)

		if ucs_ntsd.as_sddl() != sam_ntsd.as_sddl():
			utils.fail("ERROR: %s: NT Security descriptor not re-synchronized from %s to %s" % (PHASE, sync_from, sync_to))


class Testcase_GPO_Security_Descriptor_SAM_to_UDM(Testclass_GPO_Security_Descriptor):
	"""Check that an owner change made directly in SAM arrives on the GPO object in UCS LDAP."""

	def _read_ucs_ntsd(self, ucs_dn, phase):
		"""Return the security descriptor stored on *ucs_dn* in UCS LDAP; fail the test if there is none."""
		uldap_msg = self.get_ucs_ldap_object(ucs_dn)
		try:
			return self.get_ntsd(uldap_msg)
		except ValueError as ex:
			utils.fail("ERROR: %s: %s" % (phase, ex.args[0]))

	def run(self):
		"""Run the three phases of the test.

		* preparation: create a GPO, wait for it to be synchronised and
		  verify that SAM and UCS hold the same descriptor,
		* test: set the owner to the "Domain Users" SID in SAM and verify
		  that UCS LDAP follows,
		* cleanup: set the owner back to the "Domain Admins" SID in SAM and
		  verify that UCS LDAP follows again.
		"""
		sync_from = "SAM"
		sync_to = "UDM"
		print("GPO Security Descriptor sync from %s to %s" % (sync_from, sync_to))
		PHASE = "preparation"

		self.create_gpo(logtag=PHASE)
		print('GPO Name: %s' % self.gponame)
		ldb_msg = self.get_ldb_gpo(self.gponame)
		sam_ntsd = self.get_ntsd(ldb_msg)
		self.assert_owner(sam_ntsd, self.DA_SID, logtag=PHASE)
		self.wait_for_s4connector_sync_to_ucs(ldb_msg, logtag=PHASE)

		# Re-read the GPO from SAM: checking the message fetched before the
		# sync again could never fail, and the comparison with UCS below
		# should use the descriptor as it is after the sync.
		ldb_msg = self.get_ldb_gpo(self.gponame)
		sam_ntsd = self.get_ntsd(ldb_msg)
		self.assert_owner(sam_ntsd, self.DA_SID, logtag=PHASE)

		# The lower-cased DN is only used for searches in UCS LDAP.
		ucs_dn = str(ldb_msg.dn).lower().replace(self.ucr["samba4/ldap/base"].lower(), self.ucr["ldap/base"].lower())
		ucs_ntsd = self._read_ucs_ntsd(ucs_dn, PHASE)
		if ucs_ntsd.as_sddl() != sam_ntsd.as_sddl():
			utils.fail("ERROR: %s: NT Security descriptor differs between %s and %s" % (PHASE, sync_from, sync_to))

		PHASE = "test"

		sam_ntsd.owner_sid = self.DU_SID
		self.admin_samdb_sdutil.modify_sd_on_dn(str(ldb_msg.dn), sam_ntsd)

		# Fetch the object again to verify the write and to get its new uSNChanged.
		ldb_msg = self.get_ldb_gpo(self.gponame)
		sam_ntsd = self.get_ntsd(ldb_msg)
		self.assert_owner(sam_ntsd, self.DU_SID, logtag=PHASE)

		self.wait_for_s4connector_sync_to_ucs(ldb_msg, logtag=PHASE)

		ucs_ntsd = self._read_ucs_ntsd(ucs_dn, PHASE)

		if ucs_ntsd.as_sddl() != sam_ntsd.as_sddl():
			print('ucs_ntsd.as_sddl: %s' % ucs_ntsd.as_sddl())
			print('sam_ntsd.as_sddl: %s' % sam_ntsd.as_sddl())
			utils.fail("ERROR: %s: NT Security descriptor not synchronized from %s to %s" % (PHASE, sync_from, sync_to))

		PHASE = "cleanup"

		sam_ntsd.owner_sid = self.DA_SID
		self.admin_samdb_sdutil.modify_sd_on_dn(str(ldb_msg.dn), sam_ntsd)

		ldb_msg = self.get_ldb_gpo(self.gponame)
		sam_ntsd = self.get_ntsd(ldb_msg)
		self.assert_owner(sam_ntsd, self.DA_SID, logtag=PHASE)

		self.wait_for_s4connector_sync_to_ucs(ldb_msg, logtag=PHASE)
		ucs_ntsd = self._read_ucs_ntsd(ucs_dn, PHASE)

		if ucs_ntsd.as_sddl() != sam_ntsd.as_sddl():
			utils.fail("ERROR: %s: NT Security descriptor not re-synchronized from %s to %s" % (PHASE, sync_from, sync_to))


if __name__ == "__main__":
	# Guard from the s4connector test helper, called before anything is set up.
	s4connector.exit_if_connector_not_running()

	# Both directions are run one after the other on a shared UDM session;
	# each test case is entered and left on its own.
	testcases = (
		Testcase_GPO_Security_Descriptor_SAM_to_UDM,
		Testcase_GPO_Security_Descriptor_UDM_to_SAM,
	)
	with udm_test.UCSTestUDM() as udm:
		for testcase in testcases:
			with testcase(udm) as test:
				test.run()
