diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 190 |
1 files changed, 115 insertions, 75 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b102f0240..da62b1434 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -82,6 +82,7 @@ OPERATORS = { operators.eq: ' = ', operators.concat_op: ' || ', operators.match_op: ' MATCH ', + operators.notmatch_op: ' NOT MATCH ', operators.in_op: ' IN ', operators.notin_op: ' NOT IN ', operators.comma_op: ', ', @@ -247,15 +248,16 @@ class Compiled(object): return self.execute(*multiparams, **params).scalar() -class TypeCompiler(object): - +class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)): """Produces DDL specification for TypeEngine objects.""" + ensure_kwarg = 'visit_\w+' + def __init__(self, dialect): self.dialect = dialect - def process(self, type_): - return type_._compiler_dispatch(self) + def process(self, type_, **kw): + return type_._compiler_dispatch(self, **kw) class _CompileLabel(visitors.Visitable): @@ -637,8 +639,9 @@ class SQLCompiler(Compiled): def visit_index(self, index, **kwargs): return index.name - def visit_typeclause(self, typeclause, **kwargs): - return self.dialect.type_compiler.process(typeclause.type) + def visit_typeclause(self, typeclause, **kw): + kw['type_expression'] = typeclause + return self.dialect.type_compiler.process(typeclause.type, **kw) def post_process_text(self, text): return text @@ -862,14 +865,18 @@ class SQLCompiler(Compiled): else: return "%s = 0" % self.process(element.element, **kw) - def visit_binary(self, binary, **kw): + def visit_notmatch_op_binary(self, binary, operator, **kw): + return "NOT %s" % self.visit_binary( + binary, override_operator=operators.match_op) + + def visit_binary(self, binary, override_operator=None, **kw): # don't allow "? = ?" to render if self.ansi_bind_rules and \ isinstance(binary.left, elements.BindParameter) and \ isinstance(binary.right, elements.BindParameter): kw['literal_binds'] = True - operator_ = binary.operator + operator_ = override_operator or binary.operator disp = getattr(self, "visit_%s_binary" % operator_.__name__, None) if disp: return disp(binary, operator_, **kw) @@ -1188,12 +1195,16 @@ class SQLCompiler(Compiled): self, asfrom=True, **kwargs ) + if cte._suffixes: + text += " " + self._generate_prefixes( + cte, cte._suffixes, **kwargs) + self.ctes[cte] = text if asfrom: if cte_alias_name: text = self.preparer.format_alias(cte, cte_alias_name) - text += " AS " + cte_name + text += self.get_render_as_alias_suffix(cte_name) else: return self.preparer.format_alias(cte, cte_name) return text @@ -1212,8 +1223,8 @@ class SQLCompiler(Compiled): elif asfrom: ret = alias.original._compiler_dispatch(self, asfrom=True, **kwargs) + \ - " AS " + \ - self.preparer.format_alias(alias, alias_name) + self.get_render_as_alias_suffix( + self.preparer.format_alias(alias, alias_name)) if fromhints and alias in fromhints: ret = self.format_from_hint_text(ret, alias, @@ -1223,6 +1234,9 @@ class SQLCompiler(Compiled): else: return alias.original._compiler_dispatch(self, **kwargs) + def get_render_as_alias_suffix(self, alias_name_text): + return " AS " + alias_name_text + def _add_to_result_map(self, keyname, name, objects, type_): if not self.dialect.case_sensitive: keyname = keyname.lower() @@ -1549,6 +1563,10 @@ class SQLCompiler(Compiled): compound_index == 0 and toplevel: text = self._render_cte_clause() + text + if select._suffixes: + text += " " + self._generate_prefixes( + select, select._suffixes, **kwargs) + self.stack.pop(-1) if asfrom and parens: @@ -2086,7 +2104,9 @@ class DDLCompiler(Compiled): (table.description, column.name, ce.args[0]) )) - const = self.create_table_constraints(table) + const = self.create_table_constraints( + table, _include_foreign_key_constraints= + create.include_foreign_key_constraints) if const: text += ", \n\t" + const @@ -2110,7 +2130,9 @@ class DDLCompiler(Compiled): return text - def create_table_constraints(self, table): + def create_table_constraints( + self, table, + _include_foreign_key_constraints=None): # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) @@ -2118,8 +2140,15 @@ class DDLCompiler(Compiled): if table.primary_key: constraints.append(table.primary_key) + all_fkcs = table.foreign_key_constraints + if _include_foreign_key_constraints is not None: + omit_fkcs = all_fkcs.difference(_include_foreign_key_constraints) + else: + omit_fkcs = set() + constraints.extend([c for c in table._sorted_constraints - if c is not table.primary_key]) + if c is not table.primary_key and + c not in omit_fkcs]) return ", \n\t".join( p for p in @@ -2214,15 +2243,26 @@ class DDLCompiler(Compiled): self.preparer.format_sequence(drop.element) def visit_drop_constraint(self, drop): + constraint = drop.element + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + else: + formatted_name = None + + if formatted_name is None: + raise exc.CompileError( + "Can't emit DROP CONSTRAINT for constraint %r; " + "it has no name" % drop.element) return "ALTER TABLE %s DROP CONSTRAINT %s%s" % ( self.preparer.format_table(drop.element.table), - self.preparer.format_constraint(drop.element), + formatted_name, drop.cascade and " CASCADE" or "" ) def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process(column.type) + self.dialect.type_compiler.process( + column.type, type_expression=column) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -2346,13 +2386,13 @@ class DDLCompiler(Compiled): class GenericTypeCompiler(TypeCompiler): - def visit_FLOAT(self, type_): + def visit_FLOAT(self, type_, **kw): return "FLOAT" - def visit_REAL(self, type_): + def visit_REAL(self, type_, **kw): return "REAL" - def visit_NUMERIC(self, type_): + def visit_NUMERIC(self, type_, **kw): if type_.precision is None: return "NUMERIC" elif type_.scale is None: @@ -2363,7 +2403,7 @@ class GenericTypeCompiler(TypeCompiler): {'precision': type_.precision, 'scale': type_.scale} - def visit_DECIMAL(self, type_): + def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return "DECIMAL" elif type_.scale is None: @@ -2374,31 +2414,31 @@ class GenericTypeCompiler(TypeCompiler): {'precision': type_.precision, 'scale': type_.scale} - def visit_INTEGER(self, type_): + def visit_INTEGER(self, type_, **kw): return "INTEGER" - def visit_SMALLINT(self, type_): + def visit_SMALLINT(self, type_, **kw): return "SMALLINT" - def visit_BIGINT(self, type_): + def visit_BIGINT(self, type_, **kw): return "BIGINT" - def visit_TIMESTAMP(self, type_): + def visit_TIMESTAMP(self, type_, **kw): return 'TIMESTAMP' - def visit_DATETIME(self, type_): + def visit_DATETIME(self, type_, **kw): return "DATETIME" - def visit_DATE(self, type_): + def visit_DATE(self, type_, **kw): return "DATE" - def visit_TIME(self, type_): + def visit_TIME(self, type_, **kw): return "TIME" - def visit_CLOB(self, type_): + def visit_CLOB(self, type_, **kw): return "CLOB" - def visit_NCLOB(self, type_): + def visit_NCLOB(self, type_, **kw): return "NCLOB" def _render_string_type(self, type_, name): @@ -2410,91 +2450,91 @@ class GenericTypeCompiler(TypeCompiler): text += ' COLLATE "%s"' % type_.collation return text - def visit_CHAR(self, type_): + def visit_CHAR(self, type_, **kw): return self._render_string_type(type_, "CHAR") - def visit_NCHAR(self, type_): + def visit_NCHAR(self, type_, **kw): return self._render_string_type(type_, "NCHAR") - def visit_VARCHAR(self, type_): + def visit_VARCHAR(self, type_, **kw): return self._render_string_type(type_, "VARCHAR") - def visit_NVARCHAR(self, type_): + def visit_NVARCHAR(self, type_, **kw): return self._render_string_type(type_, "NVARCHAR") - def visit_TEXT(self, type_): + def visit_TEXT(self, type_, **kw): return self._render_string_type(type_, "TEXT") - def visit_BLOB(self, type_): + def visit_BLOB(self, type_, **kw): return "BLOB" - def visit_BINARY(self, type_): + def visit_BINARY(self, type_, **kw): return "BINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_VARBINARY(self, type_): + def visit_VARBINARY(self, type_, **kw): return "VARBINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_BOOLEAN(self, type_): + def visit_BOOLEAN(self, type_, **kw): return "BOOLEAN" - def visit_large_binary(self, type_): - return self.visit_BLOB(type_) + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_, **kw) - def visit_boolean(self, type_): - return self.visit_BOOLEAN(type_) + def visit_boolean(self, type_, **kw): + return self.visit_BOOLEAN(type_, **kw) - def visit_time(self, type_): - return self.visit_TIME(type_) + def visit_time(self, type_, **kw): + return self.visit_TIME(type_, **kw) - def visit_datetime(self, type_): - return self.visit_DATETIME(type_) + def visit_datetime(self, type_, **kw): + return self.visit_DATETIME(type_, **kw) - def visit_date(self, type_): - return self.visit_DATE(type_) + def visit_date(self, type_, **kw): + return self.visit_DATE(type_, **kw) - def visit_big_integer(self, type_): - return self.visit_BIGINT(type_) + def visit_big_integer(self, type_, **kw): + return self.visit_BIGINT(type_, **kw) - def visit_small_integer(self, type_): - return self.visit_SMALLINT(type_) + def visit_small_integer(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) - def visit_integer(self, type_): - return self.visit_INTEGER(type_) + def visit_integer(self, type_, **kw): + return self.visit_INTEGER(type_, **kw) - def visit_real(self, type_): - return self.visit_REAL(type_) + def visit_real(self, type_, **kw): + return self.visit_REAL(type_, **kw) - def visit_float(self, type_): - return self.visit_FLOAT(type_) + def visit_float(self, type_, **kw): + return self.visit_FLOAT(type_, **kw) - def visit_numeric(self, type_): - return self.visit_NUMERIC(type_) + def visit_numeric(self, type_, **kw): + return self.visit_NUMERIC(type_, **kw) - def visit_string(self, type_): - return self.visit_VARCHAR(type_) + def visit_string(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) - def visit_unicode(self, type_): - return self.visit_VARCHAR(type_) + def visit_unicode(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) - def visit_text(self, type_): - return self.visit_TEXT(type_) + def visit_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) - def visit_unicode_text(self, type_): - return self.visit_TEXT(type_) + def visit_unicode_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) - def visit_enum(self, type_): - return self.visit_VARCHAR(type_) + def visit_enum(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) - def visit_null(self, type_): + def visit_null(self, type_, **kw): raise exc.CompileError("Can't generate DDL for %r; " "did you forget to specify a " "type on this Column?" % type_) - def visit_type_decorator(self, type_): - return self.process(type_.type_engine(self.dialect)) + def visit_type_decorator(self, type_, **kw): + return self.process(type_.type_engine(self.dialect), **kw) - def visit_user_defined(self, type_): - return type_.get_col_spec() + def visit_user_defined(self, type_, **kw): + return type_.get_col_spec(**kw) class IdentifierPreparer(object): |