diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-02-11 20:50:41 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-02-11 20:50:41 +0000 |
| commit | 280274812261868e8f665f706cd27e06eaff4302 (patch) | |
| tree | e39e17c4a18469c7f47e5a83b19e5f63eaa7b548 /lib/sqlalchemy | |
| parent | 349c00c97a1931cb28cb199b12af1bde82f5bd1d (diff) | |
| download | sqlalchemy-280274812261868e8f665f706cd27e06eaff4302.tar.gz | |
streamlined engine.schemagenerator and engine.schemadropper methodology
added support for creating PassiveDefault (i.e. regular DEFAULT) on table columns
postgres can reflect default values via information_schema
added unittests for PassiveDefault values getting created, inserted, coming back in result sets
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/ansisql.py | 19 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/information_schema.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 22 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/sqlite.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine.py | 28 | ||||
| -rw-r--r-- | lib/sqlalchemy/schema.py | 10 |
8 files changed, 76 insertions, 47 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 9688cb67b..3b4ae64a7 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -20,11 +20,11 @@ def engine(**params): class ANSISQLEngine(sqlalchemy.engine.SQLEngine): - def schemagenerator(self, proxy, **params): - return ANSISchemaGenerator(proxy, **params) + def schemagenerator(self, **params): + return ANSISchemaGenerator(self, **params) - def schemadropper(self, proxy, **params): - return ANSISchemaDropper(proxy, **params) + def schemadropper(self, **params): + return ANSISchemaDropper(self, **params) def compiler(self, statement, parameters, **kwargs): return ANSICompiler(self, statement, parameters, **kwargs) @@ -492,7 +492,6 @@ class ANSICompiler(sql.Compiled): class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator): - def get_column_specification(self, column, override_pk=False, first_pk=False): raise NotImplementedError() @@ -521,6 +520,16 @@ class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator): def post_create_table(self, table): return '' + def get_column_default_string(self, column): + if isinstance(column.default, schema.PassiveDefault): + if not isinstance(column.default.arg, str): + arg = str(column.default.arg.compile(self.engine)) + else: + arg = column.default.arg + return arg + else: + return None + def visit_column(self, column): pass diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py index c0503c25c..f6dd251cd 100644 --- a/lib/sqlalchemy/databases/information_schema.py +++ b/lib/sqlalchemy/databases/information_schema.py @@ -31,6 +31,7 @@ gen_columns = schema.Table("columns", generic_engine, Column("character_maximum_length", Integer), Column("numeric_precision", Integer), Column("numeric_scale", Integer), + Column("column_default", Integer), schema="information_schema") gen_constraints = schema.Table("table_constraints", generic_engine, @@ -109,15 +110,16 @@ def reflecttable(engine, table, ischema_names, use_mysql=False): row = c.fetchone() if row is None: break -# print "row! " + repr(row) + #print "row! " + repr(row) # continue - (name, type, nullable, charlen, numericprec, numericscale) = ( + (name, type, nullable, charlen, numericprec, numericscale, default) = ( row[columns.c.column_name], row[columns.c.data_type], row[columns.c.is_nullable] == 'YES', row[columns.c.character_maximum_length], row[columns.c.numeric_precision], row[columns.c.numeric_scale], + row[columns.c.column_default] ) args = [] @@ -127,7 +129,10 @@ def reflecttable(engine, table, ischema_names, use_mysql=False): coltype = ischema_names[type] #print "coltype " + repr(coltype) + " args " + repr(args) coltype = coltype(*args) - table.append_item(schema.Column(name, coltype, nullable = nullable)) + colargs= [] + if default is not None: + colargs.append(PassiveDefault(default)) + table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs)) s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True) if not use_mysql: diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 6734274cd..0afac7df3 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -132,8 +132,8 @@ class MySQLEngine(ansisql.ANSISQLEngine): def compiler(self, statement, bindparams, **kwargs): return MySQLCompiler(self, statement, bindparams, **kwargs) - def schemagenerator(self, proxy, **params): - return MySQLSchemaGenerator(proxy, **params) + def schemagenerator(self, **params): + return MySQLSchemaGenerator(self, **params) def get_default_schema_name(self): if not hasattr(self, '_default_schema_name'): @@ -234,6 +234,13 @@ class MySQLTableImpl(sql.TableImpl): self.mysql_engine = mysql_engine class MySQLCompiler(ansisql.ANSICompiler): + + def visit_function(self, func): + if len(func.clauses): + super(MySQLCompiler, self).visit_function(func) + else: + self.strings[func] = func.name + def limit_clause(self, select): text = "" if select.limit is not None: @@ -248,6 +255,9 @@ class MySQLCompiler(ansisql.ANSICompiler): class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, first_pk=False): colspec = column.name + " " + column.type.get_col_spec() + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default if not column.nullable: colspec += " NOT NULL" diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 857b0c2fc..2ce07a3c6 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -104,10 +104,10 @@ class OracleSQLEngine(ansisql.ANSISQLEngine): def compiler(self, statement, bindparams, **kwargs): return OracleCompiler(self, statement, bindparams, use_ansi=self._use_ansi, **kwargs) - def schemagenerator(self, proxy, **params): - return OracleSchemaGenerator(proxy, **params) - def schemadropper(self, proxy, **params): - return OracleSchemaDropper(proxy, **params) + def schemagenerator(self, **params): + return OracleSchemaGenerator(self, **params) + def schemadropper(self, **params): + return OracleSchemaDropper(self, **params) def defaultrunner(self, proxy): return OracleDefaultRunner(self, proxy) @@ -227,6 +227,9 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name colspec += " " + column.type.get_col_spec() + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default if not column.nullable: colspec += " NOT NULL" diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 9122c2afa..5d0a4e172 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -192,11 +192,11 @@ class PGSQLEngine(ansisql.ANSISQLEngine): def compiler(self, statement, bindparams, **kwargs): return PGCompiler(self, statement, bindparams, **kwargs) - def schemagenerator(self, proxy, **params): - return PGSchemaGenerator(proxy, **params) + def schemagenerator(self, **params): + return PGSchemaGenerator(self, **params) - def schemadropper(self, proxy, **params): - return PGSchemaDropper(proxy, **params) + def schemadropper(self, **params): + return PGSchemaDropper(self, **params) def defaultrunner(self, proxy): return PGDefaultRunner(self, proxy) @@ -254,6 +254,12 @@ class PGSQLEngine(ansisql.ANSISQLEngine): class PGCompiler(ansisql.ANSICompiler): + def visit_function(self, func): + if len(func.clauses): + super(PGCompiler, self).visit_function(func) + else: + self.strings[func] = func.name + def visit_insert_column(self, column): # Postgres advises against OID usage and turns it off in 8.1, # effectively making cursor.lastrowid @@ -273,14 +279,16 @@ class PGCompiler(ansisql.ANSICompiler): return text class PGSchemaGenerator(ansisql.ANSISchemaGenerator): + def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name - if isinstance(column.default, schema.PassiveDefault): - colspec += " DEFAULT " + column.default.text - elif column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): colspec += " SERIAL" else: colspec += " " + column.type.get_col_spec() + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default if not column.nullable: colspec += " NOT NULL" diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 83fb00205..5401c350f 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -148,8 +148,8 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): def dbapi(self): return sqlite - def schemagenerator(self, proxy, **params): - return SQLiteSchemaGenerator(proxy, **params) + def schemagenerator(self, **params): + return SQLiteSchemaGenerator(self, **params) def reflecttable(self, table): c = self.execute("PRAGMA table_info(" + table.name + ")", {}) @@ -226,6 +226,10 @@ class SQLiteCompiler(ansisql.ANSICompiler): class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name + " " + column.type.get_col_spec() + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + if not column.nullable: colspec += " NOT NULL" if column.primary_key and not override_pk: diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 29acdc665..aa8e89ca4 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -103,13 +103,13 @@ def engine_descriptors(): class SchemaIterator(schema.SchemaVisitor): """a visitor that can gather text into a buffer and execute the contents of the buffer.""" - def __init__(self, sqlproxy, **params): + def __init__(self, engine, **params): """initializes this SchemaIterator and initializes its buffer. sqlproxy - a callable function returned by SQLEngine.proxy(), which executes a statement plus optional parameters. """ - self.sqlproxy = sqlproxy + self.engine = engine self.buffer = StringIO.StringIO() def append(self, s): @@ -120,7 +120,7 @@ class SchemaIterator(schema.SchemaVisitor): """executes the contents of the SchemaIterator's buffer using its sql proxy and clears out the buffer.""" try: - return self.sqlproxy(self.buffer.getvalue()) + return self.engine.execute(self.buffer.getvalue(), None) finally: self.buffer.truncate(0) @@ -250,21 +250,17 @@ class SQLEngine(schema.SchemaEngine): """returns a sql.text() object for performing literal queries.""" return sql.text(text, engine=self, *args, **kwargs) - def schemagenerator(self, proxy, **params): + def schemagenerator(self, **params): """returns a schema.SchemaVisitor instance that can generate schemas, when it is - invoked to traverse a set of schema objects. The - "proxy" argument is a callable will execute a given string SQL statement - and a dictionary or list of parameters. + invoked to traverse a set of schema objects. schemagenerator is called via the create() method. """ raise NotImplementedError() - def schemadropper(self, proxy, **params): + def schemadropper(self, **params): """returns a schema.SchemaVisitor instance that can drop schemas, when it is - invoked to traverse a set of schema objects. The - "proxy" argument is a callable will execute a given string SQL statement - and a dictionary or list of parameters. + invoked to traverse a set of schema objects. schemagenerator is called via the drop() method. """ @@ -300,11 +296,11 @@ class SQLEngine(schema.SchemaEngine): def create(self, table, **params): """creates a table within this engine's database connection given a schema.Table object.""" - table.accept_visitor(self.schemagenerator(self.proxy(), **params)) + table.accept_visitor(self.schemagenerator(**params)) def drop(self, table, **params): """drops a table within this engine's database connection given a schema.Table object.""" - table.accept_visitor(self.schemadropper(self.proxy(), **params)) + table.accept_visitor(self.schemadropper(**params)) def compile(self, statement, parameters, **kwargs): """given a sql.ClauseElement statement plus optional bind parameters, creates a new @@ -369,12 +365,6 @@ class SQLEngine(schema.SchemaEngine): """implementations might want to put logic here for turning autocommit on/off, etc.""" connection.commit() - def proxy(self, **kwargs): - """provides a callable that will execute the given string statement and parameters. - The statement and parameters should be in the format specific to the particular database; - i.e. named or positional.""" - return lambda s, p = None: self.execute(s, p, **kwargs) - def connection(self): """returns a managed DBAPI connection from this SQLEngine's connection pool.""" return self._pool.connect() diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 01b7c7a11..8e85fb310 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -19,7 +19,7 @@ from sqlalchemy.util import * from sqlalchemy.types import * import copy, re, string -__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'SchemaEngine', 'SchemaVisitor'] +__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'SchemaEngine', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] class SchemaItem(object): @@ -418,12 +418,12 @@ class DefaultGenerator(SchemaItem): class PassiveDefault(DefaultGenerator): """a default that takes effect on the database side""" - def __init__(self, text): - self.text = text + def __init__(self, arg): + self.arg = arg def accept_visitor(self, visitor): - return visitor_visit_passive_default(self) + return visitor.visit_passive_default(self) def __repr__(self): - return "PassiveDefault(%s)" % repr(self.text) + return "PassiveDefault(%s)" % repr(self.arg) class ColumnDefault(DefaultGenerator): """A plain default value on a column. this could correspond to a constant, |
