summaryrefslogtreecommitdiff
path: root/Lib
diff options
context:
space:
mode:
authorNick Coghlan <ncoghlan@gmail.com>2012-07-08 00:45:33 +1000
committerNick Coghlan <ncoghlan@gmail.com>2012-07-08 00:45:33 +1000
commit3008ec070f40974790f6a85f7ceabf0876950abf (patch)
treebc7c86ec1c468fc57fdee384b963f1c18856ba69 /Lib
parent9a9c28ce7a051b37a91e4fc7aef70bcdcda25047 (diff)
downloadcpython-git-3008ec070f40974790f6a85f7ceabf0876950abf.tar.gz
Issue 14814: Ensure ordering semantics across all 3 entity types in ipaddress are consistent and well-defined
Diffstat (limited to 'Lib')
-rw-r--r--Lib/ipaddress.py136
-rw-r--r--Lib/test/test_ipaddress.py158
2 files changed, 166 insertions, 128 deletions
diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py
index b1e07fc992..2019009551 100644
--- a/Lib/ipaddress.py
+++ b/Lib/ipaddress.py
@@ -12,7 +12,7 @@ __version__ = '1.0'
import struct
-
+import functools
IPV4LENGTH = 32
IPV6LENGTH = 128
@@ -405,7 +405,38 @@ def get_mixed_type_key(obj):
return NotImplemented
-class _IPAddressBase:
+class _TotalOrderingMixin:
+ # Helper that derives the other comparison operations from
+ # __lt__ and __eq__
+ def __eq__(self, other):
+ raise NotImplementedError
+ def __ne__(self, other):
+ equal = self.__eq__(other)
+ if equal is NotImplemented:
+ return NotImplemented
+ return not equal
+ def __lt__(self, other):
+ raise NotImplementedError
+ def __le__(self, other):
+ less = self.__lt__(other)
+ if less is NotImplemented or not less:
+ return self.__eq__(other)
+ return less
+ def __gt__(self, other):
+ less = self.__lt__(other)
+ if less is NotImplemented:
+ return NotImplemented
+ equal = self.__eq__(other)
+ if equal is NotImplemented:
+ return NotImplemented
+ return not (less or equal)
+ def __ge__(self, other):
+ less = self.__lt__(other)
+ if less is NotImplemented:
+ return NotImplemented
+ return not less
+
+class _IPAddressBase(_TotalOrderingMixin):
"""The mother class."""
@@ -465,7 +496,6 @@ class _IPAddressBase:
prefixlen = self._prefixlen
return self._string_from_ip_int(self._ip_int_from_prefix(prefixlen))
-
class _BaseAddress(_IPAddressBase):
"""A generic IP object.
@@ -493,24 +523,6 @@ class _BaseAddress(_IPAddressBase):
except AttributeError:
return NotImplemented
- def __ne__(self, other):
- eq = self.__eq__(other)
- if eq is NotImplemented:
- return NotImplemented
- return not eq
-
- def __le__(self, other):
- gt = self.__gt__(other)
- if gt is NotImplemented:
- return NotImplemented
- return not gt
-
- def __ge__(self, other):
- lt = self.__lt__(other)
- if lt is NotImplemented:
- return NotImplemented
- return not lt
-
def __lt__(self, other):
if self._version != other._version:
raise TypeError('%s and %s are not of the same version' % (
@@ -522,17 +534,6 @@ class _BaseAddress(_IPAddressBase):
return self._ip < other._ip
return False
- def __gt__(self, other):
- if self._version != other._version:
- raise TypeError('%s and %s are not of the same version' % (
- self, other))
- if not isinstance(other, _BaseAddress):
- raise TypeError('%s and %s are not of the same type' % (
- self, other))
- if self._ip != other._ip:
- return self._ip > other._ip
- return False
-
# Shorthand for Integer addition and subtraction. This is not
# meant to ever support addition/subtraction of addresses.
def __add__(self, other):
@@ -625,31 +626,6 @@ class _BaseNetwork(_IPAddressBase):
return self.netmask < other.netmask
return False
- def __gt__(self, other):
- if self._version != other._version:
- raise TypeError('%s and %s are not of the same version' % (
- self, other))
- if not isinstance(other, _BaseNetwork):
- raise TypeError('%s and %s are not of the same type' % (
- self, other))
- if self.network_address != other.network_address:
- return self.network_address > other.network_address
- if self.netmask != other.netmask:
- return self.netmask > other.netmask
- return False
-
- def __le__(self, other):
- gt = self.__gt__(other)
- if gt is NotImplemented:
- return NotImplemented
- return not gt
-
- def __ge__(self, other):
- lt = self.__lt__(other)
- if lt is NotImplemented:
- return NotImplemented
- return not lt
-
def __eq__(self, other):
try:
return (self._version == other._version and
@@ -658,12 +634,6 @@ class _BaseNetwork(_IPAddressBase):
except AttributeError:
return NotImplemented
- def __ne__(self, other):
- eq = self.__eq__(other)
- if eq is NotImplemented:
- return NotImplemented
- return not eq
-
def __hash__(self):
return hash(int(self.network_address) ^ int(self.netmask))
@@ -1292,11 +1262,27 @@ class IPv4Interface(IPv4Address):
self.network.prefixlen)
def __eq__(self, other):
+ address_equal = IPv4Address.__eq__(self, other)
+ if not address_equal or address_equal is NotImplemented:
+ return address_equal
try:
- return (IPv4Address.__eq__(self, other) and
- self.network == other.network)
+ return self.network == other.network
except AttributeError:
+ # An interface with an associated network is NOT the
+ # same as an unassociated address. That's why the hash
+ # takes the extra info into account.
+ return False
+
+ def __lt__(self, other):
+ address_less = IPv4Address.__lt__(self, other)
+ if address_less is NotImplemented:
return NotImplemented
+ try:
+ return self.network < other.network
+ except AttributeError:
+ # We *do* allow addresses and interfaces to be sorted. The
+ # unassociated address is considered less than all interfaces.
+ return False
def __hash__(self):
return self._ip ^ self._prefixlen ^ int(self.network.network_address)
@@ -1928,11 +1914,27 @@ class IPv6Interface(IPv6Address):
self.network.prefixlen)
def __eq__(self, other):
+ address_equal = IPv6Address.__eq__(self, other)
+ if not address_equal or address_equal is NotImplemented:
+ return address_equal
try:
- return (IPv6Address.__eq__(self, other) and
- self.network == other.network)
+ return self.network == other.network
except AttributeError:
+ # An interface with an associated network is NOT the
+ # same as an unassociated address. That's why the hash
+ # takes the extra info into account.
+ return False
+
+ def __lt__(self, other):
+ address_less = IPv6Address.__lt__(self, other)
+ if address_less is NotImplemented:
return NotImplemented
+ try:
+ return self.network < other.network
+ except AttributeError:
+ # We *do* allow addresses and interfaces to be sorted. The
+ # unassociated address is considered less than all interfaces.
+ return False
def __hash__(self):
return self._ip ^ self._prefixlen ^ int(self.network.network_address)
diff --git a/Lib/test/test_ipaddress.py b/Lib/test/test_ipaddress.py
index 417c98677f..5aaf736740 100644
--- a/Lib/test/test_ipaddress.py
+++ b/Lib/test/test_ipaddress.py
@@ -415,6 +415,93 @@ class FactoryFunctionErrors(ErrorReporting):
self.assertFactoryError(ipaddress.ip_network, "network")
+class ComparisonTests(unittest.TestCase):
+
+ v4addr = ipaddress.IPv4Address(1)
+ v4net = ipaddress.IPv4Network(1)
+ v4intf = ipaddress.IPv4Interface(1)
+ v6addr = ipaddress.IPv6Address(1)
+ v6net = ipaddress.IPv6Network(1)
+ v6intf = ipaddress.IPv6Interface(1)
+
+ v4_addresses = [v4addr, v4intf]
+ v4_objects = v4_addresses + [v4net]
+ v6_addresses = [v6addr, v6intf]
+ v6_objects = v6_addresses + [v6net]
+ objects = v4_objects + v6_objects
+
+ def test_foreign_type_equality(self):
+ # __eq__ should never raise TypeError directly
+ other = object()
+ for obj in self.objects:
+ self.assertNotEqual(obj, other)
+ self.assertFalse(obj == other)
+ self.assertEqual(obj.__eq__(other), NotImplemented)
+ self.assertEqual(obj.__ne__(other), NotImplemented)
+
+ def test_mixed_type_equality(self):
+ # Ensure none of the internal objects accidentally
+ # expose the right set of attributes to become "equal"
+ for lhs in self.objects:
+ for rhs in self.objects:
+ if lhs is rhs:
+ continue
+ self.assertNotEqual(lhs, rhs)
+
+ def test_containment(self):
+ for obj in self.v4_addresses:
+ self.assertIn(obj, self.v4net)
+ for obj in self.v6_addresses:
+ self.assertIn(obj, self.v6net)
+ for obj in self.v4_objects + [self.v6net]:
+ self.assertNotIn(obj, self.v6net)
+ for obj in self.v6_objects + [self.v4net]:
+ self.assertNotIn(obj, self.v4net)
+
+ def test_mixed_type_ordering(self):
+ for lhs in self.objects:
+ for rhs in self.objects:
+ if isinstance(lhs, type(rhs)) or isinstance(rhs, type(lhs)):
+ continue
+ self.assertRaises(TypeError, lambda: lhs < rhs)
+ self.assertRaises(TypeError, lambda: lhs > rhs)
+ self.assertRaises(TypeError, lambda: lhs <= rhs)
+ self.assertRaises(TypeError, lambda: lhs >= rhs)
+
+ def test_mixed_type_key(self):
+ # with get_mixed_type_key, you can sort addresses and network.
+ v4_ordered = [self.v4addr, self.v4net, self.v4intf]
+ v6_ordered = [self.v6addr, self.v6net, self.v6intf]
+ self.assertEqual(v4_ordered,
+ sorted(self.v4_objects,
+ key=ipaddress.get_mixed_type_key))
+ self.assertEqual(v6_ordered,
+ sorted(self.v6_objects,
+ key=ipaddress.get_mixed_type_key))
+ self.assertEqual(v4_ordered + v6_ordered,
+ sorted(self.objects,
+ key=ipaddress.get_mixed_type_key))
+ self.assertEqual(NotImplemented, ipaddress.get_mixed_type_key(object))
+
+ def test_incompatible_versions(self):
+ # These should always raise TypeError
+ v4addr = ipaddress.ip_address('1.1.1.1')
+ v4net = ipaddress.ip_network('1.1.1.1')
+ v6addr = ipaddress.ip_address('::1')
+ v6net = ipaddress.ip_address('::1')
+
+ self.assertRaises(TypeError, v4addr.__lt__, v6addr)
+ self.assertRaises(TypeError, v4addr.__gt__, v6addr)
+ self.assertRaises(TypeError, v4net.__lt__, v6net)
+ self.assertRaises(TypeError, v4net.__gt__, v6net)
+
+ self.assertRaises(TypeError, v6addr.__lt__, v4addr)
+ self.assertRaises(TypeError, v6addr.__gt__, v4addr)
+ self.assertRaises(TypeError, v6net.__lt__, v4net)
+ self.assertRaises(TypeError, v6net.__gt__, v4net)
+
+
+
class IpaddrUnitTest(unittest.TestCase):
def setUp(self):
@@ -495,67 +582,6 @@ class IpaddrUnitTest(unittest.TestCase):
self.assertEqual(str(self.ipv6_network.hostmask),
'::ffff:ffff:ffff:ffff')
- def testEqualityChecks(self):
- # __eq__ should never raise TypeError directly
- other = object()
- def assertEqualityNotImplemented(instance):
- self.assertEqual(instance.__eq__(other), NotImplemented)
- self.assertEqual(instance.__ne__(other), NotImplemented)
- self.assertFalse(instance == other)
- self.assertTrue(instance != other)
-
- assertEqualityNotImplemented(self.ipv4_address)
- assertEqualityNotImplemented(self.ipv4_network)
- assertEqualityNotImplemented(self.ipv4_interface)
- assertEqualityNotImplemented(self.ipv6_address)
- assertEqualityNotImplemented(self.ipv6_network)
- assertEqualityNotImplemented(self.ipv6_interface)
-
- def testBadVersionComparison(self):
- # These should always raise TypeError
- v4addr = ipaddress.ip_address('1.1.1.1')
- v4net = ipaddress.ip_network('1.1.1.1')
- v6addr = ipaddress.ip_address('::1')
- v6net = ipaddress.ip_address('::1')
-
- self.assertRaises(TypeError, v4addr.__lt__, v6addr)
- self.assertRaises(TypeError, v4addr.__gt__, v6addr)
- self.assertRaises(TypeError, v4net.__lt__, v6net)
- self.assertRaises(TypeError, v4net.__gt__, v6net)
-
- self.assertRaises(TypeError, v6addr.__lt__, v4addr)
- self.assertRaises(TypeError, v6addr.__gt__, v4addr)
- self.assertRaises(TypeError, v6net.__lt__, v4net)
- self.assertRaises(TypeError, v6net.__gt__, v4net)
-
- def testMixedTypeComparison(self):
- v4addr = ipaddress.ip_address('1.1.1.1')
- v4net = ipaddress.ip_network('1.1.1.1/32')
- v6addr = ipaddress.ip_address('::1')
- v6net = ipaddress.ip_network('::1/128')
-
- self.assertFalse(v4net.__contains__(v6net))
- self.assertFalse(v6net.__contains__(v4net))
-
- self.assertRaises(TypeError, lambda: v4addr < v4net)
- self.assertRaises(TypeError, lambda: v4addr > v4net)
- self.assertRaises(TypeError, lambda: v4net < v4addr)
- self.assertRaises(TypeError, lambda: v4net > v4addr)
-
- self.assertRaises(TypeError, lambda: v6addr < v6net)
- self.assertRaises(TypeError, lambda: v6addr > v6net)
- self.assertRaises(TypeError, lambda: v6net < v6addr)
- self.assertRaises(TypeError, lambda: v6net > v6addr)
-
- # with get_mixed_type_key, you can sort addresses and network.
- self.assertEqual([v4addr, v4net],
- sorted([v4net, v4addr],
- key=ipaddress.get_mixed_type_key))
- self.assertEqual([v6addr, v6net],
- sorted([v6net, v6addr],
- key=ipaddress.get_mixed_type_key))
- self.assertEqual(NotImplemented, ipaddress.get_mixed_type_key(object))
-
def testIpFromInt(self):
self.assertEqual(self.ipv4_interface._ip,
ipaddress.IPv4Interface(16909060)._ip)
@@ -1049,6 +1075,16 @@ class IpaddrUnitTest(unittest.TestCase):
self.assertTrue(ipaddress.ip_address('::1') <=
ipaddress.ip_address('::2'))
+ def testInterfaceComparison(self):
+ self.assertTrue(ipaddress.ip_interface('1.1.1.1') <=
+ ipaddress.ip_interface('1.1.1.1'))
+ self.assertTrue(ipaddress.ip_interface('1.1.1.1') <=
+ ipaddress.ip_interface('1.1.1.2'))
+ self.assertTrue(ipaddress.ip_interface('::1') <=
+ ipaddress.ip_interface('::1'))
+ self.assertTrue(ipaddress.ip_interface('::1') <=
+ ipaddress.ip_interface('::2'))
+
def testNetworkComparison(self):
# ip1 and ip2 have the same network address
ip1 = ipaddress.IPv4Network('1.1.1.0/24')