diff options
author | Nick Coghlan <ncoghlan@gmail.com> | 2012-07-08 00:45:33 +1000 |
---|---|---|
committer | Nick Coghlan <ncoghlan@gmail.com> | 2012-07-08 00:45:33 +1000 |
commit | 3008ec070f40974790f6a85f7ceabf0876950abf (patch) | |
tree | bc7c86ec1c468fc57fdee384b963f1c18856ba69 /Lib | |
parent | 9a9c28ce7a051b37a91e4fc7aef70bcdcda25047 (diff) | |
download | cpython-git-3008ec070f40974790f6a85f7ceabf0876950abf.tar.gz |
Issue 14814: Ensure ordering semantics across all 3 entity types in ipaddress are consistent and well-defined
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/ipaddress.py | 136 | ||||
-rw-r--r-- | Lib/test/test_ipaddress.py | 158 |
2 files changed, 166 insertions, 128 deletions
diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py index b1e07fc992..2019009551 100644 --- a/Lib/ipaddress.py +++ b/Lib/ipaddress.py @@ -12,7 +12,7 @@ __version__ = '1.0' import struct - +import functools IPV4LENGTH = 32 IPV6LENGTH = 128 @@ -405,7 +405,38 @@ def get_mixed_type_key(obj): return NotImplemented -class _IPAddressBase: +class _TotalOrderingMixin: + # Helper that derives the other comparison operations from + # __lt__ and __eq__ + def __eq__(self, other): + raise NotImplementedError + def __ne__(self, other): + equal = self.__eq__(other) + if equal is NotImplemented: + return NotImplemented + return not equal + def __lt__(self, other): + raise NotImplementedError + def __le__(self, other): + less = self.__lt__(other) + if less is NotImplemented or not less: + return self.__eq__(other) + return less + def __gt__(self, other): + less = self.__lt__(other) + if less is NotImplemented: + return NotImplemented + equal = self.__eq__(other) + if equal is NotImplemented: + return NotImplemented + return not (less or equal) + def __ge__(self, other): + less = self.__lt__(other) + if less is NotImplemented: + return NotImplemented + return not less + +class _IPAddressBase(_TotalOrderingMixin): """The mother class.""" @@ -465,7 +496,6 @@ class _IPAddressBase: prefixlen = self._prefixlen return self._string_from_ip_int(self._ip_int_from_prefix(prefixlen)) - class _BaseAddress(_IPAddressBase): """A generic IP object. @@ -493,24 +523,6 @@ class _BaseAddress(_IPAddressBase): except AttributeError: return NotImplemented - def __ne__(self, other): - eq = self.__eq__(other) - if eq is NotImplemented: - return NotImplemented - return not eq - - def __le__(self, other): - gt = self.__gt__(other) - if gt is NotImplemented: - return NotImplemented - return not gt - - def __ge__(self, other): - lt = self.__lt__(other) - if lt is NotImplemented: - return NotImplemented - return not lt - def __lt__(self, other): if self._version != other._version: raise TypeError('%s and %s are not of the same version' % ( @@ -522,17 +534,6 @@ class _BaseAddress(_IPAddressBase): return self._ip < other._ip return False - def __gt__(self, other): - if self._version != other._version: - raise TypeError('%s and %s are not of the same version' % ( - self, other)) - if not isinstance(other, _BaseAddress): - raise TypeError('%s and %s are not of the same type' % ( - self, other)) - if self._ip != other._ip: - return self._ip > other._ip - return False - # Shorthand for Integer addition and subtraction. This is not # meant to ever support addition/subtraction of addresses. def __add__(self, other): @@ -625,31 +626,6 @@ class _BaseNetwork(_IPAddressBase): return self.netmask < other.netmask return False - def __gt__(self, other): - if self._version != other._version: - raise TypeError('%s and %s are not of the same version' % ( - self, other)) - if not isinstance(other, _BaseNetwork): - raise TypeError('%s and %s are not of the same type' % ( - self, other)) - if self.network_address != other.network_address: - return self.network_address > other.network_address - if self.netmask != other.netmask: - return self.netmask > other.netmask - return False - - def __le__(self, other): - gt = self.__gt__(other) - if gt is NotImplemented: - return NotImplemented - return not gt - - def __ge__(self, other): - lt = self.__lt__(other) - if lt is NotImplemented: - return NotImplemented - return not lt - def __eq__(self, other): try: return (self._version == other._version and @@ -658,12 +634,6 @@ class _BaseNetwork(_IPAddressBase): except AttributeError: return NotImplemented - def __ne__(self, other): - eq = self.__eq__(other) - if eq is NotImplemented: - return NotImplemented - return not eq - def __hash__(self): return hash(int(self.network_address) ^ int(self.netmask)) @@ -1292,11 +1262,27 @@ class IPv4Interface(IPv4Address): self.network.prefixlen) def __eq__(self, other): + address_equal = IPv4Address.__eq__(self, other) + if not address_equal or address_equal is NotImplemented: + return address_equal try: - return (IPv4Address.__eq__(self, other) and - self.network == other.network) + return self.network == other.network except AttributeError: + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. + return False + + def __lt__(self, other): + address_less = IPv4Address.__lt__(self, other) + if address_less is NotImplemented: return NotImplemented + try: + return self.network < other.network + except AttributeError: + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. + return False def __hash__(self): return self._ip ^ self._prefixlen ^ int(self.network.network_address) @@ -1928,11 +1914,27 @@ class IPv6Interface(IPv6Address): self.network.prefixlen) def __eq__(self, other): + address_equal = IPv6Address.__eq__(self, other) + if not address_equal or address_equal is NotImplemented: + return address_equal try: - return (IPv6Address.__eq__(self, other) and - self.network == other.network) + return self.network == other.network except AttributeError: + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. + return False + + def __lt__(self, other): + address_less = IPv6Address.__lt__(self, other) + if address_less is NotImplemented: return NotImplemented + try: + return self.network < other.network + except AttributeError: + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. + return False def __hash__(self): return self._ip ^ self._prefixlen ^ int(self.network.network_address) diff --git a/Lib/test/test_ipaddress.py b/Lib/test/test_ipaddress.py index 417c98677f..5aaf736740 100644 --- a/Lib/test/test_ipaddress.py +++ b/Lib/test/test_ipaddress.py @@ -415,6 +415,93 @@ class FactoryFunctionErrors(ErrorReporting): self.assertFactoryError(ipaddress.ip_network, "network") +class ComparisonTests(unittest.TestCase): + + v4addr = ipaddress.IPv4Address(1) + v4net = ipaddress.IPv4Network(1) + v4intf = ipaddress.IPv4Interface(1) + v6addr = ipaddress.IPv6Address(1) + v6net = ipaddress.IPv6Network(1) + v6intf = ipaddress.IPv6Interface(1) + + v4_addresses = [v4addr, v4intf] + v4_objects = v4_addresses + [v4net] + v6_addresses = [v6addr, v6intf] + v6_objects = v6_addresses + [v6net] + objects = v4_objects + v6_objects + + def test_foreign_type_equality(self): + # __eq__ should never raise TypeError directly + other = object() + for obj in self.objects: + self.assertNotEqual(obj, other) + self.assertFalse(obj == other) + self.assertEqual(obj.__eq__(other), NotImplemented) + self.assertEqual(obj.__ne__(other), NotImplemented) + + def test_mixed_type_equality(self): + # Ensure none of the internal objects accidentally + # expose the right set of attributes to become "equal" + for lhs in self.objects: + for rhs in self.objects: + if lhs is rhs: + continue + self.assertNotEqual(lhs, rhs) + + def test_containment(self): + for obj in self.v4_addresses: + self.assertIn(obj, self.v4net) + for obj in self.v6_addresses: + self.assertIn(obj, self.v6net) + for obj in self.v4_objects + [self.v6net]: + self.assertNotIn(obj, self.v6net) + for obj in self.v6_objects + [self.v4net]: + self.assertNotIn(obj, self.v4net) + + def test_mixed_type_ordering(self): + for lhs in self.objects: + for rhs in self.objects: + if isinstance(lhs, type(rhs)) or isinstance(rhs, type(lhs)): + continue + self.assertRaises(TypeError, lambda: lhs < rhs) + self.assertRaises(TypeError, lambda: lhs > rhs) + self.assertRaises(TypeError, lambda: lhs <= rhs) + self.assertRaises(TypeError, lambda: lhs >= rhs) + + def test_mixed_type_key(self): + # with get_mixed_type_key, you can sort addresses and network. + v4_ordered = [self.v4addr, self.v4net, self.v4intf] + v6_ordered = [self.v6addr, self.v6net, self.v6intf] + self.assertEqual(v4_ordered, + sorted(self.v4_objects, + key=ipaddress.get_mixed_type_key)) + self.assertEqual(v6_ordered, + sorted(self.v6_objects, + key=ipaddress.get_mixed_type_key)) + self.assertEqual(v4_ordered + v6_ordered, + sorted(self.objects, + key=ipaddress.get_mixed_type_key)) + self.assertEqual(NotImplemented, ipaddress.get_mixed_type_key(object)) + + def test_incompatible_versions(self): + # These should always raise TypeError + v4addr = ipaddress.ip_address('1.1.1.1') + v4net = ipaddress.ip_network('1.1.1.1') + v6addr = ipaddress.ip_address('::1') + v6net = ipaddress.ip_address('::1') + + self.assertRaises(TypeError, v4addr.__lt__, v6addr) + self.assertRaises(TypeError, v4addr.__gt__, v6addr) + self.assertRaises(TypeError, v4net.__lt__, v6net) + self.assertRaises(TypeError, v4net.__gt__, v6net) + + self.assertRaises(TypeError, v6addr.__lt__, v4addr) + self.assertRaises(TypeError, v6addr.__gt__, v4addr) + self.assertRaises(TypeError, v6net.__lt__, v4net) + self.assertRaises(TypeError, v6net.__gt__, v4net) + + + class IpaddrUnitTest(unittest.TestCase): def setUp(self): @@ -495,67 +582,6 @@ class IpaddrUnitTest(unittest.TestCase): self.assertEqual(str(self.ipv6_network.hostmask), '::ffff:ffff:ffff:ffff') - def testEqualityChecks(self): - # __eq__ should never raise TypeError directly - other = object() - def assertEqualityNotImplemented(instance): - self.assertEqual(instance.__eq__(other), NotImplemented) - self.assertEqual(instance.__ne__(other), NotImplemented) - self.assertFalse(instance == other) - self.assertTrue(instance != other) - - assertEqualityNotImplemented(self.ipv4_address) - assertEqualityNotImplemented(self.ipv4_network) - assertEqualityNotImplemented(self.ipv4_interface) - assertEqualityNotImplemented(self.ipv6_address) - assertEqualityNotImplemented(self.ipv6_network) - assertEqualityNotImplemented(self.ipv6_interface) - - def testBadVersionComparison(self): - # These should always raise TypeError - v4addr = ipaddress.ip_address('1.1.1.1') - v4net = ipaddress.ip_network('1.1.1.1') - v6addr = ipaddress.ip_address('::1') - v6net = ipaddress.ip_address('::1') - - self.assertRaises(TypeError, v4addr.__lt__, v6addr) - self.assertRaises(TypeError, v4addr.__gt__, v6addr) - self.assertRaises(TypeError, v4net.__lt__, v6net) - self.assertRaises(TypeError, v4net.__gt__, v6net) - - self.assertRaises(TypeError, v6addr.__lt__, v4addr) - self.assertRaises(TypeError, v6addr.__gt__, v4addr) - self.assertRaises(TypeError, v6net.__lt__, v4net) - self.assertRaises(TypeError, v6net.__gt__, v4net) - - def testMixedTypeComparison(self): - v4addr = ipaddress.ip_address('1.1.1.1') - v4net = ipaddress.ip_network('1.1.1.1/32') - v6addr = ipaddress.ip_address('::1') - v6net = ipaddress.ip_network('::1/128') - - self.assertFalse(v4net.__contains__(v6net)) - self.assertFalse(v6net.__contains__(v4net)) - - self.assertRaises(TypeError, lambda: v4addr < v4net) - self.assertRaises(TypeError, lambda: v4addr > v4net) - self.assertRaises(TypeError, lambda: v4net < v4addr) - self.assertRaises(TypeError, lambda: v4net > v4addr) - - self.assertRaises(TypeError, lambda: v6addr < v6net) - self.assertRaises(TypeError, lambda: v6addr > v6net) - self.assertRaises(TypeError, lambda: v6net < v6addr) - self.assertRaises(TypeError, lambda: v6net > v6addr) - - # with get_mixed_type_key, you can sort addresses and network. - self.assertEqual([v4addr, v4net], - sorted([v4net, v4addr], - key=ipaddress.get_mixed_type_key)) - self.assertEqual([v6addr, v6net], - sorted([v6net, v6addr], - key=ipaddress.get_mixed_type_key)) - self.assertEqual(NotImplemented, ipaddress.get_mixed_type_key(object)) - def testIpFromInt(self): self.assertEqual(self.ipv4_interface._ip, ipaddress.IPv4Interface(16909060)._ip) @@ -1049,6 +1075,16 @@ class IpaddrUnitTest(unittest.TestCase): self.assertTrue(ipaddress.ip_address('::1') <= ipaddress.ip_address('::2')) + def testInterfaceComparison(self): + self.assertTrue(ipaddress.ip_interface('1.1.1.1') <= + ipaddress.ip_interface('1.1.1.1')) + self.assertTrue(ipaddress.ip_interface('1.1.1.1') <= + ipaddress.ip_interface('1.1.1.2')) + self.assertTrue(ipaddress.ip_interface('::1') <= + ipaddress.ip_interface('::1')) + self.assertTrue(ipaddress.ip_interface('::1') <= + ipaddress.ip_interface('::2')) + def testNetworkComparison(self): # ip1 and ip2 have the same network address ip1 = ipaddress.IPv4Network('1.1.1.0/24') |