summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/postgresql
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2023-04-21 12:35:01 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2023-04-21 12:35:01 +0000
commit98b77c36ed90894a4f7d4b9b43b8903675f36717 (patch)
tree975140993cacd542fc99ece27b46e06021bd706c /lib/sqlalchemy/dialects/postgresql
parentbf3a395e8f3c8814cab25e980a8762ef484bbb56 (diff)
parent609f432563954167b8f0148e43c70c08380e8ba4 (diff)
downloadsqlalchemy-98b77c36ed90894a4f7d4b9b43b8903675f36717.tar.gz
Merge "Add intersection method to Range class" into main
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/ranges.py46
1 files changed, 46 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py
index 3cf2ceb44..cefd280ea 100644
--- a/lib/sqlalchemy/dialects/postgresql/ranges.py
+++ b/lib/sqlalchemy/dialects/postgresql/ranges.py
@@ -641,6 +641,43 @@ class Range(Generic[_T]):
def __sub__(self, other: Range[_T]) -> Range[_T]:
return self.difference(other)
+ def intersection(self, other: Range[_T]) -> Range[_T]:
+ """Compute the intersection of this range with the `other`."""
+ if self.empty or other.empty or not self.overlaps(other):
+ return Range(None, None, empty=True)
+
+ slower = self.lower
+ slower_b = self.bounds[0]
+ supper = self.upper
+ supper_b = self.bounds[1]
+ olower = other.lower
+ olower_b = other.bounds[0]
+ oupper = other.upper
+ oupper_b = other.bounds[1]
+
+ if self._compare_edges(slower, slower_b, olower, olower_b) < 0:
+ rlower = olower
+ rlower_b = olower_b
+ else:
+ rlower = slower
+ rlower_b = slower_b
+
+ if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0:
+ rupper = oupper
+ rupper_b = oupper_b
+ else:
+ rupper = supper
+ rupper_b = supper_b
+
+ return Range(
+ rlower,
+ rupper,
+ bounds=cast(_BoundsType, rlower_b + rupper_b),
+ )
+
+ def __mul__(self, other: Range[_T]) -> Range[_T]:
+ return self.intersection(other)
+
def __str__(self) -> str:
return self._stringify()
@@ -809,6 +846,15 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
__sub__ = difference
+ def intersection(self, other: Any) -> ColumnElement[Range[_T]]:
+ """Range expression. Returns the intersection of the two ranges.
+ Will raise an exception if the resulting range is not
+ contiguous.
+ """
+ return self.expr.op("*")(other) # type: ignore
+
+ __mul__ = intersection
+
class AbstractRangeImpl(AbstractRange[Range[_T]]):
"""Marker for AbstractRange that will apply a subclass-specific