-rw-r--r--   lib/sqlalchemy/dialects/postgresql/__init__.py        3
-rw-r--r--   lib/sqlalchemy/dialects/postgresql/base.py            3
-rw-r--r--   lib/sqlalchemy/dialects/postgresql/pgjson.py        128
-rw-r--r--   lib/sqlalchemy/dialects/postgresql/psycopg2.py       15
-rw-r--r--   test/dialect/postgresql/test_types.py               191
5 files changed, 338 insertions, 2 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py
index 3c1d19504..728f1629f 100644
--- a/lib/sqlalchemy/dialects/postgresql/__init__.py
+++ b/lib/sqlalchemy/dialects/postgresql/__init__.py
@@ -15,6 +15,7 @@ from .base import \
TSVECTOR
from .constraints import ExcludeConstraint
from .hstore import HSTORE, hstore
+from .pgjson import JSON
from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \
TSTZRANGE
@@ -24,5 +25,5 @@ __all__ = (
'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA', 'BOOLEAN',
'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'Any', 'All', 'array', 'HSTORE',
'hstore', 'INT4RANGE', 'INT8RANGE', 'NUMRANGE', 'DATERANGE',
-    'TSRANGE', 'TSTZRANGE', 'TSVECTOR'
+    'TSRANGE', 'TSTZRANGE', 'TSVECTOR', 'JSON'
)
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index d3380afdd..a7f838009 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -1246,6 +1246,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
def visit_HSTORE(self, type_):
return "HSTORE"
+ def visit_JSON(self, type_):
+ return "JSON"
+
def visit_INT4RANGE(self, type_):
return "INT4RANGE"
diff --git a/lib/sqlalchemy/dialects/postgresql/pgjson.py b/lib/sqlalchemy/dialects/postgresql/pgjson.py
new file mode 100644
index 000000000..a29d0bbcc
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/pgjson.py
@@ -0,0 +1,128 @@
+# postgresql/pgjson.py
+# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+import json
+
+from .base import ARRAY, ischema_names
+from ... import types as sqltypes
+from ...sql import functions as sqlfunc
+from ...sql.operators import custom_op
+from ... import util
+
+__all__ = ('JSON', )
+
+
+class JSON(sqltypes.TypeEngine):
+ """Represent the Postgresql JSON type.
+
+ The :class:`.JSON` type stores arbitrary JSON format data, e.g.::
+
+ data_table = Table('data_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', JSON)
+ )
+
+ with engine.connect() as conn:
+ conn.execute(
+ data_table.insert(),
+ data = {"key1": "value1", "key2": "value2"}
+ )
+
+ :class:`.JSON` provides several operations:
+
+ * Index operations::
+
+ data_table.c.data['some key']
+
+ * Index operations returning text (required for text comparison or casting)::
+
+ data_table.c.data.get_item_as_text('some key') == 'some value'
+
+ * Path index operations::
+
+ data_table.c.data.get_path("{key_1, key_2, ..., key_n}")
+
+ * Path index operations returning text (required for text comparison or casting)::
+
+        data_table.c.data.get_path_as_text("{key_1, key_2, ..., key_n}") == 'some value'
+
+ Please be aware that when used with the SQLAlchemy ORM, you will need to
+ replace the JSON object present on an attribute with a new object in order
+ for any changes to be properly persisted.
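+
+    For example (a sketch only; ``MyData`` stands in for a mapped class with a
+    ``data`` column of this type, and ``session`` for an active ORM session),
+    assigning a new dictionary rather than mutating the existing one in place
+    is what marks the attribute as changed::
+
+        obj = session.query(MyData).first()
+        obj.data = dict(obj.data, key3="value3")  # replace, don't mutate
+        session.commit()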
+
+ .. versionadded:: 0.9
+ """
+
+ __visit_name__ = 'JSON'
+
+ def __init__(self, json_serializer=None, json_deserializer=None):
+ if json_serializer:
+ self.json_serializer = json_serializer
+ else:
+ self.json_serializer = json.dumps
+ if json_deserializer:
+ self.json_deserializer = json_deserializer
+ else:
+ self.json_deserializer = json.loads
+
+ class comparator_factory(sqltypes.Concatenable.Comparator):
+ """Define comparison operations for :class:`.JSON`."""
+
+ def __getitem__(self, other):
+ """Text expression. Get the value at a given key."""
+ # I'm choosing to return text here so the result can be cast,
+ # compared with strings, etc.
+ #
+ # The only downside to this is that you cannot dereference more
+ # than one level deep in json structures, though comparator
+ # support for multi-level dereference is lacking anyhow.
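+            # For instance, chained access such as expr['a']['b'] will not
+            # build a nested '->' expression, because the first access
+            # already yields a text-typed expression; get_path('{a, b}') is
+            # the way to reach nested keys.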
+ return self.expr.op('->', precedence=5)(other)
+
+ def get_item_as_text(self, other):
+ """Text expression. Get the value at the given key as text. Use
+ this when you need to cast the type of the returned value."""
+ return self.expr.op('->>', precedence=5)(other)
+
+ def get_path(self, other):
+ """Text expression. Get the value at a given path. Paths are of
+ the form {key_1, key_2, ..., key_n}."""
+ return self.expr.op('#>', precedence=5)(other)
+
+ def get_path_as_text(self, other):
+ """Text expression. Get the value at a given path, as text.
+ Paths are of the form {key_1, key_2, ..., key_n}. Use this when
+ you need to cast the type of the returned value."""
+ return self.expr.op('#>>', precedence=5)(other)
+
+ def _adapt_expression(self, op, other_comparator):
+ if isinstance(op, custom_op):
+ if op.opstring == '->':
+ return op, sqltypes.Text
+ return sqltypes.Concatenable.Comparator.\
+ _adapt_expression(self, op, other_comparator)
+
+ def bind_processor(self, dialect):
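+        # Note that a Python None passed through here is serialized to the
+        # JSON literal 'null', not rendered as SQL NULL.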
+ if util.py2k:
+ encoding = dialect.encoding
+ def process(value):
+ return self.json_serializer(value).encode(encoding)
+ else:
+ def process(value):
+ return self.json_serializer(value)
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if util.py2k:
+ encoding = dialect.encoding
+ def process(value):
+ return self.json_deserializer(value.decode(encoding))
+ else:
+ def process(value):
+ return self.json_deserializer(value)
+ return process
+
+
+ischema_names['json'] = JSON
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
index c3c749523..4a9248e5f 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -179,6 +179,7 @@ from .base import PGDialect, PGCompiler, \
ENUM, ARRAY, _DECIMAL_TYPES, _FLOAT_TYPES,\
_INT_TYPES
from .hstore import HSTORE
+from .pgjson import JSON
logger = logging.getLogger('sqlalchemy.dialects.postgresql')
@@ -233,6 +234,17 @@ class _PGHStore(HSTORE):
else:
return super(_PGHStore, self).result_processor(dialect, coltype)
+
+class _PGJSON(JSON):
+ # I've omitted the bind processor here because the method of serializing
+ # involves registering specific types to auto-serialize, and the adapter
+    # is just a thin wrapper over json.dumps.
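+    # (Illustrative only, not something this dialect wires up: with
+    # psycopg2 >= 2.5 an application can also serialize on the driver side
+    # by wrapping values itself, e.g.
+    #
+    #     from psycopg2.extras import Json
+    #     cursor.execute("insert into t (data) values (%s)", (Json({"a": 1}),))
+    # )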
+ def result_processor(self, dialect, coltype):
+ if dialect._has_native_json:
+ return None
+ else:
+ return super(_PGJSON, self).result_processor(dialect, coltype)
+
# When we're handed literal SQL, ensure it's a SELECT-query. Since
# 8.3, combining cursors and "FOR UPDATE" has been fine.
SERVER_SIDE_CURSOR_RE = re.compile(
@@ -317,6 +329,7 @@ class PGDialect_psycopg2(PGDialect):
psycopg2_version = (0, 0)
_has_native_hstore = False
+ _has_native_json = False
colspecs = util.update_copy(
PGDialect.colspecs,
@@ -325,6 +338,7 @@ class PGDialect_psycopg2(PGDialect):
ENUM: _PGEnum, # needs force_unicode
sqltypes.Enum: _PGEnum, # needs force_unicode
HSTORE: _PGHStore,
+ JSON: _PGJSON
}
)
@@ -352,6 +366,7 @@ class PGDialect_psycopg2(PGDialect):
self._has_native_hstore = self.use_native_hstore and \
self._hstore_oids(connection.connection) \
is not None
+ self._has_native_json = self.psycopg2_version >= (2, 5)
@classmethod
def dbapi(cls):
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 6e8609448..19df131fd 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -15,7 +15,8 @@ from sqlalchemy.orm import Session, mapper, aliased
from sqlalchemy import exc, schema, types
from sqlalchemy.dialects.postgresql import base as postgresql
from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, \
- INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE
+ INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE, \
+ JSON
import decimal
from sqlalchemy import util
from sqlalchemy.testing.util import round_decimal
@@ -1663,3 +1664,191 @@ class DateTimeTZRangeTests(_RangeTypeMixin, fixtures.TablesTest):
def _data_obj(self):
return self.extras.DateTimeTZRange(*self.tstzs())
+
+
+class JSONTest(fixtures.TestBase):
+ def _assert_sql(self, construct, expected):
+ dialect = postgresql.dialect()
+ compiled = str(construct.compile(dialect=dialect))
+ compiled = re.sub(r'\s+', ' ', compiled)
+ expected = re.sub(r'\s+', ' ', expected)
+ eq_(compiled, expected)
+
+ def setup(self):
+ metadata = MetaData()
+ self.test_table = Table('test_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('test_column', JSON)
+ )
+ self.jsoncol = self.test_table.c.test_column
+
+ def _test_where(self, whereclause, expected):
+ stmt = select([self.test_table]).where(whereclause)
+ self._assert_sql(
+ stmt,
+ "SELECT test_table.id, test_table.test_column FROM test_table "
+ "WHERE %s" % expected
+ )
+
+ def _test_cols(self, colclause, expected, from_=True):
+ stmt = select([colclause])
+ self._assert_sql(
+ stmt,
+ (
+ "SELECT %s" +
+ (" FROM test_table" if from_ else "")
+ ) % expected
+ )
+
+ def test_bind_serialize_default(self):
+ from sqlalchemy.engine import default
+
+ dialect = default.DefaultDialect()
+ proc = self.test_table.c.test_column.type._cached_bind_processor(dialect)
+ eq_(
+ proc({"A": [1, 2, 3, True, False]}),
+ '{"A": [1, 2, 3, true, false]}'
+ )
+
+ def test_result_deserialize_default(self):
+ from sqlalchemy.engine import default
+
+ dialect = default.DefaultDialect()
+ proc = self.test_table.c.test_column.type._cached_result_processor(
+ dialect, None)
+ eq_(
+ proc('{"A": [1, 2, 3, true, false]}'),
+ {"A": [1, 2, 3, True, False]}
+ )
+
+    # This test is a bit misleading -- in real life you will need to cast
+    # the result before you can do much of anything with it.
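+    # An illustrative form of such a cast (not exercised by this test, and
+    # assuming ``cast`` and ``Text`` are imported from sqlalchemy) would be:
+    #     cast(self.jsoncol['bar'], Text) == 'some value'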
+ def test_where_getitem(self):
+ self._test_where(
+ self.jsoncol['bar'] == None,
+ "(test_table.test_column -> %(test_column_1)s) IS NULL"
+ )
+
+ def test_where_path(self):
+ self._test_where(
+ self.jsoncol.get_path('{"foo", 1}') == None,
+ "(test_table.test_column #> %(test_column_1)s) IS NULL"
+ )
+
+ def test_where_getitem_as_text(self):
+ self._test_where(
+ self.jsoncol.get_item_as_text('bar') == None,
+ "(test_table.test_column ->> %(test_column_1)s) IS NULL"
+ )
+
+ def test_where_path_as_text(self):
+ self._test_where(
+ self.jsoncol.get_path_as_text('{"foo", 1}') == None,
+ "(test_table.test_column #>> %(test_column_1)s) IS NULL"
+ )
+
+ def test_cols_get(self):
+ self._test_cols(
+ self.jsoncol['foo'],
+ "test_table.test_column -> %(test_column_1)s AS anon_1",
+ True
+ )
+
+
+class JSONRoundTripTest(fixtures.TablesTest):
+ __only_on__ = 'postgresql'
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table('data_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(30), nullable=False),
+ Column('data', JSON)
+ )
+
+ def _fixture_data(self, engine):
+ data_table = self.tables.data_table
+ engine.execute(
+ data_table.insert(),
+ {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}},
+ {'name': 'r2', 'data': {"k1": "r2v1", "k2": "r2v2"}},
+ {'name': 'r3', 'data': {"k1": "r3v1", "k2": "r3v2"}},
+ {'name': 'r4', 'data': {"k1": "r4v1", "k2": "r4v2"}},
+ {'name': 'r5', 'data': {"k1": "r5v1", "k2": "r5v2"}},
+ )
+
+ def _assert_data(self, compare):
+ data = testing.db.execute(
+ select([self.tables.data_table.c.data]).
+ order_by(self.tables.data_table.c.name)
+ ).fetchall()
+ eq_([d for d, in data], compare)
+
+ def _test_insert(self, engine):
+ engine.execute(
+ self.tables.data_table.insert(),
+ {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}}
+ )
+ self._assert_data([{"k1": "r1v1", "k2": "r1v2"}])
+
+ def _non_native_engine(self):
+ if testing.against("postgresql+psycopg2"):
+ engine = engines.testing_engine()
+ else:
+ engine = testing.db
+ engine.connect()
+ return engine
+
+ def test_reflect(self):
+ from sqlalchemy import inspect
+ insp = inspect(testing.db)
+ cols = insp.get_columns('data_table')
+ assert isinstance(cols[2]['type'], JSON)
+
+ @testing.only_on("postgresql+psycopg2")
+ def test_insert_native(self):
+ engine = testing.db
+ self._test_insert(engine)
+
+ def test_insert_python(self):
+ engine = self._non_native_engine()
+ self._test_insert(engine)
+
+ @testing.only_on("postgresql+psycopg2")
+ def test_criterion_native(self):
+ engine = testing.db
+ self._fixture_data(engine)
+ self._test_criterion(engine)
+
+ def test_criterion_python(self):
+ engine = self._non_native_engine()
+ self._fixture_data(engine)
+ self._test_criterion(engine)
+
+ def test_path_query(self):
+ engine = testing.db
+ self._fixture_data(engine)
+ data_table = self.tables.data_table
+ result = engine.execute(
+ select([data_table.c.data]).where(
+ data_table.c.data.get_path_as_text('{k1}') == 'r3v1'
+ )
+ ).first()
+ eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))
+
+ def test_query_returned_as_text(self):
+ engine = testing.db
+ self._fixture_data(engine)
+ data_table = self.tables.data_table
+ result = engine.execute(
+ select([data_table.c.data.get_item_as_text('k1')])
+ ).first()
+        assert isinstance(result[0], util.string_types)
+
+ def _test_criterion(self, engine):
+ data_table = self.tables.data_table
+ result = engine.execute(
+ select([data_table.c.data]).where(
+ data_table.c.data.get_item_as_text('k1') == 'r3v1'
+ )
+ ).first()
+ eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))