summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/mssql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-01-16 20:03:33 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2015-01-16 20:03:33 -0500
commitf3a892a3ef666e299107a990bf4eae7ed9a953ae (patch)
tree01c0bbb71be7b397fd2f91b406c3ae7889b2306d /lib/sqlalchemy/dialects/mssql
parent79fa69f1f37fdbc0dfec6bdea1e07f52bfe18f7b (diff)
downloadsqlalchemy-f3a892a3ef666e299107a990bf4eae7ed9a953ae.tar.gz
- Custom dialects that implement :class:`.GenericTypeCompiler` can
now be constructed such that the visit methods receive an indication of the owning expression object, if any. Any visit method that accepts keyword arguments (e.g. ``**kw``) will in most cases receive a keyword argument ``type_expression``, referring to the expression object that the type is contained within. For columns in DDL, the dialect's compiler class may need to alter its ``get_column_specification()`` method to support this as well. The ``UserDefinedType.get_col_spec()`` method will also receive ``type_expression`` if it provides ``**kw`` in its argument signature. fixes #3074
Diffstat (limited to 'lib/sqlalchemy/dialects/mssql')
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py82
1 files changed, 42 insertions, 40 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 5d84975c0..92d7e4ab3 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -694,7 +694,6 @@ ischema_names = {
class MSTypeCompiler(compiler.GenericTypeCompiler):
-
def _extend(self, spec, type_, length=None):
"""Extend a string-type declaration with standard SQL
COLLATE annotations.
@@ -715,115 +714,115 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
return ' '.join([c for c in (spec, collation)
if c is not None])
- def visit_FLOAT(self, type_):
+ def visit_FLOAT(self, type_, **kw):
precision = getattr(type_, 'precision', None)
if precision is None:
return "FLOAT"
else:
return "FLOAT(%(precision)s)" % {'precision': precision}
- def visit_TINYINT(self, type_):
+ def visit_TINYINT(self, type_, **kw):
return "TINYINT"
- def visit_DATETIMEOFFSET(self, type_):
+ def visit_DATETIMEOFFSET(self, type_, **kw):
if type_.precision:
return "DATETIMEOFFSET(%s)" % type_.precision
else:
return "DATETIMEOFFSET"
- def visit_TIME(self, type_):
+ def visit_TIME(self, type_, **kw):
precision = getattr(type_, 'precision', None)
if precision:
return "TIME(%s)" % precision
else:
return "TIME"
- def visit_DATETIME2(self, type_):
+ def visit_DATETIME2(self, type_, **kw):
precision = getattr(type_, 'precision', None)
if precision:
return "DATETIME2(%s)" % precision
else:
return "DATETIME2"
- def visit_SMALLDATETIME(self, type_):
+ def visit_SMALLDATETIME(self, type_, **kw):
return "SMALLDATETIME"
- def visit_unicode(self, type_):
- return self.visit_NVARCHAR(type_)
+ def visit_unicode(self, type_, **kw):
+ return self.visit_NVARCHAR(type_, **kw)
- def visit_text(self, type_):
+ def visit_text(self, type_, **kw):
if self.dialect.deprecate_large_types:
- return self.visit_VARCHAR(type_)
+ return self.visit_VARCHAR(type_, **kw)
else:
- return self.visit_TEXT(type_)
+ return self.visit_TEXT(type_, **kw)
- def visit_unicode_text(self, type_):
+ def visit_unicode_text(self, type_, **kw):
if self.dialect.deprecate_large_types:
- return self.visit_NVARCHAR(type_)
+ return self.visit_NVARCHAR(type_, **kw)
else:
- return self.visit_NTEXT(type_)
+ return self.visit_NTEXT(type_, **kw)
- def visit_NTEXT(self, type_):
+ def visit_NTEXT(self, type_, **kw):
return self._extend("NTEXT", type_)
- def visit_TEXT(self, type_):
+ def visit_TEXT(self, type_, **kw):
return self._extend("TEXT", type_)
- def visit_VARCHAR(self, type_):
+ def visit_VARCHAR(self, type_, **kw):
return self._extend("VARCHAR", type_, length=type_.length or 'max')
- def visit_CHAR(self, type_):
+ def visit_CHAR(self, type_, **kw):
return self._extend("CHAR", type_)
- def visit_NCHAR(self, type_):
+ def visit_NCHAR(self, type_, **kw):
return self._extend("NCHAR", type_)
- def visit_NVARCHAR(self, type_):
+ def visit_NVARCHAR(self, type_, **kw):
return self._extend("NVARCHAR", type_, length=type_.length or 'max')
- def visit_date(self, type_):
+ def visit_date(self, type_, **kw):
if self.dialect.server_version_info < MS_2008_VERSION:
- return self.visit_DATETIME(type_)
+ return self.visit_DATETIME(type_, **kw)
else:
- return self.visit_DATE(type_)
+ return self.visit_DATE(type_, **kw)
- def visit_time(self, type_):
+ def visit_time(self, type_, **kw):
if self.dialect.server_version_info < MS_2008_VERSION:
- return self.visit_DATETIME(type_)
+ return self.visit_DATETIME(type_, **kw)
else:
- return self.visit_TIME(type_)
+ return self.visit_TIME(type_, **kw)
- def visit_large_binary(self, type_):
+ def visit_large_binary(self, type_, **kw):
if self.dialect.deprecate_large_types:
- return self.visit_VARBINARY(type_)
+ return self.visit_VARBINARY(type_, **kw)
else:
- return self.visit_IMAGE(type_)
+ return self.visit_IMAGE(type_, **kw)
- def visit_IMAGE(self, type_):
+ def visit_IMAGE(self, type_, **kw):
return "IMAGE"
- def visit_VARBINARY(self, type_):
+ def visit_VARBINARY(self, type_, **kw):
return self._extend(
"VARBINARY",
type_,
length=type_.length or 'max')
- def visit_boolean(self, type_):
+ def visit_boolean(self, type_, **kw):
return self.visit_BIT(type_)
- def visit_BIT(self, type_):
+ def visit_BIT(self, type_, **kw):
return "BIT"
- def visit_MONEY(self, type_):
+ def visit_MONEY(self, type_, **kw):
return "MONEY"
- def visit_SMALLMONEY(self, type_):
+ def visit_SMALLMONEY(self, type_, **kw):
return 'SMALLMONEY'
- def visit_UNIQUEIDENTIFIER(self, type_):
+ def visit_UNIQUEIDENTIFIER(self, type_, **kw):
return "UNIQUEIDENTIFIER"
- def visit_SQL_VARIANT(self, type_):
+ def visit_SQL_VARIANT(self, type_, **kw):
return 'SQL_VARIANT'
@@ -1240,8 +1239,11 @@ class MSSQLStrictCompiler(MSSQLCompiler):
class MSDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs):
- colspec = (self.preparer.format_column(column) + " "
- + self.dialect.type_compiler.process(column.type))
+ colspec = (
+ self.preparer.format_column(column) + " "
+ + self.dialect.type_compiler.process(
+ column.type, type_expression=column)
+ )
if column.nullable is not None:
if not column.nullable or column.primary_key or \