summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/postgresql/pygresql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/pygresql.py')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/pygresql.py138
1 files changed, 118 insertions, 20 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/pygresql.py b/lib/sqlalchemy/dialects/postgresql/pygresql.py
index 79ad3f188..d30206613 100644
--- a/lib/sqlalchemy/dialects/postgresql/pygresql.py
+++ b/lib/sqlalchemy/dialects/postgresql/pygresql.py
@@ -19,7 +19,9 @@ import re
from ... import exc, processors, util
from ...types import Numeric, JSON as Json
-from .base import PGDialect, _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID
+from ...sql.elements import Null
+from .base import PGDialect, PGCompiler, PGIdentifierPreparer, \
+ _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID
from .hstore import HSTORE
from .json import JSON, JSONB
@@ -30,23 +32,24 @@ class _PGNumeric(Numeric):
return None
def result_processor(self, dialect, coltype):
- oid = coltype.oid
+ if not isinstance(coltype, int):
+ coltype = coltype.oid
if self.asdecimal:
- if oid in _FLOAT_TYPES:
+ if coltype in _FLOAT_TYPES:
return processors.to_decimal_processor_factory(
decimal.Decimal,
self._effective_decimal_return_scale)
- elif oid in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
# PyGreSQL returns Decimal natively for 1700 (numeric)
return None
else:
raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype)
else:
- if oid in _FLOAT_TYPES:
+ if coltype in _FLOAT_TYPES:
# PyGreSQL returns float natively for 701 (float8)
return None
- elif oid in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
return processors.to_float
else:
raise exc.InvalidRequestError(
@@ -56,51 +59,121 @@ class _PGNumeric(Numeric):
class _PGHStore(HSTORE):
def bind_processor(self, dialect):
- if not dialect.use_native_hstore:
+ if not dialect.has_native_hstore:
return super(_PGHStore, self).bind_processor(dialect)
+ hstore = dialect.dbapi.Hstore
+ def process(value):
+ if isinstance(value, dict):
+ return hstore(value)
+ return value
+ return process
def result_processor(self, dialect, coltype):
- if not dialect.use_native_hstore:
+ if not dialect.has_native_hstore:
return super(_PGHStore, self).result_processor(dialect, coltype)
class _PGJSON(JSON):
def bind_processor(self, dialect):
- if not dialect.use_native_json:
+ if not dialect.has_native_json:
return super(_PGJSON, self).bind_processor(dialect)
+ json = dialect.dbapi.Json
+
+ def process(value):
+ if value is self.NULL:
+ value = None
+ elif isinstance(value, Null) or (
+ value is None and self.none_as_null):
+ return None
+ if value is None or isinstance(value, (dict, list)):
+ return json(value)
+ return value
+
+ return process
def result_processor(self, dialect, coltype):
- if not dialect.use_native_json:
+ if not dialect.has_native_json:
return super(_PGJSON, self).result_processor(dialect, coltype)
class _PGJSONB(JSONB):
def bind_processor(self, dialect):
- if not dialect.use_native_json:
+ if not dialect.has_native_json:
return super(_PGJSONB, self).bind_processor(dialect)
+ json = dialect.dbapi.Json
+
+ def process(value):
+ if value is self.NULL:
+ value = None
+ elif isinstance(value, Null) or (
+ value is None and self.none_as_null):
+ return None
+ if value is None or isinstance(value, (dict, list)):
+ return json(value)
+ return value
+
+ return process
def result_processor(self, dialect, coltype):
- if not dialect.use_native_json:
+ if not dialect.has_native_json:
return super(_PGJSONB, self).result_processor(dialect, coltype)
class _PGUUID(UUID):
def bind_processor(self, dialect):
- if not dialect.use_native_uuid:
+ if not dialect.has_native_uuid:
return super(_PGUUID, self).bind_processor(dialect)
+ uuid = dialect.dbapi.Uuid
+
+ def process(value):
+ if value is None:
+ return None
+ if isinstance(value, (str, bytes)):
+ if len(value) == 16:
+ return uuid(bytes=value)
+ return uuid(value)
+ if isinstance(value, int):
+ return uuid(int=value)
+ return value
+
+ return process
def result_processor(self, dialect, coltype):
- if not dialect.use_native_uuid:
+ if not dialect.has_native_uuid:
return super(_PGUUID, self).result_processor(dialect, coltype)
+ if not self.as_uuid:
+ def process(value):
+ if value is not None:
+ return str(value)
+ return process
+
+
+class _PGCompiler(PGCompiler):
+
+ def visit_mod_binary(self, binary, operator, **kw):
+ return self.process(binary.left, **kw) + " %% " + \
+ self.process(binary.right, **kw)
+
+ def post_process_text(self, text):
+ return text.replace('%', '%%')
+
+
+class _PGIdentifierPreparer(PGIdentifierPreparer):
+
+ def _escape_identifier(self, value):
+ value = value.replace(self.escape_quote, self.escape_to_quote)
+ return value.replace('%', '%%')
class PGDialect_pygresql(PGDialect):
+
driver = 'pygresql'
- _has_native_hstore = False
+ statement_compiler = _PGCompiler
+ preparer = _PGIdentifierPreparer
@classmethod
def dbapi(cls):
@@ -119,8 +192,7 @@ class PGDialect_pygresql(PGDialect):
}
)
- def __init__(self, use_native_hstore=False,
- use_native_json=False, use_native_uuid=True, **kwargs):
+ def __init__(self, **kwargs):
super(PGDialect_pygresql, self).__init__(**kwargs)
try:
version = self.dbapi.version
@@ -129,9 +201,18 @@ class PGDialect_pygresql(PGDialect):
except (AttributeError, ValueError, TypeError):
version = (0, 0)
self.dbapi_version = version
- self.use_native_hstore = use_native_hstore and version >= (5, 0)
- self.use_native_json = use_native_json and version >= (5, 0)
- self.use_native_uuid = use_native_uuid and version >= (5, 0)
+ if version < (5, 0):
+ has_native_hstore = has_native_json = has_native_uuid = False
+ if version != (0, 0):
+ util.warn("PyGreSQL is only fully supported by SQLAlchemy"
+ " since version 5.0.")
+ else:
+ self.supports_unicode_statements = True
+ self.supports_unicode_binds = True
+ has_native_hstore = has_native_json = has_native_uuid = True
+ self.has_native_hstore = has_native_hstore
+ self.has_native_json = has_native_json
+ self.has_native_uuid = has_native_uuid
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
@@ -141,5 +222,22 @@ class PGDialect_pygresql(PGDialect):
opts.update(url.query)
return [], opts
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.Error):
+ if not connection:
+ return False
+ try:
+ connection = connection.connection
+ except AttributeError:
+ pass
+ else:
+ if not connection:
+ return False
+ try:
+ return connection.closed
+ except AttributeError: # PyGreSQL < 5.0
+ return connection._cnx is None
+ return False
+
dialect = PGDialect_pygresql