diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/pygresql.py')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/pygresql.py | 138 |
1 files changed, 118 insertions, 20 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/pygresql.py b/lib/sqlalchemy/dialects/postgresql/pygresql.py index 79ad3f188..d30206613 100644 --- a/lib/sqlalchemy/dialects/postgresql/pygresql.py +++ b/lib/sqlalchemy/dialects/postgresql/pygresql.py @@ -19,7 +19,9 @@ import re from ... import exc, processors, util from ...types import Numeric, JSON as Json -from .base import PGDialect, _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID +from ...sql.elements import Null +from .base import PGDialect, PGCompiler, PGIdentifierPreparer, \ + _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID from .hstore import HSTORE from .json import JSON, JSONB @@ -30,23 +32,24 @@ class _PGNumeric(Numeric): return None def result_processor(self, dialect, coltype): - oid = coltype.oid + if not isinstance(coltype, int): + coltype = coltype.oid if self.asdecimal: - if oid in _FLOAT_TYPES: + if coltype in _FLOAT_TYPES: return processors.to_decimal_processor_factory( decimal.Decimal, self._effective_decimal_return_scale) - elif oid in _DECIMAL_TYPES or coltype in _INT_TYPES: + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: # PyGreSQL returns Decimal natively for 1700 (numeric) return None else: raise exc.InvalidRequestError( "Unknown PG numeric type: %d" % coltype) else: - if oid in _FLOAT_TYPES: + if coltype in _FLOAT_TYPES: # PyGreSQL returns float natively for 701 (float8) return None - elif oid in _DECIMAL_TYPES or coltype in _INT_TYPES: + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: return processors.to_float else: raise exc.InvalidRequestError( @@ -56,51 +59,121 @@ class _PGNumeric(Numeric): class _PGHStore(HSTORE): def bind_processor(self, dialect): - if not dialect.use_native_hstore: + if not dialect.has_native_hstore: return super(_PGHStore, self).bind_processor(dialect) + hstore = dialect.dbapi.Hstore + def process(value): + if isinstance(value, dict): + return hstore(value) + return value + return process def result_processor(self, dialect, coltype): - if not dialect.use_native_hstore: + if not dialect.has_native_hstore: return super(_PGHStore, self).result_processor(dialect, coltype) class _PGJSON(JSON): def bind_processor(self, dialect): - if not dialect.use_native_json: + if not dialect.has_native_json: return super(_PGJSON, self).bind_processor(dialect) + json = dialect.dbapi.Json + + def process(value): + if value is self.NULL: + value = None + elif isinstance(value, Null) or ( + value is None and self.none_as_null): + return None + if value is None or isinstance(value, (dict, list)): + return json(value) + return value + + return process def result_processor(self, dialect, coltype): - if not dialect.use_native_json: + if not dialect.has_native_json: return super(_PGJSON, self).result_processor(dialect, coltype) class _PGJSONB(JSONB): def bind_processor(self, dialect): - if not dialect.use_native_json: + if not dialect.has_native_json: return super(_PGJSONB, self).bind_processor(dialect) + json = dialect.dbapi.Json + + def process(value): + if value is self.NULL: + value = None + elif isinstance(value, Null) or ( + value is None and self.none_as_null): + return None + if value is None or isinstance(value, (dict, list)): + return json(value) + return value + + return process def result_processor(self, dialect, coltype): - if not dialect.use_native_json: + if not dialect.has_native_json: return super(_PGJSONB, self).result_processor(dialect, coltype) class _PGUUID(UUID): def bind_processor(self, dialect): - if not dialect.use_native_uuid: + if not dialect.has_native_uuid: return super(_PGUUID, self).bind_processor(dialect) + uuid = dialect.dbapi.Uuid + + def process(value): + if value is None: + return None + if isinstance(value, (str, bytes)): + if len(value) == 16: + return uuid(bytes=value) + return uuid(value) + if isinstance(value, int): + return uuid(int=value) + return value + + return process def result_processor(self, dialect, coltype): - if not dialect.use_native_uuid: + if not dialect.has_native_uuid: return super(_PGUUID, self).result_processor(dialect, coltype) + if not self.as_uuid: + def process(value): + if value is not None: + return str(value) + return process + + +class _PGCompiler(PGCompiler): + + def visit_mod_binary(self, binary, operator, **kw): + return self.process(binary.left, **kw) + " %% " + \ + self.process(binary.right, **kw) + + def post_process_text(self, text): + return text.replace('%', '%%') + + +class _PGIdentifierPreparer(PGIdentifierPreparer): + + def _escape_identifier(self, value): + value = value.replace(self.escape_quote, self.escape_to_quote) + return value.replace('%', '%%') class PGDialect_pygresql(PGDialect): + driver = 'pygresql' - _has_native_hstore = False + statement_compiler = _PGCompiler + preparer = _PGIdentifierPreparer @classmethod def dbapi(cls): @@ -119,8 +192,7 @@ class PGDialect_pygresql(PGDialect): } ) - def __init__(self, use_native_hstore=False, - use_native_json=False, use_native_uuid=True, **kwargs): + def __init__(self, **kwargs): super(PGDialect_pygresql, self).__init__(**kwargs) try: version = self.dbapi.version @@ -129,9 +201,18 @@ class PGDialect_pygresql(PGDialect): except (AttributeError, ValueError, TypeError): version = (0, 0) self.dbapi_version = version - self.use_native_hstore = use_native_hstore and version >= (5, 0) - self.use_native_json = use_native_json and version >= (5, 0) - self.use_native_uuid = use_native_uuid and version >= (5, 0) + if version < (5, 0): + has_native_hstore = has_native_json = has_native_uuid = False + if version != (0, 0): + util.warn("PyGreSQL is only fully supported by SQLAlchemy" + " since version 5.0.") + else: + self.supports_unicode_statements = True + self.supports_unicode_binds = True + has_native_hstore = has_native_json = has_native_uuid = True + self.has_native_hstore = has_native_hstore + self.has_native_json = has_native_json + self.has_native_uuid = has_native_uuid def create_connect_args(self, url): opts = url.translate_connect_args(username='user') @@ -141,5 +222,22 @@ class PGDialect_pygresql(PGDialect): opts.update(url.query) return [], opts + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.Error): + if not connection: + return False + try: + connection = connection.connection + except AttributeError: + pass + else: + if not connection: + return False + try: + return connection.closed + except AttributeError: # PyGreSQL < 5.0 + return connection._cnx is None + return False + dialect = PGDialect_pygresql |