summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorGary Lockyer <gary@catalyst.net.nz>2018-10-09 11:09:20 +1300
committerAndrew Bartlett <abartlet@samba.org>2018-11-20 22:14:16 +0100
commit9a1277c1ecd76ab50ef0c509f6834ae9311c8daf (patch)
tree7c3c551ab2bcae1ec728e1868477b32cdd8406eb /python
parentad8bb6fcd08be28c40f2522d640333e9e69b7852 (diff)
downloadsamba-9a1277c1ecd76ab50ef0c509f6834ae9311c8daf.tar.gz
tests samr: Extra tests for samr_QueryDisplayInfo
Add extra tests to test the content returned by samr_QueryDisplayInfo, which is not tested for the ADDC. Also adds tests for the result caching added in the following commit. Signed-off-by: Gary Lockyer <gary@catalyst.net.nz> Reviewed-by: Andrew Bartlett <abartlet@samba.org>
Diffstat (limited to 'python')
-rw-r--r--python/samba/tests/dcerpc/sam.py390
1 files changed, 389 insertions, 1 deletions
diff --git a/python/samba/tests/dcerpc/sam.py b/python/samba/tests/dcerpc/sam.py
index 9c432e3a37e..d02827e8acc 100644
--- a/python/samba/tests/dcerpc/sam.py
+++ b/python/samba/tests/dcerpc/sam.py
@@ -19,9 +19,20 @@
"""Tests for samba.dcerpc.sam."""
-from samba.dcerpc import samr, security
+from samba.dcerpc import samr, security, lsa
from samba.tests import RpcInterfaceTestCase
+from samba.tests import env_loadparm, delete_force
+from samba.credentials import Credentials
+from samba.auth import system_session
+from samba.samdb import SamDB
+from samba.dsdb import (
+ ATYPE_NORMAL_ACCOUNT,
+ ATYPE_WORKSTATION_TRUST,
+ GTYPE_SECURITY_UNIVERSAL_GROUP,
+ GTYPE_SECURITY_GLOBAL_GROUP)
+from samba import generate_random_password
+import os
# FIXME: Pidl should be doing this for us
def toArray(handle, array, num_entries):
@@ -33,6 +44,32 @@ class SamrTests(RpcInterfaceTestCase):
def setUp(self):
super(SamrTests, self).setUp()
self.conn = samr.samr("ncalrpc:", self.get_loadparm())
+ self.open_samdb()
+ self.open_domain_handle()
+
+ #
+ # Open the samba database
+ #
+ def open_samdb(self):
+ self.lp = env_loadparm()
+ self.domain = os.environ["DOMAIN"]
+ self.creds = Credentials()
+ self.creds.guess(self.lp)
+ self.session = system_session()
+ self.samdb = SamDB(
+ session_info=self.session, credentials=self.creds, lp=self.lp)
+
+ #
+ # Open a SAMR Domain handle
+ def open_domain_handle(self):
+ self.handle = self.conn.Connect2(
+ None, security.SEC_FLAG_MAXIMUM_ALLOWED)
+
+ self.domain_sid = self.conn.LookupDomain(
+ self.handle, lsa.String(self.domain))
+
+ self.domain_handle = self.conn.OpenDomain(
+ self.handle, security.SEC_FLAG_MAXIMUM_ALLOWED, self.domain_sid)
def test_connect5(self):
(level, info, handle) = self.conn.Connect5(None, 0, 1, samr.ConnectInfo1())
@@ -45,3 +82,354 @@ class SamrTests(RpcInterfaceTestCase):
handle = self.conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
domains = toArray(*self.conn.EnumDomains(handle, 0, 4294967295))
self.conn.Close(handle)
+
+ # Create groups based on the id list supplied, the id is used to
+ # form a unique name and description.
+ #
+ # returns a list of the created dn's, which can be passed to delete_dns
+ # to clean up after the test has run.
+ def create_groups(self, ids):
+ dns = []
+ for i in ids:
+ name = "SAMR_GRP%d" % i
+ dn = "cn=%s,cn=Users,%s" % (name, self.samdb.domain_dn())
+ delete_force(self.samdb, dn)
+
+ self.samdb.newgroup(name)
+ dns.append(dn)
+ return dns
+
+ # Create user accounts based on the id list supplied, the id is used to
+ # form a unique name and description.
+ #
+ # returns a list of the created dn's, which can be passed to delete_dns
+ # to clean up after the test has run.
+ def create_users(self, ids):
+ dns = []
+ for i in ids:
+ name = "SAMR_USER%d" % i
+ dn = "cn=%s,CN=USERS,%s" % (name, self.samdb.domain_dn())
+ delete_force(self.samdb, dn)
+
+ # We only need the user to exist, we don't need a password
+ self.samdb.newuser(
+ name,
+ password=None,
+ setpassword=False,
+ description="Description for " + name,
+ givenname="given%dname" % i,
+ surname="surname%d" % i)
+ dns.append(dn)
+ return dns
+
+ # Create computer accounts based on the id list supplied, the id is used to
+ # form a unique name and description.
+ #
+ # returns a list of the created dn's, which can be passed to delete_dns
+ # to clean up after the test has run.
+ def create_computers(self, ids):
+ dns = []
+ for i in ids:
+ name = "SAMR_CMP%d" % i
+ dn = "cn=%s,cn=COMPUTERS,%s" % (name, self.samdb.domain_dn())
+ delete_force(self.samdb, dn)
+
+ self.samdb.newcomputer(name, description="Description of " + name)
+ dns.append(dn)
+ return dns
+
+ # Delete the specified dn's.
+ #
+ # Used to clean up entries created by individual tests.
+ #
+ def delete_dns(self, dns):
+ for dn in dns:
+ delete_force(self.samdb, dn)
+
+ # Common tests for QueryDisplayInfo
+ #
+ def _test_QueryDisplayInfo(
+ self, level, check_results, select, attributes, add_elements):
+ #
+ # Get the expected results by querying the samdb database directly.
+ # We do this rather than use a list of expected results as this runs
+ # with other tests so we do not have a known fixed list of elements
+ expected = self.samdb.search(expression=select, attrs=attributes)
+ self.assertTrue(len(expected) > 0)
+
+ #
+ # Perform QueryDisplayInfo with max results greater than the expected
+ # number of results.
+ (ts, rs, actual) = self.conn.QueryDisplayInfo(
+ self.domain_handle, level, 0, 1024, 4294967295)
+
+ self.assertEquals(len(expected), ts)
+ self.assertEquals(len(expected), rs)
+ check_results(expected, actual.entries)
+
+ #
+ # Perform QueryDisplayInfo with max results set to the number of
+ # results returned from the first query, should return the same results
+ (ts1, rs1, actual1) = self.conn.QueryDisplayInfo(
+ self.domain_handle, level, 0, rs, 4294967295)
+ self.assertEquals(ts, ts1)
+ self.assertEquals(rs, rs1)
+ check_results(expected, actual1.entries)
+
+ #
+ # Perform QueryDisplayInfo and get the last two results.
+ # Note: We are assuming there are at least three entries
+ self.assertTrue(ts > 2)
+ (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
+ self.domain_handle, level, (ts - 2), 2, 4294967295)
+ self.assertEquals(ts, ts2)
+ self.assertEquals(2, rs2)
+ check_results(list(expected)[-2:], actual2.entries)
+
+ #
+ # Perform QueryDisplayInfo and get the first two results.
+ # Note: We are assuming there are at least three entries
+ self.assertTrue(ts > 2)
+ (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
+ self.domain_handle, level, 0, 2, 4294967295)
+ self.assertEquals(ts, ts2)
+ self.assertEquals(2, rs2)
+ check_results(list(expected)[:2], actual2.entries)
+
+ #
+ # Perform QueryDisplayInfo and get two results in the middle of the
+ # list i.e. not the first or the last entry.
+ # Note: We are assuming there are at least four entries
+ self.assertTrue(ts > 3)
+ (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
+ self.domain_handle, level, 1, 2, 4294967295)
+ self.assertEquals(ts, ts2)
+ self.assertEquals(2, rs2)
+ check_results(list(expected)[1:2], actual2.entries)
+
+ #
+ # To check that cached values are being returned rather than the
+ # results being re-read from disk we add elements, and request all
+ # but the first result.
+ #
+ dns = add_elements([1000, 1002, 1003, 1004])
+
+ #
+ # Perform QueryDisplayInfo and get all but the first result.
+ # We should be using the cached results so the entries we just added
+ # should not be present
+ (ts3, rs3, actual3) = self.conn.QueryDisplayInfo(
+ self.domain_handle, level, 1, 1024, 4294967295)
+ self.assertEquals(ts, ts3)
+ self.assertEquals(len(expected) - 1, rs3)
+ check_results(list(expected)[1:], actual3.entries)
+
+ #
+ # Perform QueryDisplayInfo and get all the results.
+ # As the start index is zero we should reread the data from disk and
+ # the added entries should be there
+ new = self.samdb.search(expression=select, attrs=attributes)
+ (ts4, rs4, actual4) = self.conn.QueryDisplayInfo(
+ self.domain_handle, level, 0, 1024, 4294967295)
+ self.assertEquals(len(expected) + len(dns), ts4)
+ self.assertEquals(len(expected) + len(dns), rs4)
+ check_results(new, actual4.entries)
+
+ # Delete the added DN's and query all but the first entry.
+ # This should ensure the cached results are used and that the
+ # missing entry code is triggered.
+ self.delete_dns(dns)
+ (ts5, rs5, actual5) = self.conn.QueryDisplayInfo(
+ self.domain_handle, level, 1, 1024, 4294967295)
+ self.assertEquals(len(expected) + len(dns), ts5)
+ # The deleted results will be filtered from the result set so should
+ # be missing from the returned results.
+ # Note: depending on the GUID order, the first result in the cache may
+ # be a deleted entry, in which case the results will contain all
+ # the expected elements, otherwise the first expected result will
+ # be missing.
+ if rs5 == len(expected):
+ check_results(expected, actual5.entries)
+ elif rs5 == (len(expected) - 1):
+ check_results(list(expected)[1:], actual5.entries)
+ else:
+ self.fail("Incorrect number of entries {0}".format(rs5))
+
+ #
+ # Perform QueryDisplayInfo specifying an index past the end of the
+ # available data.
+ # Should return no data.
+ (ts6, rs6, actual6) = self.conn.QueryDisplayInfo(
+ self.domain_handle, level, ts5, 1, 4294967295)
+ self.assertEquals(ts5, ts6)
+ self.assertEquals(0, rs6)
+
+ self.conn.Close(self.handle)
+
+ # Test for QueryDisplayInfo, Level 1
+ # Returns the sAMAccountName, displayName and description for all
+ # the user accounts.
+ #
+ def test_QueryDisplayInfo_level_1(self):
+ def check_results(expected, actual):
+ # Assume the QueryDisplayInfo and ldb.search return their results
+ # in the same order
+ for (e, a) in zip(expected, actual):
+ self.assertTrue(isinstance(a, samr.DispEntryGeneral))
+ self.assertEquals(str(e["sAMAccountName"]),
+ str(a.account_name))
+
+ # The displayName and description are optional.
+ # In the expected results they will be missing, in
+ # samr.DispEntryGeneral the corresponding attribute will have a
+ # length of zero.
+ #
+ if a.full_name.length == 0:
+ self.assertFalse("displayName" in e)
+ else:
+ self.assertEquals(str(e["displayName"]), str(a.full_name))
+
+ if a.description.length == 0:
+ self.assertFalse("description" in e)
+ else:
+ self.assertEquals(str(e["description"]),
+ str(a.description))
+ # Create four user accounts
+ # to ensure that we have the minimum needed for the tests.
+ dns = self.create_users([1, 2, 3, 4])
+
+ select = "(&(objectclass=user)(sAMAccountType={0}))".format(
+ ATYPE_NORMAL_ACCOUNT)
+ attributes = ["sAMAccountName", "displayName", "description"]
+ self._test_QueryDisplayInfo(
+ 1, check_results, select, attributes, self.create_users)
+
+ self.delete_dns(dns)
+
+ # Test for QueryDisplayInfo, Level 2
+ # Returns the sAMAccountName and description for all
+ # the computer accounts.
+ #
+ def test_QueryDisplayInfo_level_2(self):
+ def check_results(expected, actual):
+ # Assume the QueryDisplayInfo and ldb.search return their results
+ # in the same order
+ for (e, a) in zip(expected, actual):
+ self.assertTrue(isinstance(a, samr.DispEntryFull))
+ self.assertEquals(str(e["sAMAccountName"]),
+ str(a.account_name))
+
+ # The description is optional.
+ # In the expected results they will be missing, in
+ # samr.DispEntryGeneral the corresponding attribute will have a
+ # length of zero.
+ #
+ if a.description.length == 0:
+ self.assertFalse("description" in e)
+ else:
+ self.assertEquals(str(e["description"]),
+ str(a.description))
+
+ # Create four computer accounts
+ # to ensure that we have the minimum needed for the tests.
+ dns = self.create_computers([1, 2, 3, 4])
+
+ select = "(&(objectclass=user)(sAMAccountType={0}))".format(
+ ATYPE_WORKSTATION_TRUST)
+ attributes = ["sAMAccountName", "description"]
+ self._test_QueryDisplayInfo(
+ 2, check_results, select, attributes, self.create_computers)
+
+ self.delete_dns(dns)
+
+ # Test for QueryDisplayInfo, Level 3
+ # Returns the sAMAccountName and description for all
+ # the groups.
+ #
+ def test_QueryDisplayInfo_level_3(self):
+ def check_results(expected, actual):
+ # Assume the QueryDisplayInfo and ldb.search return their results
+ # in the same order
+ for (e, a) in zip(expected, actual):
+ self.assertTrue(isinstance(a, samr.DispEntryFullGroup))
+ self.assertEquals(str(e["sAMAccountName"]),
+ str(a.account_name))
+
+ # The description is optional.
+ # In the expected results they will be missing, in
+ # samr.DispEntryGeneral the corresponding attribute will have a
+ # length of zero.
+ #
+ if a.description.length == 0:
+ self.assertFalse("description" in e)
+ else:
+ self.assertEquals(str(e["description"]),
+ str(a.description))
+
+ # Create four groups
+ # to ensure that we have the minimum needed for the tests.
+ dns = self.create_groups([1, 2, 3, 4])
+
+ select = "(&(|(groupType=%d)(groupType=%d))(objectClass=group))" % (
+ GTYPE_SECURITY_UNIVERSAL_GROUP,
+ GTYPE_SECURITY_GLOBAL_GROUP)
+ attributes = ["sAMAccountName", "description"]
+ self._test_QueryDisplayInfo(
+ 3, check_results, select, attributes, self.create_groups)
+
+ self.delete_dns(dns)
+
+ # Test for QueryDisplayInfo, Level 4
+ # Returns the sAMAccountName (as an ASCII string)
+ # for all the user accounts.
+ #
+ def test_QueryDisplayInfo_level_4(self):
+ def check_results(expected, actual):
+ # Assume the QueryDisplayInfo and ldb.search return their results
+ # in the same order
+ for (e, a) in zip(expected, actual):
+ self.assertTrue(isinstance(a, samr.DispEntryAscii))
+ self.assertTrue(
+ isinstance(a.account_name, lsa.AsciiStringLarge))
+ self.assertEquals(
+ str(e["sAMAccountName"]), str(a.account_name.string))
+
+ # Create four user accounts
+ # to ensure that we have the minimum needed for the tests.
+ dns = self.create_users([1, 2, 3, 4])
+
+ select = "(&(objectclass=user)(sAMAccountType={0}))".format(
+ ATYPE_NORMAL_ACCOUNT)
+ attributes = ["sAMAccountName", "displayName", "description"]
+ self._test_QueryDisplayInfo(
+ 4, check_results, select, attributes, self.create_users)
+
+ self.delete_dns(dns)
+
+ # Test for QueryDisplayInfo, Level 5
+ # Returns the sAMAccountName (as an ASCII string)
+ # for all the groups.
+ #
+ def test_QueryDisplayInfo_level_5(self):
+ def check_results(expected, actual):
+ # Assume the QueryDisplayInfo and ldb.search return their results
+ # in the same order
+ for (e, a) in zip(expected, actual):
+ self.assertTrue(isinstance(a, samr.DispEntryAscii))
+ self.assertTrue(
+ isinstance(a.account_name, lsa.AsciiStringLarge))
+ self.assertEquals(
+ str(e["sAMAccountName"]), str(a.account_name.string))
+
+ # Create four groups
+ # to ensure that we have the minimum needed for the tests.
+ dns = self.create_groups([1, 2, 3, 4])
+
+ select = "(&(|(groupType=%d)(groupType=%d))(objectClass=group))" % (
+ GTYPE_SECURITY_UNIVERSAL_GROUP,
+ GTYPE_SECURITY_GLOBAL_GROUP)
+ attributes = ["sAMAccountName", "description"]
+ self._test_QueryDisplayInfo(
+ 5, check_results, select, attributes, self.create_groups)
+
+ self.delete_dns(dns)