summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-08-11 00:03:26 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-08-11 00:03:26 +0000
commit1391efea78d552fa81dd056c21e7570fba437bcf (patch)
treedf7cd37efeefeb9048fff7b8bdf9fe1d75ce81d9
parentb852fcbce0204fb8edcb7fda9605824a1193e685 (diff)
downloadsqlalchemy-1391efea78d552fa81dd056c21e7570fba437bcf.tar.gz
repaired oracle savepoint implementation
-rw-r--r--lib/sqlalchemy/ansisql.py10
-rw-r--r--lib/sqlalchemy/databases/oracle.py13
-rw-r--r--test/engine/transaction.py4
3 files changed, 20 insertions, 7 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 430027ed8..0efaf8657 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -745,13 +745,13 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
return text
def visit_savepoint(self, savepoint_stmt):
- return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+ return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
def visit_rollback_to_savepoint(self, savepoint_stmt):
- return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+ return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
def visit_release_savepoint(self, savepoint_stmt):
- return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+ return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
def __str__(self):
return self.string
@@ -1052,8 +1052,8 @@ class ANSIIdentifierPreparer(object):
def format_alias(self, alias, name=None):
return self.__generic_obj_format(alias, name or alias.name)
- def format_savepoint(self, savepoint):
- return self.__generic_obj_format(savepoint, savepoint)
+ def format_savepoint(self, savepoint, name=None):
+ return self.__generic_obj_format(savepoint, name or savepoint.ident)
def format_constraint(self, constraint):
return self.__generic_obj_format(constraint, constraint.name)
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
index 9b8bb2f9e..7bbc63fba 100644
--- a/lib/sqlalchemy/databases/oracle.py
+++ b/lib/sqlalchemy/databases/oracle.py
@@ -280,12 +280,19 @@ class OracleDialect(ansisql.ANSIDialect):
else:
return "rowid"
+ def do_release_savepoint(self, connection, name):
+ # Oracle does not support RELEASE SAVEPOINT
+ pass
+
def create_execution_context(self, *args, **kwargs):
return OracleExecutionContext(self, *args, **kwargs)
def compiler(self, statement, bindparams, **kwargs):
return OracleCompiler(self, statement, bindparams, **kwargs)
+ def preparer(self):
+ return OracleIdentifierPreparer(self)
+
def schemagenerator(self, *args, **kwargs):
return OracleSchemaGenerator(self, *args, **kwargs)
@@ -662,4 +669,10 @@ class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
def visit_sequence(self, seq):
return self.connection.execute("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL").scalar()
+class OracleIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+ def format_savepoint(self, savepoint):
+ name = re.sub(r'^_+', '', savepoint.ident)
+ return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
+
+
dialect = OracleDialect
diff --git a/test/engine/transaction.py b/test/engine/transaction.py
index 3c84684da..8c516b195 100644
--- a/test/engine/transaction.py
+++ b/test/engine/transaction.py
@@ -173,7 +173,7 @@ class TransactionTest(PersistTest):
)
connection.close()
- @testing.supported('postgres', 'mysql')
+ @testing.supported('postgres', 'mysql', 'oracle')
@testing.exclude('mysql', '<', (5, 0, 3))
def testtwophasetransaction(self):
connection = testbase.db.connect()
@@ -301,7 +301,7 @@ class TLTransactionTest(PersistTest):
tlengine = create_engine(testbase.db.url, strategy='threadlocal')
metadata = MetaData()
users = Table('query_users', metadata,
- Column('user_id', INT, primary_key = True),
+ Column('user_id', INT, Sequence('query_users_id_seq', optional=True), primary_key=True),
Column('user_name', VARCHAR(20)),
test_needs_acid=True,
)