diff options
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/__init__.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 59 | ||||
-rw-r--r-- | test/dialect/postgresql/test_types.py | 14 |
3 files changed, 73 insertions, 4 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 98fe6f085..a5ddbd575 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -13,7 +13,7 @@ from .base import \ INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \ INET, CIDR, UUID, BIT, MACADDR, OID, DOUBLE_PRECISION, TIMESTAMP, TIME, \ DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array, Any, All, \ - TSVECTOR, DropEnumType + TSVECTOR, DropEnumType, LTREE, LQUERY, LTXTQUERY from .constraints import ExcludeConstraint from .hstore import HSTORE, hstore from .json import JSON, JSONElement, JSONB @@ -27,5 +27,5 @@ __all__ = ( 'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'Any', 'All', 'array', 'HSTORE', 'hstore', 'INT4RANGE', 'INT8RANGE', 'NUMRANGE', 'DATERANGE', 'TSRANGE', 'TSTZRANGE', 'json', 'JSON', 'JSONB', 'JSONElement', - 'DropEnumType' + 'DropEnumType', 'LTREE', 'LQUERY', 'LTXTQUERY' ) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 22c66dbbb..0793ca3ab 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -598,6 +598,51 @@ class OID(sqltypes.TypeEngine): __visit_name__ = "OID" +class LTREE(sqltypes.TypeEngine): + + """Postgresql LTREE type. + """ + + class Comparator(sqltypes.Concatenable.Comparator): + + def ancestor_of(self, other): + if isinstance(other, list): + return self.expr.op('@>')(expression.cast(other, ARRAY(LTREE))) + else: + return self.expr.op('@>')(other) + + def descendant_of(self, other): + if isinstance(other, list): + return self.expr.op('<@')(expression.cast(other, ARRAY(LTREE))) + else: + return self.expr.op('<@')(other) + + def lquery(self, other): + if isinstance(other, list): + return self.expr.op('?')(expression.cast(other, ARRAY(LQUERY))) + else: + return self.expr.op('~')(other) + + def ltxtquery(self, other): + return self.expr.op('@')(other) + + comparator_factory = Comparator + + __visit_name__ = 'LTREE' + +PGLTree = LTREE + + +class LQUERY(sqltypes.TypeEngine): + __visit_name__ = 'LQUERY' +PGLQuery = LQUERY + + +class LTXTQUERY(sqltypes.TypeEngine): + __visit_name__ = 'LTXTQUERY' +PGLTxtQuery = LTXTQUERY + + class TIMESTAMP(sqltypes.TIMESTAMP): def __init__(self, timezone=False, precision=None): @@ -1360,7 +1405,10 @@ ischema_names = { 'interval': INTERVAL, 'interval year to month': INTERVAL, 'interval day to second': INTERVAL, - 'tsvector': TSVECTOR + 'tsvector': TSVECTOR, + 'ltree': LTREE, + 'lquery': LQUERY, + 'ltxtquery': LTXTQUERY } @@ -1782,6 +1830,15 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): if type_.dimensions is not None else 1)) + def visit_LTREE(self, type_, **kw): + return 'LTREE' + + def visit_LQUERY(self, type_, **kw): + return 'LQUERY' + + def visit_LTXTQUERY(self, type_, **kw): + return 'LTXTQUERY' + class PGIdentifierPreparer(compiler.IdentifierPreparer): diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index fac0f2df8..456efbf92 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -14,7 +14,7 @@ import sqlalchemy as sa from sqlalchemy.dialects.postgresql import base as postgresql from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, \ INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE, \ - JSON, JSONB + JSON, JSONB, LTREE, LQUERY, LTXTQUERY import decimal from sqlalchemy import util from sqlalchemy.testing.util import round_decimal @@ -1052,6 +1052,18 @@ class TimestampTest(fixtures.TestBase, AssertsExecutionResults): eq_(result[0], datetime.datetime(2007, 12, 25, 0, 0)) +class LTreeTest(fixtures.TestBase, AssertsExecutionResults): + __only_on__ = 'postgresql >= 9.1.0' + __backend__ = True + + def test_ltree(self): + engine = testing.db + connection = engine.connect() + + # The Postgresql server must have the LTREE extension enabled. + # `CREATE EXTENSION ltree;` + + class SpecialTypesTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): """test DDL and reflection of PG-specific types """ |