summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-10-15 00:07:06 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-10-15 00:07:06 +0000
commit7e5e985c0e17a2d300f9aa8633c3610db600f2e2 (patch)
tree553780288c3fc75697d1558187c85f09a9cb42ed
parent6b40f50b87a03172d77abf0e50f42b565f416645 (diff)
downloadsqlalchemy-7e5e985c0e17a2d300f9aa8633c3610db600f2e2.tar.gz
- ForeignKey(Constraint) supports "use_alter=True", to create/drop a foreign key
via ALTER. this allows circular foreign key relationships to be set up.
-rw-r--r--CHANGES2
-rw-r--r--lib/sqlalchemy/ansisql.py58
-rw-r--r--lib/sqlalchemy/databases/mysql.py3
-rw-r--r--lib/sqlalchemy/databases/sqlite.py9
-rw-r--r--lib/sqlalchemy/engine/base.py2
-rw-r--r--lib/sqlalchemy/schema.py7
-rw-r--r--lib/sqlalchemy/sql_util.py2
-rw-r--r--test/orm/cycles.py15
8 files changed, 72 insertions, 26 deletions
diff --git a/CHANGES b/CHANGES
index b9ac78ee9..5d3c1c34d 100644
--- a/CHANGES
+++ b/CHANGES
@@ -40,6 +40,8 @@
indexed. a comparison clause between two pks that are derived from the
same underlying tables (i.e. such as two Alias objects) can be generated
via table1.primary_key==table2.primary_key
+ - ForeignKey(Constraint) supports "use_alter=True", to create/drop a foreign key
+ via ALTER. this allows circular foreign key relationships to be set up.
- append_item() methods removed from Table and Column; preferably
construct Table/Column/related objects inline, but if needed use
append_column(), append_foreign_key(), append_constraint(), etc.
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 208b2f603..b6923c7da 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -606,8 +606,20 @@ class ANSICompiler(sql.Compiled):
def __str__(self):
return self.get_str(self.statement)
-
-class ANSISchemaGenerator(engine.SchemaIterator):
+class ANSISchemaBase(engine.SchemaIterator):
+ def find_alterables(self, tables):
+ alterables = []
+ class FindAlterables(schema.SchemaVisitor):
+ def visit_foreign_key_constraint(self, constraint):
+ if constraint.use_alter and constraint.table in tables:
+ alterables.append(constraint)
+ findalterables = FindAlterables()
+ for table in tables:
+ for c in table.constraints:
+ c.accept_schema_visitor(findalterables)
+ return alterables
+
+class ANSISchemaGenerator(ANSISchemaBase):
def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs):
super(ANSISchemaGenerator, self).__init__(engine, proxy, **kwargs)
self.checkfirst = checkfirst
@@ -620,11 +632,13 @@ class ANSISchemaGenerator(engine.SchemaIterator):
raise NotImplementedError()
def visit_metadata(self, metadata):
- for table in metadata.table_iterator(reverse=False, tables=self.tables):
- if self.checkfirst and self.dialect.has_table(self.connection, table.name):
- continue
+ collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name))]
+ for table in collection:
table.accept_schema_visitor(self, traverse=False)
-
+ if self.supports_alter():
+ for alterable in self.find_alterables(collection):
+ self.add_foreignkey(alterable)
+
def visit_table(self, table):
for column in table.columns:
if column.default is not None:
@@ -687,9 +701,22 @@ class ANSISchemaGenerator(engine.SchemaIterator):
if constraint.name is not None:
self.append("%s " % constraint.name)
self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
-
+
+ def supports_alter(self):
+ return True
+
def visit_foreign_key_constraint(self, constraint):
+ if constraint.use_alter and self.supports_alter():
+ return
self.append(", \n\t ")
+ self.define_foreign_key(constraint)
+
+ def add_foreignkey(self, constraint):
+ self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table))
+ self.define_foreign_key(constraint)
+ self.execute()
+
+ def define_foreign_key(self, constraint):
if constraint.name is not None:
self.append("CONSTRAINT %s " % constraint.name)
self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
@@ -721,7 +748,7 @@ class ANSISchemaGenerator(engine.SchemaIterator):
string.join([self.preparer.format_column(c) for c in index.columns], ', ')))
self.execute()
-class ANSISchemaDropper(engine.SchemaIterator):
+class ANSISchemaDropper(ANSISchemaBase):
def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs):
super(ANSISchemaDropper, self).__init__(engine, proxy, **kwargs)
self.checkfirst = checkfirst
@@ -731,14 +758,23 @@ class ANSISchemaDropper(engine.SchemaIterator):
self.dialect = self.engine.dialect
def visit_metadata(self, metadata):
- for table in metadata.table_iterator(reverse=True, tables=self.tables):
- if self.checkfirst and not self.dialect.has_table(self.connection, table.name):
- continue
+ collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or self.dialect.has_table(self.connection, t.name))]
+ if self.supports_alter():
+ for alterable in self.find_alterables(collection):
+ self.drop_foreignkey(alterable)
+ for table in collection:
table.accept_schema_visitor(self, traverse=False)
+ def supports_alter(self):
+ return True
+
def visit_index(self, index):
self.append("\nDROP INDEX " + index.name)
self.execute()
+
+ def drop_foreignkey(self, constraint):
+ self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (self.preparer.format_table(constraint.table), constraint.name))
+ self.execute()
def visit_table(self, table):
for column in table.columns:
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index 2fa7e9227..86b74c364 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -456,6 +456,9 @@ class MySQLSchemaDropper(ansisql.ANSISchemaDropper):
def visit_index(self, index):
self.append("\nDROP INDEX " + index.name + " ON " + index.table.name)
self.execute()
+ def drop_foreignkey(self, constraint):
+ self.append("ALTER TABLE %s DROP FOREIGN KEY %s" % (self.preparer.format_table(constraint.table), constraint.name))
+ self.execute()
class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
def __init__(self, dialect):
diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py
index 90cd66dd3..a4445b1a8 100644
--- a/lib/sqlalchemy/databases/sqlite.py
+++ b/lib/sqlalchemy/databases/sqlite.py
@@ -147,6 +147,8 @@ class SQLiteDialect(ansisql.ANSIDialect):
return SQLiteCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
return SQLiteSchemaGenerator(*args, **kwargs)
+ def schemadropper(self, *args, **kwargs):
+ return SQLiteSchemaDropper(*args, **kwargs)
def preparer(self):
return SQLiteIdentifierPreparer(self)
def create_connect_args(self, url):
@@ -283,6 +285,9 @@ class SQLiteCompiler(ansisql.ANSICompiler):
return ansisql.ANSICompiler.binary_operator_string(self, binary)
class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
+ def supports_alter(self):
+ return False
+
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
@@ -303,6 +308,10 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
# else:
# super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint)
+class SQLiteSchemaDropper(ansisql.ANSISchemaDropper):
+ def supports_alter(self):
+ return False
+
class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
def __init__(self, dialect):
super(SQLiteIdentifierPreparer, self).__init__(dialect, omit_schema=True)
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 4ba5e1115..83db06090 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -225,7 +225,7 @@ class Connection(Connectable):
"""when no Transaction is present, this is called after executions to provide "autocommit" behavior."""
# TODO: have the dialect determine if autocommit can be set on the connection directly without this
# extra step
- if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP', statement.lstrip().upper()):
+ if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP|ALTER', statement.lstrip().upper()):
self._commit_impl()
def _autorollback(self):
if not self.in_transaction():
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 5728d7c37..88d52f075 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -491,7 +491,7 @@ class ForeignKey(SchemaItem):
One or more ForeignKey objects are used within a ForeignKeyConstraint
object which represents the table-level constraint definition."""
- def __init__(self, column, constraint=None, use_alter=False):
+ def __init__(self, column, constraint=None, use_alter=False, name=None):
"""Construct a new ForeignKey object.
"column" can be a schema.Column object representing the relationship,
@@ -507,6 +507,7 @@ class ForeignKey(SchemaItem):
self._column = None
self.constraint = constraint
self.use_alter = use_alter
+ self.name = name
def __repr__(self):
return "ForeignKey(%s)" % repr(self._get_colspec())
@@ -575,7 +576,7 @@ class ForeignKey(SchemaItem):
self.parent = column
if self.constraint is None and isinstance(self.parent.table, Table):
- self.constraint = ForeignKeyConstraint([],[], use_alter=self.use_alter)
+ self.constraint = ForeignKeyConstraint([],[], use_alter=self.use_alter, name=self.name)
self.parent.table.append_constraint(self.constraint)
self.constraint._append_fk(self)
@@ -699,6 +700,8 @@ class ForeignKeyConstraint(Constraint):
self.elements = util.Set()
self.onupdate = onupdate
self.ondelete = ondelete
+ if self.name is None and use_alter:
+ raise exceptions.ArgumentError("Alterable ForeignKey/ForeignKeyConstraint requires a name")
self.use_alter = use_alter
def _set_parent(self, table):
self.table = table
diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py
index 94caade68..4935b1add 100644
--- a/lib/sqlalchemy/sql_util.py
+++ b/lib/sqlalchemy/sql_util.py
@@ -40,6 +40,8 @@ class TableCollection(object):
tuples = []
class TVisitor(schema.SchemaVisitor):
def visit_foreign_key(_self, fkey):
+ if fkey.use_alter:
+ return
parent_table = fkey.column.table
if parent_table in self:
child_table = fkey.parent.table
diff --git a/test/orm/cycles.py b/test/orm/cycles.py
index eebe7af75..0ff3abb7b 100644
--- a/test/orm/cycles.py
+++ b/test/orm/cycles.py
@@ -213,28 +213,19 @@ class OneToManyManyToOneTest(AssertMixin):
global ball
ball = Table('ball', metadata,
Column('id', Integer, Sequence('ball_id_seq', optional=True), primary_key=True),
- Column('person_id', Integer),
+ Column('person_id', Integer, ForeignKey('person.id', use_alter=True, name='fk_person_id')),
Column('data', String(30))
)
person = Table('person', metadata,
Column('id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
Column('favorite_ball_id', Integer, ForeignKey('ball.id')),
Column('data', String(30))
-# Column('favorite_ball_id', Integer),
)
- ball.create()
- person.create()
- ball.c.person_id.append_foreign_key(ForeignKey('person.id'))
+ metadata.create_all()
- # make the test more complete for postgres
- if db.engine.__module__.endswith('postgres'):
- db.execute("alter table ball add constraint fk_ball_person foreign key (person_id) references person(id)", {})
def tearDownAll(self):
- if db.engine.__module__.endswith('postgres'):
- db.execute("alter table ball drop constraint fk_ball_person", {})
- person.drop()
- ball.drop()
+ metadata.drop_all()
def tearDown(self):
clear_mappers()