summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2010-12-15 12:44:37 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2010-12-15 12:44:37 -0500
commit5a832a49e37ca9259fbad286335367927d0ec60e (patch)
tree2654979a0c3fdb048588850276e827d93df450e5 /lib/sqlalchemy
parentbff2f6f3fbb0450cb9d0d09a25845a437c3df85e (diff)
downloadsqlalchemy-5a832a49e37ca9259fbad286335367927d0ec60e.tar.gz
- an approach I like better, remove most adapt() methods and use a generic
copier - mssql reflection fix, but this will come in again from the tip merge
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/access/base.py5
-rw-r--r--lib/sqlalchemy/dialects/maxdb/base.py10
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py56
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py28
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py4
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py31
-rw-r--r--lib/sqlalchemy/types.py39
-rw-r--r--lib/sqlalchemy/util/__init__.py3
-rw-r--r--lib/sqlalchemy/util/langhelpers.py13
9 files changed, 60 insertions, 129 deletions
diff --git a/lib/sqlalchemy/dialects/access/base.py b/lib/sqlalchemy/dialects/access/base.py
index 75ea91287..cf35b3e0a 100644
--- a/lib/sqlalchemy/dialects/access/base.py
+++ b/lib/sqlalchemy/dialects/access/base.py
@@ -50,15 +50,10 @@ class AcSmallInteger(types.SmallInteger):
return "SMALLINT"
class AcDateTime(types.DateTime):
- def __init__(self, *a, **kw):
- super(AcDateTime, self).__init__(False)
-
def get_col_spec(self):
return "DATETIME"
class AcDate(types.Date):
- def __init__(self, *a, **kw):
- super(AcDate, self).__init__(False)
def get_col_spec(self):
return "DATETIME"
diff --git a/lib/sqlalchemy/dialects/maxdb/base.py b/lib/sqlalchemy/dialects/maxdb/base.py
index 9a1e10f51..3d45bb670 100644
--- a/lib/sqlalchemy/dialects/maxdb/base.py
+++ b/lib/sqlalchemy/dialects/maxdb/base.py
@@ -116,15 +116,13 @@ class _StringType(sqltypes.String):
class MaxString(_StringType):
_type = 'VARCHAR'
- def __init__(self, *a, **kw):
- super(MaxString, self).__init__(*a, **kw)
-
class MaxUnicode(_StringType):
_type = 'VARCHAR'
def __init__(self, length=None, **kw):
- super(MaxUnicode, self).__init__(length=length, encoding='unicode')
+ kw['encoding'] = 'unicode'
+ super(MaxUnicode, self).__init__(length=length, **kw)
class MaxChar(_StringType):
@@ -134,8 +132,8 @@ class MaxChar(_StringType):
class MaxText(_StringType):
_type = 'LONG'
- def __init__(self, *a, **kw):
- super(MaxText, self).__init__(*a, **kw)
+ def __init__(self, length=None, **kw):
+ super(MaxText, self).__init__(length, **kw)
def get_col_spec(self):
spec = 'LONG'
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 4c0a00890..c5f891fb7 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -280,16 +280,15 @@ class _StringType(object):
class TEXT(_StringType, sqltypes.TEXT):
"""MSSQL TEXT type, for variable-length text up to 2^31 characters."""
- def __init__(self, *args, **kw):
+ def __init__(self, length=None, collation=None, **kw):
"""Construct a TEXT.
:param collation: Optional, a column-level collation for this string
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- collation = kw.pop('collation', None)
_StringType.__init__(self, collation)
- sqltypes.Text.__init__(self, *args, **kw)
+ sqltypes.Text.__init__(self, length, **kw)
class NTEXT(_StringType, sqltypes.UnicodeText):
"""MSSQL NTEXT type, for variable-length unicode text up to 2^30
@@ -297,24 +296,22 @@ class NTEXT(_StringType, sqltypes.UnicodeText):
__visit_name__ = 'NTEXT'
- def __init__(self, *args, **kwargs):
+ def __init__(self, length=None, collation=None, **kw):
"""Construct a NTEXT.
:param collation: Optional, a column-level collation for this string
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- collation = kwargs.pop('collation', None)
_StringType.__init__(self, collation)
- length = kwargs.pop('length', None)
- sqltypes.UnicodeText.__init__(self, length, **kwargs)
+ sqltypes.UnicodeText.__init__(self, length, **kw)
class VARCHAR(_StringType, sqltypes.VARCHAR):
"""MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum
of 8,000 characters."""
- def __init__(self, *args, **kw):
+ def __init__(self, length=None, collation=None, **kw):
"""Construct a VARCHAR.
:param length: Optinal, maximum data length, in characters.
@@ -335,16 +332,15 @@ class VARCHAR(_StringType, sqltypes.VARCHAR):
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- collation = kw.pop('collation', None)
_StringType.__init__(self, collation)
- sqltypes.VARCHAR.__init__(self, *args, **kw)
+ sqltypes.VARCHAR.__init__(self, length, **kw)
class NVARCHAR(_StringType, sqltypes.NVARCHAR):
"""MSSQL NVARCHAR type.
For variable-length unicode character data up to 4,000 characters."""
- def __init__(self, *args, **kw):
+ def __init__(self, length=None, collation=None, **kw):
"""Construct a NVARCHAR.
:param length: Optional, Maximum data length, in characters.
@@ -353,15 +349,14 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR):
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- collation = kw.pop('collation', None)
_StringType.__init__(self, collation)
- sqltypes.NVARCHAR.__init__(self, *args, **kw)
+ sqltypes.NVARCHAR.__init__(self, length, **kw)
class CHAR(_StringType, sqltypes.CHAR):
"""MSSQL CHAR type, for fixed-length non-Unicode data with a maximum
of 8,000 characters."""
- def __init__(self, *args, **kw):
+ def __init__(self, length=None, collation=None, **kw):
"""Construct a CHAR.
:param length: Optinal, maximum data length, in characters.
@@ -382,16 +377,15 @@ class CHAR(_StringType, sqltypes.CHAR):
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- collation = kw.pop('collation', None)
_StringType.__init__(self, collation)
- sqltypes.CHAR.__init__(self, *args, **kw)
+ sqltypes.CHAR.__init__(self, length, **kw)
class NCHAR(_StringType, sqltypes.NCHAR):
"""MSSQL NCHAR type.
For fixed-length unicode character data up to 4,000 characters."""
- def __init__(self, *args, **kw):
+ def __init__(self, length=None, collation=None, **kw):
"""Construct an NCHAR.
:param length: Optional, Maximum data length, in characters.
@@ -400,9 +394,8 @@ class NCHAR(_StringType, sqltypes.NCHAR):
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- collation = kw.pop('collation', None)
_StringType.__init__(self, collation)
- sqltypes.NCHAR.__init__(self, *args, **kw)
+ sqltypes.NCHAR.__init__(self, length, **kw)
class IMAGE(sqltypes.LargeBinary):
__visit_name__ = 'IMAGE'
@@ -1150,8 +1143,8 @@ class MSDialect(default.DefaultDialect):
"and sch.name=:schname "
"and ind.is_primary_key=0",
bindparams=[
- sql.bindparam('tabname', tablename, sqltypes.Unicode),
- sql.bindparam('schname', current_schema, sqltypes.Unicode)
+ sql.bindparam('tabname', tablename, sqltypes.String(convert_unicode=True)),
+ sql.bindparam('schname', current_schema, sqltypes.String(convert_unicode=True))
]
)
)
@@ -1163,16 +1156,19 @@ class MSDialect(default.DefaultDialect):
'column_names':[]
}
rp = connection.execute(
- sql.text("select ind_col.index_id, col.name from sys.columns as col "
- "join sys.index_columns as ind_col on "
- "ind_col.column_id=col.column_id "
- "join sys.tables as tab on tab.object_id=col.object_id "
- "join sys.schemas as sch on sch.schema_id=tab.schema_id "
- "where tab.name=:tabname "
- "and sch.name=:schname",
+ sql.text(
+ "select ind_col.index_id, ind_col.object_id, col.name "
+ "from sys.columns as col "
+ "join sys.tables as tab on tab.object_id=col.object_id "
+ "join sys.index_columns as ind_col on "
+ "(ind_col.column_id=col.column_id and "
+ "ind_col.object_id=tab.object_id) "
+ "join sys.schemas as sch on sch.schema_id=tab.schema_id "
+ "where tab.name=:tabname "
+ "and sch.name=:schname",
bindparams=[
- sql.bindparam('tabname', tablename, sqltypes.Unicode),
- sql.bindparam('schname', current_schema, sqltypes.Unicode)
+ sql.bindparam('tabname', tablename, sqltypes.String(convert_unicode=True)),
+ sql.bindparam('schname', current_schema, sqltypes.String(convert_unicode=True))
]),
)
for row in rp:
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 5c3289bfb..528e94965 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -233,17 +233,11 @@ SET_RE = re.compile(
class _NumericType(object):
"""Base for MySQL numeric types."""
- def __init__(self, **kw):
- self.unsigned = kw.pop('unsigned', False)
- self.zerofill = kw.pop('zerofill', False)
+ def __init__(self, unsigned=False, zerofill=False, **kw):
+ self.unsigned = unsigned
+ self.zerofill = zerofill
super(_NumericType, self).__init__(**kw)
- def adapt(self, typeimpl, **kw):
- return super(_NumericType, self).adapt(
- typeimpl,
- unsigned=self.unsigned,
- zerofill=self.zerofill)
-
class _FloatType(_NumericType, sqltypes.Float):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
if isinstance(self, (REAL, DOUBLE)) and \
@@ -263,11 +257,6 @@ class _IntegerType(_NumericType, sqltypes.Integer):
self.display_width = display_width
super(_IntegerType, self).__init__(**kw)
- def adapt(self, typeimpl, **kw):
- return super(_IntegerType, self).adapt(
- typeimpl,
- display_width=self.display_width)
-
class _StringType(sqltypes.String):
"""Base for MySQL string types."""
@@ -288,17 +277,6 @@ class _StringType(sqltypes.String):
self.national = national
super(_StringType, self).__init__(**kw)
- def adapt(self, typeimpl, **kw):
- return super(_StringType, self).adapt(
- typeimpl,
- charset=self.charset,
- collation=self.collation,
- ascii=self.ascii,
- binary=self.binary,
- national=self.national,
- **kw
- )
-
def __repr__(self):
attributes = inspect.getargspec(self.__init__)[0][1:]
attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:])
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index 256972696..3d97b504e 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -215,10 +215,6 @@ class INTERVAL(sqltypes.TypeEngine):
return INTERVAL(day_precision=interval.day_precision,
second_precision=interval.second_precision)
- def adapt(self, impltype):
- return impltype(day_precision=self.day_precision,
- second_precision=self.second_precision)
-
@property
def _type_affinity(self):
return sqltypes.Interval
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index c9920c930..72b58a71c 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -133,23 +133,12 @@ class TIMESTAMP(sqltypes.TIMESTAMP):
super(TIMESTAMP, self).__init__(timezone=timezone)
self.precision = precision
- def adapt(self, impltype, **kw):
- return impltype(
- precision=self.precision,
- timezone=self.timezone,
- **kw)
class TIME(sqltypes.TIME):
def __init__(self, timezone=False, precision=None):
super(TIME, self).__init__(timezone=timezone)
self.precision = precision
- def adapt(self, impltype, **kw):
- return impltype(
- precision=self.precision,
- timezone=self.timezone,
- **kw)
-
class INTERVAL(sqltypes.TypeEngine):
"""Postgresql INTERVAL type.
@@ -161,9 +150,6 @@ class INTERVAL(sqltypes.TypeEngine):
def __init__(self, precision=None):
self.precision = precision
- def adapt(self, impltype):
- return impltype(self.precision)
-
@classmethod
def _adapt_from_generic_interval(cls, interval):
return INTERVAL(precision=interval.second_precision)
@@ -176,6 +162,9 @@ PGInterval = INTERVAL
class BIT(sqltypes.TypeEngine):
__visit_name__ = 'BIT'
+ def __init__(self, length=1):
+ self.length= length
+
PGBit = BIT
class UUID(sqltypes.TypeEngine):
@@ -226,9 +215,6 @@ class UUID(sqltypes.TypeEngine):
else:
return None
- def adapt(self, impltype, **kw):
- return impltype(as_uuid=self.as_uuid, **kw)
-
PGUuid = UUID
class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
@@ -300,13 +286,6 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
def is_mutable(self):
return self.mutable
- def adapt(self, impltype):
- return impltype(
- self.item_type,
- mutable=self.mutable,
- as_tuple=self.as_tuple
- )
-
def bind_processor(self, dialect):
item_proc = self.item_type.dialect_impl(dialect).bind_processor(dialect)
if item_proc:
@@ -647,7 +626,7 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
return "INTERVAL"
def visit_BIT(self, type_):
- return "BIT"
+ return "BIT(%d)" % type_.length
def visit_UUID(self, type_):
return "UUID"
@@ -1102,7 +1081,7 @@ class PGDialect(default.DefaultDialect):
elif attype == 'double precision':
args = (53, )
elif attype == 'integer':
- args = (32, 0)
+ args = ()
elif attype in ('timestamp with time zone',
'time with time zone'):
kwargs['timezone'] = True
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index 447938461..f5df02367 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -182,7 +182,7 @@ class TypeEngine(AbstractType):
return dialect.type_descriptor(self)
def adapt(self, cls, **kw):
- return cls(**kw)
+ return util.constructor_copy(self, cls, **kw)
def _coerce_compared_value(self, op, value):
_coerced_type = _type_map.get(type(value), NULLTYPE)
@@ -221,7 +221,7 @@ class TypeEngine(AbstractType):
encode('ascii', 'backslashreplace')
# end Py2K
- def __init__(self, *args, **kwargs):
+ def __init__(self):
# supports getargspec of the __init__ method
# used by generic __repr__
pass
@@ -642,6 +642,9 @@ def adapt_type(typeobj, colspecs):
return typeobj
return typeobj.adapt(impltype)
+
+
+
class NullType(TypeEngine):
"""An unknown type.
@@ -788,15 +791,6 @@ class String(Concatenable, TypeEngine):
self.unicode_error = unicode_error
self._warn_on_bytestring = _warn_on_bytestring
- def adapt(self, impltype, **kw):
- return impltype(
- length=self.length,
- convert_unicode=self.convert_unicode,
- unicode_error=self.unicode_error,
- _warn_on_bytestring=self._warn_on_bytestring,
- **kw
- )
-
def bind_processor(self, dialect):
if self.convert_unicode or dialect.convert_unicode:
if dialect.supports_unicode_binds and \
@@ -816,10 +810,11 @@ class String(Concatenable, TypeEngine):
return None
else:
encoder = codecs.getencoder(dialect.encoding)
+ warn_on_bytestring = self._warn_on_bytestring
def process(value):
if isinstance(value, unicode):
return encoder(value, self.unicode_error)[0]
- elif value is not None:
+ elif warn_on_bytestring and value is not None:
util.warn("Unicode type received non-unicode bind "
"param value")
return value
@@ -1092,13 +1087,6 @@ class Numeric(_DateAffinity, TypeEngine):
self.scale = scale
self.asdecimal = asdecimal
- def adapt(self, impltype, **kw):
- return impltype(
- precision=self.precision,
- scale=self.scale,
- asdecimal=self.asdecimal,
- **kw)
-
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
@@ -1190,10 +1178,6 @@ class Float(Numeric):
self.precision = precision
self.asdecimal = asdecimal
- def adapt(self, impltype, **kw):
- return impltype(precision=self.precision,
- asdecimal=self.asdecimal, **kw)
-
def result_processor(self, dialect, coltype):
if self.asdecimal:
return processors.to_decimal_processor_factory(decimal.Decimal)
@@ -1240,9 +1224,6 @@ class DateTime(_DateAffinity, TypeEngine):
def __init__(self, timezone=False):
self.timezone = timezone
- def adapt(self, impltype, **kw):
- return impltype(timezone=self.timezone, **kw)
-
def get_dbapi_type(self, dbapi):
return dbapi.DATETIME
@@ -1300,9 +1281,6 @@ class Time(_DateAffinity,TypeEngine):
def __init__(self, timezone=False):
self.timezone = timezone
- def adapt(self, impltype, **kw):
- return impltype(timezone=self.timezone, **kw)
-
def get_dbapi_type(self, dbapi):
return dbapi.DATETIME
@@ -1362,9 +1340,6 @@ class _Binary(TypeEngine):
else:
return super(_Binary, self)._coerce_compared_value(op, value)
- def adapt(self, impltype, **kw):
- return impltype(length=self.length, **kw)
-
def get_dbapi_type(self, dbapi):
return dbapi.BINARY
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index 9119e35b7..ae1eb3ac5 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -24,7 +24,8 @@ from langhelpers import iterate_attributes, class_hierarchy, \
reset_memoized, group_expirable_memoized_property, importlater, \
monkeypatch_proxied_specials, asbool, bool_or_str, coerce_kw_type,\
duck_type_collection, assert_arg_type, symbol, dictlike_iteritems,\
- classproperty, set_creation_order, warn_exception, warn, NoneType
+ classproperty, set_creation_order, warn_exception, warn, NoneType,\
+ constructor_copy
from deprecations import warn_deprecated, warn_pending_deprecation, \
deprecated, pending_deprecation
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index d85793ee0..945e2a6bd 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -516,6 +516,19 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True):
else:
kw[key] = type_(kw[key])
+
+def constructor_copy(obj, cls, **kw):
+ """Instantiate cls using the __dict__ of obj as constructor arguments.
+
+ Uses inspect to match the named arguments of ``cls``.
+
+ """
+
+ names = get_cls_kwargs(cls)
+ kw.update((k, obj.__dict__[k]) for k in names if k in obj.__dict__)
+ return cls(**kw)
+
+
def duck_type_collection(specimen, default=None):
"""Given an instance or class, guess if it is or is acting as one of
the basic collection types: list, set and dict. If the __emulates__