| field | value |
|---|---|
| author | Mike Bayer <mike_mp@zzzcomputing.com>, 2006-05-25 14:20:23 +0000 |
| committer | Mike Bayer <mike_mp@zzzcomputing.com>, 2006-05-25 14:20:23 +0000 |
| commit | bb79e2e871d0a4585164c1a6ed626d96d0231975 (patch) |
| tree | 6d457ba6c36c408b45db24ec3c29e147fe7504ff /lib/sqlalchemy/databases |
| parent | 4fc3a0648699c2b441251ba4e1d37a9107bd1986 (diff) |
| download | sqlalchemy-bb79e2e871d0a4585164c1a6ed626d96d0231975.tar.gz |
merged 0.2 branch into trunk; 0.1 now in sqlalchemy/branches/rel_0_1
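The diffs below convert each database module from a 0.1-style engine class (`FBSQLEngine`, `MySQLEngine`, `PGSQLEngine`, ...) subclassing `ansisql.ANSISQLEngine` into a 0.2-style dialect subclassing `ansisql.ANSIDialect`, paired with a `DefaultExecutionContext` and exported as a module-level `dialect` attribute; connect arguments now come from a URL via `create_connect_args(url)`. The skeleton below is only a sketch of that shape for orientation — `MyDialect`, `MyExecutionContext`, and the connect-arg key list are placeholders, not code from this commit:

```python
import sqlalchemy.ansisql as ansisql
import sqlalchemy.engine.default as default

class MyExecutionContext(default.DefaultExecutionContext):
    def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
        # per-statement state (e.g. lastrowid after an INSERT) now lives here,
        # as in the mysql.py and sqlite.py changes below
        if getattr(compiled, "isinsert", False):
            self._last_inserted_ids = [proxy().lastrowid]

class MyDialect(ansisql.ANSIDialect):
    def create_connect_args(self, url):
        # turn the engine URL into (args, kwargs) for the DB-API connect() call
        opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
        return ([], opts)

    def create_execution_context(self):
        return MyExecutionContext(self)

# the engine machinery picks the dialect up from this module attribute
dialect = MyDialect
```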
Diffstat (limited to 'lib/sqlalchemy/databases')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | lib/sqlalchemy/databases/firebird.py | 41 |
| -rw-r--r-- | lib/sqlalchemy/databases/information_schema.py | 58 |
| -rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 2 |
| -rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 87 |
| -rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 90 |
| -rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 134 |
| -rw-r--r-- | lib/sqlalchemy/databases/sqlite.py | 88 |
7 files changed, 247 insertions, 253 deletions
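One recurring change worth flagging before the per-file diffs: both oracle.py and postgres.py replace the old private `_executemany` hook with `do_executemany`, looping over the parameter sets so that drivers which misreport rowcounts from `executemany()` still yield an accurate total, accumulated on the execution context. A sketch of that shared shape, with `context._rowcount` as in the diffs:

```python
def do_executemany(self, cursor, statement, parameters, context=None):
    # execute each parameter set individually so cursor.rowcount stays accurate
    rowcount = 0
    for param in parameters:
        cursor.execute(statement, param)
        rowcount += cursor.rowcount
    if context is not None:
        context._rowcount = rowcount
```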
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 99ef9eb9f..7dc48a54a 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -5,7 +5,7 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -import sys, StringIO, string, datetime +import sys, StringIO, string import sqlalchemy.sql as sql import sqlalchemy.schema as schema @@ -30,16 +30,6 @@ class FBSmallInteger(sqltypes.Smallinteger): class FBDateTime(sqltypes.DateTime): def get_col_spec(self): return "DATE" - def convert_bind_param(self, value, engine): - if value is not None: - if isinstance(value, datetime.datetime): - seconds = float(str(value.second) + "." - + str(value.microsecond)) - return kinterbasdb.date_conv_out((value.year, value.month, value.day, - value.hour, value.minute, seconds)) - return kinterbasdb.timestamp_conv_in(value) - else: - return None class FBText(sqltypes.TEXT): def get_col_spec(self): return "BLOB SUB_TYPE 2" @@ -84,12 +74,13 @@ def descriptor(): ]} class FBSQLEngine(ansisql.ANSISQLEngine): - def __init__(self, opts, module=None, use_oids=False, **params): + def __init__(self, opts, use_ansi = True, module = None, **params): + self._use_ansi = use_ansi + self.opts = opts or {} if module is None: self.module = kinterbasdb else: self.module = module - self.opts = self._translate_connect_args(('host', 'database', 'user', 'password'), opts) ansisql.ANSISQLEngine.__init__(self, **params) def do_commit(self, connection): @@ -111,7 +102,7 @@ class FBSQLEngine(ansisql.ANSISQLEngine): return self.context.last_inserted_ids def compiler(self, statement, bindparams, **kwargs): - return FBCompiler(statement, bindparams, engine=self, **kwargs) + return FBCompiler(statement, bindparams, engine=self, use_ansi=self._use_ansi, **kwargs) def schemagenerator(self, **params): return FBSchemaGenerator(self, **params) @@ -197,6 +188,21 @@ class FBSQLEngine(ansisql.ANSISQLEngine): class FBCompiler(ansisql.ANSICompiler): """firebird compiler modifies the lexical structure of Select statements to work under non-ANSI configured Firebird databases, if the use_ansi flag is False.""" + + def __init__(self, engine, statement, parameters, use_ansi = True, **kwargs): + self._outertable = None + self._use_ansi = use_ansi + ansisql.ANSICompiler.__init__(self, engine, statement, parameters, **kwargs) + + def visit_column(self, column): + if self._use_ansi: + return ansisql.ANSICompiler.visit_column(self, column) + + if column.table is self._outertable: + self.strings[column] = "%s.%s(+)" % (column.table.name, column.name) + else: + self.strings[column] = "%s.%s" % (column.table.name, column.name) + def visit_function(self, func): if len(func.clauses): super(FBCompiler, self).visit_function(func) @@ -217,11 +223,10 @@ class FBCompiler(ansisql.ANSICompiler): """ called when building a SELECT statment, position is just before column list Firebird puts the limit and offset right after the select...thanks for adding the visit_select_precolumns!!!""" - result = '' if select.offset: - result +=" FIRST %s " % select.offset + result +=" FIRST " + select.offset if select.limit: - result += " SKIP %s " % select.limit + result += " SKIP " + select.limit if select.distinct: result += " DISTINCT " return result @@ -229,8 +234,6 @@ class FBCompiler(ansisql.ANSICompiler): def limit_clause(self, select): """Already taken care of in the visit_select_precolumns method.""" return "" - def default_from(self): - return ' from RDB$DATABASE ' class 
FBSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py index 468e9a548..55c522558 100644 --- a/lib/sqlalchemy/databases/information_schema.py +++ b/lib/sqlalchemy/databases/information_schema.py @@ -7,22 +7,22 @@ from sqlalchemy.exceptions import * from sqlalchemy import * from sqlalchemy.ansisql import * -generic_engine = ansisql.engine() +ischema = MetaData() -gen_schemata = schema.Table("schemata", generic_engine, +schemata = schema.Table("schemata", ischema, Column("catalog_name", String), Column("schema_name", String), Column("schema_owner", String), schema="information_schema") -gen_tables = schema.Table("tables", generic_engine, +tables = schema.Table("tables", ischema, Column("table_catalog", String), Column("table_schema", String), Column("table_name", String), Column("table_type", String), schema="information_schema") -gen_columns = schema.Table("columns", generic_engine, +columns = schema.Table("columns", ischema, Column("table_schema", String), Column("table_name", String), Column("column_name", String), @@ -35,28 +35,40 @@ gen_columns = schema.Table("columns", generic_engine, Column("column_default", Integer), schema="information_schema") -gen_constraints = schema.Table("table_constraints", generic_engine, +constraints = schema.Table("table_constraints", ischema, Column("table_schema", String), Column("table_name", String), Column("constraint_name", String), Column("constraint_type", String), schema="information_schema") -gen_column_constraints = schema.Table("constraint_column_usage", generic_engine, +column_constraints = schema.Table("constraint_column_usage", ischema, Column("table_schema", String), Column("table_name", String), Column("column_name", String), Column("constraint_name", String), schema="information_schema") -gen_key_constraints = schema.Table("key_column_usage", generic_engine, +pg_key_constraints = schema.Table("key_column_usage", ischema, Column("table_schema", String), Column("table_name", String), Column("column_name", String), Column("constraint_name", String), schema="information_schema") -gen_ref_constraints = schema.Table("referential_constraints", generic_engine, +mysql_key_constraints = schema.Table("key_column_usage", ischema, + Column("table_schema", String), + Column("table_name", String), + Column("column_name", String), + Column("constraint_name", String), + Column("referenced_table_schema", String), + Column("referenced_table_name", String), + Column("referenced_column_name", String), + schema="information_schema") + +key_constraints = pg_key_constraints + +ref_constraints = schema.Table("referential_constraints", ischema, Column("constraint_catalog", String), Column("constraint_schema", String), Column("constraint_name", String), @@ -88,37 +100,25 @@ class ISchema(object): return self.cache[name] -def reflecttable(engine, table, ischema_names, use_mysql=False): - columns = gen_columns.toengine(engine) - constraints = gen_constraints.toengine(engine) +def reflecttable(connection, table, ischema_names, use_mysql=False): if use_mysql: # no idea which INFORMATION_SCHEMA spec is correct, mysql or postgres - key_constraints = schema.Table("key_column_usage", engine, - Column("table_schema", String), - Column("table_name", String), - Column("column_name", String), - Column("constraint_name", String), - Column("referenced_table_schema", String), - Column("referenced_table_name", 
String), - Column("referenced_column_name", String), - schema="information_schema", useexisting=True) + key_constraints = mysql_key_constraints else: - column_constraints = gen_column_constraints.toengine(engine) - key_constraints = gen_key_constraints.toengine(engine) - - + key_constraints = pg_key_constraints + if table.schema is not None: current_schema = table.schema else: - current_schema = engine.get_default_schema_name() + current_schema = connection.default_schema_name() s = select([columns], sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema), order_by=[columns.c.ordinal_position]) - c = s.execute() + c = connection.execute(s) while True: row = c.fetchone() if row is None: @@ -160,7 +160,7 @@ def reflecttable(engine, table, ischema_names, use_mysql=False): s.append_whereclause(constraints.c.table_name==table.name) s.append_whereclause(constraints.c.table_schema==current_schema) colmap = [constraints.c.constraint_type, key_constraints.c.column_name, key_constraints.c.referenced_table_schema, key_constraints.c.referenced_table_name, key_constraints.c.referenced_column_name] - c = s.execute() + c = connection.execute(s) while True: row = c.fetchone() @@ -178,6 +178,8 @@ def reflecttable(engine, table, ischema_names, use_mysql=False): if type=='PRIMARY KEY': table.c[constrained_column]._set_primary_key() elif type=='FOREIGN KEY': - remotetable = Table(referred_table, engine, autoload = True, schema=referred_schema) + if current_schema == referred_schema: + referred_schema = table.schema + remotetable = Table(referred_table, table.metadata, autoload=True, autoload_with=connection, schema=referred_schema) table.c[constrained_column].append_item(schema.ForeignKey(remotetable.c[referred_column])) diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 6a7ef91b3..a8124537a 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -455,7 +455,7 @@ class MSSQLCompiler(ansisql.ANSICompiler): super(MSSQLCompiler, self).visit_column(column) if column.table is not None and self.tablealiases.has_key(column.table): self.strings[column] = \ - self.strings[self.tablealiases[column.table]._get_col_by_original(column.original)] + self.strings[self.tablealiases[column.table].corresponding_column(column.original)] class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 60435f220..0a480ec11 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -6,14 +6,11 @@ import sys, StringIO, string, types, re, datetime -import sqlalchemy.sql as sql -import sqlalchemy.engine as engine -import sqlalchemy.schema as schema -import sqlalchemy.ansisql as ansisql +from sqlalchemy import sql,engine,schema,ansisql +from sqlalchemy.engine import default import sqlalchemy.types as sqltypes -from sqlalchemy import * import sqlalchemy.databases.information_schema as ischema -from sqlalchemy.exceptions import * +import sqlalchemy.exceptions as exceptions try: import MySQLdb as mysql @@ -26,7 +23,7 @@ class MSNumeric(sqltypes.Numeric): class MSDouble(sqltypes.Numeric): def __init__(self, precision = None, length = None): if (precision is None and length is not None) or (precision is not None and length is None): - raise ArgumentError("You must specify both precision and length or omit both altogether.") + raise exceptions.ArgumentError("You must specify both precision and length or omit both altogether.") 
super(MSDouble, self).__init__(precision, length) def get_col_spec(self): if self.precision is not None and self.length is not None: @@ -56,7 +53,7 @@ class MSDate(sqltypes.Date): class MSTime(sqltypes.Time): def get_col_spec(self): return "TIME" - def convert_result_value(self, value, engine): + def convert_result_value(self, value, dialect): # convert from a timedelta value if value is not None: return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60)) @@ -129,72 +126,66 @@ def descriptor(): return {'name':'mysql', 'description':'MySQL', 'arguments':[ - ('user',"Database Username",None), - ('passwd',"Database Password",None), - ('db',"Database Name",None), + ('username',"Database Username",None), + ('password',"Database Password",None), + ('database',"Database Name",None), ('host',"Hostname", None), ]} -class MySQLEngine(ansisql.ANSISQLEngine): - def __init__(self, opts, module = None, **params): + +class MySQLExecutionContext(default.DefaultExecutionContext): + def post_exec(self, engine, proxy, compiled, parameters, **kwargs): + if getattr(compiled, "isinsert", False): + self._last_inserted_ids = [proxy().lastrowid] + +class MySQLDialect(ansisql.ANSIDialect): + def __init__(self, module = None, **kwargs): if module is None: self.module = mysql - self.opts = self._translate_connect_args(('host', 'db', 'user', 'passwd'), opts) - ansisql.ANSISQLEngine.__init__(self, **params) + ansisql.ANSIDialect.__init__(self, **kwargs) + + def create_connect_args(self, url): + opts = url.translate_connect_args(['host', 'db', 'user', 'passwd', 'port']) + return [[], opts] - def connect_args(self): - return [[], self.opts] + def create_execution_context(self): + return MySQLExecutionContext(self) def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def last_inserted_ids(self): - return self.context.last_inserted_ids def supports_sane_rowcount(self): return False def compiler(self, statement, bindparams, **kwargs): - return MySQLCompiler(statement, bindparams, engine=self, **kwargs) + return MySQLCompiler(self, statement, bindparams, **kwargs) - def schemagenerator(self, **params): - return MySQLSchemaGenerator(self, **params) + def schemagenerator(self, *args, **kwargs): + return MySQLSchemaGenerator(*args, **kwargs) - def schemadropper(self, **params): - return MySQLSchemaDropper(self, **params) + def schemadropper(self, *args, **kwargs): + return MySQLSchemaDropper(*args, **kwargs) def get_default_schema_name(self): if not hasattr(self, '_default_schema_name'): self._default_schema_name = text("select database()", self).scalar() return self._default_schema_name - - def last_inserted_ids(self): - return self.context.last_inserted_ids - - def post_exec(self, proxy, compiled, parameters, **kwargs): - if getattr(compiled, "isinsert", False): - self.context.last_inserted_ids = [proxy().lastrowid] - # executemany just runs normally, since we arent using rowcount at all with mysql -# def _executemany(self, c, statement, parameters): - # """we need accurate rowcounts for updates, inserts and deletes. 
mysql is *also* is not nice enough - # to produce this correctly for an executemany, so we do our own executemany here.""" - # rowcount = 0 - # for param in parameters: - # c.execute(statement, param) - # rowcount += c.rowcount - # self.context.rowcount = rowcount - def dbapi(self): return self.module - def reflecttable(self, table): + def has_table(self, connection, table_name): + cursor = connection.execute("show table status like '" + table_name + "'") + return bool( not not cursor.rowcount ) + + def reflecttable(self, connection, table): # to use information_schema: #ischema.reflecttable(self, table, ischema_names, use_mysql=True) - tabletype, foreignkeyD = self.moretableinfo(table=table) + tabletype, foreignkeyD = self.moretableinfo(connection, table=table) table.kwargs['mysql_engine'] = tabletype - c = self.execute("describe " + table.name, {}) + c = connection.execute("describe " + table.name, {}) while True: row = c.fetchone() if row is None: @@ -224,7 +215,7 @@ class MySQLEngine(ansisql.ANSISQLEngine): default=default ))) - def moretableinfo(self, table): + def moretableinfo(self, connection, table): """Return (tabletype, {colname:foreignkey,...}) execute(SHOW CREATE TABLE child) => CREATE TABLE `child` ( @@ -233,7 +224,7 @@ class MySQLEngine(ansisql.ANSISQLEngine): KEY `par_ind` (`parent_id`), CONSTRAINT `child_ibfk_1` FOREIGN KEY (`parent_id`) REFERENCES `parent` (`id`) ON DELETE CASCADE\n) TYPE=InnoDB """ - c = self.execute("SHOW CREATE TABLE " + table.name, {}) + c = connection.execute("SHOW CREATE TABLE " + table.name, {}) desc = c.fetchone()[1].strip() tabletype = '' lastparen = re.search(r'\)[^\)]*\Z', desc) @@ -277,7 +268,7 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): if column.primary_key: if not override_pk: colspec += " PRIMARY KEY" - if not column.foreign_key and first_pk and isinstance(column.type, types.Integer): + if not column.foreign_key and first_pk and isinstance(column.type, sqltypes.Integer): colspec += " AUTO_INCREMENT" if column.foreign_key: colspec += ", FOREIGN KEY (%s) REFERENCES %s(%s)" % (column.name, column.foreign_key.column.table.name, column.foreign_key.column.name) @@ -294,3 +285,5 @@ class MySQLSchemaDropper(ansisql.ANSISchemaDropper): def visit_index(self, index): self.append("\nDROP INDEX " + index.name + " ON " + index.table.name) self.execute() + +dialect = MySQLDialect
\ No newline at end of file diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 16c6cb218..b27f87dd0 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -7,11 +7,14 @@ import sys, StringIO, string +import sqlalchemy.util as util import sqlalchemy.sql as sql +import sqlalchemy.engine as engine +import sqlalchemy.engine.default as default import sqlalchemy.schema as schema import sqlalchemy.ansisql as ansisql -from sqlalchemy import * import sqlalchemy.types as sqltypes +import sqlalchemy.exceptions as exceptions try: import cx_Oracle @@ -93,8 +96,6 @@ AND ac.r_constraint_name = rem.constraint_name(+) -- order multiple primary keys correctly ORDER BY ac.constraint_name, loc.position""" -def engine(*args, **params): - return OracleSQLEngine(*args, **params) def descriptor(): return {'name':'oracle', @@ -104,45 +105,53 @@ def descriptor(): ('user', 'Username', None), ('password', 'Password', None) ]} + +class OracleExecutionContext(default.DefaultExecutionContext): + pass -class OracleSQLEngine(ansisql.ANSISQLEngine): - def __init__(self, opts, use_ansi = True, module = None, threaded=False, **params): - self._use_ansi = use_ansi - self.opts = self._translate_connect_args((None, 'dsn', 'user', 'password'), opts) - self.opts['threaded'] = threaded +class OracleDialect(ansisql.ANSIDialect): + def __init__(self, use_ansi=True, module=None, threaded=True, **kwargs): + self.use_ansi = use_ansi + self.threaded = threaded if module is None: self.module = cx_Oracle else: self.module = module - ansisql.ANSISQLEngine.__init__(self, **params) + ansisql.ANSIDialect.__init__(self, **kwargs) def dbapi(self): return self.module - def connect_args(self): - return [[], self.opts] + def create_connect_args(self, url): + opts = url.translate_connect_args([None, 'dsn', 'user', 'password']) + opts['threaded'] = self.threaded + return ([], opts) def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def last_inserted_ids(self): - return self.context.last_inserted_ids - def oid_column_name(self): return "rowid" + def create_execution_context(self): + return OracleExecutionContext(self) + def compiler(self, statement, bindparams, **kwargs): - return OracleCompiler(self, statement, bindparams, use_ansi=self._use_ansi, **kwargs) - - def schemagenerator(self, **params): - return OracleSchemaGenerator(self, **params) - def schemadropper(self, **params): - return OracleSchemaDropper(self, **params) - def defaultrunner(self, proxy): - return OracleDefaultRunner(self, proxy) + return OracleCompiler(self, statement, bindparams, **kwargs) + def schemagenerator(self, *args, **kwargs): + return OracleSchemaGenerator(*args, **kwargs) + def schemadropper(self, *args, **kwargs): + return OracleSchemaDropper(*args, **kwargs) + def defaultrunner(self, engine, proxy): + return OracleDefaultRunner(engine, proxy) + + + def has_table(self, connection, table_name): + cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':table_name.upper()}) + return bool( cursor.fetchone() is not None ) - def reflecttable(self, table): - c = self.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS where TABLE_NAME = :table_name", {'table_name':table.name.upper()}) + def reflecttable(self, connection, table): + c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT 
from ALL_TAB_COLUMNS where TABLE_NAME = :table_name", {'table_name':table.name.upper()}) while True: row = c.fetchone() @@ -171,14 +180,14 @@ class OracleSQLEngine(ansisql.ANSISQLEngine): colargs = [] if default is not None: - colargs.append(PassiveDefault(sql.text(default))) + colargs.append(schema.PassiveDefault(sql.text(default))) name = name.lower() table.append_item (schema.Column(name, coltype, nullable=nullable, *colargs)) - c = self.execute(constraintSQL, {'table_name' : table.name.upper()}) + c = connection.execute(constraintSQL, {'table_name' : table.name.upper()}) while True: row = c.fetchone() if row is None: @@ -189,34 +198,24 @@ class OracleSQLEngine(ansisql.ANSISQLEngine): table.c[local_column]._set_primary_key() elif cons_type == 'R': table.c[local_column].append_item( - schema.ForeignKey(Table(remote_table, - self, + schema.ForeignKey(schema.Table(remote_table, + table.metadata, autoload=True).c[remote_column] ) ) - def last_inserted_ids(self): - return self.context.last_inserted_ids - - def pre_exec(self, proxy, compiled, parameters, **kwargs): - pass - - def _executemany(self, c, statement, parameters): + def do_executemany(self, c, statement, parameters, context=None): rowcount = 0 for param in parameters: c.execute(statement, param) rowcount += c.rowcount - self.context.rowcount = rowcount + if context is not None: + context._rowcount = rowcount class OracleCompiler(ansisql.ANSICompiler): """oracle compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if the use_ansi flag is False.""" - def __init__(self, engine, statement, parameters, use_ansi = True, **kwargs): - self._outertable = None - self._use_ansi = use_ansi - ansisql.ANSICompiler.__init__(self, statement, parameters, engine=engine, **kwargs) - def default_from(self): """called when a SELECT statement has no froms, and no FROM clause is to be appended. gives Oracle a chance to tack on a "FROM DUAL" to the string output. """ @@ -226,7 +225,7 @@ class OracleCompiler(ansisql.ANSICompiler): return len(func.clauses) > 0 def visit_join(self, join): - if self._use_ansi: + if self.dialect.use_ansi: return ansisql.ANSICompiler.visit_join(self, join) self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right) @@ -251,7 +250,7 @@ class OracleCompiler(ansisql.ANSICompiler): def visit_column(self, column): ansisql.ANSICompiler.visit_column(self, column) - if not self._use_ansi and self._outertable is not None and column.table is self._outertable: + if not self.dialect.use_ansi and getattr(self, '_outertable', None) is not None and column.table is self._outertable: self.strings[column] = self.strings[column] + "(+)" def visit_insert(self, insert): @@ -275,12 +274,15 @@ class OracleCompiler(ansisql.ANSICompiler): self.strings[select.order_by_clause] = "" ansisql.ANSICompiler.visit_select(self, select) return + if select.limit is not None or select.offset is not None: select._oracle_visit = True # to use ROW_NUMBER(), an ORDER BY is required. 
orderby = self.strings[select.order_by_clause] if not orderby: orderby = select.oid_column + orderby.accept_visitor(self) + orderby = self.strings[orderby] select.append_column(sql.ColumnClause("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")) limitselect = sql.select([c for c in select.c if c.key!='ora_rn']) if select.offset is not None: @@ -330,3 +332,5 @@ class OracleDefaultRunner(ansisql.ANSIDefaultRunner): def visit_sequence(self, seq): return self.proxy("SELECT " + seq.name + ".nextval FROM DUAL").fetchone()[0] + +dialect = OracleDialect diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index a92cb340d..b6917c035 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -9,11 +9,11 @@ import datetime, sys, StringIO, string, types, re import sqlalchemy.util as util import sqlalchemy.sql as sql import sqlalchemy.engine as engine +import sqlalchemy.engine.default as default import sqlalchemy.schema as schema import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes -from sqlalchemy.exceptions import * -from sqlalchemy import * +import sqlalchemy.exceptions as exceptions import information_schema as ischema try: @@ -47,7 +47,7 @@ class PG2DateTime(sqltypes.DateTime): def get_col_spec(self): return "TIMESTAMP" class PG1DateTime(sqltypes.DateTime): - def convert_bind_param(self, value, engine): + def convert_bind_param(self, value, dialect): if value is not None: if isinstance(value, datetime.datetime): seconds = float(str(value.second) + "." @@ -59,7 +59,7 @@ class PG1DateTime(sqltypes.DateTime): return psycopg.TimestampFromMx(value) else: return None - def convert_result_value(self, value, engine): + def convert_result_value(self, value, dialect): if value is None: return None second_parts = str(value.second).split(".") @@ -68,21 +68,20 @@ class PG1DateTime(sqltypes.DateTime): return datetime.datetime(value.year, value.month, value.day, value.hour, value.minute, seconds, microseconds) - def get_col_spec(self): return "TIMESTAMP" class PG2Date(sqltypes.Date): def get_col_spec(self): return "DATE" class PG1Date(sqltypes.Date): - def convert_bind_param(self, value, engine): + def convert_bind_param(self, value, dialect): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime # this one doesnt seem to work with the "emulation" mode if value is not None: return psycopg.DateFromMx(value) else: return None - def convert_result_value(self, value, engine): + def convert_result_value(self, value, dialect): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime return value def get_col_spec(self): @@ -91,14 +90,14 @@ class PG2Time(sqltypes.Time): def get_col_spec(self): return "TIME" class PG1Time(sqltypes.Time): - def convert_bind_param(self, value, engine): + def convert_bind_param(self, value, dialect): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime # this one doesnt seem to work with the "emulation" mode if value is not None: return psycopg.TimeFromMx(value) else: return None - def convert_result_value(self, value, engine): + def convert_result_value(self, value, dialect): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime return value def get_col_spec(self): @@ -175,18 +174,35 @@ def descriptor(): return {'name':'postgres', 'description':'PostGres', 'arguments':[ - ('user',"Database Username",None), + ('username',"Database Username",None), ('password',"Database 
Password",None), ('database',"Database Name",None), ('host',"Hostname", None), ]} -class PGSQLEngine(ansisql.ANSISQLEngine): - def __init__(self, opts, module=None, use_oids=False, **params): +class PGExecutionContext(default.DefaultExecutionContext): + + def post_exec(self, engine, proxy, compiled, parameters, **kwargs): + if getattr(compiled, "isinsert", False) and self.last_inserted_ids is None: + if not engine.dialect.use_oids: + pass + # will raise invalid error when they go to get them + else: + table = compiled.statement.table + cursor = proxy() + if cursor.lastrowid is not None and table is not None and len(table.primary_key): + s = sql.select(table.primary_key, table.oid_column == cursor.lastrowid) + c = s.compile(engine=engine) + cursor = proxy(str(c), c.get_params()) + row = cursor.fetchone() + self._last_inserted_ids = [v for v in row] + +class PGDialect(ansisql.ANSIDialect): + def __init__(self, module=None, use_oids=False, **params): self.use_oids = use_oids if module is None: - if psycopg is None: - raise ArgumentError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument") + #if psycopg is None: + # raise exceptions.ArgumentError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument") self.module = psycopg else: self.module = module @@ -198,17 +214,19 @@ class PGSQLEngine(ansisql.ANSISQLEngine): self.version = 1 except: self.version = 1 - self.opts = self._translate_connect_args(('host', 'database', 'user', 'password'), opts) - if self.opts.has_key('port'): + ansisql.ANSIDialect.__init__(self, **params) + + def create_connect_args(self, url): + opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) + if opts.has_key('port'): if self.version == 2: - self.opts['port'] = int(self.opts['port']) + opts['port'] = int(opts['port']) else: - self.opts['port'] = str(self.opts['port']) - - ansisql.ANSISQLEngine.__init__(self, **params) - - def connect_args(self): - return [[], self.opts] + opts['port'] = str(opts['port']) + return ([], opts) + + def create_execution_context(self): + return PGExecutionContext(self) def type_descriptor(self, typeobj): if self.version == 2: @@ -217,25 +235,22 @@ class PGSQLEngine(ansisql.ANSISQLEngine): return sqltypes.adapt_type(typeobj, pg1_colspecs) def compiler(self, statement, bindparams, **kwargs): - return PGCompiler(statement, bindparams, engine=self, **kwargs) - - def schemagenerator(self, **params): - return PGSchemaGenerator(self, **params) - - def schemadropper(self, **params): - return PGSchemaDropper(self, **params) - - def defaultrunner(self, proxy=None): - return PGDefaultRunner(self, proxy) + return PGCompiler(self, statement, bindparams, **kwargs) + def schemagenerator(self, *args, **kwargs): + return PGSchemaGenerator(*args, **kwargs) + def schemadropper(self, *args, **kwargs): + return PGSchemaDropper(*args, **kwargs) + def defaultrunner(self, engine, proxy): + return PGDefaultRunner(engine, proxy) - def get_default_schema_name(self): + def get_default_schema_name(self, connection): if not hasattr(self, '_default_schema_name'): - self._default_schema_name = text("select current_schema()", self).scalar() + self._default_schema_name = connection.scalar("select current_schema()", None) return self._default_schema_name def last_inserted_ids(self): if self.context.last_inserted_ids is None: - raise InvalidRequestError("no INSERT executed, or cant use cursor.lastrowid without Postgres OIDs enabled") + raise exceptions.InvalidRequestError("no INSERT executed, or cant use 
cursor.lastrowid without Postgres OIDs enabled") else: return self.context.last_inserted_ids @@ -245,51 +260,32 @@ class PGSQLEngine(ansisql.ANSISQLEngine): else: return None - def pre_exec(self, proxy, statement, parameters, **kwargs): - return - - def post_exec(self, proxy, compiled, parameters, **kwargs): - if getattr(compiled, "isinsert", False) and self.context.last_inserted_ids is None: - if not self.use_oids: - pass - # will raise invalid error when they go to get them - else: - table = compiled.statement.table - cursor = proxy() - if cursor.lastrowid is not None and table is not None and len(table.primary_key): - s = sql.select(table.primary_key, table.oid_column == cursor.lastrowid) - c = s.compile() - cursor = proxy(str(c), c.get_params()) - row = cursor.fetchone() - self.context.last_inserted_ids = [v for v in row] - - def _executemany(self, c, statement, parameters): + def do_executemany(self, c, statement, parameters, context=None): """we need accurate rowcounts for updates, inserts and deletes. psycopg2 is not nice enough to produce this correctly for an executemany, so we do our own executemany here.""" rowcount = 0 for param in parameters: - try: - c.execute(statement, param) - except Exception, e: - raise exceptions.SQLError(statement, param, e) + c.execute(statement, param) rowcount += c.rowcount - self.context.rowcount = rowcount + if context is not None: + context._rowcount = rowcount def dbapi(self): return self.module - def reflecttable(self, table): + def has_table(self, connection, table_name): + cursor = connection.execute("""select relname from pg_class where lower(relname) = %(name)s""", {'name':table_name.lower()}) + return bool( not not cursor.rowcount ) + + def reflecttable(self, connection, table): if self.version == 2: ischema_names = pg2_ischema_names else: ischema_names = pg1_ischema_names - # give ischema the given table's engine with which to look up - # other tables, not 'self', since it could be a ProxyEngine - ischema.reflecttable(table.engine, table, ischema_names) + ischema.reflecttable(connection, table, ischema_names) class PGCompiler(ansisql.ANSICompiler): - def visit_insert_column(self, column, parameters): # Postgres advises against OID usage and turns it off in 8.1, @@ -322,7 +318,7 @@ class PGCompiler(ansisql.ANSICompiler): return "DISTINCT ON (" + str(select.distinct) + ") " else: return "" - + def binary_operator_string(self, binary): if isinstance(binary.type, sqltypes.String) and binary.operator == '+': return '||' @@ -333,7 +329,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name - if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + if column.primary_key and isinstance(column.type, sqltypes.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): colspec += " SERIAL" else: colspec += " " + column.type.engine_impl(self.engine).get_col_spec() @@ -367,7 +363,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): if isinstance(column.default, schema.PassiveDefault): c = self.proxy("select %s" % column.default.arg) return c.fetchone()[0] - elif isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + elif isinstance(column.type, sqltypes.Integer) and (column.default is None 
or (isinstance(column.default, schema.Sequence) and column.default.optional)): sch = column.table.schema if sch is not None: exc = "select nextval('%s.%s_%s_seq')" % (sch, column.table.name, column.name) @@ -386,3 +382,5 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): return c.fetchone()[0] else: return None + +dialect = PGDialect diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index a7536ee4e..4d9f562ae 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -7,13 +7,9 @@ import sys, StringIO, string, types, re -import sqlalchemy.sql as sql -import sqlalchemy.engine as engine -import sqlalchemy.schema as schema -import sqlalchemy.ansisql as ansisql +from sqlalchemy import sql, engine, schema, ansisql, exceptions, pool +import sqlalchemy.engine.default as default import sqlalchemy.types as sqltypes -from sqlalchemy.exceptions import * -from sqlalchemy.ansisql import * import datetime,time pysqlite2_timesupport = False # Change this if the init.d guys ever get around to supporting time cols @@ -38,12 +34,12 @@ class SLSmallInteger(sqltypes.Smallinteger): class SLDateTime(sqltypes.DateTime): def get_col_spec(self): return "TIMESTAMP" - def convert_bind_param(self, value, engine): + def convert_bind_param(self, value, dialect): if value is not None: return str(value) else: return None - def _cvt(self, value, engine, fmt): + def _cvt(self, value, dialect, fmt): if value is None: return None parts = value.split('.') @@ -53,20 +49,20 @@ class SLDateTime(sqltypes.DateTime): except ValueError: (value, microsecond) = (value, 0) return time.strptime(value, fmt)[0:6] + (microsecond,) - def convert_result_value(self, value, engine): - tup = self._cvt(value, engine, "%Y-%m-%d %H:%M:%S") + def convert_result_value(self, value, dialect): + tup = self._cvt(value, dialect, "%Y-%m-%d %H:%M:%S") return tup and datetime.datetime(*tup) class SLDate(SLDateTime): def get_col_spec(self): return "DATE" - def convert_result_value(self, value, engine): - tup = self._cvt(value, engine, "%Y-%m-%d") + def convert_result_value(self, value, dialect): + tup = self._cvt(value, dialect, "%Y-%m-%d") return tup and datetime.date(*tup[0:3]) class SLTime(SLDateTime): def get_col_spec(self): return "TIME" - def convert_result_value(self, value, engine): - tup = self._cvt(value, engine, "%H:%M:%S") + def convert_result_value(self, value, dialect): + tup = self._cvt(value, dialect, "%H:%M:%S") return tup and datetime.time(*tup[4:7]) class SLText(sqltypes.TEXT): def get_col_spec(self): @@ -115,33 +111,32 @@ pragma_names = { if pysqlite2_timesupport: colspecs.update({sqltypes.Time : SLTime}) pragma_names.update({'TIME' : SLTime}) - -def engine(opts, **params): - return SQLiteSQLEngine(opts, **params) def descriptor(): return {'name':'sqlite', 'description':'SQLite', 'arguments':[ - ('filename', "Database Filename",None) + ('database', "Database Filename",None) ]} - -class SQLiteSQLEngine(ansisql.ANSISQLEngine): - def __init__(self, opts, **params): - if sqlite is None: - raise ArgumentError("Couldn't import sqlite or pysqlite2") - self.filename = opts.pop('filename', ':memory:') - self.opts = opts or {} - params['poolclass'] = sqlalchemy.pool.SingletonThreadPool - ansisql.ANSISQLEngine.__init__(self, **params) - def post_exec(self, proxy, compiled, parameters, **kwargs): - if getattr(compiled, "isinsert", False): - self.context.last_inserted_ids = [proxy().lastrowid] +class SQLiteExecutionContext(default.DefaultExecutionContext): + def post_exec(self, 
engine, proxy, compiled, parameters, **kwargs): + if getattr(compiled, "isinsert", False): + self._last_inserted_ids = [proxy().lastrowid] + +class SQLiteDialect(ansisql.ANSIDialect): + def compiler(self, statement, bindparams, **kwargs): + return SQLiteCompiler(self, statement, bindparams, **kwargs) + def schemagenerator(self, *args, **kwargs): + return SQLiteSchemaGenerator(*args, **kwargs) + def create_connect_args(self, url): + filename = url.database or ':memory:' + return ([filename], {}) def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - + def create_execution_context(self): + return SQLiteExecutionContext(self) def last_inserted_ids(self): return self.context.last_inserted_ids @@ -151,20 +146,21 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): def connect_args(self): return ([self.filename], self.opts) - def compiler(self, statement, bindparams, **kwargs): - return SQLiteCompiler(statement, bindparams, engine=self, **kwargs) - def dbapi(self): + if sqlite is None: + raise ArgumentError("Couldn't import sqlite or pysqlite2") return sqlite def push_session(self): raise InvalidRequestError("SQLite doesn't support nested sessions") - def schemagenerator(self, **params): - return SQLiteSchemaGenerator(self, **params) + def has_table(self, connection, table_name): + cursor = connection.execute("PRAGMA table_info(" + table_name + ")", {}) + row = cursor.fetchone() + return (row is not None) - def reflecttable(self, table): - c = self.execute("PRAGMA table_info(" + table.name + ")", {}) + def reflecttable(self, connection, table): + c = connection.execute("PRAGMA table_info(" + table.name + ")", {}) while True: row = c.fetchone() if row is None: @@ -183,7 +179,7 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): #print "args! " +repr(args) coltype = coltype(*[int(a) for a in args]) table.append_item(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable)) - c = self.execute("PRAGMA foreign_key_list(" + table.name + ")", {}) + c = connection.execute("PRAGMA foreign_key_list(" + table.name + ")", {}) while True: row = c.fetchone() if row is None: @@ -192,10 +188,10 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): #print "row! 
" + repr(row) # look up the table based on the given table's engine, not 'self', # since it could be a ProxyEngine - remotetable = Table(tablename, table.engine, autoload = True) + remotetable = schema.Table(tablename, table.metadata, autoload=True, autoload_with=connection) table.c[localcol].append_item(schema.ForeignKey(remotetable.c[remotecol])) # check for UNIQUE indexes - c = self.execute("PRAGMA index_list(" + table.name + ")", {}) + c = connection.execute("PRAGMA index_list(" + table.name + ")", {}) unique_indexes = [] while True: row = c.fetchone() @@ -205,7 +201,7 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): unique_indexes.append(row[1]) # loop thru unique indexes for one that includes the primary key for idx in unique_indexes: - c = self.execute("PRAGMA index_info(" + idx + ")", {}) + c = connection.execute("PRAGMA index_info(" + idx + ")", {}) cols = [] while True: row = c.fetchone() @@ -219,9 +215,6 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): table.columns[col]._set_primary_key() class SQLiteCompiler(ansisql.ANSICompiler): - def __init__(self, *args, **params): - params.setdefault('paramstyle', 'named') - ansisql.ANSICompiler.__init__(self, *args, **params) def limit_clause(self, select): text = "" if select.limit is not None: @@ -238,7 +231,7 @@ class SQLiteCompiler(ansisql.ANSICompiler): return '||' else: return ansisql.ANSICompiler.binary_operator_string(self, binary) - + class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec() @@ -277,4 +270,5 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): for index in table.indexes: self.visit_index(index) - +dialect = SQLiteDialect +poolclass = pool.SingletonThreadPool |

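Taken together, the reflection entry points in these diffs now receive an explicit connection (`reflecttable(self, connection, table)`, `has_table(connection, table_name)`), and foreign-key targets are loaded through the owning table's `MetaData` with `autoload_with=connection` instead of through a bound engine. A hedged usage sketch of that style against the 0.2-era API — the URL form and the `users` table are assumptions for illustration:

```python
from sqlalchemy import create_engine, MetaData, Table

engine = create_engine('sqlite:///example.db')   # assumed URL form / existing file
meta = MetaData()
conn = engine.connect()

# reflect an existing table through an explicit connection
users = Table('users', meta, autoload=True, autoload_with=conn)
print([c.name for c in users.columns])
conn.close()
```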