# -*- coding: utf-8 -*- # Copyright Andrew Bartlett 2018 # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation; either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . # from __future__ import print_function import optparse import samba import samba.getopt as options import sys import os import time from samba.auth import system_session from samba.tests import TestCase import ldb ERRCODE_ENTRY_EXISTS = 68 ERRCODE_OPERATIONS_ERROR = 1 ERRCODE_INVALID_VALUE = 21 ERRCODE_CLASS_VIOLATION = 65 parser = optparse.OptionParser("{0} ".format(sys.argv[0])) sambaopts = options.SambaOptions(parser) # use command line creds if available credopts = options.CredentialsOptions(parser) parser.add_option_group(credopts) parser.add_option("-v", action="store_true", dest="verbose", help="print successful expression outputs") opts, args = parser.parse_args() if len(args) < 1: parser.print_usage() sys.exit(1) lp = sambaopts.get_loadparm() creds = credopts.get_credentials(lp) # Set properly at end of file. host = None global ou_count ou_count = 0 class ComplexExpressionTests(TestCase): # Using setUpClass instead of setup because we're not modifying any # records in the tests @classmethod def setUpClass(cls): super(ComplexExpressionTests, cls).setUpClass() cls.samdb = samba.samdb.SamDB(host, lp=lp, session_info=system_session(), credentials=creds) ou_name = "ComplexExprTest" cls.base_dn = "OU={0},{1}".format(ou_name, cls.samdb.domain_dn()) try: cls.samdb.delete(cls.base_dn, ["tree_delete:1"]) except: pass try: cls.samdb.create_ou(cls.base_dn) except ldb.LdbError as e: if e.args[0] == ERRCODE_ENTRY_EXISTS: print(('test ou {ou} already exists. Delete with ' '"samba-tool group delete OU={ou} ' '--force-subtree-delete"').format(ou=ou_name)) raise e cls.name_template = "testuser{0}" cls.default_n = 10 # These fields are carefully hand-picked from the schema. They have # syntax and handling appropriate for our test structure. cls.largeint_f = "accountExpires" cls.str_f = "accountNameHistory" cls.int_f = "flags" cls.enum_f = "preferredDeliveryMethod" cls.time_f = "msTSExpireDate" cls.ranged_int_f = "countryCode" @classmethod def tearDownClass(cls): cls.samdb.delete(cls.base_dn, ["tree_delete:1"]) # Make test OU containing users with field=val for each val def make_test_objects(self, field, vals): global ou_count ou_count += 1 ou_dn = "OU=testou{0},{1}".format(ou_count, self.base_dn) self.samdb.create_ou(ou_dn) ldap_objects = [{"dn": "CN=testuser{0},{1}".format(n, ou_dn), "name": self.name_template.format(n), "objectClass": "user", field: n} for n in vals] for ldap_object in ldap_objects: # It's useful to keep appropriate python types in the ldap_object # dict but smdb's 'add' function expects strings. stringed_ldap_object = {k: str(v) for (k, v) in ldap_object.items()} try: self.samdb.add(stringed_ldap_object) except ldb.LdbError as e: print("failed to add %s" % (stringed_ldap_object)) raise e return ou_dn, ldap_objects # Run search expr and print out time. This function should be used for # almost all searching. def time_ldap_search(self, expr, dn): time_taken = 0 try: start_time = time.time() res = self.samdb.search(base=dn, scope=ldb.SCOPE_SUBTREE, expression=expr) time_taken = time.time() - start_time except Exception as e: print("failed expr " + expr) raise e print("{0} took {1}s".format(expr, time_taken)) return res, time_taken # Take an ldap expression and an equivalent python expression. # Run and time the ldap expression and compare the result to the python # expression run over the a list of ldap_object dicts. def assertLDAPQuery(self, ldap_expr, ou_dn, py_expr, ldap_objects): # run (and time) the LDAP search expression over the DB res, time_taken = self.time_ldap_search(ldap_expr, ou_dn) results = {str(row.get('name')[0]) for row in res} # build the set of expected results by evaluating the python-equivalent # of the search expression over the same set of objects expected_results = set() for ldap_object in ldap_objects: try: final_expr = py_expr.format(**ldap_object) except KeyError: # If the format on the py_expr hits a key error, then # ldap_object doesn't have the field, so it shouldn't match. continue if eval(final_expr): expected_results.add(str(ldap_object['name'])) self.assertEqual(results, expected_results) if opts.verbose: ldap_object_names = {l['name'] for l in ldap_objects} excluded = ldap_object_names - results excluded = "\n ".join(excluded) or "[NOTHING]" returned = "\n ".join(expected_results) or "[NOTHING]" print("PASS: Expression {0} took {1}s and returned:" "\n {2}\n" "Excluded:\n {3}\n".format(ldap_expr, time_taken, returned, excluded)) # Basic integer range test def test_int_range(self, field=None): n = self.default_n field = field or self.int_f ou_dn, ldap_objects = self.make_test_objects(field, range(n)) expr = "(&(%s>=%s)(%s<=%s))" % (field, n-1, field, n+1) py_expr = "%d <= {%s} <= %d" % (n-1, field, n+1) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) half_n = int(n/2) expr = "(%s<=%s)" % (field, half_n) py_expr = "{%s} <= %d" % (field, half_n) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) expr = "(%s>=%s)" % (field, half_n) py_expr = "{%s} >= %d" % (field, half_n) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) # Same test again for largeint and enum def test_largeint_range(self): self.test_int_range(self.largeint_f) def test_enum_range(self): self.test_int_range(self.enum_f) # Special range test for integer field with upper and lower bounds defined. # The bounds are checked on insertion, not search, so we should be able # to compare to a constant that is outside bounds. def test_ranged_int_range(self): field = self.ranged_int_f ubound = 2**16 width = 8 vals = list(range(ubound-width, ubound)) ou_dn, ldap_objects = self.make_test_objects(field, vals) # Check <= value above overflow returns all vals expr = "(%s<=%d)" % (field, ubound+5) py_expr = "{%s} <= %d" % (field, ubound+5) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) # Test range also works for time fields def test_time_range(self): n = self.default_n field = self.time_f n = self.default_n width = int(n/2) base_time = 20050116175514 time_range = [base_time + t for t in range(-width, width)] time_range = [str(t) + ".0Z" for t in time_range] ou_dn, ldap_objects = self.make_test_objects(field, time_range) expr = "(%s<=%s)" % (field, str(base_time) + ".0Z") py_expr = 'int("{%s}"[:-3]) <= %d' % (field, base_time) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) expr = "(&(%s>=%s)(%s<=%s))" % (field, str(base_time-1) + ".0Z", field, str(base_time+1) + ".0Z") py_expr = '%d <= int("{%s}"[:-3]) <= %d' % (base_time-1, field, base_time+1) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) # Run each comparison op on a simple test set. Time taken will be printed. def test_int_single_cmp_op_speeds(self, field=None): n = self.default_n field = field or self.int_f ou_dn, ldap_objects = self.make_test_objects(field, range(n)) comp_ops = ['=', '<=', '>='] py_comp_ops = ['==', '<=', '>='] exprs = ["(%s%s%d)" % (field, c, n) for c in comp_ops] py_exprs = ["{%s}%s%d" % (field, c, n) for c in py_comp_ops] for expr, py_expr in zip(exprs, py_exprs): self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) def test_largeint_single_cmp_op_speeds(self): self.test_int_single_cmp_op_speeds(self.largeint_f) def test_enum_single_cmp_op_speeds(self): self.test_int_single_cmp_op_speeds(self.enum_f) # Check strings are ordered using a naive ordering. def test_str_ordering(self): field = self.str_f a_ord = ord('A') n = 10 str_range = ['abc{0}d'.format(chr(c)) for c in range(a_ord, a_ord+n)] ou_dn, ldap_objects = self.make_test_objects(field, str_range) half_n = int(a_ord + n/2) # Basic <= and >= statements expr = "(%s>=abc%s)" % (field, chr(half_n)) py_expr = "'{%s}' >= 'abc%s'" % (field, chr(half_n)) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) expr = "(%s<=abc%s)" % (field, chr(half_n)) py_expr = "'{%s}' <= 'abc%s'" % (field, chr(half_n)) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) # String range expr = "(&(%s>=abc%s)(%s<=abc%s))" % (field, chr(half_n-2), field, chr(half_n+2)) py_expr = "'abc%s' <= '{%s}' <= 'abc%s'" % (chr(half_n-2), field, chr(half_n+2)) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) # Integers treated as string expr = "(%s>=1)" % (field) py_expr = "'{%s}' >= '1'" % (field) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) # Windows returns nothing for invalid expressions. Expected fail on samba. def test_invalid_expressions(self, field=None): field = field or self.int_f n = self.default_n ou_dn, ldap_objects = self.make_test_objects(field, list(range(n))) int_expressions = ["(%s>=abc)", "(%s<=abc)", "(%s=abc)"] for expr in int_expressions: expr = expr % (field) self.assertLDAPQuery(expr, ou_dn, "False", ldap_objects) def test_largeint_invalid_expressions(self): self.test_invalid_expressions(self.largeint_f) def test_enum_invalid_expressions(self): self.test_invalid_expressions(self.enum_f) def test_case_insensitive(self): str_range = ["äbc"+str(n) for n in range(10)] ou_dn, ldap_objects = self.make_test_objects(self.str_f, str_range) expr = "(%s=äbc1)" % (self.str_f) pyexpr = '"{%s}"=="äbc1"' % (self.str_f) self.assertLDAPQuery(expr, ou_dn, pyexpr, ldap_objects) expr = "(%s=ÄbC1)" % (self.str_f) self.assertLDAPQuery(expr, ou_dn, pyexpr, ldap_objects) # Check negative numbers can be entered and compared def test_negative_cmp(self, field=None): field = field or self.int_f width = 6 around_zero = list(range(-width, width)) ou_dn, ldap_objects = self.make_test_objects(field, around_zero) expr = "(%s>=-3)" % (field) py_expr = "{%s} >= -3" % (field) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) def test_negative_cmp_largeint(self): self.test_negative_cmp(self.largeint_f) def test_negative_cmp_enum(self): self.test_negative_cmp(self.enum_f) # Check behaviour on insertion and comparison of zero-prefixed numbers. # Samba errors on insertion, Windows strips the leading zeroes. def test_zero_prefix(self, field=None): field = field or self.int_f # Test comparison with 0-prefixed constants. n = self.default_n ou_dn, ldap_objects = self.make_test_objects(field, list(range(n))) expr = "(%s>=00%d)" % (field, n/2) py_expr = "{%s} >= %d" % (field, n/2) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) # Delete the test OU so we don't mix it up with the next one. self.samdb.delete(ou_dn, ["tree_delete:1"]) # Try inserting 0-prefixed numbers, check it fails. zero_pref_nums = ['00'+str(num) for num in range(n)] try: ou_dn, ldap_objects = self.make_test_objects(field, zero_pref_nums) except ldb.LdbError as e: if e.args[0] != ERRCODE_INVALID_VALUE: raise e return # Samba doesn't get this far - the exception is raised. Windows allows # the insertion and removes the leading 0s as tested below. # Either behaviour is fine. print("LDAP allowed insertion of 0-prefixed nums for field " + field) res = self.samdb.search(base=ou_dn, scope=ldb.SCOPE_SUBTREE, expression="(objectClass=user)") returned_nums = [str(r.get(field)[0]) for r in res] expect = [str(n) for n in range(n)] self.assertEqual(set(returned_nums), set(expect)) expr = "(%s>=%d)" % (field, n/2) py_expr = "{%s} >= %d" % (field, n/2) for ldap_object in ldap_objects: ldap_object[field] = int(ldap_object[field]) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) def test_zero_prefix_largeint(self): self.test_zero_prefix(self.largeint_f) def test_zero_prefix_enum(self): self.test_zero_prefix(self.enum_f) # Check integer overflow is handled as best it can be. def test_int_overflow(self, field=None, of=None): field = field or self.int_f of = of or 2**31-1 width = 8 vals = list(range(of-width, of+width)) ou_dn, ldap_objects = self.make_test_objects(field, vals) # Check ">=overflow" doesn't return vals past overflow expr = "(%s>=%d)" % (field, of-3) py_expr = "%d <= {%s} <= %d" % (of-3, field, of) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) # "<=overflow" returns everything expr = "(%s<=%d)" % (field, of) py_expr = "True" self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) # Values past overflow should be negative expr = "(&(%s<=%d)(%s>=0))" % (field, of, field) py_expr = "{%s} <= %d" % (field, of) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) expr = "(%s<=0)" % (field) py_expr = "{%s} >= %d" % (field, of+1) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) # Get the values back out and check vals past overflow are negative. res = self.samdb.search(base=ou_dn, scope=ldb.SCOPE_SUBTREE, expression="(objectClass=user)") returned_nums = [str(r.get(field)[0]) for r in res] # Note: range(a,b) == [a..b-1] (confusing) up_to_overflow = list(range(of-width, of+1)) negatives = list(range(-of-1, -of+width-2)) expect = [str(n) for n in up_to_overflow + negatives] self.assertEqual(set(returned_nums), set(expect)) def test_enum_overflow(self): self.test_int_overflow(self.enum_f, 2**31-1) # Check cmp works on uSNChanged. We can't insert uSNChanged vals, they get # added automatically so we'll just insert some objects and go with what # we get. def test_usnchanged(self): field = "uSNChanged" n = 10 # Note we can't actually set uSNChanged via LDAP (LDB ignores it), # so the input val range doesn't matter here ou_dn, _ = self.make_test_objects(field, list(range(n))) # Get the assigned uSNChanged values res = self.samdb.search(base=ou_dn, scope=ldb.SCOPE_SUBTREE, expression="(objectClass=user)") # Our vals got ignored so make ldap_objects from search result ldap_objects = [{'name': str(r['name'][0]), field: int(r[field][0])} for r in res] # Get the median val and use as the number in the test search expr. nums = [l[field] for l in ldap_objects] nums = list(sorted(nums)) search_num = nums[int(len(nums)/2)] expr = "(&(%s<=%d)(objectClass=user))" % (field, search_num) py_expr = "{%s} <= %d" % (field, search_num) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) expr = "(&(%s>=%d)(objectClass=user))" % (field, search_num) py_expr = "{%s} >= %d" % (field, search_num) self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects) # If we're called independently then import subunit, get host from first # arg and run. Otherwise, subunit ran us so just set host from env. # We always try to run over LDAP rather than direct file, so that # search timings are not impacted by opening and closing the tdb file. if __name__ == "__main__": from samba.tests.subunitrun import TestProgram host = args[0] if "://" not in host: if os.path.isfile(host): host = "tdb://%s" % host else: host = "ldap://%s" % host TestProgram(module=__name__) else: host = "ldap://" + os.getenv("SERVER")