summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChristoph Zwerschke <cito@online.de>2016-02-09 15:20:51 +0100
committerChristoph Zwerschke <cito@online.de>2016-02-09 15:20:51 +0100
commit29ef0ad7c12d9ac8c27123d945bc4fbd74ba1a3c (patch)
treecfc1801574f81439fe27da7de3a7e7fffbb99c33
parentc65a805f577d659a2cb169a447b0134c081d500a (diff)
downloadsqlalchemy-pr/234.tar.gz
Make all tests pass with postgresql+pygresqlpr/234
Note that it requires the latest PyGreSQL 5.0 to pass all tests, a warning is printed otherwise. Minor change to test.dialect.postgresql.test_types: - JSON content should not be required to be unicode if the dialect doesn't return unicode (Python 2). Minor change to dialects/postgresql.base: - Return index names properly casted to unicode
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py4
-rw-r--r--lib/sqlalchemy/dialects/postgresql/pygresql.py138
-rw-r--r--test/dialect/postgresql/test_types.py13
3 files changed, 129 insertions, 26 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index b16a82e04..25938d44b 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -2396,7 +2396,9 @@ class PGDialect(default.DefaultDialect):
i.relname
"""
- t = sql.text(IDX_SQL, typemap={'attname': sqltypes.Unicode})
+ t = sql.text(IDX_SQL, typemap={
+ 'relname': sqltypes.Unicode,
+ 'attname': sqltypes.Unicode})
c = connection.execute(t, table_oid=table_oid)
indexes = defaultdict(lambda: defaultdict(dict))
diff --git a/lib/sqlalchemy/dialects/postgresql/pygresql.py b/lib/sqlalchemy/dialects/postgresql/pygresql.py
index 79ad3f188..d30206613 100644
--- a/lib/sqlalchemy/dialects/postgresql/pygresql.py
+++ b/lib/sqlalchemy/dialects/postgresql/pygresql.py
@@ -19,7 +19,9 @@ import re
from ... import exc, processors, util
from ...types import Numeric, JSON as Json
-from .base import PGDialect, _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID
+from ...sql.elements import Null
+from .base import PGDialect, PGCompiler, PGIdentifierPreparer, \
+ _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID
from .hstore import HSTORE
from .json import JSON, JSONB
@@ -30,23 +32,24 @@ class _PGNumeric(Numeric):
return None
def result_processor(self, dialect, coltype):
- oid = coltype.oid
+ if not isinstance(coltype, int):
+ coltype = coltype.oid
if self.asdecimal:
- if oid in _FLOAT_TYPES:
+ if coltype in _FLOAT_TYPES:
return processors.to_decimal_processor_factory(
decimal.Decimal,
self._effective_decimal_return_scale)
- elif oid in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
# PyGreSQL returns Decimal natively for 1700 (numeric)
return None
else:
raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype)
else:
- if oid in _FLOAT_TYPES:
+ if coltype in _FLOAT_TYPES:
# PyGreSQL returns float natively for 701 (float8)
return None
- elif oid in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
return processors.to_float
else:
raise exc.InvalidRequestError(
@@ -56,51 +59,121 @@ class _PGNumeric(Numeric):
class _PGHStore(HSTORE):
def bind_processor(self, dialect):
- if not dialect.use_native_hstore:
+ if not dialect.has_native_hstore:
return super(_PGHStore, self).bind_processor(dialect)
+ hstore = dialect.dbapi.Hstore
+ def process(value):
+ if isinstance(value, dict):
+ return hstore(value)
+ return value
+ return process
def result_processor(self, dialect, coltype):
- if not dialect.use_native_hstore:
+ if not dialect.has_native_hstore:
return super(_PGHStore, self).result_processor(dialect, coltype)
class _PGJSON(JSON):
def bind_processor(self, dialect):
- if not dialect.use_native_json:
+ if not dialect.has_native_json:
return super(_PGJSON, self).bind_processor(dialect)
+ json = dialect.dbapi.Json
+
+ def process(value):
+ if value is self.NULL:
+ value = None
+ elif isinstance(value, Null) or (
+ value is None and self.none_as_null):
+ return None
+ if value is None or isinstance(value, (dict, list)):
+ return json(value)
+ return value
+
+ return process
def result_processor(self, dialect, coltype):
- if not dialect.use_native_json:
+ if not dialect.has_native_json:
return super(_PGJSON, self).result_processor(dialect, coltype)
class _PGJSONB(JSONB):
def bind_processor(self, dialect):
- if not dialect.use_native_json:
+ if not dialect.has_native_json:
return super(_PGJSONB, self).bind_processor(dialect)
+ json = dialect.dbapi.Json
+
+ def process(value):
+ if value is self.NULL:
+ value = None
+ elif isinstance(value, Null) or (
+ value is None and self.none_as_null):
+ return None
+ if value is None or isinstance(value, (dict, list)):
+ return json(value)
+ return value
+
+ return process
def result_processor(self, dialect, coltype):
- if not dialect.use_native_json:
+ if not dialect.has_native_json:
return super(_PGJSONB, self).result_processor(dialect, coltype)
class _PGUUID(UUID):
def bind_processor(self, dialect):
- if not dialect.use_native_uuid:
+ if not dialect.has_native_uuid:
return super(_PGUUID, self).bind_processor(dialect)
+ uuid = dialect.dbapi.Uuid
+
+ def process(value):
+ if value is None:
+ return None
+ if isinstance(value, (str, bytes)):
+ if len(value) == 16:
+ return uuid(bytes=value)
+ return uuid(value)
+ if isinstance(value, int):
+ return uuid(int=value)
+ return value
+
+ return process
def result_processor(self, dialect, coltype):
- if not dialect.use_native_uuid:
+ if not dialect.has_native_uuid:
return super(_PGUUID, self).result_processor(dialect, coltype)
+ if not self.as_uuid:
+ def process(value):
+ if value is not None:
+ return str(value)
+ return process
+
+
+class _PGCompiler(PGCompiler):
+
+ def visit_mod_binary(self, binary, operator, **kw):
+ return self.process(binary.left, **kw) + " %% " + \
+ self.process(binary.right, **kw)
+
+ def post_process_text(self, text):
+ return text.replace('%', '%%')
+
+
+class _PGIdentifierPreparer(PGIdentifierPreparer):
+
+ def _escape_identifier(self, value):
+ value = value.replace(self.escape_quote, self.escape_to_quote)
+ return value.replace('%', '%%')
class PGDialect_pygresql(PGDialect):
+
driver = 'pygresql'
- _has_native_hstore = False
+ statement_compiler = _PGCompiler
+ preparer = _PGIdentifierPreparer
@classmethod
def dbapi(cls):
@@ -119,8 +192,7 @@ class PGDialect_pygresql(PGDialect):
}
)
- def __init__(self, use_native_hstore=False,
- use_native_json=False, use_native_uuid=True, **kwargs):
+ def __init__(self, **kwargs):
super(PGDialect_pygresql, self).__init__(**kwargs)
try:
version = self.dbapi.version
@@ -129,9 +201,18 @@ class PGDialect_pygresql(PGDialect):
except (AttributeError, ValueError, TypeError):
version = (0, 0)
self.dbapi_version = version
- self.use_native_hstore = use_native_hstore and version >= (5, 0)
- self.use_native_json = use_native_json and version >= (5, 0)
- self.use_native_uuid = use_native_uuid and version >= (5, 0)
+ if version < (5, 0):
+ has_native_hstore = has_native_json = has_native_uuid = False
+ if version != (0, 0):
+ util.warn("PyGreSQL is only fully supported by SQLAlchemy"
+ " since version 5.0.")
+ else:
+ self.supports_unicode_statements = True
+ self.supports_unicode_binds = True
+ has_native_hstore = has_native_json = has_native_uuid = True
+ self.has_native_hstore = has_native_hstore
+ self.has_native_json = has_native_json
+ self.has_native_uuid = has_native_uuid
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
@@ -141,5 +222,22 @@ class PGDialect_pygresql(PGDialect):
opts.update(url.query)
return [], opts
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.Error):
+ if not connection:
+ return False
+ try:
+ connection = connection.connection
+ except AttributeError:
+ pass
+ else:
+ if not connection:
+ return False
+ try:
+ return connection.closed
+ except AttributeError: # PyGreSQL < 5.0
+ return connection._cnx is None
+ return False
+
dialect = PGDialect_pygresql
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 8818a9941..6bcc4cf9a 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -544,11 +544,11 @@ class NumericInterpretationTest(fixtures.TestBase):
__backend__ = True
def test_numeric_codes(self):
- from sqlalchemy.dialects.postgresql import psycopg2cffi, pg8000, \
- psycopg2, base
+ from sqlalchemy.dialects.postgresql import pg8000, pygresql, \
+ psycopg2, psycopg2cffi, base
- dialects = (pg8000.dialect(), psycopg2.dialect(),
- psycopg2cffi.dialect())
+ dialects = (pg8000.dialect(), pygresql.dialect(),
+ psycopg2.dialect(), psycopg2cffi.dialect())
for dialect in dialects:
typ = Numeric().dialect_impl(dialect)
for code in base._INT_TYPES + base._FLOAT_TYPES + \
@@ -2757,7 +2757,10 @@ class JSONRoundTripTest(fixtures.TablesTest):
result = engine.execute(
select([data_table.c.data['k1'].astext])
).first()
- assert isinstance(result[0], util.text_type)
+ if engine.dialect.returns_unicode_strings:
+ assert isinstance(result[0], util.text_type)
+ else:
+ assert isinstance(result[0], util.string_types)
def test_query_returned_as_int(self):
engine = testing.db