summaryrefslogtreecommitdiff
path: root/python/samba/common.py
blob: 8876e4f4faa0bfae3cb2eb68c7c6ba4987090f1a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Samba common functions
#
# Copyright (C) Matthieu Patou <mat@matws.net>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#


import ldb
from samba import dsdb
from samba.ndr import ndr_pack
from samba.dcerpc import misc
import binascii

from samba.compat import PY3


if PY3:
    # Python 3 dropped both the cmp() builtin and raw_input(); provide
    # module-level stand-ins so the code below (confirm() and
    # dsdb_Dn.__cmp__) can keep using the Python 2 names unchanged.
    raw_input = input

    def cmp(a, b):
        """Return -1, 0 or 1 as a is less than, equal to or greater than b
        (for ordinary values whose comparisons yield booleans)."""
        return (a > b) - (a < b)


def confirm(msg, forced=False, allow_all=False):
    """confirm an action with the user

    Prompts repeatedly (via raw_input) until the reply is one of the
    accepted answers. Replies are case-insensitive and surrounding
    whitespace is ignored; an empty reply means "no".

    :param msg: A string to print to the user
    :param forced: If True, do not prompt at all; print the message
        followed by "[YES]" and return True
    :param allow_all: If True, also accept "all" and "none" as replies
    :return: True for yes, False for no, or the string 'ALL' / 'NONE'
        when allow_all is set and the user picked one of those
    """
    if forced:
        print("%s [YES]" % msg)
        return True

    mapping = {
        'Y': True,
        'YES': True,
        '': False,
        'N': False,
        'NO': False,
    }

    prompt = '[y/N]'

    if allow_all:
        mapping['ALL'] = 'ALL'
        mapping['NONE'] = 'NONE'
        prompt = '[y/N/all/none]'

    while True:
        # Strip stray whitespace so replies like " y" or "yes " are not
        # rejected as unknown.
        v = raw_input(msg + ' %s ' % prompt).strip().upper()
        if v in mapping:
            return mapping[v]
        print("Unknown response '%s'" % v)


def normalise_int32(ivalue):
    '''normalise a ldap integer to signed 32 bit

    Positive values with bit 31 set (e.g. "4294967295") are reinterpreted
    as two's-complement signed 32-bit integers ("-1"); all other values
    keep their numeric value. The result is always a canonical decimal
    string.

    Presumably the input lies in the unsigned 32-bit range; larger values
    are not masked (TODO confirm callers never pass them).

    :param ivalue: an integer, or its decimal form as str or bytes
    :return: decimal string of the signed value
    '''
    value = int(ivalue)
    # Negative inputs already are in signed form; only positive values
    # with the sign bit set need to wrap around.
    if value & 0x80000000 and value > 0:
        value -= 0x100000000
    # Format the parsed int rather than the raw input so that bytes values
    # (b"5") do not come back as "b'5'".
    return str(value)


class dsdb_Dn(object):
    '''a class for binary DN

    Wraps a DN that may carry a DN-Binary ("B:<len>:<payload>:<dn>") or
    DN-String ("S:<len>:<payload>:<dn>") prefix. The prefix is kept
    verbatim in ``self.prefix``, its payload in ``self.binary`` and the
    remaining DN, parsed by ldb, in ``self.dn``.

    Note: because __eq__ is defined without __hash__, instances are
    unhashable under Python 3.
    '''

    def __init__(self, samdb, dnstring, syntax_oid=None):
        '''create a dsdb_Dn

        :param samdb: ldb handle passed to ldb.Dn to parse the DN part
        :param dnstring: string form of the value, optionally with a
            "B:" or "S:" prefix
        :param syntax_oid: DSDB syntax of the value; when None it is
            auto-detected from the prefix of dnstring
        :raises RuntimeError: if a binary/string DN prefix is malformed
        '''
        if syntax_oid is None:
            # auto-detect based on string
            if dnstring.startswith("B:"):
                syntax_oid = dsdb.DSDB_SYNTAX_BINARY_DN
            elif dnstring.startswith("S:"):
                syntax_oid = dsdb.DSDB_SYNTAX_STRING_DN
            else:
                syntax_oid = dsdb.DSDB_SYNTAX_OR_NAME
        if syntax_oid in [dsdb.DSDB_SYNTAX_BINARY_DN, dsdb.DSDB_SYNTAX_STRING_DN]:
            # Layout is <type>:<len>:<payload>:<dn>, where <len> is the
            # payload length in characters, written in decimal.
            colons = dnstring.split(':')
            if len(colons) < 4:
                raise RuntimeError("Invalid DN %s" % dnstring)
            # "<type>:" (2) + digits + ":" (1) + payload + ":" (1)
            prefix_len = 4 + len(colons[1]) + int(colons[1])
            self.prefix = dnstring[0:prefix_len]
            # A declared length that does not match the payload would
            # otherwise silently split the value in the wrong place.
            if len(self.prefix) != prefix_len or not self.prefix.endswith(':'):
                raise RuntimeError("Invalid DN %s" % dnstring)
            self.binary = self.prefix[3 + len(colons[1]):-1]
            self.dnstring = dnstring[prefix_len:]
        else:
            self.dnstring = dnstring
            self.prefix = ''
            self.binary = ''
        self.dn = ldb.Dn(samdb, self.dnstring)

    def __str__(self):
        '''return the prefix followed by the extended DN string (mode 1)'''
        return self.prefix + str(self.dn.extended_str(mode=1))

    def __cmp__(self, other):
        ''' compare dsdb_Dn values similar to parsed_dn_compare()

        Orders by the GUID extended component first, then by the
        binary payload.
        '''
        dn1 = self
        dn2 = other
        guid1 = dn1.dn.get_extended_component("GUID")
        guid2 = dn2.dn.get_extended_component("GUID")

        v = cmp(guid1, guid2)
        if v != 0:
            return v
        v = cmp(dn1.binary, dn2.binary)
        return v

    # In Python3, __cmp__ is replaced by these 6 methods. Each one returns
    # NotImplemented for foreign types, so "dn == None" evaluates to False
    # and ordering against other types raises TypeError, instead of both
    # failing with AttributeError inside __cmp__.
    def __eq__(self, other):
        if not isinstance(other, dsdb_Dn):
            return NotImplemented
        return self.__cmp__(other) == 0

    def __ne__(self, other):
        if not isinstance(other, dsdb_Dn):
            return NotImplemented
        return self.__cmp__(other) != 0

    def __lt__(self, other):
        if not isinstance(other, dsdb_Dn):
            return NotImplemented
        return self.__cmp__(other) < 0

    def __le__(self, other):
        if not isinstance(other, dsdb_Dn):
            return NotImplemented
        return self.__cmp__(other) <= 0

    def __gt__(self, other):
        if not isinstance(other, dsdb_Dn):
            return NotImplemented
        return self.__cmp__(other) > 0

    def __ge__(self, other):
        if not isinstance(other, dsdb_Dn):
            return NotImplemented
        return self.__cmp__(other) >= 0

    def get_binary_integer(self):
        '''return binary part of a dsdb_Dn as an integer, or None

        The payload is parsed as hexadecimal; None is returned when the
        value had no prefix.
        '''
        if self.prefix == '':
            return None
        return int(self.binary, 16)

    def get_bytes(self):
        '''return binary as a byte string (the payload decoded from hex)'''
        return binascii.unhexlify(self.binary)