#!/usr/bin/env python3
# Unix SMB/CIFS implementation.
# Copyright (C) Stefan Metzmacher 2020
# Copyright (C) 2020 Catalyst.Net Ltd
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

import sys
import os

sys.path.insert(0, "bin/python")
os.environ["PYTHONUNBUFFERED"] = "1"

from samba.tests import DynamicTestCase

import ldb

from samba.tests.krb5.kdc_base_test import KDCBaseTest
from samba.tests.krb5.raw_testcase import KerberosCredentials
from samba.tests.krb5.rfc4120_constants import (
    AES256_CTS_HMAC_SHA1_96,
    ARCFOUR_HMAC_MD5,
    KDC_ERR_S_PRINCIPAL_UNKNOWN,
    NT_PRINCIPAL,
)

global_asn1_print = False
global_hexdump = False


@DynamicTestCase
class SpnTests(KDCBaseTest):
    """TGS-REQ tests for various service principal name (SPN) forms.

    For every (account type, SPN template) pair, a test is generated.
    Each test registers the SPNs on a target account, requests a service
    ticket for one of them, and checks whether the KDC issues it or
    rejects it with KDC_ERR_S_PRINCIPAL_UNKNOWN.
    """

    test_account_types = {
        'computer': KDCBaseTest.AccountType.COMPUTER,
        'server': KDCBaseTest.AccountType.SERVER,
        'rodc': KDCBaseTest.AccountType.RODC
    }

    # SPN templates are formatted in two passes. The first pass fills in
    # the domain names and turns the escaped '{{account}}' into
    # '{account}'. The second pass fills in the target account's name.
    test_spns = {
        '2_part': 'ldap/{{account}}',
        '3_part_our_domain': 'ldap/{{account}}/{netbios_domain_name}',
        '3_part_our_realm': 'ldap/{{account}}/{dns_domain_name}',
        '3_part_not_our_realm': 'ldap/{{account}}/test',
        '3_part_instance': 'ldap/{{account}}:test/{dns_domain_name}'
    }

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # Class-wide cache of the mock RODC credentials; filled lazily by
        # _get_creds() the first time an RODC test runs.
        cls._mock_rodc_creds = None

    @classmethod
    def setUpDynamicTestCases(cls):
        """Generate one 'test_spn' case per (account type, SPN) pair."""
        for account_type_name, account_type in cls.test_account_types.items():
            for spn_name, spn in cls.test_spns.items():
                tname = f'{spn_name}_spn_{account_type_name}'
                targs = (account_type, spn)
                cls.generate_dynamic_test('test_spn', tname, *targs)

    def _test_spn_with_args(self, account_type, spn):
        """Request a service ticket for *spn* on an *account_type* account.

        :param account_type: one of the values in test_account_types.
        :param spn: an unformatted template taken from test_spns.
        """
        target_creds = self._get_creds(account_type)
        spn = self._format_spn(spn, target_creds)

        sname = self.PrincipalName_create(name_type=NT_PRINCIPAL,
                                          names=spn.split('/'))

        client_creds = self.get_client_creds()
        tgt = self.get_tgt(client_creds)

        samdb = self.get_samdb()
        netbios_domain_name = samdb.domain_netbios_name()
        dns_domain_name = samdb.domain_dns_name()

        subkey = self.RandomKey(tgt.session_key.etype)

        etypes = (AES256_CTS_HMAC_SHA1_96, ARCFOUR_HMAC_MD5,)

        # SERVER accounts: decrypt the ticket with the AES256 key. Other
        # account types leave the etype unspecified (None).
        if account_type is self.AccountType.SERVER:
            ticket_etype = AES256_CTS_HMAC_SHA1_96
        else:
            ticket_etype = None
        decryption_key = self.TicketDecryptionKey_from_creds(
            target_creds, etype=ticket_etype)

        # This test expects a three-part SPN whose last part is our NetBIOS
        # domain or DNS realm to be unknown to the KDC for computer
        # accounts. For SERVER and RODC accounts it is expected to resolve.
        if (spn.count('/') > 1
                and spn.endswith((netbios_domain_name, dns_domain_name))
                and account_type is not self.AccountType.SERVER
                and account_type is not self.AccountType.RODC):
            expected_error_mode = KDC_ERR_S_PRINCIPAL_UNKNOWN
            check_error_fn = self.generic_check_kdc_error
            check_rep_fn = None
        else:
            expected_error_mode = 0
            check_error_fn = None
            check_rep_fn = self.generic_check_kdc_rep

        kdc_exchange_dict = self.tgs_exchange_dict(
            expected_crealm=tgt.crealm,
            expected_cname=tgt.cname,
            expected_srealm=tgt.srealm,
            expected_sname=sname,
            ticket_decryption_key=decryption_key,
            check_rep_fn=check_rep_fn,
            check_error_fn=check_error_fn,
            check_kdc_private_fn=self.generic_check_kdc_private,
            expected_error_mode=expected_error_mode,
            tgt=tgt,
            authenticator_subkey=subkey,
            kdc_options='0',
            expect_edata=False)

        self._generic_kdc_exchange(kdc_exchange_dict,
                                   cname=None,
                                   realm=tgt.srealm,
                                   sname=sname,
                                   etypes=etypes)

    def setUp(self):
        super().setUp()
        self.do_asn1_print = global_asn1_print
        self.do_hexdump = global_hexdump

    def _format_spns(self, spns, creds=None):
        """Return a tuple with every template in *spns* formatted.

        A tuple (rather than a lazy, single-use map iterator) lets callers
        iterate the result more than once. It also makes equal SPN lists
        compare and hash equal, which matters when the result is passed
        as an option to get_cached_creds(). Presumably that function keys
        its cache on the options; verify in kdc_base_test.
        """
        return tuple(self._format_spn(spn, creds) for spn in spns)

    def _format_spn(self, spn, creds=None):
        """Fill in the domain names in *spn*, and the account name if known.

        If *creds* is None, the result still contains an '{account}'
        placeholder for a later str.format() pass.
        """
        samdb = self.get_samdb()

        spn = spn.format(netbios_domain_name=samdb.domain_netbios_name(),
                         dns_domain_name=samdb.domain_dns_name())

        if creds is not None:
            account_name = creds.get_username()
            spn = spn.format(account=account_name)

        return spn

    def _get_creds(self, account_type):
        """Return credentials for a target account that has all test SPNs.

        RODC credentials are built once and cached on the class. Other
        account types go through get_cached_creds().
        """
        # Still contains '{account}' placeholders at this point.
        spns = self._format_spns(self.test_spns.values())

        if account_type is self.AccountType.RODC:
            creds = self._mock_rodc_creds
            if creds is None:
                creds = self._get_mock_rodc_creds(spns)
                type(self)._mock_rodc_creds = creds
        else:
            creds = self.get_cached_creds(
                account_type=account_type,
                opts={
                    'spn': spns
                })

        return creds

    def _get_mock_rodc_creds(self, spns):
        """Register *spns* on the mock RODC account and return its creds.

        Each SPN template gets the RODC's name filled in and is appended to
        rodc_ctx.SPNs if not already present. The account's
        servicePrincipalName attribute is then replaced with that list.
        """
        rodc_ctx = self.get_mock_rodc_ctx()

        for spn in spns:
            spn = spn.format(account=rodc_ctx.myname)
            if spn not in rodc_ctx.SPNs:
                rodc_ctx.SPNs.append(spn)

        samdb = self.get_samdb()
        rodc_dn = ldb.Dn(samdb, rodc_ctx.acct_dn)

        msg = ldb.Message(rodc_dn)
        msg['servicePrincipalName'] = ldb.MessageElement(
            rodc_ctx.SPNs,
            ldb.FLAG_MOD_REPLACE,
            'servicePrincipalName')
        samdb.modify(msg)

        creds = KerberosCredentials()
        creds.guess(self.get_lp())
        creds.set_realm(rodc_ctx.realm.upper())
        creds.set_domain(rodc_ctx.domain_name)
        creds.set_password(rodc_ctx.acct_pass)
        creds.set_username(rodc_ctx.myname)
        creds.set_workstation(rodc_ctx.samname)
        creds.set_dn(rodc_dn)
        creds.set_spn(rodc_ctx.SPNs)

        # Read the account's current key version number from the directory.
        res = samdb.search(base=rodc_dn,
                           scope=ldb.SCOPE_BASE,
                           attrs=['msDS-KeyVersionNumber'])
        kvno = int(res[0].get('msDS-KeyVersionNumber', idx=0))
        creds.set_kvno(kvno)

        keys = self.get_keys(creds)
        self.creds_set_keys(creds, keys)

        return creds


if __name__ == "__main__":
    global_asn1_print = False
    global_hexdump = False
    import unittest
    unittest.main()