diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2023-04-21 12:35:01 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2023-04-21 12:35:01 +0000 |
| commit | 98b77c36ed90894a4f7d4b9b43b8903675f36717 (patch) | |
| tree | 975140993cacd542fc99ece27b46e06021bd706c | |
| parent | bf3a395e8f3c8814cab25e980a8762ef484bbb56 (diff) | |
| parent | 609f432563954167b8f0148e43c70c08380e8ba4 (diff) | |
| download | sqlalchemy-98b77c36ed90894a4f7d4b9b43b8903675f36717.tar.gz | |
Merge "Add intersection method to Range class" into main
| -rw-r--r-- | doc/build/changelog/unreleased_20/9509.rst | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/ranges.py | 46 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_types.py | 44 |
3 files changed, 96 insertions, 0 deletions
diff --git a/doc/build/changelog/unreleased_20/9509.rst b/doc/build/changelog/unreleased_20/9509.rst new file mode 100644 index 000000000..b50a4a028 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9509.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: usecase, postgresql + :tickets: 9509 + + Add missing :meth:`_postgresql.Range.intersection` method. + Pull request courtesy Yurii Karabas. diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 3cf2ceb44..cefd280ea 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -641,6 +641,43 @@ class Range(Generic[_T]): def __sub__(self, other: Range[_T]) -> Range[_T]: return self.difference(other) + def intersection(self, other: Range[_T]) -> Range[_T]: + """Compute the intersection of this range with the `other`.""" + if self.empty or other.empty or not self.overlaps(other): + return Range(None, None, empty=True) + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + if self._compare_edges(slower, slower_b, olower, olower_b) < 0: + rlower = olower + rlower_b = olower_b + else: + rlower = slower + rlower_b = slower_b + + if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0: + rupper = oupper + rupper_b = oupper_b + else: + rupper = supper + rupper_b = supper_b + + return Range( + rlower, + rupper, + bounds=cast(_BoundsType, rlower_b + rupper_b), + ) + + def __mul__(self, other: Range[_T]) -> Range[_T]: + return self.intersection(other) + def __str__(self) -> str: return self._stringify() @@ -809,6 +846,15 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]): __sub__ = difference + def intersection(self, other: Any) -> ColumnElement[Range[_T]]: + """Range expression. Returns the intersection of the two ranges. + Will raise an exception if the resulting range is not + contiguous. + """ + return self.expr.op("*")(other) # type: ignore + + __mul__ = intersection + class AbstractRangeImpl(AbstractRange[Range[_T]]): """Marker for AbstractRange that will apply a subclass-specific diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 5f5be3c57..f322bf354 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -4499,6 +4499,50 @@ class _RangeComparisonFixtures(_RangeTests): @testing.combinations( *_common_ranges_to_test, lambda r, e: Range(r.lower, r.lower, bounds="[]"), + lambda r, e: Range(r.lower - e, r.upper - e, bounds="[]"), + lambda r, e: Range(r.lower - e, r.upper + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.upper + e, bounds="[]"), + argnames="r1t", + ) + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.lower, r.lower, bounds="[]"), + lambda r, e: Range(r.lower, r.upper - e, bounds="(]"), + lambda r, e: Range(r.lower, r.lower + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.lower, bounds="(]"), + lambda r, e: Range(r.lower - e, r.lower + e, bounds="()"), + lambda r, e: Range(r.lower, r.upper, bounds="[]"), + lambda r, e: Range(r.lower, r.upper, bounds="()"), + argnames="r2t", + ) + def test_intersection(self, connection, r1t, r2t): + r1 = r1t(self._data_obj(), self._epsilon) + r2 = r2t(self._data_obj(), self._epsilon) + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + cast(r1, RANGE).intersection(r2), + ) + validate_q = select( + literal_column(f"'{r1}'::{range_typ}*'{r2}'::{range_typ}", RANGE), + ) + + pg_res = connection.execute(q).scalar() + + validate_intersection = connection.execute(validate_q).scalar() + eq_(pg_res, validate_intersection) + py_res = r1.intersection(r2) + eq_( + py_res, + pg_res, + f"{r1}.intersection({r2}): got {py_res}, expected {pg_res}", + ) + + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.lower, r.lower, bounds="[]"), argnames="r1t", ) @testing.combinations( |
