diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-08-11 00:03:26 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-08-11 00:03:26 +0000 |
| commit | 1391efea78d552fa81dd056c21e7570fba437bcf (patch) | |
| tree | df7cd37efeefeb9048fff7b8bdf9fe1d75ce81d9 | |
| parent | b852fcbce0204fb8edcb7fda9605824a1193e685 (diff) | |
| download | sqlalchemy-1391efea78d552fa81dd056c21e7570fba437bcf.tar.gz | |
repaired oracle savepoint implementation
| -rw-r--r-- | lib/sqlalchemy/ansisql.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 13 | ||||
| -rw-r--r-- | test/engine/transaction.py | 4 |
3 files changed, 20 insertions, 7 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 430027ed8..0efaf8657 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -745,13 +745,13 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): return text def visit_savepoint(self, savepoint_stmt): - return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) + return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) def visit_rollback_to_savepoint(self, savepoint_stmt): - return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) + return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) def visit_release_savepoint(self, savepoint_stmt): - return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) + return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) def __str__(self): return self.string @@ -1052,8 +1052,8 @@ class ANSIIdentifierPreparer(object): def format_alias(self, alias, name=None): return self.__generic_obj_format(alias, name or alias.name) - def format_savepoint(self, savepoint): - return self.__generic_obj_format(savepoint, savepoint) + def format_savepoint(self, savepoint, name=None): + return self.__generic_obj_format(savepoint, name or savepoint.ident) def format_constraint(self, constraint): return self.__generic_obj_format(constraint, constraint.name) diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 9b8bb2f9e..7bbc63fba 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -280,12 +280,19 @@ class OracleDialect(ansisql.ANSIDialect): else: return "rowid" + def do_release_savepoint(self, connection, name): + # Oracle does not support RELEASE SAVEPOINT + pass + def create_execution_context(self, *args, **kwargs): return OracleExecutionContext(self, *args, **kwargs) def compiler(self, statement, bindparams, **kwargs): return OracleCompiler(self, statement, bindparams, **kwargs) + def preparer(self): + return OracleIdentifierPreparer(self) + def schemagenerator(self, *args, **kwargs): return OracleSchemaGenerator(self, *args, **kwargs) @@ -662,4 +669,10 @@ class OracleDefaultRunner(ansisql.ANSIDefaultRunner): def visit_sequence(self, seq): return self.connection.execute("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL").scalar() +class OracleIdentifierPreparer(ansisql.ANSIIdentifierPreparer): + def format_savepoint(self, savepoint): + name = re.sub(r'^_+', '', savepoint.ident) + return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name) + + dialect = OracleDialect diff --git a/test/engine/transaction.py b/test/engine/transaction.py index 3c84684da..8c516b195 100644 --- a/test/engine/transaction.py +++ b/test/engine/transaction.py @@ -173,7 +173,7 @@ class TransactionTest(PersistTest): ) connection.close() - @testing.supported('postgres', 'mysql') + @testing.supported('postgres', 'mysql', 'oracle') @testing.exclude('mysql', '<', (5, 0, 3)) def testtwophasetransaction(self): connection = testbase.db.connect() @@ -301,7 +301,7 @@ class TLTransactionTest(PersistTest): tlengine = create_engine(testbase.db.url, strategy='threadlocal') metadata = MetaData() users = Table('query_users', metadata, - Column('user_id', INT, primary_key = True), + Column('user_id', INT, Sequence('query_users_id_seq', optional=True), primary_key=True), Column('user_name', VARCHAR(20)), test_needs_acid=True, ) |
