diff options
Diffstat (limited to 'test')
97 files changed, 3794 insertions, 2677 deletions
diff --git a/test/aaa_profiling/test_compiler.py b/test/aaa_profiling/test_compiler.py index 3e4274d47..79ae09b05 100644 --- a/test/aaa_profiling/test_compiler.py +++ b/test/aaa_profiling/test_compiler.py @@ -15,15 +15,15 @@ class CompileTest(TestBase, AssertsExecutionResults): Column('c1', Integer, primary_key=True), Column('c2', String(30))) - @profiling.function_call_count(68, {'2.4': 42}) + @profiling.function_call_count(72, {'2.4': 42, '3.0':77}) def test_insert(self): t1.insert().compile() - @profiling.function_call_count(68, {'2.4': 45}) + @profiling.function_call_count(72, {'2.4': 45}) def test_update(self): t1.update().compile() - @profiling.function_call_count(185, versions={'2.4':118}) + @profiling.function_call_count(195, versions={'2.4':118, '3.0':208}) def test_select(self): s = select([t1], t1.c.c2==t2.c.c1) s.compile() diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index 70a3cf8cd..fbf0560ca 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -1,17 +1,22 @@ from sqlalchemy.test.testing import eq_ -import gc from sqlalchemy.orm import mapper, relation, create_session, clear_mappers, sessionmaker from sqlalchemy.orm.mapper import _mapper_registry from sqlalchemy.orm.session import _sessions +from sqlalchemy.util import jython import operator from sqlalchemy.test import testing from sqlalchemy import MetaData, Integer, String, ForeignKey, PickleType -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column import sqlalchemy as sa from sqlalchemy.sql import column +from sqlalchemy.test.util import gc_collect +import gc from test.orm import _base +if jython: + from nose import SkipTest + raise SkipTest("Profiling not supported on this platform") + class A(_base.ComparableEntity): pass @@ -22,11 +27,11 @@ def profile_memory(func): # run the test 50 times. if length of gc.get_objects() # keeps growing, assert false def profile(*args): - gc.collect() + gc_collect() samples = [0 for x in range(0, 50)] for x in range(0, 50): func(*args) - gc.collect() + gc_collect() samples[x] = len(gc.get_objects()) print "sample gc sizes:", samples @@ -50,7 +55,7 @@ def profile_memory(func): def assert_no_mappers(): clear_mappers() - gc.collect() + gc_collect() assert len(_mapper_registry) == 0 class EnsureZeroed(_base.ORMTest): @@ -61,7 +66,7 @@ class EnsureZeroed(_base.ORMTest): class MemUsageTest(EnsureZeroed): # ensure a pure growing test trips the assertion - @testing.fails_if(lambda:True) + @testing.fails_if(lambda: True) def test_fixture(self): class Foo(object): pass @@ -76,11 +81,11 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('col2', String(30))) table2 = Table("mytable2", metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('col2', String(30)), Column('col3', Integer, ForeignKey("mytable.col1"))) @@ -129,11 +134,11 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('col2', String(30))) table2 = Table("mytable2", metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('col2', String(30)), Column('col3', Integer, ForeignKey("mytable.col1"))) @@ -184,13 +189,13 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('col2', String(30)) ) table2 = Table("mytable2", metadata, Column('col1', Integer, ForeignKey('mytable.col1'), - primary_key=True), + primary_key=True, test_needs_autoincrement=True), Column('col3', String(30)), ) @@ -244,12 +249,12 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('col2', String(30)) ) table2 = Table("mytable2", metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('col2', String(30)), ) @@ -308,12 +313,12 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("table1", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)) ) table2 = Table("table2", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), Column('t1id', Integer, ForeignKey('table1.id')) ) @@ -347,7 +352,7 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('col2', PickleType(comparator=operator.eq)) ) @@ -382,7 +387,7 @@ class MemUsageTest(EnsureZeroed): testing.eq_(len(session.identity_map._mutable_attrs), 12) testing.eq_(len(session.identity_map), 12) obj = None - gc.collect() + gc_collect() testing.eq_(len(session.identity_map._mutable_attrs), 0) testing.eq_(len(session.identity_map), 0) @@ -392,7 +397,7 @@ class MemUsageTest(EnsureZeroed): metadata.drop_all() def test_type_compile(self): - from sqlalchemy.databases.sqlite import SQLiteDialect + from sqlalchemy.dialects.sqlite.base import dialect as SQLiteDialect cast = sa.cast(column('x'), sa.Integer) @profile_memory def go(): diff --git a/test/aaa_profiling/test_pool.py b/test/aaa_profiling/test_pool.py index 7bb61deb2..6ae3edc98 100644 --- a/test/aaa_profiling/test_pool.py +++ b/test/aaa_profiling/test_pool.py @@ -5,6 +5,9 @@ from sqlalchemy.pool import QueuePool class QueuePoolTest(TestBase, AssertsExecutionResults): class Connection(object): + def rollback(self): + pass + def close(self): pass @@ -15,7 +18,7 @@ class QueuePoolTest(TestBase, AssertsExecutionResults): use_threadlocal=True) - @profiling.function_call_count(54, {'2.4': 38}) + @profiling.function_call_count(54, {'2.4': 38, '3.0':57}) def test_first_connect(self): conn = pool.connect() @@ -23,7 +26,7 @@ class QueuePoolTest(TestBase, AssertsExecutionResults): conn = pool.connect() conn.close() - @profiling.function_call_count(31, {'2.4': 21}) + @profiling.function_call_count(29, {'2.4': 21}) def go(): conn2 = pool.connect() return conn2 @@ -32,7 +35,7 @@ class QueuePoolTest(TestBase, AssertsExecutionResults): def test_second_samethread_connect(self): conn = pool.connect() - @profiling.function_call_count(5, {'2.4': 3}) + @profiling.function_call_count(5, {'2.4': 3, '3.0':6}) def go(): return pool.connect() c2 = go() diff --git a/test/aaa_profiling/test_zoomark.py b/test/aaa_profiling/test_zoomark.py index be2931896..e41303192 100644 --- a/test/aaa_profiling/test_zoomark.py +++ b/test/aaa_profiling/test_zoomark.py @@ -26,7 +26,7 @@ class ZooMarkTest(TestBase): """ - __only_on__ = 'postgres' + __only_on__ = 'postgresql+psycopg2' __skip_if__ = ((lambda: sys.version_info < (2, 4)), ) def test_baseline_0_setup(self): @@ -75,15 +75,15 @@ class ZooMarkTest(TestBase): Opens=datetime.time(8, 15, 59), LastEscape=datetime.datetime(2004, 7, 29, 5, 6, 7), Admission=4.95, - ).last_inserted_ids()[0] + ).inserted_primary_key[0] sdz = Zoo.insert().execute(Name =u'San Diego Zoo', Founded = datetime.date(1935, 9, 13), Opens = datetime.time(9, 0, 0), Admission = 0, - ).last_inserted_ids()[0] + ).inserted_primary_key[0] - Zoo.insert().execute( + Zoo.insert(inline=True).execute( Name = u'Montr\xe9al Biod\xf4me', Founded = datetime.date(1992, 6, 19), Opens = datetime.time(9, 0, 0), @@ -91,48 +91,48 @@ class ZooMarkTest(TestBase): ) seaworld = Zoo.insert().execute( - Name =u'Sea_World', Admission = 60).last_inserted_ids()[0] + Name =u'Sea_World', Admission = 60).inserted_primary_key[0] # Let's add a crazy futuristic Zoo to test large date values. lp = Zoo.insert().execute(Name =u'Luna Park', Founded = datetime.date(2072, 7, 17), Opens = datetime.time(0, 0, 0), Admission = 134.95, - ).last_inserted_ids()[0] + ).inserted_primary_key[0] # Animals leopardid = Animal.insert().execute(Species=u'Leopard', Lifespan=73.5, - ).last_inserted_ids()[0] + ).inserted_primary_key[0] Animal.update(Animal.c.ID==leopardid).execute(ZooID=wap, LastEscape=datetime.datetime(2004, 12, 21, 8, 15, 0, 999907)) - lion = Animal.insert().execute(Species=u'Lion', ZooID=wap).last_inserted_ids()[0] + lion = Animal.insert().execute(Species=u'Lion', ZooID=wap).inserted_primary_key[0] Animal.insert().execute(Species=u'Slug', Legs=1, Lifespan=.75) tiger = Animal.insert().execute(Species=u'Tiger', ZooID=sdz - ).last_inserted_ids()[0] + ).inserted_primary_key[0] # Override Legs.default with itself just to make sure it works. - Animal.insert().execute(Species=u'Bear', Legs=4) - Animal.insert().execute(Species=u'Ostrich', Legs=2, Lifespan=103.2) - Animal.insert().execute(Species=u'Centipede', Legs=100) + Animal.insert(inline=True).execute(Species=u'Bear', Legs=4) + Animal.insert(inline=True).execute(Species=u'Ostrich', Legs=2, Lifespan=103.2) + Animal.insert(inline=True).execute(Species=u'Centipede', Legs=100) emp = Animal.insert().execute(Species=u'Emperor Penguin', Legs=2, - ZooID=seaworld).last_inserted_ids()[0] + ZooID=seaworld).inserted_primary_key[0] adelie = Animal.insert().execute(Species=u'Adelie Penguin', Legs=2, - ZooID=seaworld).last_inserted_ids()[0] + ZooID=seaworld).inserted_primary_key[0] - Animal.insert().execute(Species=u'Millipede', Legs=1000000, ZooID=sdz) + Animal.insert(inline=True).execute(Species=u'Millipede', Legs=1000000, ZooID=sdz) # Add a mother and child to test relationships bai_yun = Animal.insert().execute(Species=u'Ape', Name=u'Bai Yun', - Legs=2).last_inserted_ids()[0] - Animal.insert().execute(Species=u'Ape', Name=u'Hua Mei', Legs=2, + Legs=2).inserted_primary_key[0] + Animal.insert(inline=True).execute(Species=u'Ape', Name=u'Hua Mei', Legs=2, MotherID=bai_yun) def test_baseline_2_insert(self): Animal = metadata.tables['Animal'] - i = Animal.insert() + i = Animal.insert(inline=True) for x in xrange(ITERATIONS): tick = i.execute(Species=u'Tick', Name=u'Tick %d' % x, Legs=8) @@ -142,7 +142,7 @@ class ZooMarkTest(TestBase): def fullobject(select): """Iterate over the full result row.""" - return list(select.execute().fetchone()) + return list(select.execute().first()) for x in xrange(ITERATIONS): # Zoos @@ -254,7 +254,7 @@ class ZooMarkTest(TestBase): for x in xrange(ITERATIONS): # Edit - SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().fetchone() + SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().first() Zoo.update(Zoo.c.ID==SDZ['ID']).execute( Name=u'The San Diego Zoo', Founded = datetime.date(1900, 1, 1), @@ -262,7 +262,7 @@ class ZooMarkTest(TestBase): Admission = "35.00") # Test edits - SDZ = Zoo.select(Zoo.c.Name==u'The San Diego Zoo').execute().fetchone() + SDZ = Zoo.select(Zoo.c.Name==u'The San Diego Zoo').execute().first() assert SDZ['Founded'] == datetime.date(1900, 1, 1), SDZ['Founded'] # Change it back @@ -273,7 +273,7 @@ class ZooMarkTest(TestBase): Admission = "0") # Test re-edits - SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().fetchone() + SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().first() assert SDZ['Founded'] == datetime.date(1935, 9, 13) def test_baseline_7_multiview(self): @@ -316,10 +316,10 @@ class ZooMarkTest(TestBase): global metadata player = lambda: dbapi_session.player() - engine = create_engine('postgres:///', creator=player) + engine = create_engine('postgresql:///', creator=player) metadata = MetaData(engine) - @profiling.function_call_count(3230, {'2.4': 1796}) + @profiling.function_call_count(2991, {'2.4': 1796}) def test_profile_1_create_tables(self): self.test_baseline_1_create_tables() @@ -327,7 +327,7 @@ class ZooMarkTest(TestBase): def test_profile_1a_populate(self): self.test_baseline_1a_populate() - @profiling.function_call_count(322, {'2.4': 202}) + @profiling.function_call_count(305, {'2.4': 202}) def test_profile_2_insert(self): self.test_baseline_2_insert() diff --git a/test/aaa_profiling/test_zoomark_orm.py b/test/aaa_profiling/test_zoomark_orm.py index 57e1e2404..660f47811 100644 --- a/test/aaa_profiling/test_zoomark_orm.py +++ b/test/aaa_profiling/test_zoomark_orm.py @@ -27,7 +27,7 @@ class ZooMarkTest(TestBase): """ - __only_on__ = 'postgres' + __only_on__ = 'postgresql+psycopg2' __skip_if__ = ((lambda: sys.version_info < (2, 5)), ) # TODO: get 2.4 support def test_baseline_0_setup(self): @@ -281,7 +281,7 @@ class ZooMarkTest(TestBase): global metadata, session player = lambda: dbapi_session.player() - engine = create_engine('postgres:///', creator=player) + engine = create_engine('postgresql:///', creator=player) metadata = MetaData(engine) session = sessionmaker()() diff --git a/test/base/test_dependency.py b/test/base/test_dependency.py index 0457d552a..890dd7607 100644 --- a/test/base/test_dependency.py +++ b/test/base/test_dependency.py @@ -178,9 +178,12 @@ class DependencySortTest(TestBase): self.assert_sort(tuples, head) def testbigsort(self): - tuples = [] - for i in range(0,1500, 2): - tuples.append((i, i+1)) + tuples = [(i, i + 1) for i in range(0, 1500, 2)] head = topological.sort_as_tree(tuples, []) + def testids(self): + # ticket:1380 regression: would raise a KeyError + topological.sort([(id(i), i) for i in range(3)], []) + + diff --git a/test/base/test_except.py b/test/base/test_except.py index efb18a153..fbe0a05de 100644 --- a/test/base/test_except.py +++ b/test/base/test_except.py @@ -1,10 +1,14 @@ """Tests exceptions and DB-API exception wrapping.""" -import exceptions as stdlib_exceptions from sqlalchemy import exc as sa_exceptions from sqlalchemy.test import TestBase +# Py3K +#StandardError = BaseException +# Py2K +from exceptions import StandardError, KeyboardInterrupt, SystemExit +# end Py2K -class Error(stdlib_exceptions.StandardError): +class Error(StandardError): """This class will be old-style on <= 2.4 and new-style on >= 2.5.""" class DatabaseError(Error): pass @@ -101,19 +105,19 @@ class WrapTest(TestBase): def test_db_error_keyboard_interrupt(self): try: raise sa_exceptions.DBAPIError.instance( - '', [], stdlib_exceptions.KeyboardInterrupt()) + '', [], KeyboardInterrupt()) except sa_exceptions.DBAPIError: self.assert_(False) - except stdlib_exceptions.KeyboardInterrupt: + except KeyboardInterrupt: self.assert_(True) def test_db_error_system_exit(self): try: raise sa_exceptions.DBAPIError.instance( - '', [], stdlib_exceptions.SystemExit()) + '', [], SystemExit()) except sa_exceptions.DBAPIError: self.assert_(False) - except stdlib_exceptions.SystemExit: + except SystemExit: self.assert_(True) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 39561e968..e4c2eaba0 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -3,6 +3,7 @@ import copy, threading from sqlalchemy import util, sql, exc from sqlalchemy.test import TestBase from sqlalchemy.test.testing import eq_, is_, ne_ +from sqlalchemy.test.util import gc_collect class OrderedDictTest(TestBase): def test_odict(self): @@ -260,7 +261,7 @@ class IdentitySetTest(TestBase): except TypeError: assert True - assert_raises(TypeError, cmp, ids) + assert_raises(TypeError, util.cmp, ids) assert_raises(TypeError, hash, ids) def test_difference(self): @@ -325,11 +326,13 @@ class DictlikeIteritemsTest(TestBase): d = subdict(a=1,b=2,c=3) self._ok(d) + # Py2K def test_UserDict(self): import UserDict d = UserDict.UserDict(a=1,b=2,c=3) self._ok(d) - + # end Py2K + def test_object(self): self._notok(object()) @@ -339,12 +342,15 @@ class DictlikeIteritemsTest(TestBase): return iter(self.baseline) self._ok(duck1()) + # Py2K def test_duck_2(self): class duck2(object): def items(duck): return list(self.baseline) self._ok(duck2()) + # end Py2K + # Py2K def test_duck_3(self): class duck3(object): def iterkeys(duck): @@ -352,6 +358,7 @@ class DictlikeIteritemsTest(TestBase): def __getitem__(duck, key): return dict(a=1,b=2,c=3).get(key) self._ok(duck3()) + # end Py2K def test_duck_4(self): class duck4(object): @@ -376,16 +383,20 @@ class DictlikeIteritemsTest(TestBase): class DuckTypeCollectionTest(TestBase): def test_sets(self): + # Py2K import sets + # end Py2K class SetLike(object): def add(self): pass class ForcedSet(list): __emulates__ = set - + for type_ in (set, + # Py2K sets.Set, + # end Py2K SetLike, ForcedSet): eq_(util.duck_type_collection(type_), set) @@ -393,12 +404,14 @@ class DuckTypeCollectionTest(TestBase): eq_(util.duck_type_collection(instance), set) for type_ in (frozenset, - sets.ImmutableSet): + # Py2K + sets.ImmutableSet + # end Py2K + ): is_(util.duck_type_collection(type_), None) instance = type_() is_(util.duck_type_collection(instance), None) - class ArgInspectionTest(TestBase): def test_get_cls_kwargs(self): class A(object): @@ -646,6 +659,8 @@ class WeakIdentityMappingTest(TestBase): assert len(data) == len(wim) == len(wim.by_id) del data[:] + gc_collect() + eq_(wim, {}) eq_(wim.by_id, {}) eq_(wim._weakrefs, {}) @@ -657,6 +672,7 @@ class WeakIdentityMappingTest(TestBase): oid = id(data[0]) del data[0] + gc_collect() assert len(data) == len(wim) == len(wim.by_id) assert oid not in wim.by_id @@ -679,6 +695,7 @@ class WeakIdentityMappingTest(TestBase): th.start() cv.wait() cv.release() + gc_collect() eq_(wim, {}) eq_(wim.by_id, {}) @@ -939,7 +956,8 @@ class TestClassHierarchy(TestBase): eq_(set(util.class_hierarchy(A)), set((A, B, C, object))) eq_(set(util.class_hierarchy(B)), set((A, B, C, object))) - + + # Py2K def test_oldstyle_mixin(self): class A(object): pass @@ -953,5 +971,5 @@ class TestClassHierarchy(TestBase): eq_(set(util.class_hierarchy(B)), set((A, B, object))) eq_(set(util.class_hierarchy(Mixin)), set()) eq_(set(util.class_hierarchy(A)), set((A, B, object))) - + # end Py2K diff --git a/test/dialect/test_firebird.py b/test/dialect/test_firebird.py index fa608c9a1..2dc6af91b 100644 --- a/test/dialect/test_firebird.py +++ b/test/dialect/test_firebird.py @@ -50,6 +50,7 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): con.execute('DROP GENERATOR gen_testtable_id') def test_table_is_reflected(self): + from sqlalchemy.types import Integer, Text, Binary, String, Date, Time, DateTime metadata = MetaData(testing.db) table = Table('testtable', metadata, autoload=True) eq_(set(table.columns.keys()), @@ -57,17 +58,17 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): "Columns of reflected table didn't equal expected columns") eq_(table.c.question.primary_key, True) eq_(table.c.question.sequence.name, 'gen_testtable_id') - eq_(table.c.question.type.__class__, firebird.FBInteger) + assert isinstance(table.c.question.type, Integer) eq_(table.c.question.server_default.arg.text, "42") - eq_(table.c.answer.type.__class__, firebird.FBString) + assert isinstance(table.c.answer.type, String) eq_(table.c.answer.server_default.arg.text, "'no answer'") - eq_(table.c.remark.type.__class__, firebird.FBText) + assert isinstance(table.c.remark.type, Text) eq_(table.c.remark.server_default.arg.text, "''") - eq_(table.c.photo.type.__class__, firebird.FBBinary) + assert isinstance(table.c.photo.type, Binary) # The following assume a Dialect 3 database - eq_(table.c.d.type.__class__, firebird.FBDate) - eq_(table.c.t.type.__class__, firebird.FBTime) - eq_(table.c.dt.type.__class__, firebird.FBDateTime) + assert isinstance(table.c.d.type, Date) + assert isinstance(table.c.t.type, Time) + assert isinstance(table.c.dt.type, DateTime) class CompileTest(TestBase, AssertsCompiledSQL): @@ -76,7 +77,13 @@ class CompileTest(TestBase, AssertsCompiledSQL): def test_alias(self): t = table('sometable', column('col1'), column('col2')) s = select([t.alias()]) - self.assert_compile(s, "SELECT sometable_1.col1, sometable_1.col2 FROM sometable sometable_1") + self.assert_compile(s, "SELECT sometable_1.col1, sometable_1.col2 FROM sometable AS sometable_1") + + dialect = firebird.FBDialect() + dialect._version_two = False + self.assert_compile(s, "SELECT sometable_1.col1, sometable_1.col2 FROM sometable sometable_1", + dialect = dialect + ) def test_function(self): self.assert_compile(func.foo(1, 2), "foo(:foo_1, :foo_2)") @@ -98,15 +105,15 @@ class CompileTest(TestBase, AssertsCompiledSQL): column('description', String(128)), ) - u = update(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name]) + u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING mytable.myid, mytable.name") - u = update(table1, values=dict(name='foo'), firebird_returning=[table1]) + u = update(table1, values=dict(name='foo')).returning(table1) self.assert_compile(u, "UPDATE mytable SET name=:name "\ "RETURNING mytable.myid, mytable.name, mytable.description") - u = update(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)]) - self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name)") + u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) + self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name) AS length_1") def test_insert_returning(self): table1 = table('mytable', @@ -115,90 +122,20 @@ class CompileTest(TestBase, AssertsCompiledSQL): column('description', String(128)), ) - i = insert(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name]) + i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING mytable.myid, mytable.name") - i = insert(table1, values=dict(name='foo'), firebird_returning=[table1]) + i = insert(table1, values=dict(name='foo')).returning(table1) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) "\ "RETURNING mytable.myid, mytable.name, mytable.description") - i = insert(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)]) - self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name)") + i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name) AS length_1") -class ReturningTest(TestBase, AssertsExecutionResults): - __only_on__ = 'firebird' - - @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') - def test_update_returning(self): - meta = MetaData(testing.db) - table = Table('tables', meta, - Column('id', Integer, Sequence('gen_tables_id'), primary_key=True), - Column('persons', Integer), - Column('full', Boolean) - ) - table.create() - try: - table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) - - result = table.update(table.c.persons > 4, dict(full=True), firebird_returning=[table.c.id]).execute() - eq_(result.fetchall(), [(1,)]) - - result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() - eq_(result2.fetchall(), [(1,True),(2,False)]) - finally: - table.drop() - - @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') - def test_insert_returning(self): - meta = MetaData(testing.db) - table = Table('tables', meta, - Column('id', Integer, Sequence('gen_tables_id'), primary_key=True), - Column('persons', Integer), - Column('full', Boolean) - ) - table.create() - try: - result = table.insert(firebird_returning=[table.c.id]).execute({'persons': 1, 'full': False}) - - eq_(result.fetchall(), [(1,)]) - - # Multiple inserts only return the last row - result2 = table.insert(firebird_returning=[table]).execute( - [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}]) - - eq_(result2.fetchall(), [(3,3,True)]) - - result3 = table.insert(firebird_returning=[table.c.id]).execute({'persons': 4, 'full': False}) - eq_([dict(row) for row in result3], [{'ID':4}]) - - result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, 1) returning persons') - eq_([dict(row) for row in result4], [{'PERSONS': 10}]) - finally: - table.drop() - - @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') - def test_delete_returning(self): - meta = MetaData(testing.db) - table = Table('tables', meta, - Column('id', Integer, Sequence('gen_tables_id'), primary_key=True), - Column('persons', Integer), - Column('full', Boolean) - ) - table.create() - try: - table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) - - result = table.delete(table.c.persons > 4, firebird_returning=[table.c.id]).execute() - eq_(result.fetchall(), [(1,)]) - - result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() - eq_(result2.fetchall(), [(2,False),]) - finally: - table.drop() -class MiscFBTests(TestBase): +class MiscTest(TestBase): __only_on__ = 'firebird' def test_strlen(self): @@ -217,12 +154,20 @@ class MiscFBTests(TestBase): try: t.insert(values=dict(name='dante')).execute() t.insert(values=dict(name='alighieri')).execute() - select([func.count(t.c.id)],func.length(t.c.name)==5).execute().fetchone()[0] == 1 + select([func.count(t.c.id)],func.length(t.c.name)==5).execute().first()[0] == 1 finally: meta.drop_all() def test_server_version_info(self): - version = testing.db.dialect.server_version_info(testing.db.connect()) + version = testing.db.dialect.server_version_info assert len(version) == 3, "Got strange version info: %s" % repr(version) + def test_percents_in_text(self): + for expr, result in ( + (text("select '%' from rdb$database"), '%'), + (text("select '%%' from rdb$database"), '%%'), + (text("select '%%%' from rdb$database"), '%%%'), + (text("select 'hello % world' from rdb$database"), "hello % world") + ): + eq_(testing.db.scalar(expr), result) diff --git a/test/dialect/test_informix.py b/test/dialect/test_informix.py index 86a4e751d..e647990d3 100644 --- a/test/dialect/test_informix.py +++ b/test/dialect/test_informix.py @@ -4,7 +4,8 @@ from sqlalchemy.test import * class CompileTest(TestBase, AssertsCompiledSQL): - __dialect__ = informix.InfoDialect() + __only_on__ = 'informix' + __dialect__ = informix.InformixDialect() def test_statements(self): meta =MetaData() diff --git a/test/dialect/test_maxdb.py b/test/dialect/test_maxdb.py index 033a05533..c69a81120 100644 --- a/test/dialect/test_maxdb.py +++ b/test/dialect/test_maxdb.py @@ -185,7 +185,7 @@ class DBAPITest(TestBase, AssertsExecutionResults): vals = [] for i in xrange(3): cr.execute('SELECT busto.NEXTVAL FROM DUAL') - vals.append(cr.fetchone()[0]) + vals.append(cr.first()[0]) # should be 1,2,3, but no... self.assert_(vals != [1,2,3]) diff --git a/test/dialect/test_mssql.py b/test/dialect/test_mssql.py index dd86ce0de..423310db6 100644 --- a/test/dialect/test_mssql.py +++ b/test/dialect/test_mssql.py @@ -2,17 +2,18 @@ from sqlalchemy.test.testing import eq_ import datetime, os, re from sqlalchemy import * -from sqlalchemy import types, exc +from sqlalchemy import types, exc, schema from sqlalchemy.orm import * from sqlalchemy.sql import table, column from sqlalchemy.databases import mssql -import sqlalchemy.engine.url as url +from sqlalchemy.dialects.mssql import pyodbc +from sqlalchemy.engine import url from sqlalchemy.test import * from sqlalchemy.test.testing import eq_ class CompileTest(TestBase, AssertsCompiledSQL): - __dialect__ = mssql.MSSQLDialect() + __dialect__ = mssql.dialect() def test_insert(self): t = table('sometable', column('somecolumn')) @@ -157,6 +158,45 @@ class CompileTest(TestBase, AssertsCompiledSQL): select([extract(field, t.c.col1)]), 'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % field) + def test_update_returning(self): + table1 = table('mytable', + column('myid', Integer), + column('name', String(128)), + column('description', String(128)), + ) + + u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) + self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, inserted.name") + + u = update(table1, values=dict(name='foo')).returning(table1) + self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, " + "inserted.name, inserted.description") + + u = update(table1, values=dict(name='foo')).returning(table1).where(table1.c.name=='bar') + self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, " + "inserted.name, inserted.description WHERE mytable.name = :name_1") + + u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) + self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT LEN(inserted.name) AS length_1") + + def test_insert_returning(self): + table1 = table('mytable', + column('myid', Integer), + column('name', String(128)), + column('description', String(128)), + ) + + i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) + self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT inserted.myid, inserted.name VALUES (:name)") + + i = insert(table1, values=dict(name='foo')).returning(table1) + self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT inserted.myid, " + "inserted.name, inserted.description VALUES (:name)") + + i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) + self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT LEN(inserted.name) AS length_1 VALUES (:name)") + + class IdentityInsertTest(TestBase, AssertsCompiledSQL): __only_on__ = 'mssql' @@ -189,9 +229,9 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL): eq_([(9, 'Python')], list(cats)) result = cattable.insert().values(description='PHP').execute() - eq_([10], result.last_inserted_ids()) + eq_([10], result.inserted_primary_key) lastcat = cattable.select().order_by(desc(cattable.c.id)).execute() - eq_((10, 'PHP'), lastcat.fetchone()) + eq_((10, 'PHP'), lastcat.first()) def test_executemany(self): cattable.insert().execute([ @@ -213,10 +253,51 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL): eq_([(91, 'Smalltalk'), (90, 'PHP')], list(lastcats)) -class ReflectionTest(TestBase): +class ReflectionTest(TestBase, ComparesTables): __only_on__ = 'mssql' - def testidentity(self): + def test_basic_reflection(self): + meta = MetaData(testing.db) + + users = Table('engine_users', meta, + Column('user_id', types.INT, primary_key=True), + Column('user_name', types.VARCHAR(20), nullable=False), + Column('test1', types.CHAR(5), nullable=False), + Column('test2', types.Float(5), nullable=False), + Column('test3', types.Text), + Column('test4', types.Numeric, nullable = False), + Column('test5', types.DateTime), + Column('parent_user_id', types.Integer, + ForeignKey('engine_users.user_id')), + Column('test6', types.DateTime, nullable=False), + Column('test7', types.Text), + Column('test8', types.Binary), + Column('test_passivedefault2', types.Integer, server_default='5'), + Column('test9', types.Binary(100)), + Column('test_numeric', types.Numeric()), + test_needs_fk=True, + ) + + addresses = Table('engine_email_addresses', meta, + Column('address_id', types.Integer, primary_key = True), + Column('remote_user_id', types.Integer, ForeignKey(users.c.user_id)), + Column('email_address', types.String(20)), + test_needs_fk=True, + ) + meta.create_all() + + try: + meta2 = MetaData() + reflected_users = Table('engine_users', meta2, autoload=True, + autoload_with=testing.db) + reflected_addresses = Table('engine_email_addresses', meta2, + autoload=True, autoload_with=testing.db) + self.assert_tables_equal(users, reflected_users) + self.assert_tables_equal(addresses, reflected_addresses) + finally: + meta.drop_all() + + def test_identity(self): meta = MetaData(testing.db) table = Table( 'identity_test', meta, @@ -240,7 +321,7 @@ class QueryUnicodeTest(TestBase): meta = MetaData(testing.db) t1 = Table('unitest_table', meta, Column('id', Integer, primary_key=True), - Column('descr', mssql.MSText(200, convert_unicode=True))) + Column('descr', mssql.MSText(convert_unicode=True))) meta.create_all() con = testing.db.connect() @@ -248,7 +329,7 @@ class QueryUnicodeTest(TestBase): con.execute(u"insert into unitest_table values ('bien mangé')".encode('UTF-8')) try: - r = t1.select().execute().fetchone() + r = t1.select().execute().first() assert isinstance(r[1], unicode), '%s is %s instead of unicode, working on %s' % ( r[1], type(r[1]), meta.bind) @@ -262,7 +343,9 @@ class QueryTest(TestBase): meta = MetaData(testing.db) t1 = Table('t1', meta, Column('id', Integer, Sequence('fred', 100, 1), primary_key=True), - Column('descr', String(200))) + Column('descr', String(200)), + implicit_returning = False + ) t2 = Table('t2', meta, Column('id', Integer, Sequence('fred', 200, 1), primary_key=True), Column('descr', String(200))) @@ -274,9 +357,9 @@ class QueryTest(TestBase): try: tr = con.begin() r = con.execute(t2.insert(), descr='hello') - self.assert_(r.last_inserted_ids() == [200]) + self.assert_(r.inserted_primary_key == [200]) r = con.execute(t1.insert(), descr='hello') - self.assert_(r.last_inserted_ids() == [100]) + self.assert_(r.inserted_primary_key == [100]) finally: tr.commit() @@ -295,6 +378,19 @@ class QueryTest(TestBase): tbl.drop() con.execute('drop schema paj') + def test_returning_no_autoinc(self): + meta = MetaData(testing.db) + + table = Table('t1', meta, Column('id', Integer, primary_key=True), Column('data', String(50))) + table.create() + try: + result = table.insert().values(id=1, data=func.lower("SomeString")).returning(table.c.id, table.c.data).execute() + eq_(result.fetchall(), [(1, 'somestring',)]) + finally: + # this will hang if the "SET IDENTITY_INSERT t1 OFF" occurs before the + # result is fetched + table.drop() + def test_delete_schema(self): meta = MetaData(testing.db) con = testing.db.connect() @@ -371,36 +467,26 @@ class SchemaTest(TestBase): ) self.column = t.c.test_column + dialect = mssql.dialect() + self.ddl_compiler = dialect.ddl_compiler(dialect, schema.CreateTable(t)) + + def _column_spec(self): + return self.ddl_compiler.get_column_specification(self.column) + def test_that_mssql_default_nullability_emits_null(self): - schemagenerator = \ - mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None) - column_specification = \ - schemagenerator.get_column_specification(self.column) - eq_("test_column VARCHAR NULL", column_specification) + eq_("test_column VARCHAR NULL", self._column_spec()) def test_that_mssql_none_nullability_does_not_emit_nullability(self): - schemagenerator = \ - mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None) self.column.nullable = None - column_specification = \ - schemagenerator.get_column_specification(self.column) - eq_("test_column VARCHAR", column_specification) + eq_("test_column VARCHAR", self._column_spec()) def test_that_mssql_specified_nullable_emits_null(self): - schemagenerator = \ - mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None) self.column.nullable = True - column_specification = \ - schemagenerator.get_column_specification(self.column) - eq_("test_column VARCHAR NULL", column_specification) + eq_("test_column VARCHAR NULL", self._column_spec()) def test_that_mssql_specified_not_nullable_emits_not_null(self): - schemagenerator = \ - mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None) self.column.nullable = False - column_specification = \ - schemagenerator.get_column_specification(self.column) - eq_("test_column VARCHAR NOT NULL", column_specification) + eq_("test_column VARCHAR NOT NULL", self._column_spec()) def full_text_search_missing(): @@ -515,79 +601,73 @@ class MatchTest(TestBase, AssertsCompiledSQL): class ParseConnectTest(TestBase, AssertsCompiledSQL): __only_on__ = 'mssql' + @classmethod + def setup_class(cls): + global dialect + dialect = pyodbc.MSDialect_pyodbc() + def test_pyodbc_connect_dsn_trusted(self): u = url.make_url('mssql://mydsn') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['dsn=mydsn;TrustedConnection=Yes'], {}], connection) def test_pyodbc_connect_old_style_dsn_trusted(self): u = url.make_url('mssql:///?dsn=mydsn') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['dsn=mydsn;TrustedConnection=Yes'], {}], connection) def test_pyodbc_connect_dsn_non_trusted(self): u = url.make_url('mssql://username:password@mydsn') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['dsn=mydsn;UID=username;PWD=password'], {}], connection) def test_pyodbc_connect_dsn_extra(self): u = url.make_url('mssql://username:password@mydsn/?LANGUAGE=us_english&foo=bar') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['dsn=mydsn;UID=username;PWD=password;LANGUAGE=us_english;foo=bar'], {}], connection) def test_pyodbc_connect(self): u = url.make_url('mssql://username:password@hostspec/database') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection) def test_pyodbc_connect_comma_port(self): u = url.make_url('mssql://username:password@hostspec:12345/database') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['DRIVER={SQL Server};Server=hostspec,12345;Database=database;UID=username;PWD=password'], {}], connection) def test_pyodbc_connect_config_port(self): u = url.make_url('mssql://username:password@hostspec/database?port=12345') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;port=12345'], {}], connection) def test_pyodbc_extra_connect(self): u = url.make_url('mssql://username:password@hostspec/database?LANGUAGE=us_english&foo=bar') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;foo=bar;LANGUAGE=us_english'], {}], connection) def test_pyodbc_odbc_connect(self): u = url.make_url('mssql:///?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection) def test_pyodbc_odbc_connect_with_dsn(self): u = url.make_url('mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['dsn=mydsn;Database=database;UID=username;PWD=password'], {}], connection) def test_pyodbc_odbc_connect_ignores_other_values(self): u = url.make_url('mssql://userdiff:passdiff@localhost/dbdiff?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection) -class TypesTest(TestBase): +class TypesTest(TestBase, AssertsExecutionResults, ComparesTables): __only_on__ = 'mssql' @classmethod def setup_class(cls): - global numeric_table, metadata + global metadata metadata = MetaData(testing.db) def teardown(self): @@ -601,26 +681,22 @@ class TypesTest(TestBase): ) metadata.create_all() - try: - test_items = [decimal.Decimal(d) for d in '1500000.00000000000000000000', - '-1500000.00000000000000000000', '1500000', - '0.0000000000000000002', '0.2', '-0.0000000000000000002', '-2E-2', - '156666.458923543', '-156666.458923543', '1', '-1', '-1234', '1234', - '2E-12', '4E8', '3E-6', '3E-7', '4.1', '1E-1', '1E-2', '1E-3', - '1E-4', '1E-5', '1E-6', '1E-7', '1E-1', '1E-8', '0.2732E2', '-0.2432E2', '4.35656E2', - '-02452E-2', '45125E-2', - '1234.58965E-2', '1.521E+15', '-1E-25', '1E-25', '1254E-25', '-1203E-25', - '0', '-0.00', '-0', '4585E12', '000000000000000000012', '000000000000.32E12', - '00000000000000.1E+12', '000000000000.2E-32'] + test_items = [decimal.Decimal(d) for d in '1500000.00000000000000000000', + '-1500000.00000000000000000000', '1500000', + '0.0000000000000000002', '0.2', '-0.0000000000000000002', '-2E-2', + '156666.458923543', '-156666.458923543', '1', '-1', '-1234', '1234', + '2E-12', '4E8', '3E-6', '3E-7', '4.1', '1E-1', '1E-2', '1E-3', + '1E-4', '1E-5', '1E-6', '1E-7', '1E-1', '1E-8', '0.2732E2', '-0.2432E2', '4.35656E2', + '-02452E-2', '45125E-2', + '1234.58965E-2', '1.521E+15', '-1E-25', '1E-25', '1254E-25', '-1203E-25', + '0', '-0.00', '-0', '4585E12', '000000000000000000012', '000000000000.32E12', + '00000000000000.1E+12', '000000000000.2E-32'] - for value in test_items: - numeric_table.insert().execute(numericcol=value) + for value in test_items: + numeric_table.insert().execute(numericcol=value) - for value in select([numeric_table.c.numericcol]).execute(): - assert value[0] in test_items, "%s not in test_items" % value[0] - - except Exception, e: - raise e + for value in select([numeric_table.c.numericcol]).execute(): + assert value[0] in test_items, "%s not in test_items" % value[0] def test_float(self): float_table = Table('float_table', metadata, @@ -643,11 +719,6 @@ class TypesTest(TestBase): raise e -class TypesTest2(TestBase, AssertsExecutionResults): - "Test Microsoft SQL Server column types" - - __only_on__ = 'mssql' - def test_money(self): "Exercise type specification for money types." @@ -659,13 +730,14 @@ class TypesTest2(TestBase, AssertsExecutionResults): 'SMALLMONEY'), ] - table_args = ['test_mssql_money', MetaData(testing.db)] + table_args = ['test_mssql_money', metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) money_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + dialect = mssql.dialect() + gen = dialect.ddl_compiler(dialect, schema.CreateTable(money_table)) for col in money_table.c: index = int(col.name[1:]) @@ -688,15 +760,27 @@ class TypesTest2(TestBase, AssertsExecutionResults): (mssql.MSDateTime, [], {}, 'DATETIME', []), + (types.DATE, [], {}, + 'DATE', ['>=', (10,)]), + (types.Date, [], {}, + 'DATE', ['>=', (10,)]), + (types.Date, [], {}, + 'DATETIME', ['<', (10,)], mssql.MSDateTime), (mssql.MSDate, [], {}, 'DATE', ['>=', (10,)]), (mssql.MSDate, [], {}, 'DATETIME', ['<', (10,)], mssql.MSDateTime), + (types.TIME, [], {}, + 'TIME', ['>=', (10,)]), + (types.Time, [], {}, + 'TIME', ['>=', (10,)]), (mssql.MSTime, [], {}, 'TIME', ['>=', (10,)]), (mssql.MSTime, [1], {}, 'TIME(1)', ['>=', (10,)]), + (types.Time, [], {}, + 'DATETIME', ['<', (10,)], mssql.MSDateTime), (mssql.MSTime, [], {}, 'DATETIME', ['<', (10,)], mssql.MSDateTime), @@ -715,14 +799,14 @@ class TypesTest2(TestBase, AssertsExecutionResults): ] - table_args = ['test_mssql_dates', MetaData(testing.db)] + table_args = ['test_mssql_dates', metadata] for index, spec in enumerate(columns): type_, args, kw, res, requires = spec[0:5] if (requires and testing._is_excluded('mssql', *requires)) or not requires: table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) dates_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + gen = testing.db.dialect.ddl_compiler(testing.db.dialect, schema.CreateTable(dates_table)) for col in dates_table.c: index = int(col.name[1:]) @@ -730,49 +814,37 @@ class TypesTest2(TestBase, AssertsExecutionResults): "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) - try: - dates_table.create(checkfirst=True) - assert True - except: - raise + dates_table.create(checkfirst=True) reflected_dates = Table('test_mssql_dates', MetaData(testing.db), autoload=True) for col in reflected_dates.c: - index = int(col.name[1:]) - testing.eq_(testing.db.dialect.type_descriptor(col.type).__class__, - len(columns[index]) > 5 and columns[index][5] or columns[index][0]) - dates_table.drop() - - def test_dates2(self): - meta = MetaData(testing.db) - t = Table('test_dates', meta, - Column('id', Integer, - Sequence('datetest_id_seq', optional=True), - primary_key=True), - Column('adate', Date), - Column('atime', Time), - Column('adatetime', DateTime)) - t.create(checkfirst=True) - try: - d1 = datetime.date(2007, 10, 30) - t1 = datetime.time(11, 2, 32) - d2 = datetime.datetime(2007, 10, 30, 11, 2, 32) - t.insert().execute(adate=d1, adatetime=d2, atime=t1) - t.insert().execute(adate=d2, adatetime=d2, atime=d2) + self.assert_types_base(col, dates_table.c[col.key]) - x = t.select().execute().fetchall()[0] - self.assert_(x.adate.__class__ == datetime.date) - self.assert_(x.atime.__class__ == datetime.time) - self.assert_(x.adatetime.__class__ == datetime.datetime) + def test_date_roundtrip(self): + t = Table('test_dates', metadata, + Column('id', Integer, + Sequence('datetest_id_seq', optional=True), + primary_key=True), + Column('adate', Date), + Column('atime', Time), + Column('adatetime', DateTime)) + metadata.create_all() + d1 = datetime.date(2007, 10, 30) + t1 = datetime.time(11, 2, 32) + d2 = datetime.datetime(2007, 10, 30, 11, 2, 32) + t.insert().execute(adate=d1, adatetime=d2, atime=t1) + t.insert().execute(adate=d2, adatetime=d2, atime=d2) - t.delete().execute() + x = t.select().execute().fetchall()[0] + self.assert_(x.adate.__class__ == datetime.date) + self.assert_(x.atime.__class__ == datetime.time) + self.assert_(x.adatetime.__class__ == datetime.datetime) - t.insert().execute(adate=d1, adatetime=d2, atime=t1) + t.delete().execute() - eq_(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)]) + t.insert().execute(adate=d1, adatetime=d2, atime=t1) - finally: - t.drop(checkfirst=True) + eq_(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)]) def test_binary(self): "Exercise type specification for binary types." @@ -781,6 +853,9 @@ class TypesTest2(TestBase, AssertsExecutionResults): # column type, args, kwargs, expected ddl (mssql.MSBinary, [], {}, 'BINARY'), + (types.Binary, [10], {}, + 'BINARY(10)'), + (mssql.MSBinary, [10], {}, 'BINARY(10)'), @@ -798,13 +873,14 @@ class TypesTest2(TestBase, AssertsExecutionResults): 'BINARY(10)') ] - table_args = ['test_mssql_binary', MetaData(testing.db)] + table_args = ['test_mssql_binary', metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) binary_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + dialect = mssql.dialect() + gen = dialect.ddl_compiler(dialect, schema.CreateTable(binary_table)) for col in binary_table.c: index = int(col.name[1:]) @@ -812,22 +888,15 @@ class TypesTest2(TestBase, AssertsExecutionResults): "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) - try: - binary_table.create(checkfirst=True) - assert True - except: - raise + metadata.create_all() reflected_binary = Table('test_mssql_binary', MetaData(testing.db), autoload=True) for col in reflected_binary.c: - # don't test the MSGenericBinary since it's a special case and - # reflected it will map to a MSImage or MSBinary depending - if not testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__ == mssql.MSGenericBinary: - testing.eq_(testing.db.dialect.type_descriptor(col.type).__class__, - testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__) + c1 =testing.db.dialect.type_descriptor(col.type).__class__ + c2 =testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__ + assert issubclass(c1, c2), "%r is not a subclass of %r" % (c1, c2) if binary_table.c[col.name].type.length: testing.eq_(col.type.length, binary_table.c[col.name].type.length) - binary_table.drop() def test_boolean(self): "Exercise type specification for boolean type." @@ -838,13 +907,14 @@ class TypesTest2(TestBase, AssertsExecutionResults): 'BIT'), ] - table_args = ['test_mssql_boolean', MetaData(testing.db)] + table_args = ['test_mssql_boolean', metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) boolean_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + dialect = mssql.dialect() + gen = dialect.ddl_compiler(dialect, schema.CreateTable(boolean_table)) for col in boolean_table.c: index = int(col.name[1:]) @@ -852,12 +922,7 @@ class TypesTest2(TestBase, AssertsExecutionResults): "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) - try: - boolean_table.create(checkfirst=True) - assert True - except: - raise - boolean_table.drop() + metadata.create_all() def test_numeric(self): "Exercise type specification and options for numeric types." @@ -865,40 +930,39 @@ class TypesTest2(TestBase, AssertsExecutionResults): columns = [ # column type, args, kwargs, expected ddl (mssql.MSNumeric, [], {}, - 'NUMERIC(10, 2)'), + 'NUMERIC'), (mssql.MSNumeric, [None], {}, 'NUMERIC'), - (mssql.MSNumeric, [12], {}, - 'NUMERIC(12, 2)'), (mssql.MSNumeric, [12, 4], {}, 'NUMERIC(12, 4)'), - (mssql.MSFloat, [], {}, - 'FLOAT(10)'), - (mssql.MSFloat, [None], {}, + (types.Float, [], {}, + 'FLOAT'), + (types.Float, [None], {}, 'FLOAT'), - (mssql.MSFloat, [12], {}, + (types.Float, [12], {}, 'FLOAT(12)'), (mssql.MSReal, [], {}, 'REAL'), - (mssql.MSInteger, [], {}, + (types.Integer, [], {}, 'INTEGER'), - (mssql.MSBigInteger, [], {}, + (types.BigInteger, [], {}, 'BIGINT'), (mssql.MSTinyInteger, [], {}, 'TINYINT'), - (mssql.MSSmallInteger, [], {}, + (types.SmallInteger, [], {}, 'SMALLINT'), ] - table_args = ['test_mssql_numeric', MetaData(testing.db)] + table_args = ['test_mssql_numeric', metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) numeric_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + dialect = mssql.dialect() + gen = dialect.ddl_compiler(dialect, schema.CreateTable(numeric_table)) for col in numeric_table.c: index = int(col.name[1:]) @@ -906,20 +970,11 @@ class TypesTest2(TestBase, AssertsExecutionResults): "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) - try: - numeric_table.create(checkfirst=True) - assert True - except: - raise - numeric_table.drop() + metadata.create_all() def test_char(self): """Exercise COLLATE-ish options on string types.""" - # modify the text_as_varchar setting since we are not testing that behavior here - text_as_varchar = testing.db.dialect.text_as_varchar - testing.db.dialect.text_as_varchar = False - columns = [ (mssql.MSChar, [], {}, 'CHAR'), @@ -960,13 +1015,14 @@ class TypesTest2(TestBase, AssertsExecutionResults): 'NTEXT COLLATE Latin1_General_CI_AS'), ] - table_args = ['test_mssql_charset', MetaData(testing.db)] + table_args = ['test_mssql_charset', metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) charset_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + dialect = mssql.dialect() + gen = dialect.ddl_compiler(dialect, schema.CreateTable(charset_table)) for col in charset_table.c: index = int(col.name[1:]) @@ -974,110 +1030,91 @@ class TypesTest2(TestBase, AssertsExecutionResults): "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) - try: - charset_table.create(checkfirst=True) - assert True - except: - raise - charset_table.drop() - - testing.db.dialect.text_as_varchar = text_as_varchar + metadata.create_all() def test_timestamp(self): """Exercise TIMESTAMP column.""" - meta = MetaData(testing.db) - - try: - columns = [ - (TIMESTAMP, - 'TIMESTAMP'), - (mssql.MSTimeStamp, - 'TIMESTAMP'), - ] - for idx, (spec, expected) in enumerate(columns): - t = Table('mssql_ts%s' % idx, meta, - Column('id', Integer, primary_key=True), - Column('t', spec, nullable=None)) - testing.eq_(colspec(t.c.t), "t %s" % expected) - self.assert_(repr(t.c.t)) - try: - t.create(checkfirst=True) - assert True - except: - raise - t.drop() - finally: - meta.drop_all() + dialect = mssql.dialect() + spec, expected = (TIMESTAMP,'TIMESTAMP') + t = Table('mssql_ts', metadata, + Column('id', Integer, primary_key=True), + Column('t', spec, nullable=None)) + gen = dialect.ddl_compiler(dialect, schema.CreateTable(t)) + testing.eq_(gen.get_column_specification(t.c.t), "t %s" % expected) + self.assert_(repr(t.c.t)) + t.create(checkfirst=True) + def test_autoincrement(self): - meta = MetaData(testing.db) - try: - Table('ai_1', meta, - Column('int_y', Integer, primary_key=True), - Column('int_n', Integer, DefaultClause('0'), - primary_key=True)) - Table('ai_2', meta, - Column('int_y', Integer, primary_key=True), - Column('int_n', Integer, DefaultClause('0'), - primary_key=True)) - Table('ai_3', meta, - Column('int_n', Integer, DefaultClause('0'), - primary_key=True, autoincrement=False), - Column('int_y', Integer, primary_key=True)) - Table('ai_4', meta, - Column('int_n', Integer, DefaultClause('0'), - primary_key=True, autoincrement=False), - Column('int_n2', Integer, DefaultClause('0'), - primary_key=True, autoincrement=False)) - Table('ai_5', meta, - Column('int_y', Integer, primary_key=True), - Column('int_n', Integer, DefaultClause('0'), - primary_key=True, autoincrement=False)) - Table('ai_6', meta, - Column('o1', String(1), DefaultClause('x'), - primary_key=True), - Column('int_y', Integer, primary_key=True)) - Table('ai_7', meta, - Column('o1', String(1), DefaultClause('x'), - primary_key=True), - Column('o2', String(1), DefaultClause('x'), - primary_key=True), - Column('int_y', Integer, primary_key=True)) - Table('ai_8', meta, - Column('o1', String(1), DefaultClause('x'), - primary_key=True), - Column('o2', String(1), DefaultClause('x'), - primary_key=True)) - meta.create_all() - - table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4', - 'ai_5', 'ai_6', 'ai_7', 'ai_8'] - mr = MetaData(testing.db) - mr.reflect(only=table_names) - - for tbl in [mr.tables[name] for name in table_names]: - for c in tbl.c: - if c.name.startswith('int_y'): - assert c.autoincrement - elif c.name.startswith('int_n'): - assert not c.autoincrement - tbl.insert().execute() + Table('ai_1', metadata, + Column('int_y', Integer, primary_key=True), + Column('int_n', Integer, DefaultClause('0'), + primary_key=True)) + Table('ai_2', metadata, + Column('int_y', Integer, primary_key=True), + Column('int_n', Integer, DefaultClause('0'), + primary_key=True)) + Table('ai_3', metadata, + Column('int_n', Integer, DefaultClause('0'), + primary_key=True, autoincrement=False), + Column('int_y', Integer, primary_key=True)) + Table('ai_4', metadata, + Column('int_n', Integer, DefaultClause('0'), + primary_key=True, autoincrement=False), + Column('int_n2', Integer, DefaultClause('0'), + primary_key=True, autoincrement=False)) + Table('ai_5', metadata, + Column('int_y', Integer, primary_key=True), + Column('int_n', Integer, DefaultClause('0'), + primary_key=True, autoincrement=False)) + Table('ai_6', metadata, + Column('o1', String(1), DefaultClause('x'), + primary_key=True), + Column('int_y', Integer, primary_key=True)) + Table('ai_7', metadata, + Column('o1', String(1), DefaultClause('x'), + primary_key=True), + Column('o2', String(1), DefaultClause('x'), + primary_key=True), + Column('int_y', Integer, primary_key=True)) + Table('ai_8', metadata, + Column('o1', String(1), DefaultClause('x'), + primary_key=True), + Column('o2', String(1), DefaultClause('x'), + primary_key=True)) + metadata.create_all() + + table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4', + 'ai_5', 'ai_6', 'ai_7', 'ai_8'] + mr = MetaData(testing.db) + + for name in table_names: + tbl = Table(name, mr, autoload=True) + for c in tbl.c: + if c.name.startswith('int_y'): + assert c.autoincrement + elif c.name.startswith('int_n'): + assert not c.autoincrement + + for counter, engine in enumerate([ + engines.testing_engine(options={'implicit_returning':False}), + engines.testing_engine(options={'implicit_returning':True}), + ] + ): + engine.execute(tbl.insert()) if 'int_y' in tbl.c: - assert select([tbl.c.int_y]).scalar() == 1 - assert list(tbl.select().execute().fetchone()).count(1) == 1 + assert engine.scalar(select([tbl.c.int_y])) == counter + 1 + assert list(engine.execute(tbl.select()).first()).count(counter + 1) == 1 else: - assert 1 not in list(tbl.select().execute().fetchone()) - finally: - meta.drop_all() - -def colspec(c): - return testing.db.dialect.schemagenerator(testing.db.dialect, - testing.db, None, None).get_column_specification(c) - + assert 1 not in list(engine.execute(tbl.select()).first()) + engine.execute(tbl.delete()) class BinaryTest(TestBase, AssertsExecutionResults): """Test the Binary and VarBinary types""" + + __only_on__ = 'mssql' + @classmethod def setup_class(cls): global binary_table, MyPickleType @@ -1125,6 +1162,11 @@ class BinaryTest(TestBase, AssertsExecutionResults): stream2 =self.load_stream('binary_data_two.dat') binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_image=stream1, data_slice=stream1[0:100], pickled=testobj1, mypickle=testobj3) binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_image=stream2, data_slice=stream2[0:99], pickled=testobj2) + + # TODO: pyodbc does not seem to accept "None" for a VARBINARY column (data=None). + # error: [Microsoft][ODBC SQL Server Driver][SQL Server]Implicit conversion from + # data type varchar to varbinary is not allowed. Use the CONVERT function to run this query. (257) + #binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_image=None, data_slice=stream2[0:99], pickled=None) binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data_image=None, data_slice=stream2[0:99], pickled=None) for stmt in ( diff --git a/test/dialect/test_mysql.py b/test/dialect/test_mysql.py index 8adb2d71c..405264152 100644 --- a/test/dialect/test_mysql.py +++ b/test/dialect/test_mysql.py @@ -1,8 +1,12 @@ from sqlalchemy.test.testing import eq_ + +# Py2K import sets +# end Py2K + from sqlalchemy import * from sqlalchemy import sql, exc -from sqlalchemy.databases import mysql +from sqlalchemy.dialects.mysql import base as mysql from sqlalchemy.test.testing import eq_ from sqlalchemy.test import * @@ -56,11 +60,11 @@ class TypesTest(TestBase, AssertsExecutionResults): # column type, args, kwargs, expected ddl # e.g. Column(Integer(10, unsigned=True)) == 'INTEGER(10) UNSIGNED' (mysql.MSNumeric, [], {}, - 'NUMERIC(10, 2)'), + 'NUMERIC'), (mysql.MSNumeric, [None], {}, 'NUMERIC'), (mysql.MSNumeric, [12], {}, - 'NUMERIC(12, 2)'), + 'NUMERIC(12)'), (mysql.MSNumeric, [12, 4], {'unsigned':True}, 'NUMERIC(12, 4) UNSIGNED'), (mysql.MSNumeric, [12, 4], {'zerofill':True}, @@ -69,11 +73,11 @@ class TypesTest(TestBase, AssertsExecutionResults): 'NUMERIC(12, 4) UNSIGNED ZEROFILL'), (mysql.MSDecimal, [], {}, - 'DECIMAL(10, 2)'), + 'DECIMAL'), (mysql.MSDecimal, [None], {}, 'DECIMAL'), (mysql.MSDecimal, [12], {}, - 'DECIMAL(12, 2)'), + 'DECIMAL(12)'), (mysql.MSDecimal, [12, None], {}, 'DECIMAL(12)'), (mysql.MSDecimal, [12, 4], {'unsigned':True}, @@ -178,11 +182,11 @@ class TypesTest(TestBase, AssertsExecutionResults): table_args.append(Column('c%s' % index, type_(*args, **kw))) numeric_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + gen = testing.db.dialect.ddl_compiler(testing.db.dialect, numeric_table) for col in numeric_table.c: index = int(col.name[1:]) - self.assert_eq(gen.get_column_specification(col), + eq_(gen.get_column_specification(col), "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) @@ -262,11 +266,11 @@ class TypesTest(TestBase, AssertsExecutionResults): table_args.append(Column('c%s' % index, type_(*args, **kw))) charset_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + gen = testing.db.dialect.ddl_compiler(testing.db.dialect, charset_table) for col in charset_table.c: index = int(col.name[1:]) - self.assert_eq(gen.get_column_specification(col), + eq_(gen.get_column_specification(col), "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) @@ -292,14 +296,14 @@ class TypesTest(TestBase, AssertsExecutionResults): Column('b7', mysql.MSBit(63)), Column('b8', mysql.MSBit(64))) - self.assert_eq(colspec(bit_table.c.b1), 'b1 BIT') - self.assert_eq(colspec(bit_table.c.b2), 'b2 BIT') - self.assert_eq(colspec(bit_table.c.b3), 'b3 BIT NOT NULL') - self.assert_eq(colspec(bit_table.c.b4), 'b4 BIT(1)') - self.assert_eq(colspec(bit_table.c.b5), 'b5 BIT(8)') - self.assert_eq(colspec(bit_table.c.b6), 'b6 BIT(32)') - self.assert_eq(colspec(bit_table.c.b7), 'b7 BIT(63)') - self.assert_eq(colspec(bit_table.c.b8), 'b8 BIT(64)') + eq_(colspec(bit_table.c.b1), 'b1 BIT') + eq_(colspec(bit_table.c.b2), 'b2 BIT') + eq_(colspec(bit_table.c.b3), 'b3 BIT NOT NULL') + eq_(colspec(bit_table.c.b4), 'b4 BIT(1)') + eq_(colspec(bit_table.c.b5), 'b5 BIT(8)') + eq_(colspec(bit_table.c.b6), 'b6 BIT(32)') + eq_(colspec(bit_table.c.b7), 'b7 BIT(63)') + eq_(colspec(bit_table.c.b8), 'b8 BIT(64)') for col in bit_table.c: self.assert_(repr(col)) @@ -314,7 +318,7 @@ class TypesTest(TestBase, AssertsExecutionResults): def roundtrip(store, expected=None): expected = expected or store table.insert(store).execute() - row = list(table.select().execute())[0] + row = table.select().execute().first() try: self.assert_(list(row) == expected) except: @@ -322,7 +326,7 @@ class TypesTest(TestBase, AssertsExecutionResults): print "Expected %s" % expected print "Found %s" % list(row) raise - table.delete().execute() + table.delete().execute().close() roundtrip([0] * 8) roundtrip([None, None, 0, None, None, None, None, None]) @@ -350,10 +354,10 @@ class TypesTest(TestBase, AssertsExecutionResults): Column('b3', mysql.MSTinyInteger(1)), Column('b4', mysql.MSTinyInteger)) - self.assert_eq(colspec(bool_table.c.b1), 'b1 BOOL') - self.assert_eq(colspec(bool_table.c.b2), 'b2 BOOL') - self.assert_eq(colspec(bool_table.c.b3), 'b3 TINYINT(1)') - self.assert_eq(colspec(bool_table.c.b4), 'b4 TINYINT') + eq_(colspec(bool_table.c.b1), 'b1 BOOL') + eq_(colspec(bool_table.c.b2), 'b2 BOOL') + eq_(colspec(bool_table.c.b3), 'b3 TINYINT(1)') + eq_(colspec(bool_table.c.b4), 'b4 TINYINT') for col in bool_table.c: self.assert_(repr(col)) @@ -364,7 +368,7 @@ class TypesTest(TestBase, AssertsExecutionResults): def roundtrip(store, expected=None): expected = expected or store table.insert(store).execute() - row = list(table.select().execute())[0] + row = table.select().execute().first() try: self.assert_(list(row) == expected) for i, val in enumerate(expected): @@ -375,7 +379,7 @@ class TypesTest(TestBase, AssertsExecutionResults): print "Expected %s" % expected print "Found %s" % list(row) raise - table.delete().execute() + table.delete().execute().close() roundtrip([None, None, None, None]) @@ -387,7 +391,7 @@ class TypesTest(TestBase, AssertsExecutionResults): meta2 = MetaData(testing.db) # replace with reflected table = Table('mysql_bool', meta2, autoload=True) - self.assert_eq(colspec(table.c.b3), 'b3 BOOL') + eq_(colspec(table.c.b3), 'b3 BOOL') roundtrip([None, None, None, None]) roundtrip([True, True, 1, 1], [True, True, True, 1]) @@ -430,7 +434,7 @@ class TypesTest(TestBase, AssertsExecutionResults): t = Table('mysql_ts%s' % idx, meta, Column('id', Integer, primary_key=True), Column('t', *spec)) - self.assert_eq(colspec(t.c.t), "t %s" % expected) + eq_(colspec(t.c.t), "t %s" % expected) self.assert_(repr(t.c.t)) t.create() r = Table('mysql_ts%s' % idx, MetaData(testing.db), @@ -460,12 +464,12 @@ class TypesTest(TestBase, AssertsExecutionResults): for table in year_table, reflected: table.insert(['1950', '50', None, 50, 1950]).execute() - row = list(table.select().execute())[0] - self.assert_eq(list(row), [1950, 2050, None, 50, 1950]) + row = table.select().execute().first() + eq_(list(row), [1950, 2050, None, 50, 1950]) table.delete().execute() self.assert_(colspec(table.c.y1).startswith('y1 YEAR')) - self.assert_eq(colspec(table.c.y4), 'y4 YEAR(2)') - self.assert_eq(colspec(table.c.y5), 'y5 YEAR(4)') + eq_(colspec(table.c.y4), 'y4 YEAR(2)') + eq_(colspec(table.c.y5), 'y5 YEAR(4)') finally: meta.drop_all() @@ -479,9 +483,9 @@ class TypesTest(TestBase, AssertsExecutionResults): Column('s2', mysql.MSSet("'a'")), Column('s3', mysql.MSSet("'5'", "'7'", "'9'"))) - self.assert_eq(colspec(set_table.c.s1), "s1 SET('dq','sq')") - self.assert_eq(colspec(set_table.c.s2), "s2 SET('a')") - self.assert_eq(colspec(set_table.c.s3), "s3 SET('5','7','9')") + eq_(colspec(set_table.c.s1), "s1 SET('dq','sq')") + eq_(colspec(set_table.c.s2), "s2 SET('a')") + eq_(colspec(set_table.c.s3), "s3 SET('5','7','9')") for col in set_table.c: self.assert_(repr(col)) @@ -494,7 +498,7 @@ class TypesTest(TestBase, AssertsExecutionResults): def roundtrip(store, expected=None): expected = expected or store table.insert(store).execute() - row = list(table.select().execute())[0] + row = table.select().execute().first() try: self.assert_(list(row) == expected) except: @@ -518,12 +522,12 @@ class TypesTest(TestBase, AssertsExecutionResults): {'s3':set(['5', '7'])}, {'s3':set(['5', '7', '9'])}, {'s3':set(['7', '9'])}) - rows = list(select( + rows = select( [set_table.c.s3], - set_table.c.s3.in_([set(['5']), set(['5', '7'])])).execute()) + set_table.c.s3.in_([set(['5']), set(['5', '7']), set(['7', '5'])]) + ).execute().fetchall() found = set([frozenset(row[0]) for row in rows]) - eq_(found, - set([frozenset(['5']), frozenset(['5', '7'])])) + eq_(found, set([frozenset(['5']), frozenset(['5', '7'])])) finally: meta.drop_all() @@ -542,17 +546,17 @@ class TypesTest(TestBase, AssertsExecutionResults): Column('e6', mysql.MSEnum("'a'", "b")), ) - self.assert_eq(colspec(enum_table.c.e1), + eq_(colspec(enum_table.c.e1), "e1 ENUM('a','b')") - self.assert_eq(colspec(enum_table.c.e2), + eq_(colspec(enum_table.c.e2), "e2 ENUM('a','b') NOT NULL") - self.assert_eq(colspec(enum_table.c.e3), + eq_(colspec(enum_table.c.e3), "e3 ENUM('a','b')") - self.assert_eq(colspec(enum_table.c.e4), + eq_(colspec(enum_table.c.e4), "e4 ENUM('a','b') NOT NULL") - self.assert_eq(colspec(enum_table.c.e5), + eq_(colspec(enum_table.c.e5), "e5 ENUM('a','b')") - self.assert_eq(colspec(enum_table.c.e6), + eq_(colspec(enum_table.c.e6), "e6 ENUM('''a''','b')") enum_table.drop(checkfirst=True) enum_table.create() @@ -585,8 +589,9 @@ class TypesTest(TestBase, AssertsExecutionResults): # This is known to fail with MySQLDB 1.2.2 beta versions # which return these as sets.Set(['a']), sets.Set(['b']) # (even on Pythons with __builtin__.set) - if testing.db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and \ - testing.db.dialect.dbapi.version_info >= (1, 2, 2): + if (not testing.against('+zxjdbc') and + testing.db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and + testing.db.dialect.dbapi.version_info >= (1, 2, 2)): # these mysqldb seem to always uses 'sets', even on later pythons import sets def convert(value): @@ -602,7 +607,7 @@ class TypesTest(TestBase, AssertsExecutionResults): e.append(tuple([convert(c) for c in row])) expected = e - self.assert_eq(res, expected) + eq_(res, expected) enum_table.drop() @testing.exclude('mysql', '<', (4,), "3.23 can't handle an ENUM of ''") @@ -637,25 +642,52 @@ class TypesTest(TestBase, AssertsExecutionResults): finally: enum_table.drop() + + +class ReflectionTest(TestBase, AssertsExecutionResults): + + __only_on__ = 'mysql' + def test_default_reflection(self): """Test reflection of column defaults.""" def_table = Table('mysql_def', MetaData(testing.db), Column('c1', String(10), DefaultClause('')), Column('c2', String(10), DefaultClause('0')), - Column('c3', String(10), DefaultClause('abc'))) + Column('c3', String(10), DefaultClause('abc')), + Column('c4', TIMESTAMP, DefaultClause('2009-04-05 12:00:00')), + Column('c5', TIMESTAMP, ), + + ) + def_table.create() try: - def_table.create() reflected = Table('mysql_def', MetaData(testing.db), - autoload=True) - for t in def_table, reflected: - assert t.c.c1.server_default.arg == '' - assert t.c.c2.server_default.arg == '0' - assert t.c.c3.server_default.arg == 'abc' + autoload=True) finally: def_table.drop() + + assert def_table.c.c1.server_default.arg == '' + assert def_table.c.c2.server_default.arg == '0' + assert def_table.c.c3.server_default.arg == 'abc' + assert def_table.c.c4.server_default.arg == '2009-04-05 12:00:00' + + assert str(reflected.c.c1.server_default.arg) == "''" + assert str(reflected.c.c2.server_default.arg) == "'0'" + assert str(reflected.c.c3.server_default.arg) == "'abc'" + assert str(reflected.c.c4.server_default.arg) == "'2009-04-05 12:00:00'" + + reflected.create() + try: + reflected2 = Table('mysql_def', MetaData(testing.db), autoload=True) + finally: + reflected.drop() + assert str(reflected2.c.c1.server_default.arg) == "''" + assert str(reflected2.c.c2.server_default.arg) == "'0'" + assert str(reflected2.c.c3.server_default.arg) == "'abc'" + assert str(reflected2.c.c4.server_default.arg) == "'2009-04-05 12:00:00'" + def test_reflection_on_include_columns(self): """Test reflection of include_columns to be sure they respect case.""" @@ -700,8 +732,8 @@ class TypesTest(TestBase, AssertsExecutionResults): ( mysql.MSSmallInteger(4), mysql.MSSmallInteger(4), ), ( mysql.MSMediumInteger(), mysql.MSMediumInteger(), ), ( mysql.MSMediumInteger(8), mysql.MSMediumInteger(8), ), - ( Binary(3), mysql.MSBlob(3), ), - ( Binary(), mysql.MSBlob() ), + ( Binary(3), mysql.TINYBLOB(), ), + ( Binary(), mysql.BLOB() ), ( mysql.MSBinary(3), mysql.MSBinary(3), ), ( mysql.MSVarBinary(3),), ( mysql.MSVarBinary(), mysql.MSBlob()), @@ -734,14 +766,15 @@ class TypesTest(TestBase, AssertsExecutionResults): # in a view, e.g. char -> varchar, tinyblob -> mediumblob # # Not sure exactly which point version has the fix. - if db.dialect.server_version_info(db.connect()) < (5, 0, 11): + if db.dialect.server_version_info < (5, 0, 11): tables = rt, else: tables = rt, rv for table in tables: for i, reflected in enumerate(table.c): - assert isinstance(reflected.type, type(expected[i])) + assert isinstance(reflected.type, type(expected[i])), \ + "element %d: %r not instance of %r" % (i, reflected.type, type(expected[i])) finally: db.execute('DROP VIEW mysql_types_v') finally: @@ -802,17 +835,12 @@ class TypesTest(TestBase, AssertsExecutionResults): tbl.insert().execute() if 'int_y' in tbl.c: assert select([tbl.c.int_y]).scalar() == 1 - assert list(tbl.select().execute().fetchone()).count(1) == 1 + assert list(tbl.select().execute().first()).count(1) == 1 else: - assert 1 not in list(tbl.select().execute().fetchone()) + assert 1 not in list(tbl.select().execute().first()) finally: meta.drop_all() - def assert_eq(self, got, wanted): - if got != wanted: - print "Expected %s" % wanted - print "Found %s" % got - eq_(got, wanted) class SQLTest(TestBase, AssertsCompiledSQL): @@ -909,11 +937,11 @@ class SQLTest(TestBase, AssertsCompiledSQL): (m.MSBit, "t.col"), # this is kind of sucky. thank you default arguments! - (NUMERIC, "CAST(t.col AS DECIMAL(10, 2))"), - (DECIMAL, "CAST(t.col AS DECIMAL(10, 2))"), - (Numeric, "CAST(t.col AS DECIMAL(10, 2))"), - (m.MSNumeric, "CAST(t.col AS DECIMAL(10, 2))"), - (m.MSDecimal, "CAST(t.col AS DECIMAL(10, 2))"), + (NUMERIC, "CAST(t.col AS DECIMAL)"), + (DECIMAL, "CAST(t.col AS DECIMAL)"), + (Numeric, "CAST(t.col AS DECIMAL)"), + (m.MSNumeric, "CAST(t.col AS DECIMAL)"), + (m.MSDecimal, "CAST(t.col AS DECIMAL)"), (FLOAT, "t.col"), (Float, "t.col"), @@ -928,8 +956,8 @@ class SQLTest(TestBase, AssertsCompiledSQL): (DateTime, "CAST(t.col AS DATETIME)"), (Date, "CAST(t.col AS DATE)"), (Time, "CAST(t.col AS TIME)"), - (m.MSDateTime, "CAST(t.col AS DATETIME)"), - (m.MSDate, "CAST(t.col AS DATE)"), + (DateTime, "CAST(t.col AS DATETIME)"), + (Date, "CAST(t.col AS DATE)"), (m.MSTime, "CAST(t.col AS TIME)"), (m.MSTimeStamp, "CAST(t.col AS DATETIME)"), (m.MSYear, "t.col"), @@ -998,12 +1026,11 @@ class SQLTest(TestBase, AssertsCompiledSQL): class RawReflectionTest(TestBase): def setup(self): - self.dialect = mysql.dialect() - self.reflector = mysql.MySQLSchemaReflector( - self.dialect.identifier_preparer) + dialect = mysql.dialect() + self.parser = mysql.MySQLTableDefinitionParser(dialect, dialect.identifier_preparer) def test_key_reflection(self): - regex = self.reflector._re_key + regex = self.parser._re_key assert regex.match(' PRIMARY KEY (`id`),') assert regex.match(' PRIMARY KEY USING BTREE (`id`),') @@ -1023,37 +1050,11 @@ class ExecutionTest(TestBase): cx = engine.connect() meta = MetaData() - - assert ('mysql', 'charset') not in cx.info - assert ('mysql', 'force_charset') not in cx.info - - cx.execute(text("SELECT 1")).fetchall() - assert ('mysql', 'charset') not in cx.info - - meta.reflect(cx) - assert ('mysql', 'charset') in cx.info - - cx.execute(text("SET @squiznart=123")) - assert ('mysql', 'charset') in cx.info - - # the charset invalidation is very conservative - cx.execute(text("SET TIMESTAMP = DEFAULT")) - assert ('mysql', 'charset') not in cx.info - - cx.info[('mysql', 'force_charset')] = 'latin1' - - assert engine.dialect._detect_charset(cx) == 'latin1' - assert cx.info[('mysql', 'charset')] == 'latin1' - - del cx.info[('mysql', 'force_charset')] - del cx.info[('mysql', 'charset')] + charset = engine.dialect._detect_charset(cx) meta.reflect(cx) - assert ('mysql', 'charset') in cx.info - - # String execution doesn't go through the detector. - cx.execute("SET TIMESTAMP = DEFAULT") - assert ('mysql', 'charset') in cx.info + eq_(cx.dialect._connection_charset, charset) + cx.close() class MatchTest(TestBase, AssertsCompiledSQL): @@ -1102,9 +1103,10 @@ class MatchTest(TestBase, AssertsCompiledSQL): metadata.drop_all() def test_expression(self): + format = testing.db.dialect.paramstyle == 'format' and '%s' or '?' self.assert_compile( matchtable.c.title.match('somstr'), - "MATCH (matchtable.title) AGAINST (%s IN BOOLEAN MODE)") + "MATCH (matchtable.title) AGAINST (%s IN BOOLEAN MODE)" % format) def test_simple_match(self): results = (matchtable.select(). @@ -1162,6 +1164,5 @@ class MatchTest(TestBase, AssertsCompiledSQL): def colspec(c): - return testing.db.dialect.schemagenerator(testing.db.dialect, - testing.db, None, None).get_column_specification(c) + return testing.db.dialect.ddl_compiler(testing.db.dialect, c.table).get_column_specification(c) diff --git a/test/dialect/test_oracle.py b/test/dialect/test_oracle.py index d9d64806e..53e0f9ec2 100644 --- a/test/dialect/test_oracle.py +++ b/test/dialect/test_oracle.py @@ -2,12 +2,14 @@ from sqlalchemy.test.testing import eq_ from sqlalchemy import * +from sqlalchemy import types as sqltypes from sqlalchemy.sql import table, column -from sqlalchemy.databases import oracle from sqlalchemy.test import * from sqlalchemy.test.testing import eq_ from sqlalchemy.test.engines import testing_engine +from sqlalchemy.dialects.oracle import cx_oracle, base as oracle from sqlalchemy.engine import default +from sqlalchemy.util import jython import os @@ -43,10 +45,10 @@ class CompileTest(TestBase, AssertsCompiledSQL): meta = MetaData() parent = Table('parent', meta, Column('id', Integer, primary_key=True), Column('name', String(50)), - owner='ed') + schema='ed') child = Table('child', meta, Column('id', Integer, primary_key=True), Column('parent_id', Integer, ForeignKey('ed.parent.id')), - owner = 'ed') + schema = 'ed') self.assert_compile(parent.join(child), "ed.parent JOIN ed.child ON ed.parent.id = ed.child.parent_id") @@ -342,6 +344,25 @@ class TypesTest(TestBase, AssertsCompiledSQL): b = bindparam("foo", u"hello world!") assert b.type.dialect_impl(dialect).get_dbapi_type(dbapi) == 'STRING' + def test_type_adapt(self): + dialect = cx_oracle.dialect() + + for start, test in [ + (DateTime(), cx_oracle._OracleDateTime), + (TIMESTAMP(), cx_oracle._OracleTimestamp), + (oracle.OracleRaw(), cx_oracle._OracleRaw), + (String(), String), + (VARCHAR(), VARCHAR), + (String(50), String), + (Unicode(), Unicode), + (Text(), cx_oracle._OracleText), + (UnicodeText(), cx_oracle._OracleUnicodeText), + (NCHAR(), NCHAR), + (oracle.RAW(50), cx_oracle._OracleRaw), + ]: + assert isinstance(start.dialect_impl(dialect), test), "wanted %r got %r" % (test, start.dialect_impl(dialect)) + + def test_reflect_raw(self): types_table = Table( 'all_types', MetaData(testing.db), @@ -354,16 +375,16 @@ class TypesTest(TestBase, AssertsCompiledSQL): def test_reflect_nvarchar(self): metadata = MetaData(testing.db) t = Table('t', metadata, - Column('data', oracle.OracleNVarchar(255)) + Column('data', sqltypes.NVARCHAR(255)) ) metadata.create_all() try: m2 = MetaData(testing.db) t2 = Table('t', m2, autoload=True) - assert isinstance(t2.c.data.type, oracle.OracleNVarchar) + assert isinstance(t2.c.data.type, sqltypes.NVARCHAR) data = u'm’a réveillé.' t2.insert().execute(data=data) - eq_(t2.select().execute().fetchone()['data'], data) + eq_(t2.select().execute().first()['data'], data) finally: metadata.drop_all() @@ -391,7 +412,7 @@ class TypesTest(TestBase, AssertsCompiledSQL): t.create(engine) try: engine.execute(t.insert(), id=1, data='this is text', bindata='this is binary') - row = engine.execute(t.select()).fetchone() + row = engine.execute(t.select()).first() eq_(row['data'].read(), 'this is text') eq_(row['bindata'].read(), 'this is binary') finally: @@ -408,7 +429,6 @@ class BufferedColumnTest(TestBase, AssertsCompiledSQL): Column('data', Binary) ) meta.create_all() - stream = os.path.join(os.path.dirname(__file__), "..", 'binary_data_one.dat') stream = file(stream).read(12000) @@ -420,17 +440,18 @@ class BufferedColumnTest(TestBase, AssertsCompiledSQL): meta.drop_all() def test_fetch(self): - eq_( - binary_table.select().execute().fetchall() , - [(i, stream) for i in range(1, 11)], - ) + result = binary_table.select().execute().fetchall() + if jython: + result = [(i, value.tostring()) for i, value in result] + eq_(result, [(i, stream) for i in range(1, 11)]) + @testing.fails_on('+zxjdbc', 'FIXME: zxjdbc should support this') def test_fetch_single_arraysize(self): eng = testing_engine(options={'arraysize':1}) - eq_( - eng.execute(binary_table.select()).fetchall(), - [(i, stream) for i in range(1, 11)], - ) + result = eng.execute(binary_table.select()).fetchall(), + if jython: + result = [(i, value.tostring()) for i, value in result] + eq_(result, [(i, stream) for i in range(1, 11)]) class SequenceTest(TestBase, AssertsCompiledSQL): def test_basic(self): diff --git a/test/dialect/test_postgres.py b/test/dialect/test_postgresql.py index 8ca714bad..e1c351a93 100644 --- a/test/dialect/test_postgres.py +++ b/test/dialect/test_postgresql.py @@ -1,18 +1,19 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +from sqlalchemy.test import engines import datetime from sqlalchemy import * from sqlalchemy.orm import * -from sqlalchemy import exc -from sqlalchemy.databases import postgres +from sqlalchemy import exc, schema +from sqlalchemy.dialects.postgresql import base as postgresql from sqlalchemy.engine.strategies import MockEngineStrategy from sqlalchemy.test import * from sqlalchemy.sql import table, column - +from sqlalchemy.test.testing import eq_ class SequenceTest(TestBase, AssertsCompiledSQL): def test_basic(self): seq = Sequence("my_seq_no_schema") - dialect = postgres.PGDialect() + dialect = postgresql.PGDialect() assert dialect.identifier_preparer.format_sequence(seq) == "my_seq_no_schema" seq = Sequence("my_seq", schema="some_schema") @@ -22,43 +23,77 @@ class SequenceTest(TestBase, AssertsCompiledSQL): assert dialect.identifier_preparer.format_sequence(seq) == '"Some_Schema"."My_Seq"' class CompileTest(TestBase, AssertsCompiledSQL): - __dialect__ = postgres.dialect() + __dialect__ = postgresql.dialect() def test_update_returning(self): - dialect = postgres.dialect() + dialect = postgresql.dialect() table1 = table('mytable', column('myid', Integer), column('name', String(128)), column('description', String(128)), ) - u = update(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name]) + u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect) - u = update(table1, values=dict(name='foo'), postgres_returning=[table1]) + u = update(table1, values=dict(name='foo')).returning(table1) self.assert_compile(u, "UPDATE mytable SET name=%(name)s "\ "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect) - u = update(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)]) - self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name)", dialect=dialect) + u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) + self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name) AS length_1", dialect=dialect) + def test_insert_returning(self): - dialect = postgres.dialect() + dialect = postgresql.dialect() table1 = table('mytable', column('myid', Integer), column('name', String(128)), column('description', String(128)), ) - i = insert(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name]) + i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect) - i = insert(table1, values=dict(name='foo'), postgres_returning=[table1]) + i = insert(table1, values=dict(name='foo')).returning(table1) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) "\ "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect) - i = insert(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)]) - self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect) + i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name) AS length_1", dialect=dialect) + + @testing.uses_deprecated(r".*argument is deprecated. Please use statement.returning.*") + def test_old_returning_names(self): + dialect = postgresql.dialect() + table1 = table('mytable', + column('myid', Integer), + column('name', String(128)), + column('description', String(128)), + ) + + u = update(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name]) + self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect) + + u = update(table1, values=dict(name='foo'), postgresql_returning=[table1.c.myid, table1.c.name]) + self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect) + + i = insert(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name]) + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect) + + def test_create_partial_index(self): + tbl = Table('testtbl', MetaData(), Column('data',Integer)) + idx = Index('test_idx1', tbl.c.data, postgresql_where=and_(tbl.c.data > 5, tbl.c.data < 10)) + + self.assert_compile(schema.CreateIndex(idx), + "CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10", dialect=postgresql.dialect()) + + @testing.uses_deprecated(r".*'postgres_where' argument has been renamed.*") + def test_old_create_partial_index(self): + tbl = Table('testtbl', MetaData(), Column('data',Integer)) + idx = Index('test_idx1', tbl.c.data, postgres_where=and_(tbl.c.data > 5, tbl.c.data < 10)) + + self.assert_compile(schema.CreateIndex(idx), + "CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10", dialect=postgresql.dialect()) def test_extract(self): t = table('t', column('col1')) @@ -70,72 +105,20 @@ class CompileTest(TestBase, AssertsCompiledSQL): "FROM t" % field) -class ReturningTest(TestBase, AssertsExecutionResults): - __only_on__ = 'postgres' - - @testing.exclude('postgres', '<', (8, 2), '8.3+ feature') - def test_update_returning(self): - meta = MetaData(testing.db) - table = Table('tables', meta, - Column('id', Integer, primary_key=True), - Column('persons', Integer), - Column('full', Boolean) - ) - table.create() - try: - table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) - - result = table.update(table.c.persons > 4, dict(full=True), postgres_returning=[table.c.id]).execute() - eq_(result.fetchall(), [(1,)]) - - result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() - eq_(result2.fetchall(), [(1,True),(2,False)]) - finally: - table.drop() - - @testing.exclude('postgres', '<', (8, 2), '8.3+ feature') - def test_insert_returning(self): - meta = MetaData(testing.db) - table = Table('tables', meta, - Column('id', Integer, primary_key=True), - Column('persons', Integer), - Column('full', Boolean) - ) - table.create() - try: - result = table.insert(postgres_returning=[table.c.id]).execute({'persons': 1, 'full': False}) - - eq_(result.fetchall(), [(1,)]) - - @testing.fails_on('postgres', 'Known limitation of psycopg2') - def test_executemany(): - # return value is documented as failing with psycopg2/executemany - result2 = table.insert(postgres_returning=[table]).execute( - [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}]) - eq_(result2.fetchall(), [(2, 2, False), (3,3,True)]) - - test_executemany() - - result3 = table.insert(postgres_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False}) - eq_([dict(row) for row in result3], [{'double_id':8}]) - - result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, true) returning persons') - eq_([dict(row) for row in result4], [{'persons': 10}]) - finally: - table.drop() - - class InsertTest(TestBase, AssertsExecutionResults): - __only_on__ = 'postgres' + __only_on__ = 'postgresql' @classmethod def setup_class(cls): global metadata + cls.engine= testing.db metadata = MetaData(testing.db) def teardown(self): metadata.drop_all() metadata.tables.clear() + if self.engine is not testing.db: + self.engine.dispose() def test_compiled_insert(self): table = Table('testtable', metadata, @@ -144,7 +127,7 @@ class InsertTest(TestBase, AssertsExecutionResults): metadata.create_all() - ins = table.insert(values={'data':bindparam('x')}).compile() + ins = table.insert(inline=True, values={'data':bindparam('x')}).compile() ins.execute({'x':"five"}, {'x':"seven"}) assert table.select().execute().fetchall() == [(1, 'five'), (2, 'seven')] @@ -155,6 +138,13 @@ class InsertTest(TestBase, AssertsExecutionResults): metadata.create_all() self._assert_data_with_sequence(table, "my_seq") + def test_sequence_returning_insert(self): + table = Table('testtable', metadata, + Column('id', Integer, Sequence('my_seq'), primary_key=True), + Column('data', String(30))) + metadata.create_all() + self._assert_data_with_sequence_returning(table, "my_seq") + def test_opt_sequence_insert(self): table = Table('testtable', metadata, Column('id', Integer, Sequence('my_seq', optional=True), primary_key=True), @@ -162,6 +152,13 @@ class InsertTest(TestBase, AssertsExecutionResults): metadata.create_all() self._assert_data_autoincrement(table) + def test_opt_sequence_returning_insert(self): + table = Table('testtable', metadata, + Column('id', Integer, Sequence('my_seq', optional=True), primary_key=True), + Column('data', String(30))) + metadata.create_all() + self._assert_data_autoincrement_returning(table) + def test_autoincrement_insert(self): table = Table('testtable', metadata, Column('id', Integer, primary_key=True), @@ -169,6 +166,13 @@ class InsertTest(TestBase, AssertsExecutionResults): metadata.create_all() self._assert_data_autoincrement(table) + def test_autoincrement_returning_insert(self): + table = Table('testtable', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30))) + metadata.create_all() + self._assert_data_autoincrement_returning(table) + def test_noautoincrement_insert(self): table = Table('testtable', metadata, Column('id', Integer, primary_key=True, autoincrement=False), @@ -177,14 +181,17 @@ class InsertTest(TestBase, AssertsExecutionResults): self._assert_data_noautoincrement(table) def _assert_data_autoincrement(self, table): + self.engine = engines.testing_engine(options={'implicit_returning':False}) + metadata.bind = self.engine + def go(): # execute with explicit id r = table.insert().execute({'id':30, 'data':'d1'}) - assert r.last_inserted_ids() == [30] + assert r.inserted_primary_key == [30] # execute with prefetch id r = table.insert().execute({'data':'d2'}) - assert r.last_inserted_ids() == [1] + assert r.inserted_primary_key == [1] # executemany with explicit ids table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}) @@ -201,7 +208,7 @@ class InsertTest(TestBase, AssertsExecutionResults): # note that the test framework doesnt capture the "preexecute" of a seqeuence # or default. we just see it in the bind params. - self.assert_sql(testing.db, go, [], with_sequences=[ + self.assert_sql(self.engine, go, [], with_sequences=[ ( "INSERT INTO testtable (id, data) VALUES (:id, :data)", {'id':30, 'data':'d1'} @@ -242,19 +249,19 @@ class InsertTest(TestBase, AssertsExecutionResults): # test the same series of events using a reflected # version of the table - m2 = MetaData(testing.db) + m2 = MetaData(self.engine) table = Table(table.name, m2, autoload=True) def go(): table.insert().execute({'id':30, 'data':'d1'}) r = table.insert().execute({'data':'d2'}) - assert r.last_inserted_ids() == [5] + assert r.inserted_primary_key == [5] table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}) table.insert().execute({'data':'d5'}, {'data':'d6'}) table.insert(inline=True).execute({'id':33, 'data':'d7'}) table.insert(inline=True).execute({'data':'d8'}) - self.assert_sql(testing.db, go, [], with_sequences=[ + self.assert_sql(self.engine, go, [], with_sequences=[ ( "INSERT INTO testtable (id, data) VALUES (:id, :data)", {'id':30, 'data':'d1'} @@ -293,7 +300,127 @@ class InsertTest(TestBase, AssertsExecutionResults): ] table.delete().execute() + def _assert_data_autoincrement_returning(self, table): + self.engine = engines.testing_engine(options={'implicit_returning':True}) + metadata.bind = self.engine + + def go(): + # execute with explicit id + r = table.insert().execute({'id':30, 'data':'d1'}) + assert r.inserted_primary_key == [30] + + # execute with prefetch id + r = table.insert().execute({'data':'d2'}) + assert r.inserted_primary_key == [1] + + # executemany with explicit ids + table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}) + + # executemany, uses SERIAL + table.insert().execute({'data':'d5'}, {'data':'d6'}) + + # single execute, explicit id, inline + table.insert(inline=True).execute({'id':33, 'data':'d7'}) + + # single execute, inline, uses SERIAL + table.insert(inline=True).execute({'data':'d8'}) + + self.assert_sql(self.engine, go, [], with_sequences=[ + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {'id':30, 'data':'d1'} + ), + ( + "INSERT INTO testtable (data) VALUES (:data) RETURNING testtable.id", + {'data': 'd2'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}] + ), + ( + "INSERT INTO testtable (data) VALUES (:data)", + [{'data':'d5'}, {'data':'d6'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':33, 'data':'d7'}] + ), + ( + "INSERT INTO testtable (data) VALUES (:data)", + [{'data':'d8'}] + ), + ]) + + assert table.select().execute().fetchall() == [ + (30, 'd1'), + (1, 'd2'), + (31, 'd3'), + (32, 'd4'), + (2, 'd5'), + (3, 'd6'), + (33, 'd7'), + (4, 'd8'), + ] + table.delete().execute() + + # test the same series of events using a reflected + # version of the table + m2 = MetaData(self.engine) + table = Table(table.name, m2, autoload=True) + + def go(): + table.insert().execute({'id':30, 'data':'d1'}) + r = table.insert().execute({'data':'d2'}) + assert r.inserted_primary_key == [5] + table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}) + table.insert().execute({'data':'d5'}, {'data':'d6'}) + table.insert(inline=True).execute({'id':33, 'data':'d7'}) + table.insert(inline=True).execute({'data':'d8'}) + + self.assert_sql(self.engine, go, [], with_sequences=[ + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {'id':30, 'data':'d1'} + ), + ( + "INSERT INTO testtable (data) VALUES (:data) RETURNING testtable.id", + {'data':'d2'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}] + ), + ( + "INSERT INTO testtable (data) VALUES (:data)", + [{'data':'d5'}, {'data':'d6'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':33, 'data':'d7'}] + ), + ( + "INSERT INTO testtable (data) VALUES (:data)", + [{'data':'d8'}] + ), + ]) + + assert table.select().execute().fetchall() == [ + (30, 'd1'), + (5, 'd2'), + (31, 'd3'), + (32, 'd4'), + (6, 'd5'), + (7, 'd6'), + (33, 'd7'), + (8, 'd8'), + ] + table.delete().execute() + def _assert_data_with_sequence(self, table, seqname): + self.engine = engines.testing_engine(options={'implicit_returning':False}) + metadata.bind = self.engine + def go(): table.insert().execute({'id':30, 'data':'d1'}) table.insert().execute({'data':'d2'}) @@ -302,7 +429,7 @@ class InsertTest(TestBase, AssertsExecutionResults): table.insert(inline=True).execute({'id':33, 'data':'d7'}) table.insert(inline=True).execute({'data':'d8'}) - self.assert_sql(testing.db, go, [], with_sequences=[ + self.assert_sql(self.engine, go, [], with_sequences=[ ( "INSERT INTO testtable (id, data) VALUES (:id, :data)", {'id':30, 'data':'d1'} @@ -343,18 +470,76 @@ class InsertTest(TestBase, AssertsExecutionResults): # cant test reflection here since the Sequence must be # explicitly specified + def _assert_data_with_sequence_returning(self, table, seqname): + self.engine = engines.testing_engine(options={'implicit_returning':True}) + metadata.bind = self.engine + + def go(): + table.insert().execute({'id':30, 'data':'d1'}) + table.insert().execute({'data':'d2'}) + table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}) + table.insert().execute({'data':'d5'}, {'data':'d6'}) + table.insert(inline=True).execute({'id':33, 'data':'d7'}) + table.insert(inline=True).execute({'data':'d8'}) + + self.assert_sql(self.engine, go, [], with_sequences=[ + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {'id':30, 'data':'d1'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (nextval('my_seq'), :data) RETURNING testtable.id", + {'data':'d2'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (nextval('%s'), :data)" % seqname, + [{'data':'d5'}, {'data':'d6'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':33, 'data':'d7'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (nextval('%s'), :data)" % seqname, + [{'data':'d8'}] + ), + ]) + + assert table.select().execute().fetchall() == [ + (30, 'd1'), + (1, 'd2'), + (31, 'd3'), + (32, 'd4'), + (2, 'd5'), + (3, 'd6'), + (33, 'd7'), + (4, 'd8'), + ] + + # cant test reflection here since the Sequence must be + # explicitly specified + def _assert_data_noautoincrement(self, table): + self.engine = engines.testing_engine(options={'implicit_returning':False}) + metadata.bind = self.engine + table.insert().execute({'id':30, 'data':'d1'}) - try: - table.insert().execute({'data':'d2'}) - assert False - except exc.IntegrityError, e: - assert "violates not-null constraint" in str(e) - try: - table.insert().execute({'data':'d2'}, {'data':'d3'}) - assert False - except exc.IntegrityError, e: - assert "violates not-null constraint" in str(e) + + if self.engine.driver == 'pg8000': + exception_cls = exc.ProgrammingError + else: + exception_cls = exc.IntegrityError + + assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}) + assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}, {'data':'d3'}) + + assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}) + + assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}, {'data':'d3'}) table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'}) table.insert(inline=True).execute({'id':33, 'data':'d4'}) @@ -369,19 +554,12 @@ class InsertTest(TestBase, AssertsExecutionResults): # test the same series of events using a reflected # version of the table - m2 = MetaData(testing.db) + m2 = MetaData(self.engine) table = Table(table.name, m2, autoload=True) table.insert().execute({'id':30, 'data':'d1'}) - try: - table.insert().execute({'data':'d2'}) - assert False - except exc.IntegrityError, e: - assert "violates not-null constraint" in str(e) - try: - table.insert().execute({'data':'d2'}, {'data':'d3'}) - assert False - except exc.IntegrityError, e: - assert "violates not-null constraint" in str(e) + + assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}) + assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}, {'data':'d3'}) table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'}) table.insert(inline=True).execute({'id':33, 'data':'d4'}) @@ -396,36 +574,36 @@ class InsertTest(TestBase, AssertsExecutionResults): class DomainReflectionTest(TestBase, AssertsExecutionResults): "Test PostgreSQL domains" - __only_on__ = 'postgres' + __only_on__ = 'postgresql' @classmethod def setup_class(cls): con = testing.db.connect() for ddl in ('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42', - 'CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0'): + 'CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0'): try: con.execute(ddl) except exc.SQLError, e: if not "already exists" in str(e): raise e con.execute('CREATE TABLE testtable (question integer, answer testdomain)') - con.execute('CREATE TABLE alt_schema.testtable(question integer, answer alt_schema.testdomain, anything integer)') - con.execute('CREATE TABLE crosschema (question integer, answer alt_schema.testdomain)') + con.execute('CREATE TABLE test_schema.testtable(question integer, answer test_schema.testdomain, anything integer)') + con.execute('CREATE TABLE crosschema (question integer, answer test_schema.testdomain)') @classmethod def teardown_class(cls): con = testing.db.connect() con.execute('DROP TABLE testtable') - con.execute('DROP TABLE alt_schema.testtable') + con.execute('DROP TABLE test_schema.testtable') con.execute('DROP TABLE crosschema') con.execute('DROP DOMAIN testdomain') - con.execute('DROP DOMAIN alt_schema.testdomain') + con.execute('DROP DOMAIN test_schema.testdomain') def test_table_is_reflected(self): metadata = MetaData(testing.db) table = Table('testtable', metadata, autoload=True) eq_(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns") - eq_(table.c.answer.type.__class__, postgres.PGInteger) + assert isinstance(table.c.answer.type, Integer) def test_domain_is_reflected(self): metadata = MetaData(testing.db) @@ -433,15 +611,15 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): eq_(str(table.columns.answer.server_default.arg), '42', "Reflected default value didn't equal expected value") assert not table.columns.answer.nullable, "Expected reflected column to not be nullable." - def test_table_is_reflected_alt_schema(self): + def test_table_is_reflected_test_schema(self): metadata = MetaData(testing.db) - table = Table('testtable', metadata, autoload=True, schema='alt_schema') + table = Table('testtable', metadata, autoload=True, schema='test_schema') eq_(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns") - eq_(table.c.anything.type.__class__, postgres.PGInteger) + assert isinstance(table.c.anything.type, Integer) def test_schema_domain_is_reflected(self): metadata = MetaData(testing.db) - table = Table('testtable', metadata, autoload=True, schema='alt_schema') + table = Table('testtable', metadata, autoload=True, schema='test_schema') eq_(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value") assert table.columns.answer.nullable, "Expected reflected column to be nullable." @@ -452,10 +630,10 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): assert table.columns.answer.nullable, "Expected reflected column to be nullable." def test_unknown_types(self): - from sqlalchemy.databases import postgres + from sqlalchemy.databases import postgresql - ischema_names = postgres.ischema_names - postgres.ischema_names = {} + ischema_names = postgresql.PGDialect.ischema_names + postgresql.PGDialect.ischema_names = {} try: m2 = MetaData(testing.db) assert_raises(exc.SAWarning, Table, "testtable", m2, autoload=True) @@ -467,11 +645,11 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): assert t3.c.answer.type.__class__ == sa.types.NullType finally: - postgres.ischema_names = ischema_names + postgresql.PGDialect.ischema_names = ischema_names -class MiscTest(TestBase, AssertsExecutionResults): - __only_on__ = 'postgres' +class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): + __only_on__ = 'postgresql' def test_date_reflection(self): m1 = MetaData(testing.db) @@ -536,26 +714,26 @@ class MiscTest(TestBase, AssertsExecutionResults): 'FROM mytable') def test_schema_reflection(self): - """note: this test requires that the 'alt_schema' schema be separate and accessible by the test user""" + """note: this test requires that the 'test_schema' schema be separate and accessible by the test user""" meta1 = MetaData(testing.db) users = Table('users', meta1, Column('user_id', Integer, primary_key = True), Column('user_name', String(30), nullable = False), - schema="alt_schema" + schema="test_schema" ) addresses = Table('email_addresses', meta1, Column('address_id', Integer, primary_key = True), Column('remote_user_id', Integer, ForeignKey(users.c.user_id)), Column('email_address', String(20)), - schema="alt_schema" + schema="test_schema" ) meta1.create_all() try: meta2 = MetaData(testing.db) - addresses = Table('email_addresses', meta2, autoload=True, schema="alt_schema") - users = Table('users', meta2, mustexist=True, schema="alt_schema") + addresses = Table('email_addresses', meta2, autoload=True, schema="test_schema") + users = Table('users', meta2, mustexist=True, schema="test_schema") print users print addresses @@ -574,12 +752,12 @@ class MiscTest(TestBase, AssertsExecutionResults): referer = Table("referer", meta1, Column("id", Integer, primary_key=True), Column("ref", Integer, ForeignKey('subject.id')), - schema="alt_schema") + schema="test_schema") meta1.create_all() try: meta2 = MetaData(testing.db) subject = Table("subject", meta2, autoload=True) - referer = Table("referer", meta2, schema="alt_schema", autoload=True) + referer = Table("referer", meta2, schema="test_schema", autoload=True) print str(subject.join(referer).onclause) self.assert_((subject.c.id==referer.c.ref).compare(subject.join(referer).onclause)) finally: @@ -589,19 +767,19 @@ class MiscTest(TestBase, AssertsExecutionResults): meta1 = MetaData(testing.db) subject = Table("subject", meta1, Column("id", Integer, primary_key=True), - schema='alt_schema_2' + schema='test_schema_2' ) referer = Table("referer", meta1, Column("id", Integer, primary_key=True), - Column("ref", Integer, ForeignKey('alt_schema_2.subject.id')), - schema="alt_schema") + Column("ref", Integer, ForeignKey('test_schema_2.subject.id')), + schema="test_schema") meta1.create_all() try: meta2 = MetaData(testing.db) - subject = Table("subject", meta2, autoload=True, schema="alt_schema_2") - referer = Table("referer", meta2, schema="alt_schema", autoload=True) + subject = Table("subject", meta2, autoload=True, schema="test_schema_2") + referer = Table("referer", meta2, schema="test_schema", autoload=True) print str(subject.join(referer).onclause) self.assert_((subject.c.id==referer.c.ref).compare(subject.join(referer).onclause)) finally: @@ -611,7 +789,7 @@ class MiscTest(TestBase, AssertsExecutionResults): meta = MetaData(testing.db) users = Table('users', meta, Column('id', Integer, primary_key=True), - Column('name', String(50)), schema='alt_schema') + Column('name', String(50)), schema='test_schema') users.create() try: users.insert().execute(id=1, name='name1') @@ -646,15 +824,15 @@ class MiscTest(TestBase, AssertsExecutionResults): user_name VARCHAR NOT NULL, user_password VARCHAR NOT NULL ); - """, None) + """) t = Table("speedy_users", meta, autoload=True) r = t.insert().execute(user_name='user', user_password='lala') - assert r.last_inserted_ids() == [1] + assert r.inserted_primary_key == [1] l = t.select().execute().fetchall() assert l == [(1, 'user', 'lala')] finally: - testing.db.execute("drop table speedy_users", None) + testing.db.execute("drop table speedy_users") @testing.emits_warning() def test_index_reflection(self): @@ -676,10 +854,10 @@ class MiscTest(TestBase, AssertsExecutionResults): testing.db.execute(""" create index idx1 on party ((id || name)) - """, None) + """) testing.db.execute(""" create unique index idx2 on party (id) where name = 'test' - """, None) + """) testing.db.execute(""" create index idx3 on party using btree @@ -713,35 +891,42 @@ class MiscTest(TestBase, AssertsExecutionResults): warnings.warn = capture_warnings._orig_showwarning m1.drop_all() - def test_create_partial_index(self): - tbl = Table('testtbl', MetaData(), Column('data',Integer)) - idx = Index('test_idx1', tbl.c.data, postgres_where=and_(tbl.c.data > 5, tbl.c.data < 10)) - - executed_sql = [] - mock_strategy = MockEngineStrategy() - mock_conn = mock_strategy.create('postgres://', executed_sql.append) + def test_set_isolation_level(self): + """Test setting the isolation level with create_engine""" + eng = create_engine(testing.db.url) + eq_( + eng.execute("show transaction isolation level").scalar(), + 'read committed') + eng = create_engine(testing.db.url, isolation_level="SERIALIZABLE") + eq_( + eng.execute("show transaction isolation level").scalar(), + 'serializable') + eng = create_engine(testing.db.url, isolation_level="FOO") - idx.create(mock_conn) + if testing.db.driver == 'zxjdbc': + exception_cls = eng.dialect.dbapi.Error + else: + exception_cls = eng.dialect.dbapi.ProgrammingError + assert_raises(exception_cls, eng.execute, "show transaction isolation level") - assert executed_sql == ['CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10'] class TimezoneTest(TestBase, AssertsExecutionResults): """Test timezone-aware datetimes. - psycopg will return a datetime with a tzinfo attached to it, if postgres + psycopg will return a datetime with a tzinfo attached to it, if postgresql returns it. python then will not let you compare a datetime with a tzinfo to a datetime that doesnt have one. this test illustrates two ways to have datetime types with and without timezone info. """ - __only_on__ = 'postgres' + __only_on__ = 'postgresql' @classmethod def setup_class(cls): global tztable, notztable, metadata metadata = MetaData(testing.db) - # current_timestamp() in postgres is assumed to return TIMESTAMP WITH TIMEZONE + # current_timestamp() in postgresql is assumed to return TIMESTAMP WITH TIMEZONE tztable = Table('tztable', metadata, Column("id", Integer, primary_key=True), Column("date", DateTime(timezone=True), onupdate=func.current_timestamp()), @@ -762,17 +947,17 @@ class TimezoneTest(TestBase, AssertsExecutionResults): somedate = testing.db.connect().scalar(func.current_timestamp().select()) tztable.insert().execute(id=1, name='row1', date=somedate) c = tztable.update(tztable.c.id==1).execute(name='newname') - print tztable.select(tztable.c.id==1).execute().fetchone() + print tztable.select(tztable.c.id==1).execute().first() def test_without_timezone(self): # get a date without a tzinfo somedate = datetime.datetime(2005, 10,20, 11, 52, 00) notztable.insert().execute(id=1, name='row1', date=somedate) c = notztable.update(notztable.c.id==1).execute(name='newname') - print notztable.select(tztable.c.id==1).execute().fetchone() + print notztable.select(tztable.c.id==1).execute().first() class ArrayTest(TestBase, AssertsExecutionResults): - __only_on__ = 'postgres' + __only_on__ = 'postgresql' @classmethod def setup_class(cls): @@ -781,10 +966,14 @@ class ArrayTest(TestBase, AssertsExecutionResults): arrtable = Table('arrtable', metadata, Column('id', Integer, primary_key=True), - Column('intarr', postgres.PGArray(Integer)), - Column('strarr', postgres.PGArray(String(convert_unicode=True)), nullable=False) + Column('intarr', postgresql.PGArray(Integer)), + Column('strarr', postgresql.PGArray(String(convert_unicode=True)), nullable=False) ) metadata.create_all() + + def teardown(self): + arrtable.delete().execute() + @classmethod def teardown_class(cls): metadata.drop_all() @@ -792,34 +981,38 @@ class ArrayTest(TestBase, AssertsExecutionResults): def test_reflect_array_column(self): metadata2 = MetaData(testing.db) tbl = Table('arrtable', metadata2, autoload=True) - assert isinstance(tbl.c.intarr.type, postgres.PGArray) - assert isinstance(tbl.c.strarr.type, postgres.PGArray) + assert isinstance(tbl.c.intarr.type, postgresql.PGArray) + assert isinstance(tbl.c.strarr.type, postgresql.PGArray) assert isinstance(tbl.c.intarr.type.item_type, Integer) assert isinstance(tbl.c.strarr.type.item_type, String) + @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays') def test_insert_array(self): arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) results = arrtable.select().execute().fetchall() eq_(len(results), 1) eq_(results[0]['intarr'], [1,2,3]) eq_(results[0]['strarr'], ['abc','def']) - arrtable.delete().execute() + @testing.fails_on('postgresql+pg8000', 'pg8000 has poor support for PG arrays') + @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays') def test_array_where(self): arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) arrtable.insert().execute(intarr=[4,5,6], strarr='ABC') results = arrtable.select().where(arrtable.c.intarr == [1,2,3]).execute().fetchall() eq_(len(results), 1) eq_(results[0]['intarr'], [1,2,3]) - arrtable.delete().execute() + @testing.fails_on('postgresql+pg8000', 'pg8000 has poor support for PG arrays') + @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays') def test_array_concat(self): arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) results = select([arrtable.c.intarr + [4,5,6]]).execute().fetchall() eq_(len(results), 1) eq_(results[0][0], [1,2,3,4,5,6]) - arrtable.delete().execute() + @testing.fails_on('postgresql+pg8000', 'pg8000 has poor support for PG arrays') + @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays') def test_array_subtype_resultprocessor(self): arrtable.insert().execute(intarr=[4,5,6], strarr=[[u'm\xe4\xe4'], [u'm\xf6\xf6']]) arrtable.insert().execute(intarr=[1,2,3], strarr=[u'm\xe4\xe4', u'm\xf6\xf6']) @@ -827,13 +1020,14 @@ class ArrayTest(TestBase, AssertsExecutionResults): eq_(len(results), 2) eq_(results[0]['strarr'], [u'm\xe4\xe4', u'm\xf6\xf6']) eq_(results[1]['strarr'], [[u'm\xe4\xe4'], [u'm\xf6\xf6']]) - arrtable.delete().execute() + @testing.fails_on('postgresql+pg8000', 'pg8000 has poor support for PG arrays') + @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays') def test_array_mutability(self): class Foo(object): pass footable = Table('foo', metadata, Column('id', Integer, primary_key=True), - Column('intarr', postgres.PGArray(Integer), nullable=True) + Column('intarr', postgresql.PGArray(Integer), nullable=True) ) mapper(Foo, footable) metadata.create_all() @@ -870,19 +1064,19 @@ class ArrayTest(TestBase, AssertsExecutionResults): sess.add(foo) sess.flush() -class TimeStampTest(TestBase, AssertsExecutionResults): - __only_on__ = 'postgres' - - @testing.uses_deprecated() +class TimestampTest(TestBase, AssertsExecutionResults): + __only_on__ = 'postgresql' + def test_timestamp(self): engine = testing.db connection = engine.connect() - s = select([func.TIMESTAMP("12/25/07").label("ts")]) - result = connection.execute(s).fetchone() + + s = select(["timestamp '2007-12-25'"]) + result = connection.execute(s).first() eq_(result[0], datetime.datetime(2007, 12, 25, 0, 0)) class ServerSideCursorsTest(TestBase, AssertsExecutionResults): - __only_on__ = 'postgres' + __only_on__ = 'postgresql+psycopg2' @classmethod def setup_class(cls): @@ -927,8 +1121,8 @@ class ServerSideCursorsTest(TestBase, AssertsExecutionResults): class SpecialTypesTest(TestBase, ComparesTables): """test DDL and reflection of PG-specific types """ - __only_on__ = 'postgres' - __excluded_on__ = (('postgres', '<', (8, 3, 0)),) + __only_on__ = 'postgresql' + __excluded_on__ = (('postgresql', '<', (8, 3, 0)),) @classmethod def setup_class(cls): @@ -936,11 +1130,11 @@ class SpecialTypesTest(TestBase, ComparesTables): metadata = MetaData(testing.db) table = Table('sometable', metadata, - Column('id', postgres.PGUuid, primary_key=True), - Column('flag', postgres.PGBit), - Column('addr', postgres.PGInet), - Column('addr2', postgres.PGMacAddr), - Column('addr3', postgres.PGCidr) + Column('id', postgresql.PGUuid, primary_key=True), + Column('flag', postgresql.PGBit), + Column('addr', postgresql.PGInet), + Column('addr2', postgresql.PGMacAddr), + Column('addr3', postgresql.PGCidr) ) metadata.create_all() @@ -957,8 +1151,8 @@ class SpecialTypesTest(TestBase, ComparesTables): class MatchTest(TestBase, AssertsCompiledSQL): - __only_on__ = 'postgres' - __excluded_on__ = (('postgres', '<', (8, 3, 0)),) + __only_on__ = 'postgresql' + __excluded_on__ = (('postgresql', '<', (8, 3, 0)),) @classmethod def setup_class(cls): @@ -992,9 +1186,16 @@ class MatchTest(TestBase, AssertsCompiledSQL): def teardown_class(cls): metadata.drop_all() - def test_expression(self): + @testing.fails_on('postgresql+pg8000', 'uses positional') + @testing.fails_on('postgresql+zxjdbc', 'uses qmark') + def test_expression_pyformat(self): self.assert_compile(matchtable.c.title.match('somstr'), "matchtable.title @@ to_tsquery(%(title_1)s)") + @testing.fails_on('postgresql+psycopg2', 'uses pyformat') + @testing.fails_on('postgresql+zxjdbc', 'uses qmark') + def test_expression_positional(self): + self.assert_compile(matchtable.c.title.match('somstr'), "matchtable.title @@ to_tsquery(%s)") + def test_simple_match(self): results = matchtable.select().where(matchtable.c.title.match('python')).order_by(matchtable.c.id).execute().fetchall() eq_([2, 5], [r.id for r in results]) diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index eb4581e20..448ee947c 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -4,7 +4,7 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import datetime from sqlalchemy import * from sqlalchemy import exc, sql -from sqlalchemy.databases import sqlite +from sqlalchemy.dialects.sqlite import base as sqlite, pysqlite as pysqlite_dialect from sqlalchemy.test import * @@ -19,7 +19,7 @@ class TestTypes(TestBase, AssertsExecutionResults): meta = MetaData(testing.db) t = Table('bool_table', meta, Column('id', Integer, primary_key=True), - Column('boo', sqlite.SLBoolean)) + Column('boo', Boolean)) try: meta.create_all() @@ -39,7 +39,7 @@ class TestTypes(TestBase, AssertsExecutionResults): def test_time_microseconds(self): dt = datetime.datetime(2008, 6, 27, 12, 0, 0, 125) # 125 usec eq_(str(dt), '2008-06-27 12:00:00.000125') - sldt = sqlite.SLDateTime() + sldt = sqlite._SLDateTime() bp = sldt.bind_processor(None) eq_(bp(dt), '2008-06-27 12:00:00.000125') @@ -69,59 +69,44 @@ class TestTypes(TestBase, AssertsExecutionResults): bindproc = t.dialect_impl(dialect).bind_processor(dialect) assert not bindproc or isinstance(bindproc(u"some string"), unicode) - @testing.uses_deprecated('Using String type with no length') def test_type_reflection(self): # (ask_for, roundtripped_as_if_different) - specs = [( String(), sqlite.SLString(), ), - ( String(1), sqlite.SLString(1), ), - ( String(3), sqlite.SLString(3), ), - ( Text(), sqlite.SLText(), ), - ( Unicode(), sqlite.SLString(), ), - ( Unicode(1), sqlite.SLString(1), ), - ( Unicode(3), sqlite.SLString(3), ), - ( UnicodeText(), sqlite.SLText(), ), - ( CLOB, sqlite.SLText(), ), - ( sqlite.SLChar(1), ), - ( CHAR(3), sqlite.SLChar(3), ), - ( NCHAR(2), sqlite.SLChar(2), ), - ( SmallInteger(), sqlite.SLSmallInteger(), ), - ( sqlite.SLSmallInteger(), ), - ( Binary(3), sqlite.SLBinary(), ), - ( Binary(), sqlite.SLBinary() ), - ( sqlite.SLBinary(3), sqlite.SLBinary(), ), - ( NUMERIC, sqlite.SLNumeric(), ), - ( NUMERIC(10,2), sqlite.SLNumeric(10,2), ), - ( Numeric, sqlite.SLNumeric(), ), - ( Numeric(10, 2), sqlite.SLNumeric(10, 2), ), - ( DECIMAL, sqlite.SLNumeric(), ), - ( DECIMAL(10, 2), sqlite.SLNumeric(10, 2), ), - ( Float, sqlite.SLFloat(), ), - ( sqlite.SLNumeric(), ), - ( INT, sqlite.SLInteger(), ), - ( Integer, sqlite.SLInteger(), ), - ( sqlite.SLInteger(), ), - ( TIMESTAMP, sqlite.SLDateTime(), ), - ( DATETIME, sqlite.SLDateTime(), ), - ( DateTime, sqlite.SLDateTime(), ), - ( sqlite.SLDateTime(), ), - ( DATE, sqlite.SLDate(), ), - ( Date, sqlite.SLDate(), ), - ( sqlite.SLDate(), ), - ( TIME, sqlite.SLTime(), ), - ( Time, sqlite.SLTime(), ), - ( sqlite.SLTime(), ), - ( BOOLEAN, sqlite.SLBoolean(), ), - ( Boolean, sqlite.SLBoolean(), ), - ( sqlite.SLBoolean(), ), + specs = [( String(), String(), ), + ( String(1), String(1), ), + ( String(3), String(3), ), + ( Text(), Text(), ), + ( Unicode(), String(), ), + ( Unicode(1), String(1), ), + ( Unicode(3), String(3), ), + ( UnicodeText(), Text(), ), + ( CHAR(1), ), + ( CHAR(3), CHAR(3), ), + ( NUMERIC, NUMERIC(), ), + ( NUMERIC(10,2), NUMERIC(10,2), ), + ( Numeric, NUMERIC(), ), + ( Numeric(10, 2), NUMERIC(10, 2), ), + ( DECIMAL, DECIMAL(), ), + ( DECIMAL(10, 2), DECIMAL(10, 2), ), + ( Float, Float(), ), + ( NUMERIC(), ), + ( TIMESTAMP, TIMESTAMP(), ), + ( DATETIME, DATETIME(), ), + ( DateTime, DateTime(), ), + ( DateTime(), ), + ( DATE, DATE(), ), + ( Date, Date(), ), + ( TIME, TIME(), ), + ( Time, Time(), ), + ( BOOLEAN, BOOLEAN(), ), + ( Boolean, Boolean(), ), ] columns = [Column('c%i' % (i + 1), t[0]) for i, t in enumerate(specs)] db = testing.db m = MetaData(db) t_table = Table('types', m, *columns) + m.create_all() try: - m.create_all() - m2 = MetaData(db) rt = Table('types', m2, autoload=True) try: @@ -131,7 +116,7 @@ class TestTypes(TestBase, AssertsExecutionResults): expected = [len(c) > 1 and c[1] or c[0] for c in specs] for table in rt, rv: for i, reflected in enumerate(table.c): - assert isinstance(reflected.type, type(expected[i])), type(expected[i]) + assert isinstance(reflected.type, type(expected[i])), "%d: %r" % (i, type(expected[i])) finally: db.execute('DROP VIEW types_v') finally: @@ -163,7 +148,7 @@ class TestDefaults(TestBase, AssertsExecutionResults): rt = Table('t_defaults', m2, autoload=True) expected = [c[1] for c in specs] for i, reflected in enumerate(rt.c): - eq_(reflected.server_default.arg.text, expected[i]) + eq_(str(reflected.server_default.arg), expected[i]) finally: m.drop_all() @@ -173,7 +158,7 @@ class TestDefaults(TestBase, AssertsExecutionResults): db = testing.db m = MetaData(db) - expected = ["'my_default'", '0'] + expected = ["my_default", '0'] table = """CREATE TABLE r_defaults ( data VARCHAR(40) DEFAULT 'my_default', val INTEGER NOT NULL DEFAULT 0 @@ -184,7 +169,7 @@ class TestDefaults(TestBase, AssertsExecutionResults): rt = Table('r_defaults', m, autoload=True) for i, reflected in enumerate(rt.c): - eq_(reflected.server_default.arg.text, expected[i]) + eq_(str(reflected.server_default.arg), expected[i]) finally: db.execute("DROP TABLE r_defaults") @@ -247,24 +232,24 @@ class DialectTest(TestBase, AssertsExecutionResults): def test_attached_as_schema(self): cx = testing.db.connect() try: - cx.execute('ATTACH DATABASE ":memory:" AS alt_schema') + cx.execute('ATTACH DATABASE ":memory:" AS test_schema') dialect = cx.dialect - assert dialect.table_names(cx, 'alt_schema') == [] + assert dialect.table_names(cx, 'test_schema') == [] meta = MetaData(cx) Table('created', meta, Column('id', Integer), - schema='alt_schema') + schema='test_schema') alt_master = Table('sqlite_master', meta, autoload=True, - schema='alt_schema') + schema='test_schema') meta.create_all(cx) - eq_(dialect.table_names(cx, 'alt_schema'), + eq_(dialect.table_names(cx, 'test_schema'), ['created']) assert len(alt_master.c) > 0 meta.clear() reflected = Table('created', meta, autoload=True, - schema='alt_schema') + schema='test_schema') assert len(reflected.c) == 1 cx.execute(reflected.insert(), dict(id=1)) @@ -282,9 +267,9 @@ class DialectTest(TestBase, AssertsExecutionResults): # note that sqlite_master is cleared, above meta.drop_all() - assert dialect.table_names(cx, 'alt_schema') == [] + assert dialect.table_names(cx, 'test_schema') == [] finally: - cx.execute('DETACH DATABASE alt_schema') + cx.execute('DETACH DATABASE test_schema') @testing.exclude('sqlite', '<', (2, 6), 'no database support') def test_temp_table_reflection(self): @@ -305,6 +290,20 @@ class DialectTest(TestBase, AssertsExecutionResults): pass raise + def test_set_isolation_level(self): + """Test setting the read uncommitted/serializable levels""" + eng = create_engine(testing.db.url) + eq_(eng.execute("PRAGMA read_uncommitted").scalar(), 0) + + eng = create_engine(testing.db.url, isolation_level="READ UNCOMMITTED") + eq_(eng.execute("PRAGMA read_uncommitted").scalar(), 1) + + eng = create_engine(testing.db.url, isolation_level="SERIALIZABLE") + eq_(eng.execute("PRAGMA read_uncommitted").scalar(), 0) + + assert_raises(exc.ArgumentError, create_engine, testing.db.url, + isolation_level="FOO") + class SQLTest(TestBase, AssertsCompiledSQL): """Tests SQLite-dialect specific compilation.""" diff --git a/test/engine/test_bind.py b/test/engine/test_bind.py index 7fd3009bc..1122f1632 100644 --- a/test/engine/test_bind.py +++ b/test/engine/test_bind.py @@ -121,7 +121,7 @@ class BindTest(testing.TestBase): table = Table('test_table', metadata, Column('foo', Integer)) - metadata.connect(bind) + metadata.bind = bind assert metadata.bind is table.bind is bind metadata.create_all() @@ -199,7 +199,7 @@ class BindTest(testing.TestBase): try: e = elem(bind=bind) assert e.bind is bind - e.execute() + e.execute().close() finally: if isinstance(bind, engine.Connection): bind.close() diff --git a/test/engine/test_ddlevents.py b/test/engine/test_ddlevents.py index 5716006d9..434a5d873 100644 --- a/test/engine/test_ddlevents.py +++ b/test/engine/test_ddlevents.py @@ -1,12 +1,13 @@ -from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message -from sqlalchemy.schema import DDL +from sqlalchemy.test.testing import assert_raises, assert_raises_message +from sqlalchemy.schema import DDL, CheckConstraint, AddConstraint, DropConstraint from sqlalchemy import create_engine from sqlalchemy import MetaData, Integer, String from sqlalchemy.test.schema import Table from sqlalchemy.test.schema import Column import sqlalchemy as tsa from sqlalchemy.test import TestBase, testing, engines - +from sqlalchemy.test.testing import AssertsCompiledSQL +from nose import SkipTest class DDLEventTest(TestBase): class Canary(object): @@ -15,25 +16,25 @@ class DDLEventTest(TestBase): self.schema_item = schema_item self.bind = bind - def before_create(self, action, schema_item, bind): + def before_create(self, action, schema_item, bind, **kw): assert self.state is None assert schema_item is self.schema_item assert bind is self.bind self.state = action - def after_create(self, action, schema_item, bind): + def after_create(self, action, schema_item, bind, **kw): assert self.state in ('before-create', 'skipped') assert schema_item is self.schema_item assert bind is self.bind self.state = action - def before_drop(self, action, schema_item, bind): + def before_drop(self, action, schema_item, bind, **kw): assert self.state is None assert schema_item is self.schema_item assert bind is self.bind self.state = action - def after_drop(self, action, schema_item, bind): + def after_drop(self, action, schema_item, bind, **kw): assert self.state in ('before-drop', 'skipped') assert schema_item is self.schema_item assert bind is self.bind @@ -232,7 +233,33 @@ class DDLExecutionTest(TestBase): assert 'klptzyxm' not in strings assert 'xyzzy' in strings assert 'fnord' in strings - + + def test_conditional_constraint(self): + metadata, users, engine = self.metadata, self.users, self.engine + nonpg_mock = engines.mock_engine(dialect_name='sqlite') + pg_mock = engines.mock_engine(dialect_name='postgresql') + + constraint = CheckConstraint('a < b',name="my_test_constraint", table=users) + + # by placing the constraint in an Add/Drop construct, + # the 'inline_ddl' flag is set to False + AddConstraint(constraint, on='postgresql').execute_at("after-create", users) + DropConstraint(constraint, on='postgresql').execute_at("before-drop", users) + + metadata.create_all(bind=nonpg_mock) + strings = " ".join(str(x) for x in nonpg_mock.mock) + assert "my_test_constraint" not in strings + metadata.drop_all(bind=nonpg_mock) + strings = " ".join(str(x) for x in nonpg_mock.mock) + assert "my_test_constraint" not in strings + + metadata.create_all(bind=pg_mock) + strings = " ".join(str(x) for x in pg_mock.mock) + assert "my_test_constraint" in strings + metadata.drop_all(bind=pg_mock) + strings = " ".join(str(x) for x in pg_mock.mock) + assert "my_test_constraint" in strings + def test_metadata(self): metadata, engine = self.metadata, self.engine DDL('mxyzptlk').execute_at('before-create', metadata) @@ -255,7 +282,10 @@ class DDLExecutionTest(TestBase): assert 'fnord' in strings def test_ddl_execute(self): - engine = create_engine('sqlite:///') + try: + engine = create_engine('sqlite:///') + except ImportError: + raise SkipTest('Requires sqlite') cx = engine.connect() table = self.users ddl = DDL('SELECT 1') @@ -286,7 +316,7 @@ class DDLExecutionTest(TestBase): r = eval(py) assert list(r) == [(1,)], py -class DDLTest(TestBase): +class DDLTest(TestBase, AssertsCompiledSQL): def mock_engine(self): executor = lambda *a, **kw: None engine = create_engine(testing.db.name + '://', @@ -297,7 +327,6 @@ class DDLTest(TestBase): def test_tokens(self): m = MetaData() - bind = self.mock_engine() sane_alone = Table('t', m, Column('id', Integer)) sane_schema = Table('t', m, Column('id', Integer), schema='s') insane_alone = Table('t t', m, Column('id', Integer)) @@ -305,20 +334,21 @@ class DDLTest(TestBase): ddl = DDL('%(schema)s-%(table)s-%(fullname)s') - eq_(ddl._expand(sane_alone, bind), '-t-t') - eq_(ddl._expand(sane_schema, bind), 's-t-s.t') - eq_(ddl._expand(insane_alone, bind), '-"t t"-"t t"') - eq_(ddl._expand(insane_schema, bind), - '"s s"-"t t"-"s s"."t t"') + dialect = self.mock_engine().dialect + self.assert_compile(ddl.against(sane_alone), '-t-t', dialect=dialect) + self.assert_compile(ddl.against(sane_schema), 's-t-s.t', dialect=dialect) + self.assert_compile(ddl.against(insane_alone), '-"t t"-"t t"', dialect=dialect) + self.assert_compile(ddl.against(insane_schema), '"s s"-"t t"-"s s"."t t"', dialect=dialect) # overrides are used piece-meal and verbatim. ddl = DDL('%(schema)s-%(table)s-%(fullname)s-%(bonus)s', context={'schema':'S S', 'table': 'T T', 'bonus': 'b'}) - eq_(ddl._expand(sane_alone, bind), 'S S-T T-t-b') - eq_(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b') - eq_(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b') - eq_(ddl._expand(insane_schema, bind), - 'S S-T T-"s s"."t t"-b') + + self.assert_compile(ddl.against(sane_alone), 'S S-T T-t-b', dialect=dialect) + self.assert_compile(ddl.against(sane_schema), 'S S-T T-s.t-b', dialect=dialect) + self.assert_compile(ddl.against(insane_alone), 'S S-T T-"t t"-b', dialect=dialect) + self.assert_compile(ddl.against(insane_schema), 'S S-T T-"s s"."t t"-b', dialect=dialect) + def test_filter(self): cx = self.mock_engine() diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 08bf80fe2..4783c5508 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -15,18 +15,20 @@ class ExecuteTest(TestBase): global users, metadata metadata = MetaData(testing.db) users = Table('users', metadata, - Column('user_id', INT, primary_key = True), + Column('user_id', INT, primary_key = True, test_needs_autoincrement=True), Column('user_name', VARCHAR(20)), ) metadata.create_all() + @engines.close_first def teardown(self): testing.db.connect().execute(users.delete()) + @classmethod def teardown_class(cls): metadata.drop_all() - @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite') + @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite', 'mysql+pyodbc', '+zxjdbc') def test_raw_qmark(self): for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack")) @@ -38,7 +40,8 @@ class ExecuteTest(TestBase): assert res.fetchall() == [(1, "jack"), (2, "fred"), (3, "ed"), (4, "horse"), (5, "barney"), (6, "donkey"), (7, 'sally')] conn.execute("delete from users") - @testing.fails_on_everything_except('mysql', 'postgres') + @testing.fails_on_everything_except('mysql+mysqldb', 'postgresql') + @testing.fails_on('postgresql+zxjdbc', 'sprintf not supported') # some psycopg2 versions bomb this. def test_raw_sprintf(self): for conn in (testing.db, testing.db.connect()): @@ -52,8 +55,8 @@ class ExecuteTest(TestBase): # pyformat is supported for mysql, but skipping because a few driver # versions have a bug that bombs out on this test. (1.2.2b3, 1.2.2c1, 1.2.2) - @testing.skip_if(lambda: testing.against('mysql'), 'db-api flaky') - @testing.fails_on_everything_except('postgres') + @testing.skip_if(lambda: testing.against('mysql+mysqldb'), 'db-api flaky') + @testing.fails_on_everything_except('postgresql+psycopg2') def test_raw_python(self): for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':1, 'name':'jack'}) @@ -63,7 +66,7 @@ class ExecuteTest(TestBase): assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally')] conn.execute("delete from users") - @testing.fails_on_everything_except('sqlite', 'oracle') + @testing.fails_on_everything_except('sqlite', 'oracle+cx_oracle') def test_raw_named(self): for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (:id, :name)", {'id':1, 'name':'jack'}) @@ -81,11 +84,12 @@ class ExecuteTest(TestBase): except tsa.exc.DBAPIError: assert True - @testing.fails_on('mssql', 'rowcount returns -1') def test_empty_insert(self): """test that execute() interprets [] as a list with no params""" result = testing.db.execute(users.insert().values(user_name=bindparam('name')), []) - eq_(result.rowcount, 1) + eq_(testing.db.execute(users.select()).fetchall(), [ + (1, None) + ]) class ProxyConnectionTest(TestBase): @testing.fails_on('firebird', 'Data type unknown') @@ -102,6 +106,7 @@ class ProxyConnectionTest(TestBase): return execute(clauseelement, *multiparams, **params) def cursor_execute(self, execute, cursor, statement, parameters, context, executemany): + print "CE", statement, parameters cursor_stmts.append( (statement, parameters, None) ) @@ -118,8 +123,8 @@ class ProxyConnectionTest(TestBase): break for engine in ( - engines.testing_engine(options=dict(proxy=MyProxy())), - engines.testing_engine(options=dict(proxy=MyProxy(), strategy='threadlocal')) + engines.testing_engine(options=dict(implicit_returning=False, proxy=MyProxy())), + engines.testing_engine(options=dict(implicit_returning=False, proxy=MyProxy(), strategy='threadlocal')) ): m = MetaData(engine) @@ -131,6 +136,7 @@ class ProxyConnectionTest(TestBase): t1.insert().execute(c1=6) assert engine.execute("select * from t1").fetchall() == [(5, 'some data'), (6, 'foo')] finally: + pass m.drop_all() engine.dispose() @@ -143,14 +149,14 @@ class ProxyConnectionTest(TestBase): ("DROP TABLE t1", {}, None) ] - if engine.dialect.preexecute_pk_sequences: + if True: # or engine.dialect.preexecute_pk_sequences: cursor = [ - ("CREATE TABLE t1", {}, None), + ("CREATE TABLE t1", {}, ()), ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, [5, 'some data']), ("SELECT lower", {'lower_2':'Foo'}, ['Foo']), ("INSERT INTO t1 (c1, c2)", {'c2': 'foo', 'c1': 6}, [6, 'foo']), - ("select * from t1", {}, None), - ("DROP TABLE t1", {}, None) + ("select * from t1", {}, ()), + ("DROP TABLE t1", {}, ()) ] else: cursor = [ diff --git a/test/engine/test_metadata.py b/test/engine/test_metadata.py index ca4fbaa48..784a7b9ce 100644 --- a/test/engine/test_metadata.py +++ b/test/engine/test_metadata.py @@ -1,11 +1,11 @@ from sqlalchemy.test.testing import assert_raises, assert_raises_message import pickle -from sqlalchemy import MetaData -from sqlalchemy import Integer, String, UniqueConstraint, CheckConstraint, ForeignKey +from sqlalchemy import Integer, String, UniqueConstraint, CheckConstraint, ForeignKey, MetaData from sqlalchemy.test.schema import Table from sqlalchemy.test.schema import Column +from sqlalchemy import schema import sqlalchemy as tsa -from sqlalchemy.test import TestBase, ComparesTables, testing, engines +from sqlalchemy.test import TestBase, ComparesTables, AssertsCompiledSQL, testing, engines from sqlalchemy.test.testing import eq_ class MetaDataTest(TestBase, ComparesTables): @@ -83,7 +83,7 @@ class MetaDataTest(TestBase, ComparesTables): meta.create_all(testing.db) try: - for test, has_constraints in ((test_to_metadata, True), (test_pickle, True), (test_pickle_via_reflect, False)): + for test, has_constraints in ((test_to_metadata, True), (test_pickle, True),(test_pickle_via_reflect, False)): table_c, table2_c = test() self.assert_tables_equal(table, table_c) self.assert_tables_equal(table2, table2_c) @@ -143,29 +143,30 @@ class MetaDataTest(TestBase, ComparesTables): MetaData(testing.db), autoload=True) -class TableOptionsTest(TestBase): - def setup(self): - self.engine = engines.mock_engine() - self.metadata = MetaData(self.engine) - +class TableOptionsTest(TestBase, AssertsCompiledSQL): def test_prefixes(self): - table1 = Table("temporary_table_1", self.metadata, + table1 = Table("temporary_table_1", MetaData(), Column("col1", Integer), prefixes = ["TEMPORARY"]) - table1.create() - assert [str(x) for x in self.engine.mock if 'CREATE TEMPORARY TABLE' in str(x)] - del self.engine.mock[:] - table2 = Table("temporary_table_2", self.metadata, + + self.assert_compile( + schema.CreateTable(table1), + "CREATE TEMPORARY TABLE temporary_table_1 (col1 INTEGER)" + ) + + table2 = Table("temporary_table_2", MetaData(), Column("col1", Integer), prefixes = ["VIRTUAL"]) - table2.create() - assert [str(x) for x in self.engine.mock if 'CREATE VIRTUAL TABLE' in str(x)] + self.assert_compile( + schema.CreateTable(table2), + "CREATE VIRTUAL TABLE temporary_table_2 (col1 INTEGER)" + ) def test_table_info(self): - - t1 = Table('foo', self.metadata, info={'x':'y'}) - t2 = Table('bar', self.metadata, info={}) - t3 = Table('bat', self.metadata) + metadata = MetaData() + t1 = Table('foo', metadata, info={'x':'y'}) + t2 = Table('bar', metadata, info={}) + t3 = Table('bat', metadata) assert t1.info == {'x':'y'} assert t2.info == {} assert t3.info == {} diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py index 6b7ac37b2..90c0969be 100644 --- a/test/engine/test_parseconnect.py +++ b/test/engine/test_parseconnect.py @@ -1,4 +1,6 @@ -import ConfigParser, StringIO +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import ConfigParser +import StringIO import sqlalchemy.engine.url as url from sqlalchemy import create_engine, engine_from_config import sqlalchemy as tsa @@ -28,8 +30,6 @@ class ParseConnectTest(TestBase): 'dbtype://username:apples%2Foranges@hostspec/mydatabase', ): u = url.make_url(text) - print u, text - print "username=", u.username, "password=", u.password, "database=", u.database, "host=", u.host assert u.drivername == 'dbtype' assert u.username == 'username' or u.username is None assert u.password == 'password' or u.password == 'apples/oranges' or u.password is None @@ -41,21 +41,28 @@ class CreateEngineTest(TestBase): def test_connect_query(self): dbapi = MockDBAPI(foober='12', lala='18', fooz='somevalue') - # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg - e = create_engine('postgres://scott:tiger@somehost/test?foober=12&lala=18&fooz=somevalue', module=dbapi) + e = create_engine( + 'postgresql://scott:tiger@somehost/test?foober=12&lala=18&fooz=somevalue', + module=dbapi, + _initialize=False + ) c = e.connect() def test_kwargs(self): dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue') - # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg - e = create_engine('postgres://scott:tiger@somehost/test?fooz=somevalue', connect_args={'foober':12, 'lala':18, 'hoho':{'this':'dict'}}, module=dbapi) + e = create_engine( + 'postgresql://scott:tiger@somehost/test?fooz=somevalue', + connect_args={'foober':12, 'lala':18, 'hoho':{'this':'dict'}}, + module=dbapi, + _initialize=False + ) c = e.connect() def test_coerce_config(self): raw = r""" [prefixed] -sqlalchemy.url=postgres://scott:tiger@somehost/test?fooz=somevalue +sqlalchemy.url=postgresql://scott:tiger@somehost/test?fooz=somevalue sqlalchemy.convert_unicode=0 sqlalchemy.echo=false sqlalchemy.echo_pool=1 @@ -65,7 +72,7 @@ sqlalchemy.pool_size=2 sqlalchemy.pool_threadlocal=1 sqlalchemy.pool_timeout=10 [plain] -url=postgres://scott:tiger@somehost/test?fooz=somevalue +url=postgresql://scott:tiger@somehost/test?fooz=somevalue convert_unicode=0 echo=0 echo_pool=1 @@ -79,7 +86,7 @@ pool_timeout=10 ini.readfp(StringIO.StringIO(raw)) expected = { - 'url': 'postgres://scott:tiger@somehost/test?fooz=somevalue', + 'url': 'postgresql://scott:tiger@somehost/test?fooz=somevalue', 'convert_unicode': 0, 'echo': False, 'echo_pool': True, @@ -97,17 +104,17 @@ pool_timeout=10 self.assert_(tsa.engine._coerce_config(plain, '') == expected) def test_engine_from_config(self): - dbapi = MockDBAPI() + dbapi = mock_dbapi config = { - 'sqlalchemy.url':'postgres://scott:tiger@somehost/test?fooz=somevalue', + 'sqlalchemy.url':'postgresql://scott:tiger@somehost/test?fooz=somevalue', 'sqlalchemy.pool_recycle':'50', 'sqlalchemy.echo':'true' } e = engine_from_config(config, module=dbapi) assert e.pool._recycle == 50 - assert e.url == url.make_url('postgres://scott:tiger@somehost/test?fooz=somevalue') + assert e.url == url.make_url('postgresql://scott:tiger@somehost/test?fooz=somevalue') assert e.echo is True def test_custom(self): @@ -116,109 +123,77 @@ pool_timeout=10 def connect(): return dbapi.connect(foober=12, lala=18, fooz='somevalue', hoho={'this':'dict'}) - # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg - e = create_engine('postgres://', creator=connect, module=dbapi) + # start the postgresql dialect, but put our mock DBAPI as the module instead of psycopg + e = create_engine('postgresql://', creator=connect, module=dbapi, _initialize=False) c = e.connect() def test_recycle(self): dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue') - e = create_engine('postgres://', pool_recycle=472, module=dbapi) + e = create_engine('postgresql://', pool_recycle=472, module=dbapi, _initialize=False) assert e.pool._recycle == 472 def test_badargs(self): - # good arg, use MockDBAPI to prevent oracle import errors - e = create_engine('oracle://', use_ansi=True, module=MockDBAPI()) - - try: - e = create_engine("foobar://", module=MockDBAPI()) - assert False - except ImportError: - assert True + assert_raises(ImportError, create_engine, "foobar://", module=mock_dbapi) # bad arg - try: - e = create_engine('postgres://', use_ansi=True, module=MockDBAPI()) - assert False - except TypeError: - assert True + assert_raises(TypeError, create_engine, 'postgresql://', use_ansi=True, module=mock_dbapi) # bad arg - try: - e = create_engine('oracle://', lala=5, use_ansi=True, module=MockDBAPI()) - assert False - except TypeError: - assert True + assert_raises(TypeError, create_engine, 'oracle://', lala=5, use_ansi=True, module=mock_dbapi) - try: - e = create_engine('postgres://', lala=5, module=MockDBAPI()) - assert False - except TypeError: - assert True + assert_raises(TypeError, create_engine, 'postgresql://', lala=5, module=mock_dbapi) - try: - e = create_engine('sqlite://', lala=5) - assert False - except TypeError: - assert True + assert_raises(TypeError, create_engine,'sqlite://', lala=5, module=mock_sqlite_dbapi) - try: - e = create_engine('mysql://', use_unicode=True, module=MockDBAPI()) - assert False - except TypeError: - assert True - - try: - # sqlite uses SingletonThreadPool which doesnt have max_overflow - e = create_engine('sqlite://', max_overflow=5) - assert False - except TypeError: - assert True + assert_raises(TypeError, create_engine, 'mysql+mysqldb://', use_unicode=True, module=mock_dbapi) - e = create_engine('mysql://', module=MockDBAPI(), connect_args={'use_unicode':True}, convert_unicode=True) + # sqlite uses SingletonThreadPool which doesnt have max_overflow + assert_raises(TypeError, create_engine, 'sqlite://', max_overflow=5, + module=mock_sqlite_dbapi) - e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True) try: - c = e.connect() - assert False - except tsa.exc.DBAPIError: - assert True + e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True) + except ImportError: + # no sqlite + pass + else: + # raises DBAPIerror due to use_unicode not a sqlite arg + assert_raises(tsa.exc.DBAPIError, e.connect) def test_urlattr(self): """test the url attribute on ``Engine``.""" - e = create_engine('mysql://scott:tiger@localhost/test', module=MockDBAPI()) + e = create_engine('mysql://scott:tiger@localhost/test', module=mock_dbapi, _initialize=False) u = url.make_url('mysql://scott:tiger@localhost/test') - e2 = create_engine(u, module=MockDBAPI()) + e2 = create_engine(u, module=mock_dbapi, _initialize=False) assert e.url.drivername == e2.url.drivername == 'mysql' assert e.url.username == e2.url.username == 'scott' assert e2.url is u def test_poolargs(self): """test that connection pool args make it thru""" - e = create_engine('postgres://', creator=None, pool_recycle=50, echo_pool=None, module=MockDBAPI()) + e = create_engine('postgresql://', creator=None, pool_recycle=50, echo_pool=None, module=mock_dbapi, _initialize=False) assert e.pool._recycle == 50 # these args work for QueuePool - e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.QueuePool, module=MockDBAPI()) + e = create_engine('postgresql://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.QueuePool, module=mock_dbapi) - try: - # but not SingletonThreadPool - e = create_engine('sqlite://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.SingletonThreadPool) - assert False - except TypeError: - assert True + # but not SingletonThreadPool + assert_raises(TypeError, create_engine, 'sqlite://', max_overflow=8, pool_timeout=60, + poolclass=tsa.pool.SingletonThreadPool, module=mock_sqlite_dbapi) class MockDBAPI(object): def __init__(self, **kwargs): self.kwargs = kwargs self.paramstyle = 'named' - def connect(self, **kwargs): - print kwargs, self.kwargs + def connect(self, *args, **kwargs): for k in self.kwargs: assert k in kwargs, "key %s not present in dictionary" % k assert kwargs[k]==self.kwargs[k], "value %s does not match %s" % (kwargs[k], self.kwargs[k]) return MockConnection() class MockConnection(object): + def get_server_info(self): + return "5.0" def close(self): pass def cursor(self): @@ -227,4 +202,6 @@ class MockCursor(object): def close(self): pass mock_dbapi = MockDBAPI() - +mock_sqlite_dbapi = msd = MockDBAPI() +msd.version_info = msd.sqlite_version_info = (99, 9, 9) +msd.sqlite_version = '99.9.9' diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index d135ad337..68637281e 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -1,7 +1,8 @@ -import threading, time, gc -from sqlalchemy import pool, interfaces +import threading, time +from sqlalchemy import pool, interfaces, create_engine, select import sqlalchemy as tsa -from sqlalchemy.test import TestBase +from sqlalchemy.test import TestBase, testing +from sqlalchemy.test.util import gc_collect, lazy_gc mcid = 1 @@ -51,7 +52,6 @@ class PoolTest(PoolTestBase): connection2 = manager.connect('foo.db') connection3 = manager.connect('bar.db') - print "connection " + repr(connection) self.assert_(connection.cursor() is not None) self.assert_(connection is connection2) self.assert_(connection2 is not connection3) @@ -70,8 +70,6 @@ class PoolTest(PoolTestBase): connection = manager.connect('foo.db') connection2 = manager.connect('foo.db') - print "connection " + repr(connection) - self.assert_(connection.cursor() is not None) self.assert_(connection is not connection2) @@ -103,7 +101,8 @@ class PoolTest(PoolTestBase): c2.close() else: c2 = None - + lazy_gc() + if useclose: c1 = p.connect() c2 = p.connect() @@ -117,6 +116,8 @@ class PoolTest(PoolTestBase): # extra tests with QueuePool to ensure connections get __del__()ed when dereferenced if isinstance(p, pool.QueuePool): + lazy_gc() + self.assert_(p.checkedout() == 0) c1 = p.connect() c2 = p.connect() @@ -126,6 +127,7 @@ class PoolTest(PoolTestBase): else: c2 = None c1 = None + lazy_gc() self.assert_(p.checkedout() == 0) def test_properties(self): @@ -164,6 +166,8 @@ class PoolTest(PoolTestBase): def __init__(self): if hasattr(self, 'connect'): self.connect = self.inst_connect + if hasattr(self, 'first_connect'): + self.first_connect = self.inst_first_connect if hasattr(self, 'checkout'): self.checkout = self.inst_checkout if hasattr(self, 'checkin'): @@ -171,14 +175,17 @@ class PoolTest(PoolTestBase): self.clear() def clear(self): self.connected = [] + self.first_connected = [] self.checked_out = [] self.checked_in = [] - def assert_total(innerself, conn, cout, cin): + def assert_total(innerself, conn, fconn, cout, cin): self.assert_(len(innerself.connected) == conn) + self.assert_(len(innerself.first_connected) == fconn) self.assert_(len(innerself.checked_out) == cout) self.assert_(len(innerself.checked_in) == cin) - def assert_in(innerself, item, in_conn, in_cout, in_cin): + def assert_in(innerself, item, in_conn, in_fconn, in_cout, in_cin): self.assert_((item in innerself.connected) == in_conn) + self.assert_((item in innerself.first_connected) == in_fconn) self.assert_((item in innerself.checked_out) == in_cout) self.assert_((item in innerself.checked_in) == in_cin) def inst_connect(self, con, record): @@ -186,6 +193,11 @@ class PoolTest(PoolTestBase): assert con is not None assert record is not None self.connected.append(con) + def inst_first_connect(self, con, record): + print "first_connect(%s, %s)" % (con, record) + assert con is not None + assert record is not None + self.first_connected.append(con) def inst_checkout(self, con, record, proxy): print "checkout(%s, %s, %s)" % (con, record, proxy) assert con is not None @@ -203,6 +215,9 @@ class PoolTest(PoolTestBase): class ListenConnect(InstrumentingListener): def connect(self, con, record): pass + class ListenFirstConnect(InstrumentingListener): + def first_connect(self, con, record): + pass class ListenCheckOut(InstrumentingListener): def checkout(self, con, record, proxy, num): pass @@ -214,40 +229,43 @@ class PoolTest(PoolTestBase): return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'), use_threadlocal=False, **kw) - def assert_listeners(p, total, conn, cout, cin): + def assert_listeners(p, total, conn, fconn, cout, cin): for instance in (p, p.recreate()): self.assert_(len(instance.listeners) == total) self.assert_(len(instance._on_connect) == conn) + self.assert_(len(instance._on_first_connect) == fconn) self.assert_(len(instance._on_checkout) == cout) self.assert_(len(instance._on_checkin) == cin) p = _pool() - assert_listeners(p, 0, 0, 0, 0) + assert_listeners(p, 0, 0, 0, 0, 0) p.add_listener(ListenAll()) - assert_listeners(p, 1, 1, 1, 1) + assert_listeners(p, 1, 1, 1, 1, 1) p.add_listener(ListenConnect()) - assert_listeners(p, 2, 2, 1, 1) + assert_listeners(p, 2, 2, 1, 1, 1) + + p.add_listener(ListenFirstConnect()) + assert_listeners(p, 3, 2, 2, 1, 1) p.add_listener(ListenCheckOut()) - assert_listeners(p, 3, 2, 2, 1) + assert_listeners(p, 4, 2, 2, 2, 1) p.add_listener(ListenCheckIn()) - assert_listeners(p, 4, 2, 2, 2) + assert_listeners(p, 5, 2, 2, 2, 2) del p - print "----" snoop = ListenAll() p = _pool(listeners=[snoop]) - assert_listeners(p, 1, 1, 1, 1) + assert_listeners(p, 1, 1, 1, 1, 1) c = p.connect() - snoop.assert_total(1, 1, 0) + snoop.assert_total(1, 1, 1, 0) cc = c.connection - snoop.assert_in(cc, True, True, False) + snoop.assert_in(cc, True, True, True, False) c.close() - snoop.assert_in(cc, True, True, True) + snoop.assert_in(cc, True, True, True, True) del c, cc snoop.clear() @@ -255,10 +273,11 @@ class PoolTest(PoolTestBase): # this one depends on immediate gc c = p.connect() cc = c.connection - snoop.assert_in(cc, False, True, False) - snoop.assert_total(0, 1, 0) + snoop.assert_in(cc, False, False, True, False) + snoop.assert_total(0, 0, 1, 0) del c, cc - snoop.assert_total(0, 1, 1) + lazy_gc() + snoop.assert_total(0, 0, 1, 1) p.dispose() snoop.clear() @@ -266,44 +285,46 @@ class PoolTest(PoolTestBase): c = p.connect() c.close() c = p.connect() - snoop.assert_total(1, 2, 1) + snoop.assert_total(1, 0, 2, 1) c.close() - snoop.assert_total(1, 2, 2) + snoop.assert_total(1, 0, 2, 2) # invalidation p.dispose() snoop.clear() c = p.connect() - snoop.assert_total(1, 1, 0) + snoop.assert_total(1, 0, 1, 0) c.invalidate() - snoop.assert_total(1, 1, 1) + snoop.assert_total(1, 0, 1, 1) c.close() - snoop.assert_total(1, 1, 1) + snoop.assert_total(1, 0, 1, 1) del c - snoop.assert_total(1, 1, 1) + lazy_gc() + snoop.assert_total(1, 0, 1, 1) c = p.connect() - snoop.assert_total(2, 2, 1) + snoop.assert_total(2, 0, 2, 1) c.close() del c - snoop.assert_total(2, 2, 2) + lazy_gc() + snoop.assert_total(2, 0, 2, 2) # detached p.dispose() snoop.clear() c = p.connect() - snoop.assert_total(1, 1, 0) + snoop.assert_total(1, 0, 1, 0) c.detach() - snoop.assert_total(1, 1, 0) + snoop.assert_total(1, 0, 1, 0) c.close() del c - snoop.assert_total(1, 1, 0) + snoop.assert_total(1, 0, 1, 0) c = p.connect() - snoop.assert_total(2, 2, 0) + snoop.assert_total(2, 0, 2, 0) c.close() del c - snoop.assert_total(2, 2, 1) + snoop.assert_total(2, 0, 2, 1) def test_listeners_callables(self): dbapi = MockDBAPI() @@ -362,262 +383,293 @@ class PoolTest(PoolTestBase): c.close() assert counts == [1, 2, 3] + def test_listener_after_oninit(self): + """Test that listeners are called after OnInit is removed""" + called = [] + def listener(*args): + called.append(True) + listener.connect = listener + engine = create_engine(testing.db.url) + engine.pool.add_listener(listener) + engine.execute(select([1])) + assert called, "Listener not called on connect" + + class QueuePoolTest(PoolTestBase): - def testqueuepool_del(self): - self._do_testqueuepool(useclose=False) - - def testqueuepool_close(self): - self._do_testqueuepool(useclose=True) - - def _do_testqueuepool(self, useclose=False): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = False) - - def status(pool): - tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout()) - print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup - return tup - - c1 = p.connect() - self.assert_(status(p) == (3,0,-2,1)) - c2 = p.connect() - self.assert_(status(p) == (3,0,-1,2)) - c3 = p.connect() - self.assert_(status(p) == (3,0,0,3)) - c4 = p.connect() - self.assert_(status(p) == (3,0,1,4)) - c5 = p.connect() - self.assert_(status(p) == (3,0,2,5)) - c6 = p.connect() - self.assert_(status(p) == (3,0,3,6)) - if useclose: - c4.close() - c3.close() - c2.close() - else: - c4 = c3 = c2 = None - self.assert_(status(p) == (3,3,3,3)) - if useclose: - c1.close() - c5.close() - c6.close() - else: - c1 = c5 = c6 = None - self.assert_(status(p) == (3,3,0,0)) - c1 = p.connect() - c2 = p.connect() - self.assert_(status(p) == (3, 1, 0, 2), status(p)) - if useclose: - c2.close() - else: - c2 = None - self.assert_(status(p) == (3, 2, 0, 1)) - - def test_timeout(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2) - c1 = p.connect() - c2 = p.connect() - c3 = p.connect() - now = time.time() - try: - c4 = p.connect() - assert False - except tsa.exc.TimeoutError, e: - assert int(time.time() - now) == 2 - - def test_timeout_race(self): - # test a race condition where the initial connecting threads all race - # to queue.Empty, then block on the mutex. each thread consumes a - # connection as they go in. when the limit is reached, the remaining - # threads go in, and get TimeoutError; even though they never got to - # wait for the timeout on queue.get(). the fix involves checking the - # timeout again within the mutex, and if so, unlocking and throwing - # them back to the start of do_get() - p = pool.QueuePool(creator = lambda: mock_dbapi.connect(delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3) - timeouts = [] - def checkout(): - for x in xrange(1): - now = time.time() - try: - c1 = p.connect() - except tsa.exc.TimeoutError, e: - timeouts.append(int(time.time()) - now) - continue - time.sleep(4) - c1.close() - - threads = [] - for i in xrange(10): - th = threading.Thread(target=checkout) - th.start() - threads.append(th) - for th in threads: - th.join() - - print timeouts - assert len(timeouts) > 0 - for t in timeouts: - assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts) - - def _test_overflow(self, thread_count, max_overflow): - def creator(): - time.sleep(.05) - return mock_dbapi.connect() - - p = pool.QueuePool(creator=creator, - pool_size=3, timeout=2, - max_overflow=max_overflow) - peaks = [] - def whammy(): - for i in range(10): - try: - con = p.connect() - time.sleep(.005) - peaks.append(p.overflow()) - con.close() - del con - except tsa.exc.TimeoutError: - pass - threads = [] - for i in xrange(thread_count): - th = threading.Thread(target=whammy) - th.start() - threads.append(th) - for th in threads: - th.join() - - self.assert_(max(peaks) <= max_overflow) - - def test_no_overflow(self): - self._test_overflow(40, 0) - - def test_max_overflow(self): - self._test_overflow(40, 5) - - def test_mixed_close(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) - c1 = p.connect() - c2 = p.connect() - assert c1 is c2 - c1.close() - c2 = None - assert p.checkedout() == 1 - c1 = None - assert p.checkedout() == 0 - - def test_weakref_kaboom(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) - c1 = p.connect() - c2 = p.connect() - c1.close() - c2 = None - del c1 - del c2 - gc.collect() - assert p.checkedout() == 0 - c3 = p.connect() - assert c3 is not None - - def test_trick_the_counter(self): - """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread - with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an - ambiguous counter. i.e. its not true reference counting.""" - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) - c1 = p.connect() - c2 = p.connect() - assert c1 is c2 - c1.close() - c2 = p.connect() - c2.close() - self.assert_(p.checkedout() != 0) - - c2.close() - self.assert_(p.checkedout() == 0) - - def test_recycle(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3) - - c1 = p.connect() - c_id = id(c1.connection) - c1.close() - c2 = p.connect() - assert id(c2.connection) == c_id - c2.close() - time.sleep(4) - c3= p.connect() - assert id(c3.connection) != c_id - - def test_invalidate(self): - dbapi = MockDBAPI() - p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) - c1 = p.connect() - c_id = c1.connection.id - c1.close(); c1=None - c1 = p.connect() - assert c1.connection.id == c_id - c1.invalidate() - c1 = None - - c1 = p.connect() - assert c1.connection.id != c_id - - def test_recreate(self): - dbapi = MockDBAPI() - p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) - p2 = p.recreate() - assert p2.size() == 1 - assert p2._use_threadlocal is False - assert p2._max_overflow == 0 - - def test_reconnect(self): - """tests reconnect operations at the pool level. SA's engine/dialect includes another - layer of reconnect support for 'database was lost' errors.""" + def testqueuepool_del(self): + self._do_testqueuepool(useclose=False) + + def testqueuepool_close(self): + self._do_testqueuepool(useclose=True) + + def _do_testqueuepool(self, useclose=False): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = False) + + def status(pool): + tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout()) + print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup + return tup + + c1 = p.connect() + self.assert_(status(p) == (3,0,-2,1)) + c2 = p.connect() + self.assert_(status(p) == (3,0,-1,2)) + c3 = p.connect() + self.assert_(status(p) == (3,0,0,3)) + c4 = p.connect() + self.assert_(status(p) == (3,0,1,4)) + c5 = p.connect() + self.assert_(status(p) == (3,0,2,5)) + c6 = p.connect() + self.assert_(status(p) == (3,0,3,6)) + if useclose: + c4.close() + c3.close() + c2.close() + else: + c4 = c3 = c2 = None + lazy_gc() + + self.assert_(status(p) == (3,3,3,3)) + if useclose: + c1.close() + c5.close() + c6.close() + else: + c1 = c5 = c6 = None + lazy_gc() + + self.assert_(status(p) == (3,3,0,0)) + + c1 = p.connect() + c2 = p.connect() + self.assert_(status(p) == (3, 1, 0, 2), status(p)) + if useclose: + c2.close() + else: + c2 = None + lazy_gc() + + self.assert_(status(p) == (3, 2, 0, 1)) + + c1.close() + + lazy_gc() + assert not pool._refs - dbapi = MockDBAPI() - p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) - c1 = p.connect() - c_id = c1.connection.id - c1.close(); c1=None - - c1 = p.connect() - assert c1.connection.id == c_id - dbapi.raise_error = True - c1.invalidate() - c1 = None - - c1 = p.connect() - assert c1.connection.id != c_id - - def test_detach(self): - dbapi = MockDBAPI() - p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) - - c1 = p.connect() - c1.detach() - c_id = c1.connection.id - - c2 = p.connect() - assert c2.connection.id != c1.connection.id - dbapi.raise_error = True - - c2.invalidate() - c2 = None - - c2 = p.connect() - assert c2.connection.id != c1.connection.id - - con = c1.connection - - assert not con.closed - c1.close() - assert con.closed - - def test_threadfairy(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) - c1 = p.connect() - c1.close() - c2 = p.connect() - assert c2.connection is not None + def test_timeout(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2) + c1 = p.connect() + c2 = p.connect() + c3 = p.connect() + now = time.time() + try: + c4 = p.connect() + assert False + except tsa.exc.TimeoutError, e: + assert int(time.time() - now) == 2 + + def test_timeout_race(self): + # test a race condition where the initial connecting threads all race + # to queue.Empty, then block on the mutex. each thread consumes a + # connection as they go in. when the limit is reached, the remaining + # threads go in, and get TimeoutError; even though they never got to + # wait for the timeout on queue.get(). the fix involves checking the + # timeout again within the mutex, and if so, unlocking and throwing + # them back to the start of do_get() + p = pool.QueuePool(creator = lambda: mock_dbapi.connect(delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3) + timeouts = [] + def checkout(): + for x in xrange(1): + now = time.time() + try: + c1 = p.connect() + except tsa.exc.TimeoutError, e: + timeouts.append(int(time.time()) - now) + continue + time.sleep(4) + c1.close() + + threads = [] + for i in xrange(10): + th = threading.Thread(target=checkout) + th.start() + threads.append(th) + for th in threads: + th.join() + + print timeouts + assert len(timeouts) > 0 + for t in timeouts: + assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts) + + def _test_overflow(self, thread_count, max_overflow): + def creator(): + time.sleep(.05) + return mock_dbapi.connect() + + p = pool.QueuePool(creator=creator, + pool_size=3, timeout=2, + max_overflow=max_overflow) + peaks = [] + def whammy(): + for i in range(10): + try: + con = p.connect() + time.sleep(.005) + peaks.append(p.overflow()) + con.close() + del con + except tsa.exc.TimeoutError: + pass + threads = [] + for i in xrange(thread_count): + th = threading.Thread(target=whammy) + th.start() + threads.append(th) + for th in threads: + th.join() + + self.assert_(max(peaks) <= max_overflow) + + lazy_gc() + assert not pool._refs + + def test_no_overflow(self): + self._test_overflow(40, 0) + + def test_max_overflow(self): + self._test_overflow(40, 5) + + def test_mixed_close(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c2 = p.connect() + assert c1 is c2 + c1.close() + c2 = None + assert p.checkedout() == 1 + c1 = None + lazy_gc() + assert p.checkedout() == 0 + + lazy_gc() + assert not pool._refs + + def test_weakref_kaboom(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c2 = p.connect() + c1.close() + c2 = None + del c1 + del c2 + gc_collect() + assert p.checkedout() == 0 + c3 = p.connect() + assert c3 is not None + + def test_trick_the_counter(self): + """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread + with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an + ambiguous counter. i.e. its not true reference counting.""" + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c2 = p.connect() + assert c1 is c2 + c1.close() + c2 = p.connect() + c2.close() + self.assert_(p.checkedout() != 0) + + c2.close() + self.assert_(p.checkedout() == 0) + + def test_recycle(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3) + + c1 = p.connect() + c_id = id(c1.connection) + c1.close() + c2 = p.connect() + assert id(c2.connection) == c_id + c2.close() + time.sleep(4) + c3= p.connect() + assert id(c3.connection) != c_id + + def test_invalidate(self): + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) + c1 = p.connect() + c_id = c1.connection.id + c1.close(); c1=None + c1 = p.connect() + assert c1.connection.id == c_id + c1.invalidate() + c1 = None + + c1 = p.connect() + assert c1.connection.id != c_id + + def test_recreate(self): + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) + p2 = p.recreate() + assert p2.size() == 1 + assert p2._use_threadlocal is False + assert p2._max_overflow == 0 + + def test_reconnect(self): + """tests reconnect operations at the pool level. SA's engine/dialect includes another + layer of reconnect support for 'database was lost' errors.""" + + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) + c1 = p.connect() + c_id = c1.connection.id + c1.close(); c1=None + + c1 = p.connect() + assert c1.connection.id == c_id + dbapi.raise_error = True + c1.invalidate() + c1 = None + + c1 = p.connect() + assert c1.connection.id != c_id + + def test_detach(self): + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) + + c1 = p.connect() + c1.detach() + c_id = c1.connection.id + + c2 = p.connect() + assert c2.connection.id != c1.connection.id + dbapi.raise_error = True + + c2.invalidate() + c2 = None + + c2 = p.connect() + assert c2.connection.id != c1.connection.id + + con = c1.connection + + assert not con.closed + c1.close() + assert con.closed + + def test_threadfairy(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c1.close() + c2 = p.connect() + assert c2.connection is not None class SingletonThreadPoolTest(PoolTestBase): def test_cleanup(self): diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index 3a525c2a7..6afd71515 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -1,12 +1,13 @@ from sqlalchemy.test.testing import eq_ +import time import weakref from sqlalchemy import select, MetaData, Integer, String, pool from sqlalchemy.test.schema import Table from sqlalchemy.test.schema import Column import sqlalchemy as tsa from sqlalchemy.test import TestBase, testing, engines -import time -import gc +from sqlalchemy.test.util import gc_collect + class MockDisconnect(Exception): pass @@ -54,7 +55,7 @@ class MockReconnectTest(TestBase): dbapi = MockDBAPI() # create engine using our current dburi - db = tsa.create_engine('postgres://foo:bar@localhost/test', module=dbapi) + db = tsa.create_engine('postgresql://foo:bar@localhost/test', module=dbapi, _initialize=False) # monkeypatch disconnect checker db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect) @@ -98,7 +99,7 @@ class MockReconnectTest(TestBase): assert id(db.pool) != pid # ensure all connections closed (pool was recycled) - gc.collect() + gc_collect() assert len(dbapi.connections) == 0 conn =db.connect() @@ -118,7 +119,7 @@ class MockReconnectTest(TestBase): pass # assert was invalidated - gc.collect() + gc_collect() assert len(dbapi.connections) == 0 assert not conn.closed assert conn.invalidated @@ -168,7 +169,7 @@ class MockReconnectTest(TestBase): assert conn.invalidated # ensure all connections closed (pool was recycled) - gc.collect() + gc_collect() assert len(dbapi.connections) == 0 # test reconnects @@ -334,7 +335,8 @@ class InvalidateDuringResultTest(TestBase): meta.drop_all() engine.dispose() - @testing.fails_on('mysql', 'FIXME: unknown') + @testing.fails_on('+mysqldb', "Buffers the result set and doesn't check for connection close") + @testing.fails_on('+pg8000', "Buffers the result set and doesn't check for connection close") def test_invalidate_on_results(self): conn = engine.connect() @@ -344,7 +346,7 @@ class InvalidateDuringResultTest(TestBase): engine.test_shutdown() try: - result.fetchone() + print "ghost result: %r" % result.fetchone() assert False except tsa.exc.DBAPIError, e: if not e.connection_invalidated: diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index ea80776a6..dff9fa1bb 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -1,17 +1,22 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import StringIO, unicodedata import sqlalchemy as sa +from sqlalchemy import types as sql_types +from sqlalchemy import schema +from sqlalchemy.engine.reflection import Inspector from sqlalchemy import MetaData from sqlalchemy.test.schema import Table from sqlalchemy.test.schema import Column import sqlalchemy as tsa from sqlalchemy.test import TestBase, ComparesTables, testing, engines +create_inspector = Inspector.from_engine metadata, users = None, None class ReflectionTest(TestBase, ComparesTables): + @testing.exclude('mssql', '<', (10, 0, 0), 'Date is only supported on MSSQL 2008+') @testing.exclude('mysql', '<', (4, 1, 1), 'early types are squirrely') def test_basic_reflection(self): meta = MetaData(testing.db) @@ -22,16 +27,16 @@ class ReflectionTest(TestBase, ComparesTables): Column('test1', sa.CHAR(5), nullable=False), Column('test2', sa.Float(5), nullable=False), Column('test3', sa.Text), - Column('test4', sa.Numeric, nullable = False), - Column('test5', sa.DateTime), + Column('test4', sa.Numeric(10, 2), nullable = False), + Column('test5', sa.Date), Column('parent_user_id', sa.Integer, sa.ForeignKey('engine_users.user_id')), - Column('test6', sa.DateTime, nullable=False), + Column('test6', sa.Date, nullable=False), Column('test7', sa.Text), Column('test8', sa.Binary), Column('test_passivedefault2', sa.Integer, server_default='5'), Column('test9', sa.Binary(100)), - Column('test_numeric', sa.Numeric()), + Column('test10', sa.Numeric(10, 2)), test_needs_fk=True, ) @@ -52,9 +57,35 @@ class ReflectionTest(TestBase, ComparesTables): self.assert_tables_equal(users, reflected_users) self.assert_tables_equal(addresses, reflected_addresses) finally: - addresses.drop() - users.drop() - + meta.drop_all() + + def test_two_foreign_keys(self): + meta = MetaData(testing.db) + t1 = Table('t1', meta, + Column('id', sa.Integer, primary_key=True), + Column('t2id', sa.Integer, sa.ForeignKey('t2.id')), + Column('t3id', sa.Integer, sa.ForeignKey('t3.id')), + test_needs_fk=True + ) + t2 = Table('t2', meta, + Column('id', sa.Integer, primary_key=True), + test_needs_fk=True + ) + t3 = Table('t3', meta, + Column('id', sa.Integer, primary_key=True), + test_needs_fk=True + ) + meta.create_all() + try: + meta2 = MetaData() + t1r, t2r, t3r = [Table(x, meta2, autoload=True, autoload_with=testing.db) for x in ('t1', 't2', 't3')] + + assert t1r.c.t2id.references(t2r.c.id) + assert t1r.c.t3id.references(t3r.c.id) + + finally: + meta.drop_all() + def test_include_columns(self): meta = MetaData(testing.db) foo = Table('foo', meta, *[Column(n, sa.String(30)) @@ -84,26 +115,68 @@ class ReflectionTest(TestBase, ComparesTables): finally: meta.drop_all() + @testing.emits_warning(r".*omitted columns") + def test_include_columns_indexes(self): + m = MetaData(testing.db) + + t1 = Table('t1', m, Column('a', sa.Integer), Column('b', sa.Integer)) + sa.Index('foobar', t1.c.a, t1.c.b) + sa.Index('bat', t1.c.a) + m.create_all() + try: + m2 = MetaData(testing.db) + t2 = Table('t1', m2, autoload=True) + assert len(t2.indexes) == 2 + m2 = MetaData(testing.db) + t2 = Table('t1', m2, autoload=True, include_columns=['a']) + assert len(t2.indexes) == 1 + + m2 = MetaData(testing.db) + t2 = Table('t1', m2, autoload=True, include_columns=['a', 'b']) + assert len(t2.indexes) == 2 + finally: + m.drop_all() + + def test_autoincrement_col(self): + """test that 'autoincrement' is reflected according to sqla's policy. + + Don't mark this test as unsupported for any backend ! + + (technically it fails with MySQL InnoDB since "id" comes before "id2") + + """ + + meta = MetaData(testing.db) + t1 = Table('test', meta, + Column('id', sa.Integer, primary_key=True), + Column('data', sa.String(50)), + ) + t2 = Table('test2', meta, + Column('id', sa.Integer, sa.ForeignKey('test.id'), primary_key=True), + Column('id2', sa.Integer, primary_key=True), + Column('data', sa.String(50)), + ) + meta.create_all() + try: + m2 = MetaData(testing.db) + t1a = Table('test', m2, autoload=True) + assert t1a._autoincrement_column is t1a.c.id + + t2a = Table('test2', m2, autoload=True) + assert t2a._autoincrement_column is t2a.c.id2 + + finally: + meta.drop_all() + def test_unknown_types(self): meta = MetaData(testing.db) t = Table("test", meta, Column('foo', sa.DateTime)) - import sys - dialect_module = sys.modules[testing.db.dialect.__module__] - - # we're relying on the presence of "ischema_names" in the - # dialect module, else we can't test this. we need to be able - # to get the dialect to not be aware of some type so we temporarily - # monkeypatch. not sure what a better way for this could be, - # except for an established dialect hook or dialect-specific tests - if not hasattr(dialect_module, 'ischema_names'): - return - - ischema_names = dialect_module.ischema_names + ischema_names = testing.db.dialect.ischema_names t.create() - dialect_module.ischema_names = {} + testing.db.dialect.ischema_names = {} try: m2 = MetaData(testing.db) assert_raises(tsa.exc.SAWarning, Table, "test", m2, autoload=True) @@ -115,7 +188,7 @@ class ReflectionTest(TestBase, ComparesTables): assert t3.c.foo.type.__class__ == sa.types.NullType finally: - dialect_module.ischema_names = ischema_names + testing.db.dialect.ischema_names = ischema_names t.drop() def test_basic_override(self): @@ -578,7 +651,6 @@ class ReflectionTest(TestBase, ComparesTables): m9.reflect() self.assert_(not m9.tables) - @testing.fails_on_everything_except('postgres', 'mysql') def test_index_reflection(self): m1 = MetaData(testing.db) t1 = Table('party', m1, @@ -698,7 +770,7 @@ class UnicodeReflectionTest(TestBase): def test_basic(self): try: # the 'convert_unicode' should not get in the way of the reflection - # process. reflecttable for oracle, postgres (others?) expect non-unicode + # process. reflecttable for oracle, postgresql (others?) expect non-unicode # strings in result sets/bind params bind = engines.utf8_engine(options={'convert_unicode':True}) metadata = MetaData(bind) @@ -713,7 +785,8 @@ class UnicodeReflectionTest(TestBase): metadata.create_all() reflected = set(bind.table_names()) - if not names.issubset(reflected): + # Jython 2.5 on Java 5 lacks unicodedata.normalize + if not names.issubset(reflected) and hasattr(unicodedata, 'normalize'): # Python source files in the utf-8 coding seem to normalize # literals as NFC (and the above are explicitly NFC). Maybe # this database normalizes NFD on reflection. @@ -741,23 +814,15 @@ class SchemaTest(TestBase): Column('col1', sa.Integer, primary_key=True), Column('col2', sa.Integer, sa.ForeignKey('someschema.table1.col1')), schema='someschema') - # ensure this doesnt crash - print [t for t in metadata.sorted_tables] - buf = StringIO.StringIO() - def foo(s, p=None): - buf.write(s) - gen = sa.create_engine(testing.db.name + "://", strategy="mock", executor=foo) - gen = gen.dialect.schemagenerator(gen.dialect, gen) - gen.traverse(table1) - gen.traverse(table2) - buf = buf.getvalue() - print buf + + t1 = str(schema.CreateTable(table1).compile(bind=testing.db)) + t2 = str(schema.CreateTable(table2).compile(bind=testing.db)) if testing.db.dialect.preparer(testing.db.dialect).omit_schema: - assert buf.index("CREATE TABLE table1") > -1 - assert buf.index("CREATE TABLE table2") > -1 + assert t1.index("CREATE TABLE table1") > -1 + assert t2.index("CREATE TABLE table2") > -1 else: - assert buf.index("CREATE TABLE someschema.table1") > -1 - assert buf.index("CREATE TABLE someschema.table2") > -1 + assert t1.index("CREATE TABLE someschema.table1") > -1 + assert t2.index("CREATE TABLE someschema.table2") > -1 @testing.crashes('firebird', 'No schema support') @testing.fails_on('sqlite', 'FIXME: unknown') @@ -767,9 +832,9 @@ class SchemaTest(TestBase): def test_explicit_default_schema(self): engine = testing.db - if testing.against('mysql'): + if testing.against('mysql+mysqldb'): schema = testing.db.url.database - elif testing.against('postgres'): + elif testing.against('postgresql'): schema = 'public' elif testing.against('sqlite'): # Works for CREATE TABLE main.foo, SELECT FROM main.foo, etc., @@ -820,4 +885,324 @@ class HasSequenceTest(TestBase): metadata.drop_all(bind=testing.db) eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), False) +# Tests related to engine.reflection + +def get_schema(): + if testing.against('oracle'): + return 'scott' + return 'test_schema' + +def createTables(meta, schema=None): + if schema: + parent_user_id = Column('parent_user_id', sa.Integer, + sa.ForeignKey('%s.users.user_id' % schema) + ) + else: + parent_user_id = Column('parent_user_id', sa.Integer, + sa.ForeignKey('users.user_id') + ) + + users = Table('users', meta, + Column('user_id', sa.INT, primary_key=True), + Column('user_name', sa.VARCHAR(20), nullable=False), + Column('test1', sa.CHAR(5), nullable=False), + Column('test2', sa.Float(5), nullable=False), + Column('test3', sa.Text), + Column('test4', sa.Numeric(10, 2), nullable = False), + Column('test5', sa.DateTime), + Column('test5-1', sa.TIMESTAMP), + parent_user_id, + Column('test6', sa.DateTime, nullable=False), + Column('test7', sa.Text), + Column('test8', sa.Binary), + Column('test_passivedefault2', sa.Integer, server_default='5'), + Column('test9', sa.Binary(100)), + Column('test10', sa.Numeric(10, 2)), + schema=schema, + test_needs_fk=True, + ) + addresses = Table('email_addresses', meta, + Column('address_id', sa.Integer, primary_key = True), + Column('remote_user_id', sa.Integer, + sa.ForeignKey(users.c.user_id)), + Column('email_address', sa.String(20)), + schema=schema, + test_needs_fk=True, + ) + return (users, addresses) + +def createIndexes(con, schema=None): + fullname = 'users' + if schema: + fullname = "%s.%s" % (schema, 'users') + query = "CREATE INDEX users_t_idx ON %s (test1, test2)" % fullname + con.execute(sa.sql.text(query)) + +def createViews(con, schema=None): + for table_name in ('users', 'email_addresses'): + fullname = table_name + if schema: + fullname = "%s.%s" % (schema, table_name) + view_name = fullname + '_v' + query = "CREATE VIEW %s AS SELECT * FROM %s" % (view_name, + fullname) + con.execute(sa.sql.text(query)) + +def dropViews(con, schema=None): + for table_name in ('email_addresses', 'users'): + fullname = table_name + if schema: + fullname = "%s.%s" % (schema, table_name) + view_name = fullname + '_v' + query = "DROP VIEW %s" % view_name + con.execute(sa.sql.text(query)) + + +class ComponentReflectionTest(TestBase): + + @testing.requires.schemas + def test_get_schema_names(self): + meta = MetaData(testing.db) + insp = Inspector(meta.bind) + + self.assert_(get_schema() in insp.get_schema_names()) + + def _test_get_table_names(self, schema=None, table_type='table', + order_by=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + createViews(meta.bind, schema) + try: + insp = Inspector(meta.bind) + if table_type == 'view': + table_names = insp.get_view_names(schema) + table_names.sort() + answer = ['email_addresses_v', 'users_v'] + else: + table_names = insp.get_table_names(schema, + order_by=order_by) + table_names.sort() + if order_by == 'foreign_key': + answer = ['users', 'email_addresses'] + else: + answer = ['email_addresses', 'users'] + eq_(table_names, answer) + finally: + dropViews(meta.bind, schema) + addresses.drop() + users.drop() + + def test_get_table_names(self): + self._test_get_table_names() + + @testing.requires.schemas + def test_get_table_names_with_schema(self): + self._test_get_table_names(get_schema()) + + def test_get_view_names(self): + self._test_get_table_names(table_type='view') + + @testing.requires.schemas + def test_get_view_names_with_schema(self): + self._test_get_table_names(get_schema(), table_type='view') + + def _test_get_columns(self, schema=None, table_type='table'): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + table_names = ['users', 'email_addresses'] + meta.create_all() + if table_type == 'view': + createViews(meta.bind, schema) + table_names = ['users_v', 'email_addresses_v'] + try: + insp = Inspector(meta.bind) + for (table_name, table) in zip(table_names, (users, addresses)): + schema_name = schema + cols = insp.get_columns(table_name, schema=schema_name) + self.assert_(len(cols) > 0, len(cols)) + # should be in order + for (i, col) in enumerate(table.columns): + eq_(col.name, cols[i]['name']) + ctype = cols[i]['type'].__class__ + ctype_def = col.type + if isinstance(ctype_def, sa.types.TypeEngine): + ctype_def = ctype_def.__class__ + + # Oracle returns Date for DateTime. + if testing.against('oracle') \ + and ctype_def in (sql_types.Date, sql_types.DateTime): + ctype_def = sql_types.Date + + # assert that the desired type and return type + # share a base within one of the generic types. + self.assert_( + len( + set( + ctype.__mro__ + ).intersection(ctype_def.__mro__) + .intersection([sql_types.Integer, sql_types.Numeric, + sql_types.DateTime, sql_types.Date, sql_types.Time, + sql_types.String, sql_types.Binary]) + ) > 0 + ,("%s(%s), %s(%s)" % (col.name, col.type, cols[i]['name'], + ctype))) + finally: + if table_type == 'view': + dropViews(meta.bind, schema) + addresses.drop() + users.drop() + + def test_get_columns(self): + self._test_get_columns() + + @testing.requires.schemas + def test_get_columns_with_schema(self): + self._test_get_columns(schema=get_schema()) + + def test_get_view_columns(self): + self._test_get_columns(table_type='view') + + @testing.requires.schemas + def test_get_view_columns_with_schema(self): + self._test_get_columns(schema=get_schema(), table_type='view') + + def _test_get_primary_keys(self, schema=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + insp = Inspector(meta.bind) + try: + users_pkeys = insp.get_primary_keys(users.name, + schema=schema) + eq_(users_pkeys, ['user_id']) + addr_pkeys = insp.get_primary_keys(addresses.name, + schema=schema) + eq_(addr_pkeys, ['address_id']) + + finally: + addresses.drop() + users.drop() + + def test_get_primary_keys(self): + self._test_get_primary_keys() + + @testing.fails_on('sqlite', 'no schemas') + def test_get_primary_keys_with_schema(self): + self._test_get_primary_keys(schema=get_schema()) + + def _test_get_foreign_keys(self, schema=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + insp = Inspector(meta.bind) + try: + expected_schema = schema + # users + users_fkeys = insp.get_foreign_keys(users.name, + schema=schema) + fkey1 = users_fkeys[0] + self.assert_(fkey1['name'] is not None) + eq_(fkey1['referred_schema'], expected_schema) + eq_(fkey1['referred_table'], users.name) + eq_(fkey1['referred_columns'], ['user_id', ]) + eq_(fkey1['constrained_columns'], ['parent_user_id']) + #addresses + addr_fkeys = insp.get_foreign_keys(addresses.name, + schema=schema) + fkey1 = addr_fkeys[0] + self.assert_(fkey1['name'] is not None) + eq_(fkey1['referred_schema'], expected_schema) + eq_(fkey1['referred_table'], users.name) + eq_(fkey1['referred_columns'], ['user_id', ]) + eq_(fkey1['constrained_columns'], ['remote_user_id']) + finally: + addresses.drop() + users.drop() + + def test_get_foreign_keys(self): + self._test_get_foreign_keys() + + @testing.requires.schemas + def test_get_foreign_keys_with_schema(self): + self._test_get_foreign_keys(schema=get_schema()) + + def _test_get_indexes(self, schema=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + createIndexes(meta.bind, schema) + try: + # The database may decide to create indexes for foreign keys, etc. + # so there may be more indexes than expected. + insp = Inspector(meta.bind) + indexes = insp.get_indexes('users', schema=schema) + indexes.sort() + expected_indexes = [ + {'unique': False, + 'column_names': ['test1', 'test2'], + 'name': 'users_t_idx'}] + index_names = [d['name'] for d in indexes] + for e_index in expected_indexes: + assert e_index['name'] in index_names + index = indexes[index_names.index(e_index['name'])] + for key in e_index: + eq_(e_index[key], index[key]) + + finally: + addresses.drop() + users.drop() + + def test_get_indexes(self): + self._test_get_indexes() + + @testing.requires.schemas + def test_get_indexes_with_schema(self): + self._test_get_indexes(schema=get_schema()) + + def _test_get_view_definition(self, schema=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + createViews(meta.bind, schema) + view_name1 = 'users_v' + view_name2 = 'email_addresses_v' + try: + insp = Inspector(meta.bind) + v1 = insp.get_view_definition(view_name1, schema=schema) + self.assert_(v1) + v2 = insp.get_view_definition(view_name2, schema=schema) + self.assert_(v2) + finally: + dropViews(meta.bind, schema) + addresses.drop() + users.drop() + + def test_get_view_definition(self): + self._test_get_view_definition() + + @testing.requires.schemas + def test_get_view_definition_with_schema(self): + self._test_get_view_definition(schema=get_schema()) + + def _test_get_table_oid(self, table_name, schema=None): + if testing.against('postgresql'): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + try: + insp = create_inspector(meta.bind) + oid = insp.get_table_oid(table_name, schema) + self.assert_(isinstance(oid, (int, long))) + finally: + addresses.drop() + users.drop() + + def test_get_table_oid(self): + self._test_get_table_oid('users') + + @testing.requires.schemas + def test_get_table_oid_with_schema(self): + self._test_get_table_oid('users', schema=get_schema()) + diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py index 6698259a4..8e3f3412d 100644 --- a/test/engine/test_transaction.py +++ b/test/engine/test_transaction.py @@ -20,7 +20,8 @@ class TransactionTest(TestBase): users.create(testing.db) def teardown(self): - testing.db.connect().execute(users.delete()) + testing.db.execute(users.delete()).close() + @classmethod def teardown_class(cls): users.drop(testing.db) @@ -40,6 +41,7 @@ class TransactionTest(TestBase): result = connection.execute("select * from query_users") assert len(result.fetchall()) == 3 transaction.commit() + connection.close() def test_rollback(self): """test a basic rollback""" @@ -176,6 +178,7 @@ class TransactionTest(TestBase): connection.close() @testing.requires.savepoints + @testing.crashes('oracle+zxjdbc', 'Errors out and causes subsequent tests to deadlock') def test_nested_subtransaction_commit(self): connection = testing.db.connect() transaction = connection.begin() @@ -274,6 +277,7 @@ class TransactionTest(TestBase): connection.close() @testing.requires.two_phase_transactions + @testing.crashes('mysql+zxjdbc', 'Deadlocks, causing subsequent tests to fail') @testing.fails_on('mysql', 'FIXME: unknown') def test_two_phase_recover(self): # MySQL recovery doesn't currently seem to work correctly @@ -369,7 +373,7 @@ class ExplicitAutoCommitTest(TestBase): Requires PostgreSQL so that we may define a custom function which modifies the database. """ - __only_on__ = 'postgres' + __only_on__ = 'postgresql' @classmethod def setup_class(cls): @@ -380,7 +384,7 @@ class ExplicitAutoCommitTest(TestBase): testing.db.execute("create function insert_foo(varchar) returns integer as 'insert into foo(data) values ($1);select 1;' language sql") def teardown(self): - foo.delete().execute() + foo.delete().execute().close() @classmethod def teardown_class(cls): @@ -453,8 +457,10 @@ class TLTransactionTest(TestBase): test_needs_acid=True, ) users.create(tlengine) + def teardown(self): - tlengine.execute(users.delete()) + tlengine.execute(users.delete()).close() + @classmethod def teardown_class(cls): users.drop(tlengine) @@ -497,6 +503,7 @@ class TLTransactionTest(TestBase): try: assert len(result.fetchall()) == 0 finally: + c.close() external_connection.close() def test_rollback(self): @@ -530,7 +537,9 @@ class TLTransactionTest(TestBase): external_connection.close() def test_commits(self): - assert tlengine.connect().execute("select count(1) from query_users").scalar() == 0 + connection = tlengine.connect() + assert connection.execute("select count(1) from query_users").scalar() == 0 + connection.close() connection = tlengine.contextual_connect() transaction = connection.begin() @@ -547,6 +556,7 @@ class TLTransactionTest(TestBase): l = result.fetchall() assert len(l) == 3, "expected 3 got %d" % len(l) transaction.commit() + connection.close() def test_rollback_off_conn(self): # test that a TLTransaction opened off a TLConnection allows that @@ -563,6 +573,7 @@ class TLTransactionTest(TestBase): try: assert len(result.fetchall()) == 0 finally: + conn.close() external_connection.close() def test_morerollback_off_conn(self): @@ -581,6 +592,8 @@ class TLTransactionTest(TestBase): try: assert len(result.fetchall()) == 0 finally: + conn.close() + conn2.close() external_connection.close() def test_commit_off_connection(self): @@ -596,6 +609,7 @@ class TLTransactionTest(TestBase): try: assert len(result.fetchall()) == 3 finally: + conn.close() external_connection.close() def test_nesting(self): @@ -712,8 +726,10 @@ class ForUpdateTest(TestBase): test_needs_acid=True, ) counters.create(testing.db) + def teardown(self): - testing.db.connect().execute(counters.delete()) + testing.db.execute(counters.delete()).close() + @classmethod def teardown_class(cls): counters.drop(testing.db) @@ -726,7 +742,7 @@ class ForUpdateTest(TestBase): for i in xrange(count): trans = con.begin() try: - existing = con.execute(sel).fetchone() + existing = con.execute(sel).first() incr = existing['counter_value'] + 1 time.sleep(delay) @@ -734,7 +750,7 @@ class ForUpdateTest(TestBase): values={'counter_value':incr})) time.sleep(delay) - readback = con.execute(sel).fetchone() + readback = con.execute(sel).first() if (readback['counter_value'] != incr): raise AssertionError("Got %s post-update, expected %s" % (readback['counter_value'], incr)) @@ -778,7 +794,7 @@ class ForUpdateTest(TestBase): self.assert_(len(errors) == 0) sel = counters.select(whereclause=counters.c.counter_id==1) - final = db.execute(sel).fetchone() + final = db.execute(sel).first() self.assert_(final['counter_value'] == iterations * thread_count) def overlap(self, ids, errors, update_style): diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index 8df449718..4a5775218 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -1,6 +1,5 @@ from sqlalchemy.test.testing import eq_, assert_raises import copy -import gc import pickle from sqlalchemy import * @@ -8,6 +7,7 @@ from sqlalchemy.orm import * from sqlalchemy.orm.collections import collection from sqlalchemy.ext.associationproxy import * from sqlalchemy.test import * +from sqlalchemy.test.util import gc_collect class DictCollection(dict): @@ -880,7 +880,7 @@ class ReconstitutionTest(TestBase): add_child('p1', 'c1') - gc.collect() + gc_collect() add_child('p1', 'c2') session.flush() @@ -895,7 +895,7 @@ class ReconstitutionTest(TestBase): p.kids.extend(['c1', 'c2']) p_copy = copy.copy(p) del p - gc.collect() + gc_collect() assert set(p_copy.kids) == set(['c1', 'c2']), p.kids diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py index ce2549099..3ee94d027 100644 --- a/test/ext/test_compiler.py +++ b/test/ext/test_compiler.py @@ -1,5 +1,7 @@ from sqlalchemy import * +from sqlalchemy.types import TypeEngine from sqlalchemy.sql.expression import ClauseElement, ColumnClause +from sqlalchemy.schema import DDLElement from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import table, column from sqlalchemy.test import * @@ -25,7 +27,35 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL): select([MyThingy('x'), MyThingy('y')]).where(MyThingy() == 5), "SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1" ) + + def test_types(self): + class MyType(TypeEngine): + pass + + @compiles(MyType, 'sqlite') + def visit_type(type, compiler, **kw): + return "SQLITE_FOO" + + @compiles(MyType, 'postgresql') + def visit_type(type, compiler, **kw): + return "POSTGRES_FOO" + + from sqlalchemy.dialects.sqlite import base as sqlite + from sqlalchemy.dialects.postgresql import base as postgresql + self.assert_compile( + MyType(), + "SQLITE_FOO", + dialect=sqlite.dialect() + ) + + self.assert_compile( + MyType(), + "POSTGRES_FOO", + dialect=postgresql.dialect() + ) + + def test_stateful(self): class MyThingy(ColumnClause): def __init__(self): @@ -71,10 +101,10 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL): ) def test_dialect_specific(self): - class AddThingy(ClauseElement): + class AddThingy(DDLElement): __visit_name__ = 'add_thingy' - class DropThingy(ClauseElement): + class DropThingy(DDLElement): __visit_name__ = 'drop_thingy' @compiles(AddThingy, 'sqlite') @@ -97,7 +127,7 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL): "DROP THINGY" ) - from sqlalchemy.databases import sqlite as base + from sqlalchemy.dialects.sqlite import base self.assert_compile(AddThingy(), "ADD SPECIAL SL THINGY", dialect=base.dialect() diff --git a/test/ext/test_declarative.py b/test/ext/test_declarative.py index 224f41731..745e3b7cf 100644 --- a/test/ext/test_declarative.py +++ b/test/ext/test_declarative.py @@ -5,8 +5,7 @@ from sqlalchemy import exc import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import MetaData, Integer, String, ForeignKey, ForeignKeyConstraint, asc, Index -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import relation, create_session, class_mapper, eagerload, compile_mappers, backref, clear_mappers, polymorphic_union, deferred from sqlalchemy.test.testing import eq_ @@ -27,14 +26,14 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relation("Address", backref="user") class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) email = Column(String(50), key='_email') user_id = Column('user_id', Integer, ForeignKey('users.id'), key='_user_id') @@ -127,7 +126,7 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) name = Column(String(50)) addresses = relation("Address", order_by="desc(Address.email)", primaryjoin="User.id==Address.user_id", foreign_keys="[Address.user_id]", @@ -136,7 +135,7 @@ class DeclarativeTest(DeclarativeTestBase): class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) email = Column(String(50)) user_id = Column(Integer) # note no foreign key @@ -180,13 +179,13 @@ class DeclarativeTest(DeclarativeTestBase): def test_uncompiled_attributes_in_relation(self): class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) email = Column(String(50)) user_id = Column(Integer, ForeignKey('users.id')) class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) name = Column(String(50)) addresses = relation("Address", order_by=Address.email, foreign_keys=Address.user_id, @@ -272,14 +271,14 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) User.name = Column('name', String(50)) User.addresses = relation("Address", backref="user") class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) Address.email = Column(String(50), key='_email') Address.user_id = Column('user_id', Integer, ForeignKey('users.id'), key='_user_id') @@ -312,14 +311,14 @@ class DeclarativeTest(DeclarativeTestBase): class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) email = Column('email', String(50)) user_id = Column('user_id', Integer, ForeignKey('users.id')) class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relation("Address", order_by=Address.email) @@ -341,14 +340,14 @@ class DeclarativeTest(DeclarativeTestBase): class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) email = Column('email', String(50)) user_id = Column('user_id', Integer, ForeignKey('users.id')) class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relation("Address", order_by=(Address.email, Address.id)) @@ -368,14 +367,14 @@ class DeclarativeTest(DeclarativeTestBase): class User(ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relation("Address", backref="user") class Address(ComparableEntity): __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) email = Column('email', String(50)) user_id = Column('user_id', Integer, ForeignKey('users.id')) @@ -478,14 +477,14 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relation("Address", backref="user") class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) email = Column('email', String(50)) user_id = Column('user_id', Integer, ForeignKey('users.id')) @@ -513,7 +512,7 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) User.a = Column('a', String(10)) @@ -535,14 +534,14 @@ class DeclarativeTest(DeclarativeTestBase): def test_column_properties(self): class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) email = Column(String(50)) user_id = Column(Integer, ForeignKey('users.id')) class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) adr_count = sa.orm.column_property( sa.select([sa.func.count(Address.id)], Address.user_id == id). @@ -588,7 +587,7 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) name = sa.orm.deferred(Column(String(50))) Base.metadata.create_all() @@ -607,7 +606,7 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) _name = Column('name', String(50)) def _set_name(self, name): self._name = "SOMENAME " + name @@ -636,7 +635,7 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) _name = Column('name', String(50)) name = sa.orm.synonym('_name', comparator_factory=CustomCompare) @@ -652,7 +651,7 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) _name = Column('name', String(50)) def _set_name(self, name): self._name = "SOMENAME " + name @@ -674,14 +673,14 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relation("Address", backref="user") class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) email = Column('email', String(50)) user_id = Column('user_id', Integer, ForeignKey(User.id)) @@ -711,14 +710,14 @@ class DeclarativeTest(DeclarativeTestBase): class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) email = Column('email', String(50)) user_id = Column('user_id', Integer, ForeignKey('users.id')) class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relation("Address", backref="user", primaryjoin=id == Address.user_id) @@ -763,7 +762,7 @@ class DeclarativeTest(DeclarativeTestBase): def test_with_explicit_autoloaded(self): meta = MetaData(testing.db) t1 = Table('t1', meta, - Column('id', String(50), primary_key=True), + Column('id', String(50), primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) meta.create_all() try: @@ -779,6 +778,70 @@ class DeclarativeTest(DeclarativeTestBase): finally: meta.drop_all() + def test_synonym_for(self): + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) + name = Column('name', String(50)) + + @decl.synonym_for('name') + @property + def namesyn(self): + return self.name + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='someuser') + eq_(u1.name, "someuser") + eq_(u1.namesyn, 'someuser') + sess.add(u1) + sess.flush() + + rt = sess.query(User).filter(User.namesyn == 'someuser').one() + eq_(rt, u1) + + def test_comparable_using(self): + class NameComparator(sa.orm.PropComparator): + @property + def upperself(self): + cls = self.prop.parent.class_ + col = getattr(cls, 'name') + return sa.func.upper(col) + + def operate(self, op, other, **kw): + return op(self.upperself, other, **kw) + + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) + name = Column('name', String(50)) + + @decl.comparable_using(NameComparator) + @property + def uc_name(self): + return self.name is not None and self.name.upper() or None + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='someuser') + eq_(u1.name, "someuser", u1.name) + eq_(u1.uc_name, 'SOMEUSER', u1.uc_name) + sess.add(u1) + sess.flush() + sess.expunge_all() + + rt = sess.query(User).filter(User.uc_name == 'SOMEUSER').one() + eq_(rt, u1) + sess.expunge_all() + + rt = sess.query(User).filter(User.uc_name.startswith('SOMEUSE')).one() + eq_(rt, u1) + + class DeclarativeInheritanceTest(DeclarativeTestBase): def test_custom_join_condition(self): class Foo(Base): @@ -797,13 +860,13 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): def test_joined(self): class Company(Base, ComparableEntity): __tablename__ = 'companies' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) employees = relation("Person") class Person(Base, ComparableEntity): __tablename__ = 'people' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) company_id = Column('company_id', Integer, ForeignKey('companies.id')) name = Column('name', String(50)) @@ -911,13 +974,13 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): class Company(Base, ComparableEntity): __tablename__ = 'companies' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) employees = relation("Person") class Person(Base, ComparableEntity): __tablename__ = 'people' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) company_id = Column('company_id', Integer, ForeignKey('companies.id')) name = Column('name', String(50)) @@ -967,13 +1030,13 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): class Company(Base, ComparableEntity): __tablename__ = 'companies' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) employees = relation("Person") class Person(Base, ComparableEntity): __tablename__ = 'people' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) company_id = Column(Integer, ForeignKey('companies.id')) name = Column(String(50)) @@ -1037,13 +1100,13 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): def test_joined_from_single(self): class Company(Base, ComparableEntity): __tablename__ = 'companies' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) employees = relation("Person") class Person(Base, ComparableEntity): __tablename__ = 'people' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) company_id = Column(Integer, ForeignKey('companies.id')) name = Column(String(50)) discriminator = Column('type', String(50)) @@ -1100,7 +1163,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): def test_add_deferred(self): class Person(Base, ComparableEntity): __tablename__ = 'people' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) Person.name = deferred(Column(String(10))) @@ -1117,6 +1180,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): Person(name='ratbert') ] ) + sess.expunge_all() + person = sess.query(Person).filter(Person.name == 'ratbert').one() assert 'name' not in person.__dict__ @@ -1127,7 +1192,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): class Person(Base, ComparableEntity): __tablename__ = 'people' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) name = Column(String(50)) discriminator = Column('type', String(50)) __mapper_args__ = {'polymorphic_on':discriminator} @@ -1139,7 +1204,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): class Language(Base, ComparableEntity): __tablename__ = 'languages' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) name = Column(String(50)) assert not hasattr(Person, 'primary_language_id') @@ -1236,12 +1301,12 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): def test_concrete(self): engineers = Table('engineers', Base.metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('primary_language', String(50)) ) managers = Table('managers', Base.metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('golf_swing', String(50)) ) @@ -1293,12 +1358,12 @@ def _produce_test(inline, stringbased): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) name = Column(String(50)) class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) email = Column(String(50)) user_id = Column(Integer, ForeignKey('users.id')) if inline: @@ -1363,16 +1428,16 @@ class DeclarativeReflectionTest(testing.TestBase): reflection_metadata = MetaData(testing.db) Table('users', reflection_metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), test_needs_fk=True) Table('addresses', reflection_metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('email', String(50)), Column('user_id', Integer, ForeignKey('users.id')), test_needs_fk=True) Table('imhandles', reflection_metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', Integer), Column('network', String(50)), Column('handle', String(50)), @@ -1398,12 +1463,17 @@ class DeclarativeReflectionTest(testing.TestBase): class User(Base, ComparableEntity): __tablename__ = 'users' __autoload__ = True + if testing.against('oracle', 'firebird'): + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) addresses = relation("Address", backref="user") class Address(Base, ComparableEntity): __tablename__ = 'addresses' __autoload__ = True + if testing.against('oracle', 'firebird'): + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) + u1 = User(name='u1', addresses=[ Address(email='one'), Address(email='two'), @@ -1428,12 +1498,16 @@ class DeclarativeReflectionTest(testing.TestBase): class User(Base, ComparableEntity): __tablename__ = 'users' __autoload__ = True + if testing.against('oracle', 'firebird'): + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) nom = Column('name', String(50), key='nom') addresses = relation("Address", backref="user") class Address(Base, ComparableEntity): __tablename__ = 'addresses' __autoload__ = True + if testing.against('oracle', 'firebird'): + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) u1 = User(nom='u1', addresses=[ Address(email='one'), @@ -1461,12 +1535,16 @@ class DeclarativeReflectionTest(testing.TestBase): class IMHandle(Base, ComparableEntity): __tablename__ = 'imhandles' __autoload__ = True + if testing.against('oracle', 'firebird'): + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) user_id = Column('user_id', Integer, ForeignKey('users.id')) class User(Base, ComparableEntity): __tablename__ = 'users' __autoload__ = True + if testing.against('oracle', 'firebird'): + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) handles = relation("IMHandle", backref="user") u1 = User(name='u1', handles=[ @@ -1487,69 +1565,3 @@ class DeclarativeReflectionTest(testing.TestBase): eq_(a1, IMHandle(network='lol', handle='zomg')) eq_(a1.user, User(name='u1')) - def test_synonym_for(self): - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - - @decl.synonym_for('name') - @property - def namesyn(self): - return self.name - - Base.metadata.create_all() - - sess = create_session() - u1 = User(name='someuser') - eq_(u1.name, "someuser") - eq_(u1.namesyn, 'someuser') - sess.add(u1) - sess.flush() - - rt = sess.query(User).filter(User.namesyn == 'someuser').one() - eq_(rt, u1) - - def test_comparable_using(self): - class NameComparator(sa.orm.PropComparator): - @property - def upperself(self): - cls = self.prop.parent.class_ - col = getattr(cls, 'name') - return sa.func.upper(col) - - def operate(self, op, other, **kw): - return op(self.upperself, other, **kw) - - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - - @decl.comparable_using(NameComparator) - @property - def uc_name(self): - return self.name is not None and self.name.upper() or None - - Base.metadata.create_all() - - sess = create_session() - u1 = User(name='someuser') - eq_(u1.name, "someuser", u1.name) - eq_(u1.uc_name, 'SOMEUSER', u1.uc_name) - sess.add(u1) - sess.flush() - sess.expunge_all() - - rt = sess.query(User).filter(User.uc_name == 'SOMEUSER').one() - eq_(rt, u1) - sess.expunge_all() - - rt = sess.query(User).filter(User.uc_name.startswith('SOMEUSE')).one() - eq_(rt, u1) - - -if __name__ == '__main__': - testing.main() diff --git a/test/ext/test_serializer.py b/test/ext/test_serializer.py index b8a8e3fef..c400797b0 100644 --- a/test/ext/test_serializer.py +++ b/test/ext/test_serializer.py @@ -96,8 +96,6 @@ class SerializeTest(MappedTest): [(7, u'jack'), (8, u'ed'), (8, u'ed'), (8, u'ed'), (9, u'fred')] ) - # fails due to pure Python pickle bug: http://bugs.python.org/issue998998 - @testing.fails_if(lambda: util.py3k) def test_query(self): q = Session.query(User).filter(User.name=='ed').options(eagerload(User.addresses)) eq_(q.all(), [User(name='ed', addresses=[Address(id=2), Address(id=3), Address(id=4)])]) diff --git a/test/orm/_base.py b/test/orm/_base.py index 8d695e912..f08d253d5 100644 --- a/test/orm/_base.py +++ b/test/orm/_base.py @@ -1,10 +1,11 @@ -import gc import inspect import sys import types import sqlalchemy as sa +import sqlalchemy.exceptions as sa_exc from sqlalchemy.test import config, testing from sqlalchemy.test.testing import resolve_artifact_names, adict +from sqlalchemy.test.engines import drop_all_tables from sqlalchemy.util import function_named @@ -74,20 +75,19 @@ class ComparableEntity(BasicEntity): if attr.startswith('_'): continue value = getattr(a, attr) - if (hasattr(value, '__iter__') and - not isinstance(value, basestring)): - try: - # catch AttributeError so that lazy loaders trigger - battr = getattr(b, attr) - except AttributeError: - return False + try: + # handle lazy loader errors + battr = getattr(b, attr) + except (AttributeError, sa_exc.UnboundExecutionError): + return False + + if hasattr(value, '__iter__'): if list(value) != list(battr): return False else: - if value is not None: - if value != getattr(b, attr, None): - return False + if value is not None and value != battr: + return False return True finally: _recursion_stack.remove(id(self)) @@ -173,7 +173,7 @@ class MappedTest(ORMTest): def setup(self): if self.run_define_tables == 'each': self.tables.clear() - self.metadata.drop_all() + drop_all_tables(self.metadata) self.metadata.clear() self.define_tables(self.metadata) self.metadata.create_all() @@ -217,7 +217,7 @@ class MappedTest(ORMTest): for cl in cls.classes.values(): cls.unregister_class(cl) ORMTest.teardown_class() - cls.metadata.drop_all() + drop_all_tables(cls.metadata) cls.metadata.bind = None @classmethod diff --git a/test/orm/_fixtures.py b/test/orm/_fixtures.py index 931d8cadf..e9d6ac165 100644 --- a/test/orm/_fixtures.py +++ b/test/orm/_fixtures.py @@ -60,7 +60,7 @@ email_bounces = fixture_table( orders = fixture_table( Table('orders', fixture_metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', None, ForeignKey('users.id')), Column('address_id', None, ForeignKey('addresses.id')), Column('description', String(30)), @@ -76,7 +76,7 @@ orders = fixture_table( dingalings = fixture_table( Table("dingalings", fixture_metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('address_id', None, ForeignKey('addresses.id')), Column('data', String(30)), test_needs_acid=True, diff --git a/test/orm/inheritance/test_abc_inheritance.py b/test/orm/inheritance/test_abc_inheritance.py index 4e55cf70e..f6d5111b2 100644 --- a/test/orm/inheritance/test_abc_inheritance.py +++ b/test/orm/inheritance/test_abc_inheritance.py @@ -3,6 +3,7 @@ from sqlalchemy.orm import * from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE from sqlalchemy.test import testing +from sqlalchemy.test.schema import Table, Column from test.orm import _base @@ -15,7 +16,7 @@ def produce_test(parent, child, direction): def define_tables(cls, metadata): global ta, tb, tc ta = ["a", metadata] - ta.append(Column('id', Integer, primary_key=True)), + ta.append(Column('id', Integer, primary_key=True, test_needs_autoincrement=True)), ta.append(Column('a_data', String(30))) if "a"== parent and direction == MANYTOONE: ta.append(Column('child_id', Integer, ForeignKey("%s.id" % child, use_alter=True, name="foo"))) diff --git a/test/orm/inheritance/test_abc_polymorphic.py b/test/orm/inheritance/test_abc_polymorphic.py index 8cad8ed78..2dab59bb2 100644 --- a/test/orm/inheritance/test_abc_polymorphic.py +++ b/test/orm/inheritance/test_abc_polymorphic.py @@ -4,13 +4,14 @@ from sqlalchemy.orm import * from sqlalchemy.util import function_named from test.orm import _base, _fixtures +from sqlalchemy.test.schema import Table, Column class ABCTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): global a, b, c a = Table('a', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('adata', String(30)), Column('type', String(30)), ) @@ -61,7 +62,7 @@ class ABCTest(_base.MappedTest): C(cdata='c1', bdata='c1', adata='c1'), C(cdata='c2', bdata='c2', adata='c2'), C(cdata='c2', bdata='c2', adata='c2'), - ] == sess.query(A).all() + ] == sess.query(A).order_by(A.id).all() assert [ B(bdata='b1', adata='b1'), diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index bad6920de..b2e00de35 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import exc as orm_exc from sqlalchemy.test import testing, engines from sqlalchemy.util import function_named from test.orm import _base, _fixtures +from sqlalchemy.test.schema import Table, Column class O2MTest(_base.MappedTest): """deals with inheritance and one-to-many relationships""" @@ -14,8 +15,7 @@ class O2MTest(_base.MappedTest): def define_tables(cls, metadata): global foo, bar, blub foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq', optional=True), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(20))) bar = Table('bar', metadata, @@ -73,9 +73,8 @@ class FalseDiscriminatorTest(_base.MappedTest): def define_tables(cls, metadata): global t1 t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True), - Column('type', Boolean, nullable=False) - ) + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('type', Boolean, nullable=False)) def test_false_on_sub(self): class Foo(object):pass @@ -108,7 +107,7 @@ class PolymorphicSynonymTest(_base.MappedTest): def define_tables(cls, metadata): global t1, t2 t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('type', String(10), nullable=False), Column('info', String(255))) t2 = Table('t2', metadata, @@ -149,12 +148,12 @@ class CascadeTest(_base.MappedTest): def define_tables(cls, metadata): global t1, t2, t3, t4 t1= Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)) ) t2 = Table('t2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('t1id', Integer, ForeignKey('t1.id')), Column('type', String(30)), Column('data', String(30)) @@ -164,7 +163,7 @@ class CascadeTest(_base.MappedTest): Column('moredata', String(30))) t4 = Table('t4', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('t3id', Integer, ForeignKey('t3.id')), Column('data', String(30))) @@ -214,8 +213,7 @@ class GetTest(_base.MappedTest): def define_tables(cls, metadata): global foo, bar, blub foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq', optional=True), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('type', String(30)), Column('data', String(20))) @@ -224,7 +222,7 @@ class GetTest(_base.MappedTest): Column('data', String(20))) blub = Table('blub', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('foo_id', Integer, ForeignKey('foo.id')), Column('bar_id', Integer, ForeignKey('bar.id')), Column('data', String(20))) @@ -304,8 +302,7 @@ class EagerLazyTest(_base.MappedTest): def define_tables(cls, metadata): global foo, bar, bar_foo foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq', optional=True), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30))) bar = Table('bar', metadata, Column('id', Integer, ForeignKey('foo.id'), primary_key=True), @@ -350,13 +347,13 @@ class FlushTest(_base.MappedTest): def define_tables(cls, metadata): global users, roles, user_roles, admins users = Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('email', String(128)), Column('password', String(16)), ) roles = Table('role', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('description', String(32)) ) @@ -366,7 +363,7 @@ class FlushTest(_base.MappedTest): ) admins = Table('admin', metadata, - Column('admin_id', Integer, primary_key=True), + Column('admin_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', Integer, ForeignKey('users.id')) ) @@ -439,7 +436,7 @@ class VersioningTest(_base.MappedTest): def define_tables(cls, metadata): global base, subtable, stuff base = Table('base', metadata, - Column('id', Integer, Sequence('version_test_seq', optional=True), primary_key=True ), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('version_id', Integer, nullable=False), Column('value', String(40)), Column('discriminator', Integer, nullable=False) @@ -449,11 +446,10 @@ class VersioningTest(_base.MappedTest): Column('subdata', String(50)) ) stuff = Table('stuff', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('parent', Integer, ForeignKey('base.id')) ) - @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.') @engines.close_open_connections def test_save_update(self): class Base(_fixtures.Base): @@ -493,16 +489,16 @@ class VersioningTest(_base.MappedTest): try: sess2.flush() - assert False + assert not testing.db.dialect.supports_sane_rowcount except orm_exc.ConcurrentModificationError, e: assert True sess2.refresh(s2) - assert s2.subdata == 'sess1 subdata' + if testing.db.dialect.supports_sane_rowcount: + assert s2.subdata == 'sess1 subdata' s2.subdata = 'sess2 subdata' sess2.flush() - @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.') def test_delete(self): class Base(_fixtures.Base): pass @@ -534,7 +530,7 @@ class VersioningTest(_base.MappedTest): try: s1.subdata = 'some new subdata' sess.flush() - assert False + assert not testing.db.dialect.supports_sane_rowcount except orm_exc.ConcurrentModificationError, e: assert True @@ -550,12 +546,12 @@ class DistinctPKTest(_base.MappedTest): global person_table, employee_table, Person, Employee person_table = Table("persons", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("name", String(80)), ) employee_table = Table("employees", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("salary", Integer), Column("person_id", Integer, ForeignKey("persons.id")), ) @@ -623,7 +619,7 @@ class SyncCompileTest(_base.MappedTest): global _a_table, _b_table, _c_table _a_table = Table('a', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data1', String(128)) ) @@ -691,7 +687,7 @@ class OverrideColKeyTest(_base.MappedTest): global base, subtable base = Table('base', metadata, - Column('base_id', Integer, primary_key=True), + Column('base_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(255)), Column('sqlite_fixer', String(10)) ) @@ -921,7 +917,7 @@ class OptimizedLoadTest(_base.MappedTest): def define_tables(cls, metadata): global base, sub base = Table('base', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)), Column('type', String(50)) ) @@ -1008,7 +1004,7 @@ class PKDiscriminatorTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): parents = Table('parents', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(60))) children = Table('children', metadata, @@ -1061,14 +1057,14 @@ class DeleteOrphanTest(_base.MappedTest): def define_tables(cls, metadata): global single, parent single = Table('single', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('type', String(50), nullable=False), Column('data', String(50)), Column('parent_id', Integer, ForeignKey('parent.id'), nullable=False), ) parent = Table('parent', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)) ) diff --git a/test/orm/inheritance/test_concrete.py b/test/orm/inheritance/test_concrete.py index 46bd171e4..3a78be9d7 100644 --- a/test/orm/inheritance/test_concrete.py +++ b/test/orm/inheritance/test_concrete.py @@ -8,6 +8,7 @@ from sqlalchemy.test import testing from test.orm import _base from sqlalchemy.orm import attributes from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.schema import Table, Column class Employee(object): def __init__(self, name): @@ -48,31 +49,31 @@ class ConcreteTest(_base.MappedTest): global managers_table, engineers_table, hackers_table, companies, employees_table companies = Table('companies', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50))) employees_table = Table('employees', metadata, - Column('employee_id', Integer, primary_key=True), + Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('company_id', Integer, ForeignKey('companies.id')) ) managers_table = Table('managers', metadata, - Column('employee_id', Integer, primary_key=True), + Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('manager_data', String(50)), Column('company_id', Integer, ForeignKey('companies.id')) ) engineers_table = Table('engineers', metadata, - Column('employee_id', Integer, primary_key=True), + Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('engineer_info', String(50)), Column('company_id', Integer, ForeignKey('companies.id')) ) hackers_table = Table('hackers', metadata, - Column('employee_id', Integer, primary_key=True), + Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('engineer_info', String(50)), Column('company_id', Integer, ForeignKey('companies.id')), @@ -320,17 +321,17 @@ class PropertyInheritanceTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('a_table', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('some_c_id', Integer, ForeignKey('c_table.id')), Column('aname', String(50)), ) Table('b_table', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('some_c_id', Integer, ForeignKey('c_table.id')), Column('bname', String(50)), ) Table('c_table', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('cname', String(50)), ) @@ -525,11 +526,11 @@ class ColKeysTest(_base.MappedTest): def define_tables(cls, metadata): global offices_table, refugees_table refugees_table = Table('refugee', metadata, - Column('refugee_fid', Integer, primary_key=True), + Column('refugee_fid', Integer, primary_key=True, test_needs_autoincrement=True), Column('refugee_name', Unicode(30), key='name')) offices_table = Table('office', metadata, - Column('office_fid', Integer, primary_key=True), + Column('office_fid', Integer, primary_key=True, test_needs_autoincrement=True), Column('office_name', Unicode(30), key='name')) @classmethod diff --git a/test/orm/inheritance/test_magazine.py b/test/orm/inheritance/test_magazine.py index 067301251..f94781c27 100644 --- a/test/orm/inheritance/test_magazine.py +++ b/test/orm/inheritance/test_magazine.py @@ -4,6 +4,7 @@ from sqlalchemy.orm import * from sqlalchemy.test import testing from sqlalchemy.util import function_named from test.orm import _base +from sqlalchemy.test.schema import Table, Column class BaseObject(object): def __init__(self, *args, **kwargs): @@ -75,49 +76,48 @@ class MagazineTest(_base.MappedTest): global publication_table, issue_table, location_table, location_name_table, magazine_table, \ page_table, magazine_page_table, classified_page_table, page_size_table - zerodefault = {} #{'default':0} publication_table = Table('publication', metadata, - Column('id', Integer, primary_key=True, default=None), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(45), default=''), ) issue_table = Table('issue', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('publication_id', Integer, ForeignKey('publication.id'), **zerodefault), - Column('issue', Integer, **zerodefault), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('publication_id', Integer, ForeignKey('publication.id')), + Column('issue', Integer), ) location_table = Table('location', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('issue_id', Integer, ForeignKey('issue.id'), **zerodefault), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('issue_id', Integer, ForeignKey('issue.id')), Column('ref', CHAR(3), default=''), - Column('location_name_id', Integer, ForeignKey('location_name.id'), **zerodefault), + Column('location_name_id', Integer, ForeignKey('location_name.id')), ) location_name_table = Table('location_name', metadata, - Column('id', Integer, primary_key=True, default=None), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(45), default=''), ) magazine_table = Table('magazine', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('location_id', Integer, ForeignKey('location.id'), **zerodefault), - Column('page_size_id', Integer, ForeignKey('page_size.id'), **zerodefault), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('location_id', Integer, ForeignKey('location.id')), + Column('page_size_id', Integer, ForeignKey('page_size.id')), ) page_table = Table('page', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('page_no', Integer, **zerodefault), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('page_no', Integer), Column('type', CHAR(1), default='p'), ) magazine_page_table = Table('magazine_page', metadata, - Column('page_id', Integer, ForeignKey('page.id'), primary_key=True, **zerodefault), - Column('magazine_id', Integer, ForeignKey('magazine.id'), **zerodefault), - Column('orders', TEXT, default=''), + Column('page_id', Integer, ForeignKey('page.id'), primary_key=True), + Column('magazine_id', Integer, ForeignKey('magazine.id')), + Column('orders', Text, default=''), ) classified_page_table = Table('classified_page', metadata, - Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True, **zerodefault), + Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True), Column('titles', String(45), default=''), ) page_size_table = Table('page_size', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('width', Integer, **zerodefault), - Column('height', Integer, **zerodefault), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('width', Integer), + Column('height', Integer), Column('name', String(45), default=''), ) @@ -176,10 +176,11 @@ def generate_round_trip_test(use_unions=False, use_joins=False): 'magazine': relation(Magazine, backref=backref('pages', order_by=page_table.c.page_no)) }) - classified_page_mapper = mapper(ClassifiedPage, classified_page_table, inherits=magazine_page_mapper, polymorphic_identity='c', primary_key=[page_table.c.id]) - #compile_mappers() - #print [str(s) for s in classified_page_mapper.primary_key] - #print classified_page_mapper.columntoproperty[page_table.c.id] + classified_page_mapper = mapper(ClassifiedPage, + classified_page_table, + inherits=magazine_page_mapper, + polymorphic_identity='c', + primary_key=[page_table.c.id]) session = create_session() diff --git a/test/orm/inheritance/test_manytomany.py b/test/orm/inheritance/test_manytomany.py index f7e676bbb..7b6ad04eb 100644 --- a/test/orm/inheritance/test_manytomany.py +++ b/test/orm/inheritance/test_manytomany.py @@ -194,11 +194,11 @@ class InheritTest3(_base.MappedTest): b.foos.append(Foo("foo #1")) b.foos.append(Foo("foo #2")) sess.flush() - compare = repr(b) + repr(sorted([repr(o) for o in b.foos])) + compare = [repr(b)] + sorted([repr(o) for o in b.foos]) sess.expunge_all() l = sess.query(Bar).all() print repr(l[0]) + repr(l[0].foos) - found = repr(l[0]) + repr(sorted([repr(o) for o in l[0].foos])) + found = [repr(l[0])] + sorted([repr(o) for o in l[0].foos]) eq_(found, compare) @testing.fails_on('maxdb', 'FIXME: unknown') diff --git a/test/orm/inheritance/test_poly_linked_list.py b/test/orm/inheritance/test_poly_linked_list.py index 67b543f31..e434218b9 100644 --- a/test/orm/inheritance/test_poly_linked_list.py +++ b/test/orm/inheritance/test_poly_linked_list.py @@ -3,6 +3,7 @@ from sqlalchemy.orm import * from test.orm import _base from sqlalchemy.test import testing +from sqlalchemy.test.schema import Table, Column class PolymorphicCircularTest(_base.MappedTest): @@ -12,7 +13,7 @@ class PolymorphicCircularTest(_base.MappedTest): def define_tables(cls, metadata): global Table1, Table1B, Table2, Table3, Data table1 = Table('table1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('related_id', Integer, ForeignKey('table1.id'), nullable=True), Column('type', String(30)), Column('name', String(30)) @@ -27,7 +28,7 @@ class PolymorphicCircularTest(_base.MappedTest): ) data = Table('data', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('node_id', Integer, ForeignKey('table1.id')), Column('data', String(30)) ) @@ -72,7 +73,7 @@ class PolymorphicCircularTest(_base.MappedTest): polymorphic_on=table1.c.type, polymorphic_identity='table1', properties={ - 'next': relation(Table1, + 'nxt': relation(Table1, backref=backref('prev', foreignkey=join.c.id, uselist=False), uselist=False, primaryjoin=join.c.id==join.c.related_id), 'data':relation(mapper(Data, data)) @@ -86,15 +87,16 @@ class PolymorphicCircularTest(_base.MappedTest): # currently, the "eager" relationships degrade to lazy relationships # due to the polymorphic load. - # the "next" relation used to have a "lazy=False" on it, but the EagerLoader raises the "self-referential" + # the "nxt" relation used to have a "lazy=False" on it, but the EagerLoader raises the "self-referential" # exception now. since eager loading would never work for that relation anyway, its better that the user # gets an exception instead of it silently not eager loading. + # NOTE: using "nxt" instead of "next" to avoid 2to3 turning it into __next__() for some reason. table1_mapper = mapper(Table1, table1, #select_table=join, polymorphic_on=table1.c.type, polymorphic_identity='table1', properties={ - 'next': relation(Table1, + 'nxt': relation(Table1, backref=backref('prev', remote_side=table1.c.id, uselist=False), uselist=False, primaryjoin=table1.c.id==table1.c.related_id), 'data':relation(mapper(Data, data), lazy=False, order_by=data.c.id) @@ -147,7 +149,7 @@ class PolymorphicCircularTest(_base.MappedTest): else: newobj = c if obj is not None: - obj.next = newobj + obj.nxt = newobj else: t = newobj obj = newobj @@ -161,7 +163,7 @@ class PolymorphicCircularTest(_base.MappedTest): node = t while (node): assertlist.append(node) - n = node.next + n = node.nxt if n is not None: assert n.prev is node node = n @@ -174,7 +176,7 @@ class PolymorphicCircularTest(_base.MappedTest): assertlist = [] while (node): assertlist.append(node) - n = node.next + n = node.nxt if n is not None: assert n.prev is node node = n @@ -188,7 +190,7 @@ class PolymorphicCircularTest(_base.MappedTest): assertlist.insert(0, node) n = node.prev if n is not None: - assert n.next is node + assert n.nxt is node node = n backwards = repr(assertlist) diff --git a/test/orm/inheritance/test_polymorph2.py b/test/orm/inheritance/test_polymorph2.py index 51b6d4970..80c14413a 100644 --- a/test/orm/inheritance/test_polymorph2.py +++ b/test/orm/inheritance/test_polymorph2.py @@ -11,6 +11,7 @@ from sqlalchemy.test import TestBase, AssertsExecutionResults, testing from sqlalchemy.util import function_named from test.orm import _base, _fixtures from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.schema import Table, Column class AttrSettable(object): def __init__(self, **kwargs): @@ -105,7 +106,7 @@ class RelationTest2(_base.MappedTest): def define_tables(cls, metadata): global people, managers, data people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('type', String(30))) @@ -201,7 +202,7 @@ class RelationTest3(_base.MappedTest): def define_tables(cls, metadata): global people, managers, data people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('colleague_id', Integer, ForeignKey('people.person_id')), Column('name', String(50)), Column('type', String(30))) @@ -307,7 +308,7 @@ class RelationTest4(_base.MappedTest): def define_tables(cls, metadata): global people, engineers, managers, cars people = Table('people', metadata, - Column('person_id', Integer, primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50))) engineers = Table('engineers', metadata, @@ -319,7 +320,7 @@ class RelationTest4(_base.MappedTest): Column('longer_status', String(70))) cars = Table('cars', metadata, - Column('car_id', Integer, primary_key=True), + Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('owner', Integer, ForeignKey('people.person_id'))) def testmanytoonepolymorphic(self): @@ -420,7 +421,7 @@ class RelationTest5(_base.MappedTest): def define_tables(cls, metadata): global people, engineers, managers, cars people = Table('people', metadata, - Column('person_id', Integer, primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('type', String(50))) @@ -433,7 +434,7 @@ class RelationTest5(_base.MappedTest): Column('longer_status', String(70))) cars = Table('cars', metadata, - Column('car_id', Integer, primary_key=True), + Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('owner', Integer, ForeignKey('people.person_id'))) def testeagerempty(self): @@ -482,7 +483,7 @@ class RelationTest6(_base.MappedTest): def define_tables(cls, metadata): global people, managers, data people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), ) @@ -525,14 +526,14 @@ class RelationTest7(_base.MappedTest): def define_tables(cls, metadata): global people, engineers, managers, cars, offroad_cars cars = Table('cars', metadata, - Column('car_id', Integer, primary_key=True), + Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(30))) offroad_cars = Table('offroad_cars', metadata, Column('car_id',Integer, ForeignKey('cars.car_id'),nullable=False,primary_key=True)) people = Table('people', metadata, - Column('person_id', Integer, primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('car_id', Integer, ForeignKey('cars.car_id'), nullable=False), Column('name', String(50))) @@ -625,7 +626,7 @@ class RelationTest8(_base.MappedTest): def define_tables(cls, metadata): global taggable, users taggable = Table('taggable', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('type', String(30)), Column('owner_id', Integer, ForeignKey('taggable.id')), ) @@ -680,11 +681,11 @@ class GenerativeTest(TestBase, AssertsExecutionResults): metadata = MetaData(testing.db) # table definitions status = Table('status', metadata, - Column('status_id', Integer, primary_key=True), + Column('status_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(20))) people = Table('people', metadata, - Column('person_id', Integer, primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('status_id', Integer, ForeignKey('status.status_id'), nullable=False), Column('name', String(50))) @@ -697,7 +698,7 @@ class GenerativeTest(TestBase, AssertsExecutionResults): Column('category', String(70))) cars = Table('cars', metadata, - Column('car_id', Integer, primary_key=True), + Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('status_id', Integer, ForeignKey('status.status_id'), nullable=False), Column('owner', Integer, ForeignKey('people.person_id'), nullable=False)) @@ -786,13 +787,13 @@ class GenerativeTest(TestBase, AssertsExecutionResults): e = exists([Car.owner], Car.owner==employee_join.c.person_id) Query(Person)._adapt_clause(employee_join, False, False) - r = session.query(Person).filter(Person.name.like('%2')).join('status').filter_by(name="active") - assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]" + r = session.query(Person).filter(Person.name.like('%2')).join('status').filter_by(name="active").order_by(Person.person_id) + eq_(str(list(r)), "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]") r = session.query(Engineer).join('status').filter(Person.name.in_(['E2', 'E3', 'E4', 'M4', 'M2', 'M1']) & (status.c.name=="active")).order_by(Person.name) - assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]" + eq_(str(list(r)), "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]") r = session.query(Person).filter(exists([1], Car.owner==Person.person_id)) - assert str(list(r)) == "[Engineer E4, field X, status Status dead]" + eq_(str(list(r)), "[Engineer E4, field X, status Status dead]") class MultiLevelTest(_base.MappedTest): @classmethod @@ -800,7 +801,7 @@ class MultiLevelTest(_base.MappedTest): global table_Employee, table_Engineer, table_Manager table_Employee = Table( 'Employee', metadata, Column( 'name', type_= String(100), ), - Column( 'id', primary_key= True, type_= Integer, ), + Column( 'id', primary_key= True, type_= Integer, test_needs_autoincrement=True), Column( 'atype', type_= String(100), ), ) @@ -878,7 +879,7 @@ class ManyToManyPolyTest(_base.MappedTest): global base_item_table, item_table, base_item_collection_table, collection_table base_item_table = Table( 'base_item', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('child_name', String(255), default=None)) item_table = Table( @@ -893,7 +894,7 @@ class ManyToManyPolyTest(_base.MappedTest): collection_table = Table( 'collection', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', Unicode(255))) def test_pjoin_compile(self): @@ -928,7 +929,7 @@ class CustomPKTest(_base.MappedTest): def define_tables(cls, metadata): global t1, t2 t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('type', String(30), nullable=False), Column('data', String(30))) # note that the primary key column in t2 is named differently @@ -1013,7 +1014,7 @@ class InheritingEagerTest(_base.MappedTest): global people, employees, tags, peopleTags people = Table('people', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('_type', String(30), nullable=False), ) @@ -1023,7 +1024,7 @@ class InheritingEagerTest(_base.MappedTest): ) tags = Table('tags', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('label', String(50), nullable=False), ) @@ -1074,11 +1075,11 @@ class MissingPolymorphicOnTest(_base.MappedTest): def define_tables(cls, metadata): global tablea, tableb, tablec, tabled tablea = Table('tablea', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('adata', String(50)), ) tableb = Table('tableb', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('aid', Integer, ForeignKey('tablea.id')), Column('data', String(50)), ) diff --git a/test/orm/inheritance/test_productspec.py b/test/orm/inheritance/test_productspec.py index b2bcb85d5..4c593e2a3 100644 --- a/test/orm/inheritance/test_productspec.py +++ b/test/orm/inheritance/test_productspec.py @@ -2,10 +2,9 @@ from datetime import datetime from sqlalchemy import * from sqlalchemy.orm import * - from sqlalchemy.test import testing from test.orm import _base - +from sqlalchemy.test.schema import Table, Column class InheritTest(_base.MappedTest): """tests some various inheritance round trips involving a particular set of polymorphic inheritance relationships""" @@ -15,14 +14,14 @@ class InheritTest(_base.MappedTest): global Product, Detail, Assembly, SpecLine, Document, RasterDocument products_table = Table('products', metadata, - Column('product_id', Integer, primary_key=True), + Column('product_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('product_type', String(128)), Column('name', String(128)), Column('mark', String(128)), ) specification_table = Table('specification', metadata, - Column('spec_line_id', Integer, primary_key=True), + Column('spec_line_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('master_id', Integer, ForeignKey("products.product_id"), nullable=True), Column('slave_id', Integer, ForeignKey("products.product_id"), @@ -31,7 +30,7 @@ class InheritTest(_base.MappedTest): ) documents_table = Table('documents', metadata, - Column('document_id', Integer, primary_key=True), + Column('document_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('document_type', String(128)), Column('product_id', Integer, ForeignKey('products.product_id')), Column('create_date', DateTime, default=lambda:datetime.now()), diff --git a/test/orm/inheritance/test_query.py b/test/orm/inheritance/test_query.py index 5b57e8f45..daf8bf3bd 100644 --- a/test/orm/inheritance/test_query.py +++ b/test/orm/inheritance/test_query.py @@ -8,6 +8,7 @@ from sqlalchemy.engine import default from sqlalchemy.test import AssertsCompiledSQL, testing from test.orm import _base, _fixtures from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.schema import Table, Column class Company(_fixtures.Base): pass @@ -38,11 +39,11 @@ def _produce_test(select_type): global companies, people, engineers, managers, boss, paperwork, machines companies = Table('companies', metadata, - Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key=True), + Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50))) people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('company_id', Integer, ForeignKey('companies.company_id')), Column('name', String(50)), Column('type', String(30))) @@ -55,7 +56,7 @@ def _produce_test(select_type): ) machines = Table('machines', metadata, - Column('machine_id', Integer, primary_key=True), + Column('machine_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('engineer_id', Integer, ForeignKey('engineers.person_id'))) @@ -71,7 +72,7 @@ def _produce_test(select_type): ) paperwork = Table('paperwork', metadata, - Column('paperwork_id', Integer, primary_key=True), + Column('paperwork_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('description', String(50)), Column('person_id', Integer, ForeignKey('people.person_id'))) @@ -771,7 +772,7 @@ class SelfReferentialTestJoinedToBase(_base.MappedTest): def define_tables(cls, metadata): global people, engineers people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('type', String(30))) @@ -831,7 +832,7 @@ class SelfReferentialJ2JTest(_base.MappedTest): def define_tables(cls, metadata): global people, engineers, managers people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('type', String(30))) @@ -947,7 +948,7 @@ class M2MFilterTest(_base.MappedTest): global people, engineers, organizations, engineers_to_org organizations = Table('organizations', metadata, - Column('id', Integer, Sequence('org_id_seq', optional=True), primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), ) engineers_to_org = Table('engineers_org', metadata, @@ -956,7 +957,7 @@ class M2MFilterTest(_base.MappedTest): ) people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('type', String(30))) @@ -1023,7 +1024,7 @@ class SelfReferentialM2MTest(_base.MappedTest, AssertsCompiledSQL): class Parent(Base): __tablename__ = 'parent' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) cls = Column(String(50)) __mapper_args__ = dict(polymorphic_on = cls ) diff --git a/test/orm/inheritance/test_selects.py b/test/orm/inheritance/test_selects.py index a151af4fa..7c9920f6f 100644 --- a/test/orm/inheritance/test_selects.py +++ b/test/orm/inheritance/test_selects.py @@ -46,6 +46,6 @@ class InheritingSelectablesTest(MappedTest): s = sessionmaker(bind=testing.db)() - assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).all() + assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).order_by(Foo.b.desc()).all() assert [Bar(), Bar()] == s.query(Bar).all() diff --git a/test/orm/inheritance/test_single.py b/test/orm/inheritance/test_single.py index 705826885..fc30955db 100644 --- a/test/orm/inheritance/test_single.py +++ b/test/orm/inheritance/test_single.py @@ -5,20 +5,21 @@ from sqlalchemy.orm import * from sqlalchemy.test import testing from test.orm import _fixtures from test.orm._base import MappedTest, ComparableEntity +from sqlalchemy.test.schema import Table, Column class SingleInheritanceTest(MappedTest): @classmethod def define_tables(cls, metadata): Table('employees', metadata, - Column('employee_id', Integer, primary_key=True), + Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('manager_data', String(50)), Column('engineer_info', String(50)), Column('type', String(20))) Table('reports', metadata, - Column('report_id', Integer, primary_key=True), + Column('report_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('employee_id', ForeignKey('employees.employee_id')), Column('name', String(50)), ) @@ -186,7 +187,7 @@ class RelationToSingleTest(MappedTest): @classmethod def define_tables(cls, metadata): Table('employees', metadata, - Column('employee_id', Integer, primary_key=True), + Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('manager_data', String(50)), Column('engineer_info', String(50)), @@ -195,7 +196,7 @@ class RelationToSingleTest(MappedTest): ) Table('companies', metadata, - Column('company_id', Integer, primary_key=True), + Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), ) @@ -342,7 +343,7 @@ class SingleOnJoinedTest(MappedTest): global persons_table, employees_table persons_table = Table('persons', metadata, - Column('person_id', Integer, primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('type', String(20), nullable=False) ) diff --git a/test/orm/sharding/test_shard.py b/test/orm/sharding/test_shard.py index 89e23fb75..e8ffaa7ca 100644 --- a/test/orm/sharding/test_shard.py +++ b/test/orm/sharding/test_shard.py @@ -6,6 +6,7 @@ from sqlalchemy.orm.shard import ShardedSession from sqlalchemy.sql import operators from sqlalchemy.test import * from sqlalchemy.test.testing import eq_ +from nose import SkipTest # TODO: ShardTest can be turned into a base for further subclasses @@ -14,7 +15,10 @@ class ShardTest(TestBase): def setup_class(cls): global db1, db2, db3, db4, weather_locations, weather_reports - db1 = create_engine('sqlite:///shard1.db') + try: + db1 = create_engine('sqlite:///shard1.db') + except ImportError: + raise SkipTest('Requires sqlite') db2 = create_engine('sqlite:///shard2.db') db3 = create_engine('sqlite:///shard3.db') db4 = create_engine('sqlite:///shard4.db') diff --git a/test/orm/test_association.py b/test/orm/test_association.py index ee7fb7af9..d537430cc 100644 --- a/test/orm/test_association.py +++ b/test/orm/test_association.py @@ -1,8 +1,7 @@ from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, create_session from test.orm import _base from sqlalchemy.test.testing import eq_ @@ -15,14 +14,14 @@ class AssociationTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('items', metadata, - Column('item_id', Integer, primary_key=True), + Column('item_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(40))) Table('item_keywords', metadata, Column('item_id', Integer, ForeignKey('items.item_id')), Column('keyword_id', Integer, ForeignKey('keywords.keyword_id')), Column('data', String(40))) Table('keywords', metadata, - Column('keyword_id', Integer, primary_key=True), + Column('keyword_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(40))) @classmethod diff --git a/test/orm/test_assorted_eager.py b/test/orm/test_assorted_eager.py index 09f007547..94a98d9ae 100644 --- a/test/orm/test_assorted_eager.py +++ b/test/orm/test_assorted_eager.py @@ -8,8 +8,7 @@ import datetime import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, backref, create_session from sqlalchemy.test.testing import eq_ from test.orm import _base @@ -37,27 +36,24 @@ class EagerTest(_base.MappedTest): cls.other_artifacts['false'] = false Table('owners', metadata , - Column('id', Integer, primary_key=True, nullable=False), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30))) Table('categories', metadata, - Column('id', Integer, primary_key=True, nullable=False), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(20))) Table('tests', metadata , - Column('id', Integer, primary_key=True, nullable=False ), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('owner_id', Integer, ForeignKey('owners.id'), nullable=False), Column('category_id', Integer, ForeignKey('categories.id'), nullable=False)) Table('options', metadata , - Column('test_id', Integer, ForeignKey('tests.id'), - primary_key=True, nullable=False), - Column('owner_id', Integer, ForeignKey('owners.id'), - primary_key=True, nullable=False), - Column('someoption', sa.Boolean, server_default=false, - nullable=False)) + Column('test_id', Integer, ForeignKey('tests.id'), primary_key=True), + Column('owner_id', Integer, ForeignKey('owners.id'), primary_key=True), + Column('someoption', sa.Boolean, server_default=false, nullable=False)) @classmethod def setup_classes(cls): @@ -219,7 +215,7 @@ class EagerTest2(_base.MappedTest): Column('data', String(50), primary_key=True)) Table('middle', metadata, - Column('id', Integer, primary_key = True), + Column('id', Integer, primary_key = True, test_needs_autoincrement=True), Column('data', String(50))) Table('right', metadata, @@ -280,17 +276,15 @@ class EagerTest3(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('datas', metadata, - Column('id', Integer, primary_key=True, nullable=False), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('a', Integer, nullable=False)) Table('foo', metadata, - Column('data_id', Integer, - ForeignKey('datas.id'), - nullable=False, primary_key=True), + Column('data_id', Integer, ForeignKey('datas.id'),primary_key=True), Column('bar', Integer)) Table('stats', metadata, - Column('id', Integer, primary_key=True, nullable=False ), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data_id', Integer, ForeignKey('datas.id')), Column('somedata', Integer, nullable=False )) @@ -364,11 +358,11 @@ class EagerTest4(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('departments', metadata, - Column('department_id', Integer, primary_key=True), + Column('department_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50))) Table('employees', metadata, - Column('person_id', Integer, primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('department_id', Integer, ForeignKey('departments.department_id'))) @@ -422,17 +416,15 @@ class EagerTest5(_base.MappedTest): Column('x', String(30))) Table('derived', metadata, - Column('uid', String(30), ForeignKey('base.uid'), - primary_key=True), + Column('uid', String(30), ForeignKey('base.uid'), primary_key=True), Column('y', String(30))) Table('derivedII', metadata, - Column('uid', String(30), ForeignKey('base.uid'), - primary_key=True), + Column('uid', String(30), ForeignKey('base.uid'), primary_key=True), Column('z', String(30))) Table('comments', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('uid', String(30), ForeignKey('base.uid')), Column('comment', String(30))) @@ -505,21 +497,21 @@ class EagerTest6(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('design_types', metadata, - Column('design_type_id', Integer, primary_key=True)) + Column('design_type_id', Integer, primary_key=True, test_needs_autoincrement=True)) Table('design', metadata, - Column('design_id', Integer, primary_key=True), + Column('design_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('design_type_id', Integer, ForeignKey('design_types.design_type_id'))) Table('parts', metadata, - Column('part_id', Integer, primary_key=True), + Column('part_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('design_id', Integer, ForeignKey('design.design_id')), Column('design_type_id', Integer, ForeignKey('design_types.design_type_id'))) Table('inherited_part', metadata, - Column('ip_id', Integer, primary_key=True), + Column('ip_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('part_id', Integer, ForeignKey('parts.part_id')), Column('design_id', Integer, ForeignKey('design.design_id'))) @@ -573,32 +565,27 @@ class EagerTest7(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('companies', metadata, - Column('company_id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('company_name', String(40))) Table('addresses', metadata, - Column('address_id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('address_id', Integer, primary_key=True,test_needs_autoincrement=True), Column('company_id', Integer, ForeignKey("companies.company_id")), Column('address', String(40))) Table('phone_numbers', metadata, - Column('phone_id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('phone_id', Integer, primary_key=True,test_needs_autoincrement=True), Column('address_id', Integer, ForeignKey('addresses.address_id')), Column('type', String(20)), Column('number', String(10))) Table('invoices', metadata, - Column('invoice_id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('invoice_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('company_id', Integer, ForeignKey("companies.company_id")), Column('date', sa.DateTime)) Table('items', metadata, - Column('item_id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('item_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('invoice_id', Integer, ForeignKey('invoices.invoice_id')), Column('code', String(20)), Column('qty', Integer)) @@ -722,12 +709,12 @@ class EagerTest8(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('prj', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('created', sa.DateTime ), Column('title', sa.Unicode(100))) Table('task', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('status_id', Integer, ForeignKey('task_status.id'), nullable=False), Column('title', sa.Unicode(100)), @@ -736,19 +723,19 @@ class EagerTest8(_base.MappedTest): Column('prj_id', Integer , ForeignKey('prj.id'), nullable=False)) Table('task_status', metadata, - Column('id', Integer, primary_key=True)) + Column('id', Integer, primary_key=True, test_needs_autoincrement=True)) Table('task_type', metadata, - Column('id', Integer, primary_key=True)) + Column('id', Integer, primary_key=True, test_needs_autoincrement=True)) Table('msg', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('posted', sa.DateTime, index=True,), Column('type_id', Integer, ForeignKey('msg_type.id')), Column('task_id', Integer, ForeignKey('task.id'))) Table('msg_type', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', sa.Unicode(20)), Column('display_name', sa.Unicode(20))) @@ -814,15 +801,15 @@ class EagerTest9(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('accounts', metadata, - Column('account_id', Integer, primary_key=True), + Column('account_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(40))) Table('transactions', metadata, - Column('transaction_id', Integer, primary_key=True), + Column('transaction_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(40))) Table('entries', metadata, - Column('entry_id', Integer, primary_key=True), + Column('entry_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(40)), Column('account_id', Integer, ForeignKey('accounts.account_id')), diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index ca8cef3ad..fa26ec7d7 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -6,7 +6,8 @@ from sqlalchemy import exc as sa_exc from sqlalchemy.test import * from sqlalchemy.test.testing import eq_ from test.orm import _base -import gc +from sqlalchemy.test.util import gc_collect +from sqlalchemy.util import cmp, jython # global for pickling tests MyTest = None @@ -80,7 +81,9 @@ class AttributesTest(_base.ORMTest): del o2.__dict__['mt2'] o2.__dict__[o_mt2_str] = former - self.assert_(pk_o == pk_o2) + # Relies on dict ordering + if not jython: + self.assert_(pk_o == pk_o2) # the above is kind of distrurbing, so let's do it again a little # differently. the string-id in serialization thing is just an @@ -93,7 +96,9 @@ class AttributesTest(_base.ORMTest): o4 = pickle.loads(pk_o3) pk_o4 = pickle.dumps(o4) - self.assert_(pk_o3 == pk_o4) + # Relies on dict ordering + if not jython: + self.assert_(pk_o3 == pk_o4) # and lastly make sure we still have our data after all that. # identical serialzation is great, *if* it's complete :) @@ -117,7 +122,7 @@ class AttributesTest(_base.ORMTest): f.bar = "foo" assert state.dict == {'bar':'foo', state.manager.STATE_ATTR:state} del f - gc.collect() + gc_collect() assert state.obj() is None assert state.dict == {} diff --git a/test/orm/test_cascade.py b/test/orm/test_cascade.py index d0a7b9ded..c523fb5f0 100644 --- a/test/orm/test_cascade.py +++ b/test/orm/test_cascade.py @@ -1,8 +1,7 @@ from sqlalchemy.test.testing import assert_raises, assert_raises_message from sqlalchemy import Integer, String, ForeignKey, Sequence, exc as sa_exc -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, create_session, class_mapper, backref from sqlalchemy.orm import attributes, exc as orm_exc from sqlalchemy.test import testing @@ -20,7 +19,7 @@ class O2MCascadeTest(_fixtures.FixtureTest): mapper(User, users, properties = dict( addresses = relation(Address, cascade="all, delete-orphan", backref="user"), orders = relation( - mapper(Order, orders), cascade="all, delete-orphan") + mapper(Order, orders), cascade="all, delete-orphan", order_by=orders.c.id) )) mapper(Dingaling,dingalings, properties={ 'address':relation(Address) @@ -50,16 +49,12 @@ class O2MCascadeTest(_fixtures.FixtureTest): orders=[Order(description="order 3"), Order(description="order 4")])) - eq_(sess.query(Order).all(), + eq_(sess.query(Order).order_by(Order.id).all(), [Order(description="order 3"), Order(description="order 4")]) o5 = Order(description="order 5") sess.add(o5) - try: - sess.flush() - assert False - except orm_exc.FlushError, e: - assert "is an orphan" in str(e) + assert_raises_message(orm_exc.FlushError, "is an orphan", sess.flush) @testing.resolve_artifact_names @@ -351,18 +346,15 @@ class M2OCascadeTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("extra", metadata, - Column("id", Integer, Sequence("extra_id_seq", optional=True), - primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("prefs_id", Integer, ForeignKey("prefs.id"))) Table('prefs', metadata, - Column('id', Integer, Sequence('prefs_id_seq', optional=True), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(40))) Table('users', metadata, - Column('id', Integer, Sequence('user_id_seq', optional=True), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(40)), Column('pref_id', Integer, ForeignKey('prefs.id'))) @@ -453,22 +445,22 @@ class M2OCascadeTest(_base.MappedTest): jack.pref = newpref jack.pref = newpref sess.flush() - eq_(sess.query(Pref).all(), + eq_(sess.query(Pref).order_by(Pref.id).all(), [Pref(data="pref 1"), Pref(data="pref 3"), Pref(data="newpref")]) class M2OCascadeDeleteTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)), Column('t2id', Integer, ForeignKey('t2.id'))) Table('t2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)), Column('t3id', Integer, ForeignKey('t3.id'))) Table('t3', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) @classmethod @@ -581,15 +573,15 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)), Column('t2id', Integer, ForeignKey('t2.id'))) Table('t2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)), Column('t3id', Integer, ForeignKey('t3.id'))) Table('t3', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) @classmethod @@ -696,12 +688,12 @@ class M2MCascadeTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('a', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), test_needs_fk=True ) Table('b', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), test_needs_fk=True @@ -713,7 +705,7 @@ class M2MCascadeTest(_base.MappedTest): ) Table('c', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), Column('bid', Integer, ForeignKey('b.id')), test_needs_fk=True @@ -838,15 +830,11 @@ class UnsavedOrphansTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('users', metadata, - Column('user_id', Integer, - Sequence('user_id_seq', optional=True), - primary_key=True), + Column('user_id', Integer,primary_key=True, test_needs_autoincrement=True), Column('name', String(40))) Table('addresses', metadata, - Column('address_id', Integer, - Sequence('address_id_seq', optional=True), - primary_key=True), + Column('address_id', Integer,primary_key=True, test_needs_autoincrement=True), Column('user_id', Integer, ForeignKey('users.user_id')), Column('email_address', String(40))) @@ -923,20 +911,17 @@ class UnsavedOrphansTest2(_base.MappedTest): @classmethod def define_tables(cls, meta): Table('orders', meta, - Column('id', Integer, Sequence('order_id_seq'), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50))) Table('items', meta, - Column('id', Integer, Sequence('item_id_seq'), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('order_id', Integer, ForeignKey('orders.id'), nullable=False), Column('name', String(50))) Table('attributes', meta, - Column('id', Integer, Sequence('attribute_id_seq'), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('item_id', Integer, ForeignKey('items.id'), nullable=False), Column('name', String(50))) @@ -982,19 +967,13 @@ class UnsavedOrphansTest3(_base.MappedTest): @classmethod def define_tables(cls, meta): Table('sales_reps', meta, - Column('sales_rep_id', Integer, - Sequence('sales_rep_id_seq'), - primary_key=True), + Column('sales_rep_id', Integer,primary_key=True, test_needs_autoincrement=True), Column('name', String(50))) Table('accounts', meta, - Column('account_id', Integer, - Sequence('account_id_seq'), - primary_key=True), + Column('account_id', Integer,primary_key=True, test_needs_autoincrement=True), Column('balance', Integer)) Table('customers', meta, - Column('customer_id', Integer, - Sequence('customer_id_seq'), - primary_key=True), + Column('customer_id', Integer,primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('sales_rep_id', Integer, ForeignKey('sales_reps.sales_rep_id')), @@ -1087,19 +1066,19 @@ class DoubleParentOrphanTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('addresses', metadata, - Column('address_id', Integer, primary_key=True), + Column('address_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('street', String(30)), ) Table('homes', metadata, - Column('home_id', Integer, primary_key=True, key="id"), + Column('home_id', Integer, primary_key=True, key="id", test_needs_autoincrement=True), Column('description', String(30)), Column('address_id', Integer, ForeignKey('addresses.address_id'), nullable=False), ) Table('businesses', metadata, - Column('business_id', Integer, primary_key=True, key="id"), + Column('business_id', Integer, primary_key=True, key="id", test_needs_autoincrement=True), Column('description', String(30), key="description"), Column('address_id', Integer, ForeignKey('addresses.address_id'), nullable=False), @@ -1159,10 +1138,10 @@ class CollectionAssignmentOrphanTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('table_a', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(30))) Table('table_b', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(30)), Column('a_id', Integer, ForeignKey('table_a.id'))) @@ -1208,12 +1187,12 @@ class PartialFlushTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("base", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("descr", String(50)) ) Table("noninh_child", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('base_id', Integer, ForeignKey('base.id')) ) diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py index 12ff25c46..3d1b30bc9 100644 --- a/test/orm/test_collection.py +++ b/test/orm/test_collection.py @@ -8,12 +8,11 @@ from sqlalchemy.orm.collections import collection import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy import util, exc as sa_exc -from sqlalchemy.orm import create_session, mapper, relation, attributes +from sqlalchemy.orm import create_session, mapper, relation, attributes from test.orm import _base -from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.testing import eq_, assert_raises class Canary(sa.orm.interfaces.AttributeExtension): def __init__(self): @@ -169,6 +168,13 @@ class CollectionsTest(_base.ORMTest): control[slice(0,-1)] = values assert_eq() + values = [creator(),creator(),creator()] + control[:] = values + direct[:] = values + def invalid(): + direct[slice(0, 6, 2)] = [creator()] + assert_raises(ValueError, invalid) + if hasattr(direct, '__delitem__'): e = creator() direct.append(e) @@ -193,7 +199,7 @@ class CollectionsTest(_base.ORMTest): del direct[::2] del control[::2] assert_eq() - + if hasattr(direct, 'remove'): e = creator() direct.append(e) @@ -202,8 +208,21 @@ class CollectionsTest(_base.ORMTest): direct.remove(e) control.remove(e) assert_eq() - - if hasattr(direct, '__setslice__'): + + if hasattr(direct, '__setitem__') or hasattr(direct, '__setslice__'): + + values = [creator(), creator()] + direct[:] = values + control[:] = values + assert_eq() + + # test slice assignment where + # slice size goes over the number of items + values = [creator(), creator()] + direct[1:3] = values + control[1:3] = values + assert_eq() + values = [creator(), creator()] direct[0:1] = values control[0:1] = values @@ -228,8 +247,19 @@ class CollectionsTest(_base.ORMTest): direct[1::2] = values control[1::2] = values assert_eq() + + values = [creator(), creator()] + direct[-1:-3] = values + control[-1:-3] = values + assert_eq() - if hasattr(direct, '__delslice__'): + values = [creator(), creator()] + direct[-2:-1] = values + control[-2:-1] = values + assert_eq() + + + if hasattr(direct, '__delitem__') or hasattr(direct, '__delslice__'): for i in range(1, 4): e = creator() direct.append(e) @@ -246,7 +276,7 @@ class CollectionsTest(_base.ORMTest): del direct[:] del control[:] assert_eq() - + if hasattr(direct, 'extend'): values = [creator(), creator(), creator()] @@ -345,6 +375,45 @@ class CollectionsTest(_base.ORMTest): self._test_list(list) self._test_list_bulk(list) + def test_list_setitem_with_slices(self): + + # this is a "list" that has no __setslice__ + # or __delslice__ methods. The __setitem__ + # and __delitem__ must therefore accept + # slice objects (i.e. as in py3k) + class ListLike(object): + def __init__(self): + self.data = list() + def append(self, item): + self.data.append(item) + def remove(self, item): + self.data.remove(item) + def insert(self, index, item): + self.data.insert(index, item) + def pop(self, index=-1): + return self.data.pop(index) + def extend(self): + assert False + def __len__(self): + return len(self.data) + def __setitem__(self, key, value): + self.data[key] = value + def __getitem__(self, key): + return self.data[key] + def __delitem__(self, key): + del self.data[key] + def __iter__(self): + return iter(self.data) + __hash__ = object.__hash__ + def __eq__(self, other): + return self.data == other + def __repr__(self): + return 'ListLike(%s)' % repr(self.data) + + self._test_adapter(ListLike) + self._test_list(ListLike) + self._test_list_bulk(ListLike) + def test_list_subclass(self): class MyList(list): pass @@ -1343,10 +1412,10 @@ class DictHelpersTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('parents', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('label', String(128))) Table('children', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('parent_id', Integer, ForeignKey('parents.id'), nullable=False), Column('a', String(128)), @@ -1481,12 +1550,12 @@ class DictHelpersTest(_base.MappedTest): class Foo(BaseObject): __tablename__ = "foo" - id = Column(Integer(), primary_key=True) + id = Column(Integer(), primary_key=True, test_needs_autoincrement=True) bar_id = Column(Integer, ForeignKey('bar.id')) class Bar(BaseObject): __tablename__ = "bar" - id = Column(Integer(), primary_key=True) + id = Column(Integer(), primary_key=True, test_needs_autoincrement=True) foos = relation(Foo, collection_class=collections.column_mapped_collection(Foo.id)) foos2 = relation(Foo, collection_class=collections.column_mapped_collection((Foo.id, Foo.bar_id))) @@ -1521,17 +1590,16 @@ class DictHelpersTest(_base.MappedTest): collection_class = lambda: Ordered2(lambda v: (v.a, v.b)) self._test_composite_mapped(collection_class) -# TODO: are these tests redundant vs. the above tests ? -# remove if so class CustomCollectionsTest(_base.MappedTest): + """test the integration of collections with mapped classes.""" @classmethod def define_tables(cls, metadata): Table('sometable', metadata, - Column('col1',Integer, primary_key=True), + Column('col1',Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30))) Table('someothertable', metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('scol1', Integer, ForeignKey('sometable.col1')), Column('data', String(20))) @@ -1646,15 +1714,50 @@ class CustomCollectionsTest(_base.MappedTest): replaced = set([id(b) for b in f.bars.values()]) self.assert_(existing != replaced) - @testing.resolve_artifact_names def test_list(self): + self._test_list(list) + + def test_list_no_setslice(self): + class ListLike(object): + def __init__(self): + self.data = list() + def append(self, item): + self.data.append(item) + def remove(self, item): + self.data.remove(item) + def insert(self, index, item): + self.data.insert(index, item) + def pop(self, index=-1): + return self.data.pop(index) + def extend(self): + assert False + def __len__(self): + return len(self.data) + def __setitem__(self, key, value): + self.data[key] = value + def __getitem__(self, key): + return self.data[key] + def __delitem__(self, key): + del self.data[key] + def __iter__(self): + return iter(self.data) + __hash__ = object.__hash__ + def __eq__(self, other): + return self.data == other + def __repr__(self): + return 'ListLike(%s)' % repr(self.data) + + self._test_list(ListLike) + + @testing.resolve_artifact_names + def _test_list(self, listcls): class Parent(object): pass class Child(object): pass mapper(Parent, sometable, properties={ - 'children':relation(Child, collection_class=list) + 'children':relation(Child, collection_class=listcls) }) mapper(Child, someothertable) diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py index fe77b3601..6fbfe7fe1 100644 --- a/test/orm/test_cycles.py +++ b/test/orm/test_cycles.py @@ -7,8 +7,7 @@ T1/T2. """ from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, backref, create_session from sqlalchemy.test.testing import eq_ from sqlalchemy.test.assertsql import RegexSQL, ExactSQL, CompiledSQL, AllOf @@ -138,7 +137,7 @@ class SelfReferentialNoPKTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('item', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('uuid', String(32), unique=True, nullable=False), Column('parent_uuid', String(32), ForeignKey('item.uuid'), nullable=True)) @@ -190,18 +189,16 @@ class InheritTestOne(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("parent", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("parent_data", String(50)), Column("type", String(10))) Table("child1", metadata, - Column("id", Integer, ForeignKey("parent.id"), - primary_key=True), + Column("id", Integer, ForeignKey("parent.id"), primary_key=True), Column("child1_data", String(50))) Table("child2", metadata, - Column("id", Integer, ForeignKey("parent.id"), - primary_key=True), + Column("id", Integer, ForeignKey("parent.id"), primary_key=True), Column("child1_id", Integer, ForeignKey("child1.id"), nullable=False), Column("child2_data", String(50))) @@ -262,7 +259,7 @@ class InheritTestTwo(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('a', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), Column('cid', Integer, ForeignKey('c.id'))) @@ -271,7 +268,7 @@ class InheritTestTwo(_base.MappedTest): Column('data', String(30))) Table('c', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), Column('aid', Integer, ForeignKey('a.id', use_alter=True, name="foo"))) @@ -311,16 +308,16 @@ class BiDirectionalManyToOneTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), Column('t2id', Integer, ForeignKey('t2.id'))) Table('t2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), Column('t1id', Integer, ForeignKey('t1.id', use_alter=True, name="foo_fk"))) Table('t3', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), Column('t1id', Integer, ForeignKey('t1.id'), nullable=False), Column('t2id', Integer, ForeignKey('t2.id'), nullable=False)) @@ -402,13 +399,11 @@ class BiDirectionalOneToManyTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('t1', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), Column('c2', Integer, ForeignKey('t2.c1'))) Table('t2', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), Column('c2', Integer, ForeignKey('t1.c1', use_alter=True, name='t1c1_fk'))) @@ -453,18 +448,18 @@ class BiDirectionalOneToManyTest2(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('t1', metadata, - Column('c1', Integer, primary_key=True), + Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), Column('c2', Integer, ForeignKey('t2.c1')), test_needs_autoincrement=True) Table('t2', metadata, - Column('c1', Integer, primary_key=True), + Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), Column('c2', Integer, ForeignKey('t1.c1', use_alter=True, name='t1c1_fq')), test_needs_autoincrement=True) Table('t1_data', metadata, - Column('c1', Integer, primary_key=True), + Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), Column('t1id', Integer, ForeignKey('t1.c1')), Column('data', String(20)), test_needs_autoincrement=True) @@ -530,15 +525,13 @@ class OneToManyManyToOneTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('ball', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('person_id', Integer, ForeignKey('person.id', use_alter=True, name='fk_person_id')), Column('data', String(30))) Table('person', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('favorite_ball_id', Integer, ForeignKey('ball.id')), Column('data', String(30))) @@ -841,7 +834,7 @@ class SelfReferentialPostUpdateTest2(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("a_table", metadata, - Column("id", Integer(), primary_key=True), + Column("id", Integer(), primary_key=True, test_needs_autoincrement=True), Column("fui", String(128)), Column("b", Integer(), ForeignKey("a_table.id"))) diff --git a/test/orm/test_defaults.py b/test/orm/test_defaults.py index b063780ac..5379c9714 100644 --- a/test/orm/test_defaults.py +++ b/test/orm/test_defaults.py @@ -2,8 +2,7 @@ import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, create_session from test.orm import _base from sqlalchemy.test.testing import eq_ @@ -15,7 +14,7 @@ class TriggerDefaultsTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): dt = Table('dt', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('col1', String(20)), Column('col2', String(20), server_default=sa.schema.FetchedValue()), @@ -34,17 +33,23 @@ class TriggerDefaultsTest(_base.MappedTest): "UPDATE dt SET col2='ins', col4='ins' " "WHERE dt.id IN (SELECT id FROM inserted);", on='mssql'), - ): - if testing.against(ins.on): - break - else: - ins = sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT ON dt " + sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT " + "ON dt " + "FOR EACH ROW " + "BEGIN " + ":NEW.col2 := 'ins'; :NEW.col4 := 'ins'; END;", + on='oracle'), + sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT ON dt " "FOR EACH ROW BEGIN " - "SET NEW.col2='ins'; SET NEW.col4='ins'; END") - ins.execute_at('after-create', dt) + "SET NEW.col2='ins'; SET NEW.col4='ins'; END", + on=lambda event, schema_item, bind, **kw: + bind.engine.name not in ('oracle', 'mssql', 'sqlite') + ), + ): + ins.execute_at('after-create', dt) + sa.DDL("DROP TRIGGER dt_ins").execute_at('before-drop', dt) - for up in ( sa.DDL("CREATE TRIGGER dt_up AFTER UPDATE ON dt " "FOR EACH ROW BEGIN " @@ -55,14 +60,19 @@ class TriggerDefaultsTest(_base.MappedTest): "UPDATE dt SET col3='up', col4='up' " "WHERE dt.id IN (SELECT id FROM deleted);", on='mssql'), - ): - if testing.against(up.on): - break - else: - up = sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt " + sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt " + "FOR EACH ROW BEGIN " + ":NEW.col3 := 'up'; :NEW.col4 := 'up'; END;", + on='oracle'), + sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt " "FOR EACH ROW BEGIN " - "SET NEW.col3='up'; SET NEW.col4='up'; END") - up.execute_at('after-create', dt) + "SET NEW.col3='up'; SET NEW.col4='up'; END", + on=lambda event, schema_item, bind, **kw: + bind.engine.name not in ('oracle', 'mssql', 'sqlite') + ), + ): + up.execute_at('after-create', dt) + sa.DDL("DROP TRIGGER dt_up").execute_at('before-drop', dt) @@ -115,7 +125,7 @@ class ExcludedDefaultsTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): dt = Table('dt', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('col1', String(20), default="hello"), ) diff --git a/test/orm/test_dynamic.py b/test/orm/test_dynamic.py index f2089a435..23a5fc876 100644 --- a/test/orm/test_dynamic.py +++ b/test/orm/test_dynamic.py @@ -3,8 +3,7 @@ import operator from sqlalchemy.orm import dynamic_loader, backref from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey, desc, select, func -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, create_session, Query, attributes from sqlalchemy.orm.dynamic import AppenderMixin from sqlalchemy.test.testing import eq_ @@ -344,7 +343,8 @@ class SessionTest(_fixtures.FixtureTest): sess.flush() sess.commit() u1.addresses.append(Address(email_address='foo@bar.com')) - eq_(u1.addresses.all(), [Address(email_address='lala@hoho.com'), Address(email_address='foo@bar.com')]) + eq_(u1.addresses.order_by(Address.id).all(), + [Address(email_address='lala@hoho.com'), Address(email_address='foo@bar.com')]) sess.rollback() eq_(u1.addresses.all(), [Address(email_address='lala@hoho.com')]) @@ -502,13 +502,13 @@ class DontDereferenceTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(40)), Column('fullname', String(100)), Column('password', String(15))) Table('addresses', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('email_address', String(100), nullable=False), Column('user_id', Integer, ForeignKey('users.id'))) diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py index 384e0472f..425c08c61 100644 --- a/test/orm/test_eager_relations.py +++ b/test/orm/test_eager_relations.py @@ -5,8 +5,7 @@ import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy.orm import eagerload, deferred, undefer from sqlalchemy import Integer, String, Date, ForeignKey, and_, select, func -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, create_session, lazyload, aliased from sqlalchemy.test.testing import eq_ from sqlalchemy.test.assertsql import CompiledSQL @@ -459,20 +458,14 @@ class EagerTest(_fixtures.FixtureTest): }) mapper(User, users, properties={ 'addresses':relation(mapper(Address, addresses), lazy=False, order_by=addresses.c.id), - 'orders':relation(Order, lazy=True) + 'orders':relation(Order, lazy=True, order_by=orders.c.id) }) sess = create_session() q = sess.query(User) - if testing.against('mysql'): - l = q.limit(2).all() - assert self.static.user_all_result[:2] == l - else: - l = q.order_by(User.id).limit(2).offset(1).all() - print self.static.user_all_result[1:3] - print l - assert self.static.user_all_result[1:3] == l + l = q.order_by(User.id).limit(2).offset(1).all() + eq_(self.static.user_all_result[1:3], l) @testing.resolve_artifact_names def test_distinct(self): @@ -483,15 +476,15 @@ class EagerTest(_fixtures.FixtureTest): s = sa.union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u') mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses), lazy=False), + 'addresses':relation(mapper(Address, addresses), lazy=False, order_by=addresses.c.id), }) sess = create_session() q = sess.query(User) def go(): - l = q.filter(s.c.u2_id==User.id).distinct().all() - assert self.static.user_address_result == l + l = q.filter(s.c.u2_id==User.id).distinct().order_by(User.id).all() + eq_(self.static.user_address_result, l) self.assert_sql_count(testing.db, go, 1) @testing.fails_on('maxdb', 'FIXME: unknown') @@ -656,9 +649,12 @@ class EagerTest(_fixtures.FixtureTest): mapper(Order, orders) mapper(User, users, properties={ - 'orders':relation(Order, backref='user', lazy=False), - 'max_order':relation(mapper(Order, max_orders, non_primary=True), lazy=False, uselist=False) + 'orders':relation(Order, backref='user', lazy=False, order_by=orders.c.id), + 'max_order':relation( + mapper(Order, max_orders, non_primary=True), + lazy=False, uselist=False) }) + q = create_session().query(User) def go(): @@ -675,7 +671,7 @@ class EagerTest(_fixtures.FixtureTest): max_order=Order(id=4) ), User(id=10), - ] == q.all() + ] == q.order_by(User.id).all() self.assert_sql_count(testing.db, go, 1) @testing.resolve_artifact_names @@ -823,15 +819,15 @@ class OrderBySecondaryTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('m2m', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('aid', Integer, ForeignKey('a.id')), Column('bid', Integer, ForeignKey('b.id'))) Table('a', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) Table('b', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) @classmethod @@ -873,8 +869,7 @@ class SelfReferentialEagerTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('nodes', metadata, - Column('id', Integer, sa.Sequence('node_id_seq', optional=True), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('parent_id', Integer, ForeignKey('nodes.id')), Column('data', String(30))) @@ -1088,11 +1083,11 @@ class MixedSelfReferentialEagerTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('a_table', metadata, - Column('id', Integer, primary_key=True) + Column('id', Integer, primary_key=True, test_needs_autoincrement=True) ) Table('b_table', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('parent_b1_id', Integer, ForeignKey('b_table.id')), Column('parent_a_id', Integer, ForeignKey('a_table.id')), Column('parent_b2_id', Integer, ForeignKey('b_table.id'))) @@ -1161,7 +1156,7 @@ class SelfReferentialM2MEagerTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('widget', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', sa.Unicode(40), nullable=False, unique=True), ) @@ -1244,7 +1239,7 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) self.assert_sql_count(testing.db, go, 1) - @testing.exclude('sqlite', '>', (0, 0, 0), "sqlite flat out blows it on the multiple JOINs") + @testing.exclude('sqlite', '>', (0, ), "sqlite flat out blows it on the multiple JOINs") @testing.resolve_artifact_names def test_two_entities_with_joins(self): sess = create_session() @@ -1337,13 +1332,13 @@ class CyclicalInheritingEagerTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('t1', metadata, - Column('c1', Integer, primary_key=True), + Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), Column('c2', String(30)), Column('type', String(30)) ) Table('t2', metadata, - Column('c1', Integer, primary_key=True), + Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), Column('c2', String(30)), Column('type', String(30)), Column('t1.id', Integer, ForeignKey('t1.c1'))) @@ -1376,12 +1371,12 @@ class SubqueryTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('users_table', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(16)) ) Table('tags_table', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', Integer, ForeignKey("users_table.id")), Column('score1', sa.Float), Column('score2', sa.Float), @@ -1461,16 +1456,20 @@ class CorrelatedSubqueryTest(_base.MappedTest): Exercises a variety of ways to configure this. """ + + # another argument for eagerload learning about inner joins + + __requires__ = ('correlated_outer_joins', ) @classmethod def define_tables(cls, metadata): users = Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)) ) stuff = Table('stuff', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('date', Date), Column('user_id', Integer, ForeignKey('users.id'))) @@ -1549,11 +1548,13 @@ class CorrelatedSubqueryTest(_base.MappedTest): if ondate: # the more 'relational' way to do this, join on the max date - stuff_view = select([func.max(salias.c.date).label('max_date')]).where(salias.c.user_id==users.c.id).correlate(users) + stuff_view = select([func.max(salias.c.date).label('max_date')]).\ + where(salias.c.user_id==users.c.id).correlate(users) else: # a common method with the MySQL crowd, which actually might perform better in some # cases - subquery does a limit with order by DESC, join on the id - stuff_view = select([salias.c.id]).where(salias.c.user_id==users.c.id).correlate(users).order_by(salias.c.date.desc()).limit(1) + stuff_view = select([salias.c.id]).where(salias.c.user_id==users.c.id).\ + correlate(users).order_by(salias.c.date.desc()).limit(1) if labeled == 'label': stuff_view = stuff_view.label('foo') diff --git a/test/orm/test_expire.py b/test/orm/test_expire.py index 659349897..c602ac963 100644 --- a/test/orm/test_expire.py +++ b/test/orm/test_expire.py @@ -1,7 +1,7 @@ """Attribute/instance expiration, deferral of attributes, etc.""" from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message -import gc +from sqlalchemy.test.util import gc_collect import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey, exc as sa_exc @@ -666,7 +666,7 @@ class ExpireTest(_fixtures.FixtureTest): assert self.static.user_address_result == userlist assert len(list(sess)) == 9 sess.expire_all() - gc.collect() + gc_collect() assert len(list(sess)) == 4 # since addresses were gc'ed userlist = sess.query(User).order_by(User.id).all() diff --git a/test/orm/test_generative.py b/test/orm/test_generative.py index 0efc1814e..8f61d4d14 100644 --- a/test/orm/test_generative.py +++ b/test/orm/test_generative.py @@ -70,12 +70,17 @@ class GenerativeQueryTest(_base.MappedTest): assert sess.query(func.min(foo.c.bar)).filter(foo.c.bar<30).one() == (0,) assert sess.query(func.max(foo.c.bar)).filter(foo.c.bar<30).one() == (29,) + # Py3K + #assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).__next__()[0] == 29 + #assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).__next__()[0] == 29 + # Py2K assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).next()[0] == 29 assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).next()[0] == 29 - + # end Py2K + @testing.resolve_artifact_names def test_aggregate_1(self): - if (testing.against('mysql') and + if (testing.against('mysql') and not testing.against('+zxjdbc') and testing.db.dialect.dbapi.version_info[:4] == (1, 2, 1, 'gamma')): return @@ -95,10 +100,18 @@ class GenerativeQueryTest(_base.MappedTest): def test_aggregate_3(self): query = create_session().query(Foo) + # Py3K + #avg_f = query.filter(foo.c.bar<30).values(sa.func.avg(foo.c.bar)).__next__()[0] + # Py2K avg_f = query.filter(foo.c.bar<30).values(sa.func.avg(foo.c.bar)).next()[0] + # end Py2K assert round(avg_f, 1) == 14.5 + # Py3K + #avg_o = query.filter(foo.c.bar<30).values(sa.func.avg(foo.c.bar)).__next__()[0] + # Py2K avg_o = query.filter(foo.c.bar<30).values(sa.func.avg(foo.c.bar)).next()[0] + # end Py2K assert round(avg_o, 1) == 14.5 @testing.resolve_artifact_names diff --git a/test/orm/test_instrumentation.py b/test/orm/test_instrumentation.py index b4c8f8601..6390e2596 100644 --- a/test/orm/test_instrumentation.py +++ b/test/orm/test_instrumentation.py @@ -488,7 +488,7 @@ class InstrumentationCollisionTest(_base.ORMTest): class B(A): __sa_instrumentation_manager__ = staticmethod(mgr_factory) - assert_raises(TypeError, attributes.register_class, B) + assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, B) def test_single_up(self): @@ -499,7 +499,8 @@ class InstrumentationCollisionTest(_base.ORMTest): class B(A): __sa_instrumentation_manager__ = staticmethod(mgr_factory) attributes.register_class(B) - assert_raises(TypeError, attributes.register_class, A) + + assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, A) def test_diamond_b1(self): mgr_factory = lambda cls: attributes.ClassManager(cls) @@ -507,10 +508,10 @@ class InstrumentationCollisionTest(_base.ORMTest): class A(object): pass class B1(A): pass class B2(A): - __sa_instrumentation_manager__ = mgr_factory + __sa_instrumentation_manager__ = staticmethod(mgr_factory) class C(object): pass - assert_raises(TypeError, attributes.register_class, B1) + assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, B1) def test_diamond_b2(self): mgr_factory = lambda cls: attributes.ClassManager(cls) @@ -518,10 +519,11 @@ class InstrumentationCollisionTest(_base.ORMTest): class A(object): pass class B1(A): pass class B2(A): - __sa_instrumentation_manager__ = mgr_factory + __sa_instrumentation_manager__ = staticmethod(mgr_factory) class C(object): pass - assert_raises(TypeError, attributes.register_class, B2) + attributes.register_class(B2) + assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, B1) def test_diamond_c_b(self): mgr_factory = lambda cls: attributes.ClassManager(cls) @@ -529,12 +531,12 @@ class InstrumentationCollisionTest(_base.ORMTest): class A(object): pass class B1(A): pass class B2(A): - __sa_instrumentation_manager__ = mgr_factory + __sa_instrumentation_manager__ = staticmethod(mgr_factory) class C(object): pass attributes.register_class(C) - assert_raises(TypeError, attributes.register_class, B1) + assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, B1) class OnLoadTest(_base.ORMTest): """Check that Events.on_load is not hit in regular attributes operations.""" diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py index 819f29911..8c196cfcf 100644 --- a/test/orm/test_lazy_relations.py +++ b/test/orm/test_lazy_relations.py @@ -163,9 +163,8 @@ class LazyTest(_fixtures.FixtureTest): # use a union all to get a lot of rows to join against u2 = users.alias('u2') s = sa.union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u') - print [key for key in s.c.keys()] - l = q.filter(s.c.u2_id==User.id).distinct().all() - assert self.static.user_all_result == l + l = q.filter(s.c.u2_id==User.id).order_by(User.id).distinct().all() + eq_(self.static.user_all_result, l) @testing.resolve_artifact_names def test_one_to_many_scalar(self): diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index 13913578a..c34ccdbab 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -3,9 +3,8 @@ from sqlalchemy.test.testing import assert_raises, assert_raises_message import sqlalchemy as sa from sqlalchemy.test import testing, pickleable -from sqlalchemy import MetaData, Integer, String, ForeignKey, func -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy import MetaData, Integer, String, ForeignKey, func, util +from sqlalchemy.test.schema import Table, Column from sqlalchemy.engine import default from sqlalchemy.orm import mapper, relation, backref, create_session, class_mapper, compile_mappers, reconstructor, validates, aliased from sqlalchemy.orm import defer, deferred, synonym, attributes, column_property, composite, relation, dynamic_loader, comparable_property @@ -390,7 +389,7 @@ class MapperTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_self_ref_synonym(self): t = Table('nodes', MetaData(), - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('parent_id', Integer, ForeignKey('nodes.id'))) class Node(object): @@ -432,7 +431,7 @@ class MapperTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_prop_filters(self): t = Table('person', MetaData(), - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('type', String(128)), Column('name', String(128)), Column('employee_number', Integer), @@ -870,6 +869,7 @@ class MapperTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_comparable_column(self): class MyComparator(sa.orm.properties.ColumnProperty.Comparator): + __hash__ = None def __eq__(self, other): # lower case comparison return func.lower(self.__clause_element__()) == func.lower(other) @@ -1451,12 +1451,12 @@ class DeferredTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_group(self): """Deferred load with a group""" - mapper(Order, orders, properties={ - 'userident': deferred(orders.c.user_id, group='primary'), - 'addrident': deferred(orders.c.address_id, group='primary'), - 'description': deferred(orders.c.description, group='primary'), - 'opened': deferred(orders.c.isopen, group='primary') - }) + mapper(Order, orders, properties=util.OrderedDict([ + ('userident', deferred(orders.c.user_id, group='primary')), + ('addrident', deferred(orders.c.address_id, group='primary')), + ('description', deferred(orders.c.description, group='primary')), + ('opened', deferred(orders.c.isopen, group='primary')) + ])) sess = create_session() q = sess.query(Order).order_by(Order.id) @@ -1562,10 +1562,12 @@ class DeferredTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_undefer_group(self): - mapper(Order, orders, properties={ - 'userident':deferred(orders.c.user_id, group='primary'), - 'description':deferred(orders.c.description, group='primary'), - 'opened':deferred(orders.c.isopen, group='primary')}) + mapper(Order, orders, properties=util.OrderedDict([ + ('userident',deferred(orders.c.user_id, group='primary')), + ('description',deferred(orders.c.description, group='primary')), + ('opened',deferred(orders.c.isopen, group='primary')) + ] + )) sess = create_session() q = sess.query(Order).order_by(Order.id) @@ -1796,11 +1798,11 @@ class DeferredPopulationTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("thing", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("name", String(20))) Table("human", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("thing_id", Integer, ForeignKey("thing.id")), Column("name", String(20))) @@ -1884,13 +1886,12 @@ class CompositeTypesTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('graphs', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('version_id', Integer, primary_key=True, nullable=True), Column('name', String(30))) Table('edges', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('graph_id', Integer, nullable=False), Column('graph_version_id', Integer, nullable=False), Column('x1', Integer), @@ -1902,7 +1903,7 @@ class CompositeTypesTest(_base.MappedTest): ['graphs.id', 'graphs.version_id'])) Table('foobars', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('x1', Integer, default=2), Column('x2', Integer), Column('x3', Integer, default=15), @@ -2041,7 +2042,7 @@ class CompositeTypesTest(_base.MappedTest): # test pk with one column NULL # TODO: can't seem to get NULL in for a PK value - # in either mysql or postgres, autoincrement=False etc. + # in either mysql or postgresql, autoincrement=False etc. # notwithstanding @testing.fails_on_everything_except("sqlite") def go(): @@ -2475,33 +2476,26 @@ class RequirementsTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('ht1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('value', String(10))) Table('ht2', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('ht1_id', Integer, ForeignKey('ht1.id')), Column('value', String(10))) Table('ht3', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('value', String(10))) Table('ht4', metadata, - Column('ht1_id', Integer, ForeignKey('ht1.id'), - primary_key=True), - Column('ht3_id', Integer, ForeignKey('ht3.id'), - primary_key=True)) + Column('ht1_id', Integer, ForeignKey('ht1.id'), primary_key=True), + Column('ht3_id', Integer, ForeignKey('ht3.id'), primary_key=True)) Table('ht5', metadata, - Column('ht1_id', Integer, ForeignKey('ht1.id'), - primary_key=True)) + Column('ht1_id', Integer, ForeignKey('ht1.id'), primary_key=True)) Table('ht6', metadata, - Column('ht1a_id', Integer, ForeignKey('ht1.id'), - primary_key=True), - Column('ht1b_id', Integer, ForeignKey('ht1.id'), - primary_key=True), + Column('ht1a_id', Integer, ForeignKey('ht1.id'), primary_key=True), + Column('ht1b_id', Integer, ForeignKey('ht1.id'), primary_key=True), Column('value', String(10))) + # Py2K @testing.resolve_artifact_names def test_baseclass(self): class OldStyle: @@ -2516,7 +2510,8 @@ class RequirementsTest(_base.MappedTest): # TODO: is weakref support detectable without an instance? #self.assertRaises(sa.exc.ArgumentError, mapper, NoWeakrefSupport, t2) - + # end Py2K + @testing.resolve_artifact_names def test_comparison_overrides(self): """Simple tests to ensure users can supply comparison __methods__. @@ -2618,12 +2613,12 @@ class MagicNamesTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('cartographers', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('alias', String(50)), Column('quip', String(100))) Table('maps', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('cart_id', Integer, ForeignKey('cartographers.id')), Column('state', String(2)), @@ -2665,7 +2660,7 @@ class MagicNamesTest(_base.MappedTest): for reserved in (sa.orm.attributes.ClassManager.STATE_ATTR, sa.orm.attributes.ClassManager.MANAGER_ATTR): t = Table('t', sa.MetaData(), - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column(reserved, Integer)) class T(object): pass diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py index f4e3872b0..5433515ca 100644 --- a/test/orm/test_merge.py +++ b/test/orm/test_merge.py @@ -1,13 +1,13 @@ from sqlalchemy.test.testing import assert_raises, assert_raises_message import sqlalchemy as sa -from sqlalchemy import Table, Column, Integer, PickleType +from sqlalchemy import Integer, PickleType import operator from sqlalchemy.test import testing from sqlalchemy.util import OrderedSet from sqlalchemy.orm import mapper, relation, create_session, PropComparator, synonym, comparable_property, sessionmaker from sqlalchemy.test.testing import eq_, ne_ from test.orm import _base, _fixtures - +from sqlalchemy.test.schema import Table, Column class MergeTest(_fixtures.FixtureTest): """Session.merge() functionality""" @@ -103,6 +103,7 @@ class MergeTest(_fixtures.FixtureTest): 'addresses':relation(Address, backref='user', collection_class=OrderedSet, + order_by=addresses.c.id, cascade="all, delete-orphan") }) mapper(Address, addresses) @@ -154,6 +155,7 @@ class MergeTest(_fixtures.FixtureTest): mapper(User, users, properties={ 'addresses':relation(Address, backref='user', + order_by=addresses.c.id, collection_class=OrderedSet)}) mapper(Address, addresses) on_load = self.on_load_tracker(User) @@ -300,20 +302,20 @@ class MergeTest(_fixtures.FixtureTest): # test with "dontload" merge sess5 = create_session() - u = sess5.merge(u, dont_load=True) + u = sess5.merge(u, load=False) assert len(u.addresses) for a in u.addresses: assert a.user is u def go(): sess5.flush() # no changes; therefore flush should do nothing - # but also, dont_load wipes out any difference in committed state, + # but also, load=False wipes out any difference in committed state, # so no flush at all self.assert_sql_count(testing.db, go, 0) eq_(on_load.called, 15) sess4 = create_session() - u = sess4.merge(u, dont_load=True) + u = sess4.merge(u, load=False) # post merge change u.addresses[1].email_address='afafds' def go(): @@ -445,17 +447,35 @@ class MergeTest(_fixtures.FixtureTest): assert u3 is u @testing.resolve_artifact_names - def test_transient_dontload(self): + def test_transient_no_load(self): mapper(User, users) sess = create_session() u = User() - assert_raises_message(sa.exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True) + assert_raises_message(sa.exc.InvalidRequestError, "load=False option does not support", sess.merge, u, load=False) + @testing.resolve_artifact_names + def test_dont_load_deprecated(self): + mapper(User, users) + + sess = create_session() + u = User(name='ed') + sess.add(u) + sess.flush() + u = sess.query(User).first() + sess.expunge(u) + sess.execute(users.update().values(name='jack')) + @testing.uses_deprecated("dont_load=True has been renamed") + def go(): + u1 = sess.merge(u, dont_load=True) + assert u1 in sess + assert u1.name=='ed' + assert u1 not in sess.dirty + go() @testing.resolve_artifact_names - def test_dontload_with_backrefs(self): - """dontload populates relations in both directions without requiring a load""" + def test_no_load_with_backrefs(self): + """load=False populates relations in both directions without requiring a load""" mapper(User, users, properties={ 'addresses':relation(mapper(Address, addresses), backref='user') }) @@ -470,7 +490,7 @@ class MergeTest(_fixtures.FixtureTest): assert 'user' in u.addresses[1].__dict__ sess = create_session() - u2 = sess.merge(u, dont_load=True) + u2 = sess.merge(u, load=False) assert 'user' in u2.addresses[1].__dict__ eq_(u2.addresses[1].user, User(id=7, name='fred')) @@ -479,7 +499,7 @@ class MergeTest(_fixtures.FixtureTest): sess.close() sess = create_session() - u = sess.merge(u2, dont_load=True) + u = sess.merge(u2, load=False) assert 'user' not in u.addresses[1].__dict__ eq_(u.addresses[1].user, User(id=7, name='fred')) @@ -488,12 +508,12 @@ class MergeTest(_fixtures.FixtureTest): def test_dontload_with_eager(self): """ - This test illustrates that with dont_load=True, we can't just copy the + This test illustrates that with load=False, we can't just copy the committed_state of the merged instance over; since it references collection objects which themselves are to be merged. This committed_state would instead need to be piecemeal 'converted' to represent the correct objects. However, at the moment I'd rather not - support this use case; if you are merging with dont_load=True, you're + support this use case; if you are merging with load=False, you're typically dealing with caching and the merged objects shouldnt be 'dirty'. @@ -516,16 +536,16 @@ class MergeTest(_fixtures.FixtureTest): u2 = sess2.query(User).options(sa.orm.eagerload('addresses')).get(7) sess3 = create_session() - u3 = sess3.merge(u2, dont_load=True) + u3 = sess3.merge(u2, load=False) def go(): sess3.flush() self.assert_sql_count(testing.db, go, 0) @testing.resolve_artifact_names - def test_dont_load_disallows_dirty(self): - """dont_load doesnt support 'dirty' objects right now + def test_no_load_disallows_dirty(self): + """load=False doesnt support 'dirty' objects right now - (see test_dont_load_with_eager()). Therefore lets assert it. + (see test_no_load_with_eager()). Therefore lets assert it. """ mapper(User, users) @@ -539,17 +559,17 @@ class MergeTest(_fixtures.FixtureTest): u.name = 'ed' sess2 = create_session() try: - sess2.merge(u, dont_load=True) + sess2.merge(u, load=False) assert False except sa.exc.InvalidRequestError, e: - assert ("merge() with dont_load=True option does not support " + assert ("merge() with load=False option does not support " "objects marked as 'dirty'. flush() all changes on mapped " - "instances before merging with dont_load=True.") in str(e) + "instances before merging with load=False.") in str(e) u2 = sess2.query(User).get(7) sess3 = create_session() - u3 = sess3.merge(u2, dont_load=True) + u3 = sess3.merge(u2, load=False) assert not sess3.dirty def go(): sess3.flush() @@ -557,7 +577,7 @@ class MergeTest(_fixtures.FixtureTest): @testing.resolve_artifact_names - def test_dont_load_sets_backrefs(self): + def test_no_load_sets_backrefs(self): mapper(User, users, properties={ 'addresses':relation(mapper(Address, addresses),backref='user')}) @@ -575,17 +595,17 @@ class MergeTest(_fixtures.FixtureTest): assert u.addresses[0].user is u sess2 = create_session() - u2 = sess2.merge(u, dont_load=True) + u2 = sess2.merge(u, load=False) assert not sess2.dirty def go(): assert u2.addresses[0].user is u2 self.assert_sql_count(testing.db, go, 0) @testing.resolve_artifact_names - def test_dont_load_preserves_parents(self): - """Merge with dont_load does not trigger a 'delete-orphan' operation. + def test_no_load_preserves_parents(self): + """Merge with load=False does not trigger a 'delete-orphan' operation. - merge with dont_load sets attributes without using events. this means + merge with load=False sets attributes without using events. this means the 'hasparent' flag is not propagated to the newly merged instance. in fact this works out OK, because the '_state.parents' collection on the newly merged instance is empty; since the mapper doesn't see an @@ -610,7 +630,7 @@ class MergeTest(_fixtures.FixtureTest): assert u.addresses[0].user is u sess2 = create_session() - u2 = sess2.merge(u, dont_load=True) + u2 = sess2.merge(u, load=False) assert not sess2.dirty a2 = u2.addresses[0] a2.email_address='somenewaddress' @@ -624,19 +644,19 @@ class MergeTest(_fixtures.FixtureTest): # this use case is not supported; this is with a pending Address on # the pre-merged object, and we currently dont support 'dirty' objects - # being merged with dont_load=True. in this case, the empty + # being merged with load=False. in this case, the empty # '_state.parents' collection would be an issue, since the optimistic # flag is False in _is_orphan() for pending instances. so if we start - # supporting 'dirty' with dont_load=True, this test will need to pass + # supporting 'dirty' with load=False, this test will need to pass sess = create_session() u = sess.query(User).get(7) u.addresses.append(Address()) sess2 = create_session() try: - u2 = sess2.merge(u, dont_load=True) + u2 = sess2.merge(u, load=False) assert False - # if dont_load is changed to support dirty objects, this code + # if load=False is changed to support dirty objects, this code # needs to pass a2 = u2.addresses[0] a2.email_address='somenewaddress' @@ -647,7 +667,7 @@ class MergeTest(_fixtures.FixtureTest): eq_(sess2.query(User).get(u2.id).addresses[0].email_address, 'somenewaddress') except sa.exc.InvalidRequestError, e: - assert "dont_load=True option does not support" in str(e) + assert "load=False option does not support" in str(e) @testing.resolve_artifact_names def test_synonym_comparable(self): @@ -737,7 +757,7 @@ class MutableMergeTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("data", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', PickleType(comparator=operator.eq)) ) diff --git a/test/orm/test_naturalpks.py b/test/orm/test_naturalpks.py index 1376c402e..e99bfb794 100644 --- a/test/orm/test_naturalpks.py +++ b/test/orm/test_naturalpks.py @@ -6,8 +6,7 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, create_session from sqlalchemy.test.testing import eq_ from test.orm import _base @@ -16,6 +15,11 @@ class NaturalPKTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): + if testing.against('oracle'): + fk_args = dict(deferrable=True, initially='deferred') + else: + fk_args = dict(onupdate='cascade') + users = Table('users', metadata, Column('username', String(50), primary_key=True), Column('fullname', String(100)), @@ -23,7 +27,7 @@ class NaturalPKTest(_base.MappedTest): addresses = Table('addresses', metadata, Column('email', String(50), primary_key=True), - Column('username', String(50), ForeignKey('users.username', onupdate="cascade")), + Column('username', String(50), ForeignKey('users.username', **fk_args)), test_needs_fk=True) items = Table('items', metadata, @@ -32,8 +36,8 @@ class NaturalPKTest(_base.MappedTest): test_needs_fk=True) users_to_items = Table('users_to_items', metadata, - Column('username', String(50), ForeignKey('users.username', onupdate='cascade'), primary_key=True), - Column('itemname', String(50), ForeignKey('items.itemname', onupdate='cascade'), primary_key=True), + Column('username', String(50), ForeignKey('users.username', **fk_args), primary_key=True), + Column('itemname', String(50), ForeignKey('items.itemname', **fk_args), primary_key=True), test_needs_fk=True) @classmethod @@ -110,6 +114,7 @@ class NaturalPKTest(_base.MappedTest): @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE') def test_onetomany_passive(self): self._test_onetomany(True) @@ -161,6 +166,7 @@ class NaturalPKTest(_base.MappedTest): @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE') def test_manytoone_passive(self): self._test_manytoone(True) @@ -203,6 +209,7 @@ class NaturalPKTest(_base.MappedTest): eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE') def test_onetoone_passive(self): self._test_onetoone(True) @@ -244,6 +251,7 @@ class NaturalPKTest(_base.MappedTest): eq_([Address(username='ed')], sess.query(Address).all()) @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE') def test_bidirectional_passive(self): self._test_bidirectional(True) @@ -298,10 +306,12 @@ class NaturalPKTest(_base.MappedTest): @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE') def test_manytomany_passive(self): self._test_manytomany(True) - @testing.fails_on('mysql', 'the executemany() of the association table fails to report the correct row count') + # mysqldb executemany() of the association table fails to report the correct row count + @testing.fails_if(lambda: testing.against('mysql') and not testing.against('+zxjdbc')) def test_manytomany_nonpassive(self): self._test_manytomany(False) @@ -361,10 +371,15 @@ class SelfRefTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): + if testing.against('oracle'): + fk_args = dict(deferrable=True, initially='deferred') + else: + fk_args = dict(onupdate='cascade') + Table('nodes', metadata, Column('name', String(50), primary_key=True), Column('parent', String(50), - ForeignKey('nodes.name', onupdate='cascade'))) + ForeignKey('nodes.name', **fk_args))) @classmethod def setup_classes(cls): @@ -400,17 +415,22 @@ class SelfRefTest(_base.MappedTest): class NonPKCascadeTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): + if testing.against('oracle'): + fk_args = dict(deferrable=True, initially='deferred') + else: + fk_args = dict(onupdate='cascade') + Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('username', String(50), unique=True), Column('fullname', String(100)), test_needs_fk=True) Table('addresses', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('email', String(50)), Column('username', String(50), - ForeignKey('users.username', onupdate="cascade")), + ForeignKey('users.username', **fk_args)), test_needs_fk=True ) @@ -422,6 +442,7 @@ class NonPKCascadeTest(_base.MappedTest): pass @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE') def test_onetomany_passive(self): self._test_onetomany(True) diff --git a/test/orm/test_onetoone.py b/test/orm/test_onetoone.py index 0d66915ea..6880f1f74 100644 --- a/test/orm/test_onetoone.py +++ b/test/orm/test_onetoone.py @@ -1,8 +1,7 @@ import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, create_session from test.orm import _base @@ -11,13 +10,13 @@ class O2OTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('jack', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('number', String(50)), Column('status', String(20)), Column('subroom', String(5))) Table('port', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(30)), Column('description', String(100)), Column('jack_id', Integer, ForeignKey("jack.id"))) diff --git a/test/orm/test_pickled.py b/test/orm/test_pickled.py index 5343cc15b..6ac9f2470 100644 --- a/test/orm/test_pickled.py +++ b/test/orm/test_pickled.py @@ -3,8 +3,7 @@ import pickle import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, create_session, attributes from test.orm import _base, _fixtures @@ -60,7 +59,7 @@ class PickleTest(_fixtures.FixtureTest): u2 = pickle.loads(pickle.dumps(u1)) sess2 = create_session() - u2 = sess2.merge(u2, dont_load=True) + u2 = sess2.merge(u2, load=False) eq_(u2.name, 'ed') eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) @@ -94,7 +93,7 @@ class PickleTest(_fixtures.FixtureTest): u2 = pickle.loads(pickle.dumps(u1)) sess2 = create_session() - u2 = sess2.merge(u2, dont_load=True) + u2 = sess2.merge(u2, load=False) eq_(u2.name, 'ed') assert 'addresses' not in u2.__dict__ ad = u2.addresses[0] @@ -136,7 +135,7 @@ class PolymorphicDeferredTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(30)), Column('type', String(30))) Table('email_users', metadata, diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 88a95bf76..8cb7ef969 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -109,13 +109,8 @@ class GetTest(QueryTest): pass s = users.select(users.c.id!=12).alias('users') m = mapper(SomeUser, s) - print s.primary_key - print m.primary_key assert s.primary_key == m.primary_key - row = s.select(use_labels=True).execute().fetchone() - print row[s.primary_key[0]] - sess = create_session() assert sess.query(SomeUser).get(7).name == 'jack' @@ -145,15 +140,20 @@ class GetTest(QueryTest): @testing.requires.unicode_connections def test_unicode(self): """test that Query.get properly sets up the type for the bind parameter. using unicode would normally fail - on postgres, mysql and oracle unless it is converted to an encoded string""" + on postgresql, mysql and oracle unless it is converted to an encoded string""" metadata = MetaData(engines.utf8_engine()) table = Table('unicode_data', metadata, - Column('id', Unicode(40), primary_key=True), + Column('id', Unicode(40), primary_key=True, test_needs_autoincrement=True), Column('data', Unicode(40))) try: metadata.create_all() + # Py3K + #ustring = 'petit voix m\xe2\x80\x99a' + # Py2K ustring = 'petit voix m\xe2\x80\x99a'.decode('utf-8') + # end Py2K + table.insert().execute(id=ustring, data=ustring) class LocalFoo(Base): pass @@ -195,7 +195,7 @@ class GetTest(QueryTest): assert u.addresses[0].email_address == 'jack@bean.com' assert u.orders[1].items[2].description == 'item 5' - @testing.fails_on_everything_except('sqlite', 'mssql') + @testing.fails_on_everything_except('sqlite', '+pyodbc', '+zxjdbc') def test_query_str(self): s = create_session() q = s.query(User).filter(User.id==1) @@ -299,7 +299,12 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): def test_arithmetic(self): create_session().query(User) for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'), - (operator.sub, '-'), (operator.div, '/'), + (operator.sub, '-'), + # Py3k + #(operator.truediv, '/'), + # Py2K + (operator.div, '/'), + # end Py2K ): for (lhs, rhs, res) in ( (5, User.id, ':id_1 %s users.id'), @@ -489,10 +494,16 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): sess = create_session() self.assert_compile(sess.query(users).select_from(users.select()).with_labels().statement, - "SELECT users.id AS users_id, users.name AS users_name FROM users, (SELECT users.id AS id, users.name AS name FROM users) AS anon_1") + "SELECT users.id AS users_id, users.name AS users_name FROM users, " + "(SELECT users.id AS id, users.name AS name FROM users) AS anon_1", + dialect=default.DefaultDialect() + ) self.assert_compile(sess.query(users, exists([1], from_obj=addresses)).with_labels().statement, - "SELECT users.id AS users_id, users.name AS users_name, EXISTS (SELECT 1 FROM addresses) AS anon_1 FROM users") + "SELECT users.id AS users_id, users.name AS users_name, EXISTS " + "(SELECT 1 FROM addresses) AS anon_1 FROM users", + dialect=default.DefaultDialect() + ) # a little tedious here, adding labels to work around Query's auto-labelling. # also correlate needed explicitly. hmmm..... @@ -687,15 +698,19 @@ class FilterTest(QueryTest): sess = create_session() assert [Address(id=5)] == sess.query(Address).filter(Address.user.has(name='fred')).all() - assert [Address(id=2), Address(id=3), Address(id=4), Address(id=5)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'))).all() + assert [Address(id=2), Address(id=3), Address(id=4), Address(id=5)] == \ + sess.query(Address).filter(Address.user.has(User.name.like('%ed%'))).order_by(Address.id).all() - assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).all() + assert [Address(id=2), Address(id=3), Address(id=4)] == \ + sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).order_by(Address.id).all() # test has() doesn't overcorrelate - assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user").filter(Address.user.has(User.name.like('%ed%'), id=8)).all() + assert [Address(id=2), Address(id=3), Address(id=4)] == \ + sess.query(Address).join("user").filter(Address.user.has(User.name.like('%ed%'), id=8)).order_by(Address.id).all() # test has() doesnt' get subquery contents adapted by aliased join - assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user", aliased=True).filter(Address.user.has(User.name.like('%ed%'), id=8)).all() + assert [Address(id=2), Address(id=3), Address(id=4)] == \ + sess.query(Address).join("user", aliased=True).filter(Address.user.has(User.name.like('%ed%'), id=8)).order_by(Address.id).all() dingaling = sess.query(Dingaling).get(2) assert [User(id=9)] == sess.query(User).filter(User.addresses.any(Address.dingaling==dingaling)).all() @@ -730,8 +745,8 @@ class FilterTest(QueryTest): assert [Address(id=5)] == sess.query(Address).filter(Address.dingaling==dingaling).all() # m2m - eq_(sess.query(Item).filter(Item.keywords==None).all(), [Item(id=4), Item(id=5)]) - eq_(sess.query(Item).filter(Item.keywords!=None).all(), [Item(id=1),Item(id=2), Item(id=3)]) + eq_(sess.query(Item).filter(Item.keywords==None).order_by(Item.id).all(), [Item(id=4), Item(id=5)]) + eq_(sess.query(Item).filter(Item.keywords!=None).order_by(Item.id).all(), [Item(id=1),Item(id=2), Item(id=3)]) def test_filter_by(self): sess = create_session() @@ -748,8 +763,9 @@ class FilterTest(QueryTest): sess = create_session() # o2o - eq_([Address(id=1), Address(id=3), Address(id=4)], sess.query(Address).filter(Address.dingaling==None).all()) - eq_([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).all()) + eq_([Address(id=1), Address(id=3), Address(id=4)], + sess.query(Address).filter(Address.dingaling==None).order_by(Address.id).all()) + eq_([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).order_by(Address.id).all()) # m2o eq_([Order(id=5)], sess.query(Order).filter(Order.address==None).all()) @@ -806,11 +822,15 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): s = create_session() + oracle_as = not testing.against('oracle') and "AS " or "" + self.assert_compile( s.query(User).options(eagerload(User.addresses)).from_self().statement, "SELECT anon_1.users_id, anon_1.users_name, addresses_1.id, addresses_1.user_id, "\ - "addresses_1.email_address FROM (SELECT users.id AS users_id, users.name AS users_name FROM users) AS anon_1 "\ - "LEFT OUTER JOIN addresses AS addresses_1 ON anon_1.users_id = addresses_1.user_id ORDER BY addresses_1.id" + "addresses_1.email_address FROM (SELECT users.id AS users_id, users.name AS users_name FROM users) %(oracle_as)sanon_1 "\ + "LEFT OUTER JOIN addresses %(oracle_as)saddresses_1 ON anon_1.users_id = addresses_1.user_id ORDER BY addresses_1.id" % { + 'oracle_as':oracle_as + } ) def test_aliases(self): @@ -987,8 +1007,14 @@ class CountTest(QueryTest): class DistinctTest(QueryTest): def test_basic(self): - assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).distinct().all() - assert [User(id=7), User(id=9), User(id=8),User(id=10)] == create_session().query(User).distinct().order_by(desc(User.name)).all() + eq_( + [User(id=7), User(id=8), User(id=9),User(id=10)], + create_session().query(User).order_by(User.id).distinct().all() + ) + eq_( + [User(id=7), User(id=9), User(id=8),User(id=10)], + create_session().query(User).distinct().order_by(desc(User.name)).all() + ) def test_joined(self): """test that orderbys from a joined table get placed into the columns clause when DISTINCT is used""" @@ -1017,7 +1043,6 @@ class DistinctTest(QueryTest): class YieldTest(QueryTest): def test_basic(self): - import gc sess = create_session() q = iter(sess.query(User).yield_per(1).from_statement("select * from users")) @@ -1447,11 +1472,11 @@ class MultiplePathTest(_base.MappedTest): def define_tables(cls, metadata): global t1, t2, t1t2_1, t1t2_2 t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)) ) t2 = Table('t2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)) ) @@ -1715,7 +1740,7 @@ class MixedEntitiesTest(QueryTest): eq_(list(q2), [(u'jack',), (u'ed',)]) q = sess.query(User) - q2 = q.order_by(User.id).values(User.name, User.name + " " + cast(User.id, String)) + q2 = q.order_by(User.id).values(User.name, User.name + " " + cast(User.id, String(50))) eq_(list(q2), [(u'jack', u'jack 7'), (u'ed', u'ed 8'), (u'fred', u'fred 9'), (u'chuck', u'chuck 10')]) q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(User.id, Address.id).values(User.name, Address.email_address) @@ -1755,6 +1780,9 @@ class MixedEntitiesTest(QueryTest): eq_(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')]) @testing.fails_on('mssql', 'FIXME: unknown') + @testing.fails_on('oracle', "Oracle doesn't support boolean expressions as columns") + @testing.fails_on('postgresql+pg8000', "pg8000 parses the SQL itself before passing on to PG, doesn't parse this") + @testing.fails_on('postgresql+zxjdbc', "zxjdbc parses the SQL itself before passing on to PG, doesn't parse this") def test_values_with_boolean_selects(self): """Tests a values clause that works with select boolean evaluations""" sess = create_session() @@ -1763,6 +1791,10 @@ class MixedEntitiesTest(QueryTest): q2 = q.group_by([User.name.like('%j%')]).order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%'), func.count(User.name.like('%j%'))) eq_(list(q2), [(True, 1), (False, 3)]) + q2 = q.order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%')) + eq_(list(q2), [(True,), (False,), (False,), (False,)]) + + def test_correlated_subquery(self): """test that a subquery constructed from ORM attributes doesn't leak out those entities to the outermost query. @@ -2514,12 +2546,14 @@ class SelfReferentialTest(_base.MappedTest): sess = create_session() eq_(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), []) eq_(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')]) - eq_(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),]) + eq_(sess.query(Node).filter(~Node.children.any()).order_by(Node.id).all(), + [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),]) def test_has(self): sess = create_session() - eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) + eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).order_by(Node.id).all(), + [Node(data='n121'),Node(data='n122'),Node(data='n123')]) eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), []) eq_(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')]) @@ -2660,7 +2694,7 @@ class ExternalColumnsTest(QueryTest): for x in range(2): sess.expunge_all() def go(): - eq_(sess.query(Address).options(eagerload('user')).all(), address_result) + eq_(sess.query(Address).options(eagerload('user')).order_by(Address.id).all(), address_result) self.assert_sql_count(testing.db, go, 1) ualias = aliased(User) @@ -2691,7 +2725,9 @@ class ExternalColumnsTest(QueryTest): ) ua = aliased(User) - eq_(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(), + eq_(sess.query(Address, ua.concat, ua.count). + select_from(join(Address, ua, 'user')). + options(eagerload(Address.user)).order_by(Address.id).all(), [ (Address(id=1, user=User(id=7, concat=14, count=1)), 14, 1), (Address(id=2, user=User(id=8, concat=16, count=3)), 16, 3), @@ -2742,7 +2778,7 @@ class TestOverlyEagerEquivalentCols(_base.MappedTest): def define_tables(cls, metadata): global base, sub1, sub2 base = Table('base', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)) ) @@ -2800,12 +2836,12 @@ class UpdateDeleteTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(32)), Column('age', Integer)) Table('documents', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', None, ForeignKey('users.id')), Column('title', String(32))) @@ -2875,7 +2911,7 @@ class UpdateDeleteTest(_base.MappedTest): sess = create_session(bind=testing.db, autocommit=False) john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter('name = :name').params(name='john').delete() + sess.query(User).filter('name = :name').params(name='john').delete('fetch') assert john not in sess eq_(sess.query(User).order_by(User.id).all(), [jack,jill,jane]) @@ -2922,12 +2958,17 @@ class UpdateDeleteTest(_base.MappedTest): @testing.fails_on('mysql', 'FIXME: unknown') @testing.resolve_artifact_names - def test_delete_fallback(self): + def test_delete_invalid_evaluation(self): sess = create_session(bind=testing.db, autocommit=False) john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter(User.name == select([func.max(User.name)])).delete(synchronize_session='evaluate') + assert_raises(sa_exc.InvalidRequestError, + sess.query(User).filter(User.name == select([func.max(User.name)])).delete, synchronize_session='evaluate' + ) + + sess.query(User).filter(User.name == select([func.max(User.name)])).delete(synchronize_session='fetch') + assert john not in sess eq_(sess.query(User).order_by(User.id).all(), [jack,jill,jane]) @@ -2957,7 +2998,7 @@ class UpdateDeleteTest(_base.MappedTest): john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter('age > :x').params(x=29).update({'age': User.age - 10}, synchronize_session='evaluate') + sess.query(User).filter('age > :x').params(x=29).update({'age': User.age - 10}, synchronize_session='fetch') eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) @@ -3017,11 +3058,12 @@ class UpdateDeleteTest(_base.MappedTest): sess = create_session(bind=testing.db, autocommit=False) john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='expire') + sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='fetch') eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) + @testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount) @testing.resolve_artifact_names def test_update_returns_rowcount(self): sess = create_session(bind=testing.db, autocommit=False) @@ -3032,6 +3074,7 @@ class UpdateDeleteTest(_base.MappedTest): rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age - 10}) eq_(rowcount, 2) + @testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount) @testing.resolve_artifact_names def test_delete_returns_rowcount(self): sess = create_session(bind=testing.db, autocommit=False) @@ -3046,7 +3089,7 @@ class UpdateDeleteTest(_base.MappedTest): sess = create_session(bind=testing.db, autocommit=False) foo,bar,baz = sess.query(Document).order_by(Document.id).all() - sess.query(Document).filter(Document.user_id == 1).update({'title': Document.title+Document.title}, synchronize_session='expire') + sess.query(Document).filter(Document.user_id == 1).update({'title': Document.title+Document.title}, synchronize_session='fetch') eq_([foo.title, bar.title, baz.title], ['foofoo','barbar', 'baz']) eq_(sess.query(Document.title).order_by(Document.id).all(), zip(['foofoo','barbar', 'baz'])) @@ -3056,7 +3099,7 @@ class UpdateDeleteTest(_base.MappedTest): sess = create_session(bind=testing.db, autocommit=False) john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).options(eagerload(User.documents)).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='expire') + sess.query(User).options(eagerload(User.documents)).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='fetch') eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index fef1577f0..481deb81b 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -3,8 +3,7 @@ import datetime import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey, MetaData, and_ -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers, sessionmaker from sqlalchemy.test.testing import eq_, startswith_ from test.orm import _base, _fixtures @@ -32,17 +31,17 @@ class RelationTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("tbl_a", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("name", String(128))) Table("tbl_b", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("name", String(128))) Table("tbl_c", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("tbl_a_id", Integer, ForeignKey("tbl_a.id"), nullable=False), Column("name", String(128))) Table("tbl_d", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("tbl_c_id", Integer, ForeignKey("tbl_c.id"), nullable=False), Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")), Column("name", String(128))) @@ -132,7 +131,7 @@ class RelationTest2(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('company_t', metadata, - Column('company_id', Integer, primary_key=True), + Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', sa.Unicode(30))) Table('employee_t', metadata, @@ -395,7 +394,7 @@ class RelationTest4(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("tableA", metadata, - Column("id",Integer,primary_key=True), + Column("id",Integer,primary_key=True, test_needs_autoincrement=True), Column("foo",Integer,), test_needs_fk=True) Table("tableB",metadata, @@ -456,7 +455,7 @@ class RelationTest4(_base.MappedTest): @testing.fails_on_everything_except('sqlite', 'mysql') @testing.resolve_artifact_names def test_nullPKsOK_BtoA(self): - # postgres cant handle a nullable PK column...? + # postgresql cant handle a nullable PK column...? tableC = Table('tablec', tableA.metadata, Column('id', Integer, primary_key=True), Column('a_id', Integer, ForeignKey('tableA.id'), @@ -642,12 +641,12 @@ class RelationTest6(_base.MappedTest): @classmethod def define_tables(cls, metadata): - Table('tags', metadata, Column("id", Integer, primary_key=True), + Table('tags', metadata, Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("data", String(50)), ) Table('tag_foo', metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column('tagid', Integer), Column("data", String(50)), ) @@ -691,11 +690,11 @@ class BackrefPropagatesForwardsArgs(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)) ) Table('addresses', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', Integer), Column('email', String(50)) ) @@ -738,7 +737,7 @@ class AmbiguousJoinInterpretedAsSelfRef(_base.MappedTest): @classmethod def define_tables(cls, metadata): subscriber_table = Table('subscriber', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('dummy', String(10)) # to appease older sqlite version ) @@ -947,18 +946,18 @@ class TypeMatchTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("a", metadata, - Column('aid', Integer, primary_key=True), + Column('aid', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30))) Table("b", metadata, - Column('bid', Integer, primary_key=True), + Column('bid', Integer, primary_key=True, test_needs_autoincrement=True), Column("a_id", Integer, ForeignKey("a.aid")), Column('data', String(30))) Table("c", metadata, - Column('cid', Integer, primary_key=True), + Column('cid', Integer, primary_key=True, test_needs_autoincrement=True), Column("b_id", Integer, ForeignKey("b.bid")), Column('data', String(30))) Table("d", metadata, - Column('did', Integer, primary_key=True), + Column('did', Integer, primary_key=True, test_needs_autoincrement=True), Column("a_id", Integer, ForeignKey("a.aid")), Column('data', String(30))) @@ -1116,14 +1115,14 @@ class ViewOnlyOverlappingNames(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("t1", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(40))) Table("t2", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(40)), Column('t1id', Integer, ForeignKey('t1.id'))) Table("t3", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(40)), Column('t2id', Integer, ForeignKey('t2.id'))) @@ -1176,14 +1175,14 @@ class ViewOnlyUniqueNames(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("t1", metadata, - Column('t1id', Integer, primary_key=True), + Column('t1id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(40))) Table("t2", metadata, - Column('t2id', Integer, primary_key=True), + Column('t2id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(40)), Column('t1id_ref', Integer, ForeignKey('t1.t1id'))) Table("t3", metadata, - Column('t3id', Integer, primary_key=True), + Column('t3id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(40)), Column('t2id_ref', Integer, ForeignKey('t2.t2id'))) @@ -1309,12 +1308,12 @@ class ViewOnlyRepeatedRemoteColumn(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('foos', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('bid1', Integer,ForeignKey('bars.id')), Column('bid2', Integer,ForeignKey('bars.id'))) Table('bars', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) @testing.resolve_artifact_names @@ -1357,10 +1356,10 @@ class ViewOnlyRepeatedLocalColumn(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('foos', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) - Table('bars', metadata, Column('id', Integer, primary_key=True), + Table('bars', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('fid1', Integer, ForeignKey('foos.id')), Column('fid2', Integer, ForeignKey('foos.id')), Column('data', String(50))) @@ -1405,14 +1404,14 @@ class ViewOnlyComplexJoin(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) Table('t2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)), Column('t1id', Integer, ForeignKey('t1.id'))) Table('t3', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) Table('t2tot3', metadata, Column('t2id', Integer, ForeignKey('t2.id')), @@ -1476,10 +1475,10 @@ class ExplicitLocalRemoteTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('t1', metadata, - Column('id', String(50), primary_key=True), + Column('id', String(50), primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) Table('t2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)), Column('t1id', String(50))) diff --git a/test/orm/test_scoping.py b/test/orm/test_scoping.py index 9f2f59e19..0d6b3deae 100644 --- a/test/orm/test_scoping.py +++ b/test/orm/test_scoping.py @@ -3,13 +3,13 @@ import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy.orm import scoped_session from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, query from sqlalchemy.test.testing import eq_ from test.orm import _base + class _ScopedTest(_base.MappedTest): """Adds another lookup bucket to emulate Session globals.""" @@ -34,10 +34,10 @@ class ScopedSessionTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('table1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30))) Table('table2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('someid', None, ForeignKey('table1.id'))) @testing.resolve_artifact_names @@ -82,10 +82,10 @@ class ScopedMapperTest(_ScopedTest): @classmethod def define_tables(cls, metadata): Table('table1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30))) Table('table2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('someid', None, ForeignKey('table1.id'))) @classmethod @@ -204,11 +204,11 @@ class ScopedMapperTest2(_ScopedTest): @classmethod def define_tables(cls, metadata): Table('table1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), Column('type', String(30))) Table('table2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('someid', None, ForeignKey('table1.id')), Column('somedata', String(30))) diff --git a/test/orm/test_selectable.py b/test/orm/test_selectable.py index 0a2025360..bfa400895 100644 --- a/test/orm/test_selectable.py +++ b/test/orm/test_selectable.py @@ -3,8 +3,7 @@ from sqlalchemy.test.testing import assert_raises, assert_raises_message import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import String, Integer, select -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, create_session from sqlalchemy.test.testing import eq_ from test.orm import _base @@ -16,7 +15,7 @@ class SelectableNoFromsTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('common', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', Integer), Column('extra', String(45))) diff --git a/test/orm/test_session.py b/test/orm/test_session.py index 328cbee8e..2d99e2063 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -1,13 +1,12 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message -import gc +from sqlalchemy.test.util import gc_collect import inspect import pickle from sqlalchemy.orm import create_session, sessionmaker, attributes import sqlalchemy as sa from sqlalchemy.test import engines, testing, config from sqlalchemy import Integer, String, Sequence -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, backref, eagerload from sqlalchemy.test.testing import eq_ from test.engine import _base as engine_base @@ -229,7 +228,7 @@ class SessionTest(_fixtures.FixtureTest): u = sess.query(User).get(u.id) q = sess.query(Address).filter(Address.user==u) del u - gc.collect() + gc_collect() eq_(q.one(), Address(email_address='foo')) @@ -381,18 +380,18 @@ class SessionTest(_fixtures.FixtureTest): session = create_session(bind=testing.db) session.begin() - session.connection().execute("insert into users (name) values ('user1')") + session.connection().execute(users.insert().values(name='user1')) session.begin(subtransactions=True) session.begin_nested() - session.connection().execute("insert into users (name) values ('user2')") + session.connection().execute(users.insert().values(name='user2')) assert session.connection().execute("select count(1) from users").scalar() == 2 session.rollback() assert session.connection().execute("select count(1) from users").scalar() == 1 - session.connection().execute("insert into users (name) values ('user3')") + session.connection().execute(users.insert().values(name='user3')) session.commit() assert session.connection().execute("select count(1) from users").scalar() == 2 @@ -771,18 +770,18 @@ class SessionTest(_fixtures.FixtureTest): user = s.query(User).one() del user - gc.collect() + gc_collect() assert len(s.identity_map) == 0 user = s.query(User).one() user.name = 'fred' del user - gc.collect() + gc_collect() assert len(s.identity_map) == 1 assert len(s.dirty) == 1 assert None not in s.dirty s.flush() - gc.collect() + gc_collect() assert not s.dirty assert not s.identity_map @@ -809,13 +808,13 @@ class SessionTest(_fixtures.FixtureTest): s.add(u2) del u2 - gc.collect() + gc_collect() assert len(s.identity_map) == 1 assert len(s.dirty) == 1 assert None not in s.dirty s.flush() - gc.collect() + gc_collect() assert not s.dirty assert not s.identity_map @@ -835,14 +834,14 @@ class SessionTest(_fixtures.FixtureTest): eq_(user, User(name="ed", addresses=[Address(email_address="ed1")])) del user - gc.collect() + gc_collect() assert len(s.identity_map) == 0 user = s.query(User).options(eagerload(User.addresses)).one() user.addresses[0].email_address='ed2' user.addresses[0].user # lazyload del user - gc.collect() + gc_collect() assert len(s.identity_map) == 2 s.commit() @@ -864,7 +863,7 @@ class SessionTest(_fixtures.FixtureTest): eq_(user, User(name="ed", address=Address(email_address="ed1"))) del user - gc.collect() + gc_collect() assert len(s.identity_map) == 0 user = s.query(User).options(eagerload(User.address)).one() @@ -872,7 +871,7 @@ class SessionTest(_fixtures.FixtureTest): user.address.user # lazyload del user - gc.collect() + gc_collect() assert len(s.identity_map) == 2 s.commit() @@ -890,8 +889,7 @@ class SessionTest(_fixtures.FixtureTest): user = s.query(User).one() user = None print s.identity_map - import gc - gc.collect() + gc_collect() assert len(s.identity_map) == 1 user = s.query(User).one() @@ -901,7 +899,7 @@ class SessionTest(_fixtures.FixtureTest): s.flush() eq_(users.select().execute().fetchall(), [(user.id, 'u2')]) - + @testing.fails_on('+zxjdbc', 'http://www.sqlalchemy.org/trac/ticket/1473') @testing.resolve_artifact_names def test_prune(self): s = create_session(weak_identity_map=False) @@ -914,8 +912,7 @@ class SessionTest(_fixtures.FixtureTest): self.assert_(len(s.identity_map) == 0) self.assert_(s.prune() == 0) s.flush() - import gc - gc.collect() + gc_collect() self.assert_(s.prune() == 9) self.assert_(len(s.identity_map) == 1) @@ -1228,7 +1225,7 @@ class DisposedStates(_base.MappedTest): def define_tables(cls, metadata): global t1 t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)) ) @@ -1327,7 +1324,7 @@ class SessionInterface(testing.TestBase): def _map_it(self, cls): return mapper(cls, Table('t', sa.MetaData(), - Column('id', Integer, primary_key=True))) + Column('id', Integer, primary_key=True, test_needs_autoincrement=True))) @testing.uses_deprecated() def _test_instance_guards(self, user_arg): @@ -1447,7 +1444,7 @@ class TLTransactionTest(engine_base.AltEngineTest, _base.MappedTest): @classmethod def define_tables(cls, metadata): Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(20)), test_needs_acid=True) diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py index 5aa541cda..51b345ceb 100644 --- a/test/orm/test_transaction.py +++ b/test/orm/test_transaction.py @@ -4,13 +4,11 @@ from sqlalchemy import * from sqlalchemy.orm import attributes from sqlalchemy import exc as sa_exc from sqlalchemy.orm import * - +from sqlalchemy.test.util import gc_collect from sqlalchemy.test import testing from test.orm import _base from test.orm._fixtures import FixtureTest, User, Address, users, addresses -import gc - class TransactionTest(FixtureTest): run_setup_mappers = 'once' run_inserts = None @@ -20,7 +18,7 @@ class TransactionTest(FixtureTest): def setup_mappers(cls): mapper(User, users, properties={ 'addresses':relation(Address, backref='user', - cascade="all, delete-orphan"), + cascade="all, delete-orphan", order_by=addresses.c.id), }) mapper(Address, addresses) @@ -109,7 +107,7 @@ class AutoExpireTest(TransactionTest): assert u1_state not in s.identity_map.all_states() assert u1_state not in s._deleted del u1 - gc.collect() + gc_collect() assert u1_state.obj() is None s.rollback() diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index f95346902..4d2056b26 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -379,7 +379,6 @@ class MutableTypesTest(_base.MappedTest): "WHERE mutable_t.id = :mutable_t_id", {'mutable_t_id': f1.id, 'val': u'hi', 'data':f1.data})]) - @testing.resolve_artifact_names def test_resurrect(self): f1 = Foo() @@ -392,42 +391,13 @@ class MutableTypesTest(_base.MappedTest): f1.data.y = 19 del f1 - + gc.collect() assert len(session.identity_map) == 1 - - session.commit() - - assert session.query(Foo).one().data == pickleable.Bar(4, 19) - - - @testing.uses_deprecated() - @testing.resolve_artifact_names - def test_nocomparison(self): - """Changes are detected on MutableTypes lacking an __eq__ method.""" - f1 = Foo() - f1.data = pickleable.BarWithoutCompare(4,5) - session = create_session(autocommit=False) - session.add(f1) session.commit() - self.sql_count_(0, session.commit) - session.close() - - session = create_session(autocommit=False) - f2 = session.query(Foo).filter_by(id=f1.id).one() - self.sql_count_(0, session.commit) - - f2.data.y = 19 - self.sql_count_(1, session.commit) - session.close() - - session = create_session(autocommit=False) - f3 = session.query(Foo).filter_by(id=f1.id).one() - eq_((f3.data.x, f3.data.y), (4,19)) - self.sql_count_(0, session.commit) - session.close() + assert session.query(Foo).one().data == pickleable.Bar(4, 19) @testing.resolve_artifact_names def test_unicode(self): @@ -892,7 +862,7 @@ class DefaultTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): - use_string_defaults = testing.against('postgres', 'oracle', 'sqlite', 'mssql') + use_string_defaults = testing.against('postgresql', 'oracle', 'sqlite', 'mssql') if use_string_defaults: hohotype = String(30) @@ -910,15 +880,14 @@ class DefaultTest(_base.MappedTest): Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('hoho', hohotype, server_default=str(hohoval)), - Column('counter', Integer, default=sa.func.char_length("1234567")), - Column('foober', String(30), default="im foober", - onupdate="im the update")) + Column('counter', Integer, default=sa.func.char_length("1234567", type_=Integer)), + Column('foober', String(30), default="im foober", onupdate="im the update")) st = Table('secondary_table', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) - if testing.against('postgres', 'oracle'): + if testing.against('postgresql', 'oracle'): dt.append_column( Column('secondary_id', Integer, sa.Sequence('sec_id_seq'), unique=True)) @@ -1004,14 +973,14 @@ class DefaultTest(_base.MappedTest): # "post-update" mapper(Hoho, default_t) - h1 = Hoho(hoho="15", counter="15") + h1 = Hoho(hoho="15", counter=15) session = create_session() session.add(h1) session.flush() def go(): eq_(h1.hoho, "15") - eq_(h1.counter, "15") + eq_(h1.counter, 15) eq_(h1.foober, "im foober") self.sql_count_(0, go) @@ -1036,7 +1005,7 @@ class DefaultTest(_base.MappedTest): """A server-side default can be used as the target of a foreign key""" mapper(Hoho, default_t, properties={ - 'secondaries':relation(Secondary)}) + 'secondaries':relation(Secondary, order_by=secondary_table.c.id)}) mapper(Secondary, secondary_table) h1 = Hoho() @@ -1068,7 +1037,7 @@ class ColumnPropertyTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('data', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('a', String(50)), Column('b', String(50)) ) @@ -1681,7 +1650,7 @@ class ManyToOneTest(_fixtures.FixtureTest): l = sa.select([users, addresses], sa.and_(users.c.id==addresses.c.user_id, addresses.c.id==a.id)).execute() - eq_(l.fetchone().values(), + eq_(l.first().values(), [a.user.id, 'asdf8d', a.id, a.user_id, 'theater@foo.com']) @testing.resolve_artifact_names @@ -2201,8 +2170,14 @@ class RowSwitchTest(_base.MappedTest): sess.add(o5) sess.flush() - assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some t5')] - assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some t6', 1), (2, 'some other t6', 1)] + eq_( + list(sess.execute(t5.select(), mapper=T5)), + [(1, 'some t5')] + ) + eq_( + list(sess.execute(t6.select().order_by(t6.c.id), mapper=T5)), + [(1, 'some t6', 1), (2, 'some other t6', 1)] + ) o6 = T5(data='some other t5', id=o5.id, t6s=[ T6(data='third t6', id=3), @@ -2212,8 +2187,14 @@ class RowSwitchTest(_base.MappedTest): sess.add(o6) sess.flush() - assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some other t5')] - assert list(sess.execute(t6.select(), mapper=T5)) == [(3, 'third t6', 1), (4, 'fourth t6', 1)] + eq_( + list(sess.execute(t5.select(), mapper=T5)), + [(1, 'some other t5')] + ) + eq_( + list(sess.execute(t6.select().order_by(t6.c.id), mapper=T5)), + [(3, 'third t6', 1), (4, 'fourth t6', 1)] + ) @testing.resolve_artifact_names def test_manytomany(self): @@ -2369,6 +2350,6 @@ class TransactionTest(_base.MappedTest): # todo: on 8.3 at least, the failed commit seems to close the cursor? # needs investigation. leaving in the DDL above now to help verify # that the new deferrable support on FK isn't involved in this issue. - if testing.against('postgres'): + if testing.against('postgresql'): t1.bind.engine.dispose() diff --git a/test/orm/test_utils.py b/test/orm/test_utils.py index 06533a243..8635ad212 100644 --- a/test/orm/test_utils.py +++ b/test/orm/test_utils.py @@ -39,8 +39,13 @@ class ExtensionCarrierTest(TestBase): assert 'populate_instance' not in carrier carrier.append(interfaces.MapperExtension) + + # Py3K + #assert 'populate_instance' not in carrier + # Py2K assert 'populate_instance' in carrier - + # end Py2K + assert carrier.interface for m in carrier.interface: assert getattr(interfaces.MapperExtension, m) @@ -85,7 +90,10 @@ class AliasedClassTest(TestBase): alias = aliased(Point) assert Point.zero + # Py2K + # TODO: what is this testing ?? assert not getattr(alias, 'zero') + # end Py2K def test_classmethods(self): class Point(object): @@ -152,9 +160,17 @@ class AliasedClassTest(TestBase): self.func = func def __get__(self, instance, owner): if instance is None: + # Py3K + #args = (self.func, owner) + # Py2K args = (self.func, owner, owner.__class__) + # end Py2K else: + # Py3K + #args = (self.func, instance) + # Py2K args = (self.func, instance, owner) + # end Py2K return types.MethodType(*args) class PropertyDescriptor(object): diff --git a/test/perf/insertspeed.py b/test/perf/insertspeed.py index 32877560e..0491e9f95 100644 --- a/test/perf/insertspeed.py +++ b/test/perf/insertspeed.py @@ -2,7 +2,7 @@ import testenv; testenv.simple_setup() import sys, time from sqlalchemy import * from sqlalchemy.orm import * -from testlib import profiling +from sqlalchemy.test import profiling db = create_engine('sqlite://') metadata = MetaData(db) diff --git a/test/perf/masscreate.py b/test/perf/masscreate.py index ae32f83e2..5b8e0da55 100644 --- a/test/perf/masscreate.py +++ b/test/perf/masscreate.py @@ -3,7 +3,6 @@ import testenv; testenv.simple_setup() from sqlalchemy.orm import attributes import time -import gc manage_attributes = True init_attributes = manage_attributes and True @@ -34,7 +33,6 @@ for i in range(0,130): attributes.manage(a) a.email = 'foo@bar.com' u.addresses.append(a) -# gc.collect() # print len(managed_attributes) # managed_attributes.clear() total = time.time() - now diff --git a/test/perf/masscreate2.py b/test/perf/masscreate2.py index 25d4b4915..e525fcf99 100644 --- a/test/perf/masscreate2.py +++ b/test/perf/masscreate2.py @@ -1,9 +1,9 @@ import testenv; testenv.simple_setup() -import gc import random, string from sqlalchemy.orm import attributes +from sqlalchemy.test.util import gc_collect # with this test, run top. make sure the Python process doenst grow in size arbitrarily. @@ -33,4 +33,4 @@ for i in xrange(1000): a.user = u print "clearing" #managed_attributes.clear() - gc.collect() + gc_collect() diff --git a/test/perf/masseagerload.py b/test/perf/masseagerload.py index a848b866c..88a3ade20 100644 --- a/test/perf/masseagerload.py +++ b/test/perf/masseagerload.py @@ -1,7 +1,6 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * -from testlib import * +from sqlalchemy.test import * NUM = 500 DIVISOR = 50 diff --git a/test/perf/massload.py b/test/perf/massload.py index 9391ead2a..f6cde3adf 100644 --- a/test/perf/massload.py +++ b/test/perf/massload.py @@ -1,10 +1,8 @@ -import testenv; testenv.configure_for_tests() import time -#import gc #import sqlalchemy.orm.attributes as attributes from sqlalchemy import * from sqlalchemy.orm import * -from testlib import * +from sqlalchemy.test import * """ @@ -18,16 +16,18 @@ top while it runs NUM = 2500 class LoadTest(TestBase, AssertsExecutionResults): - def setUpAll(self): + @classmethod + def setup_class(cls): global items, meta meta = MetaData(testing.db) items = Table('items', meta, Column('item_id', Integer, primary_key=True), Column('value', String(100))) items.create() - def tearDownAll(self): + @classmethod + def teardown_class(cls): items.drop() - def setUp(self): + def setup(self): for x in range(1,NUM/500+1): l = [] for y in range(x*500-500 + 1, x*500 + 1): @@ -43,7 +43,7 @@ class LoadTest(TestBase, AssertsExecutionResults): query = sess.query(Item) for x in range (1,NUM/100): # this is not needed with cpython which clears non-circular refs immediately - #gc.collect() + #gc_collect() l = query.filter(items.c.item_id.between(x*100 - 100 + 1, x*100)).all() assert len(l) == 100 print "loaded ", len(l), " items " @@ -61,5 +61,3 @@ class LoadTest(TestBase, AssertsExecutionResults): print "total time ", total -if __name__ == "__main__": - testenv.main() diff --git a/test/perf/masssave.py b/test/perf/masssave.py index bf65c8fdf..41acd12cc 100644 --- a/test/perf/masssave.py +++ b/test/perf/masssave.py @@ -1,21 +1,23 @@ -import testenv; testenv.configure_for_tests() +import gc import types from sqlalchemy import * from sqlalchemy.orm import * -from testlib import * +from sqlalchemy.test import * NUM = 2500 class SaveTest(TestBase, AssertsExecutionResults): - def setUpAll(self): + @classmethod + def setup_class(cls): global items, metadata metadata = MetaData(testing.db) items = Table('items', metadata, Column('item_id', Integer, primary_key=True), Column('value', String(100))) items.create() - def tearDownAll(self): + @classmethod + def teardown_class(cls): clear_mappers() metadata.drop_all() @@ -50,5 +52,3 @@ class SaveTest(TestBase, AssertsExecutionResults): print x -if __name__ == "__main__": - testenv.main() diff --git a/test/perf/objselectspeed.py b/test/perf/objselectspeed.py index 896fd4c49..867a396f3 100644 --- a/test/perf/objselectspeed.py +++ b/test/perf/objselectspeed.py @@ -1,7 +1,8 @@ import testenv; testenv.simple_setup() -import time, gc, resource +import time, resource from sqlalchemy import * from sqlalchemy.orm import * +from sqlalchemy.test.util import gc_collect db = create_engine('sqlite://') @@ -68,35 +69,35 @@ def all(): usage.snap = lambda stats=None: setattr( usage, 'last', stats or resource.getrusage(resource.RUSAGE_SELF)) - gc.collect() + gc_collect() usage.snap() t = time.clock() sqlite_select(RawPerson) t2 = time.clock() usage('sqlite select/native') - gc.collect() + gc_collect() usage.snap() t = time.clock() sqlite_select(Person) t2 = time.clock() usage('sqlite select/instrumented') - gc.collect() + gc_collect() usage.snap() t = time.clock() sql_select(RawPerson) t2 = time.clock() usage('sqlalchemy.sql select/native') - gc.collect() + gc_collect() usage.snap() t = time.clock() sql_select(Person) t2 = time.clock() usage('sqlalchemy.sql select/instrumented') - gc.collect() + gc_collect() usage.snap() t = time.clock() orm_select() diff --git a/test/perf/objupdatespeed.py b/test/perf/objupdatespeed.py index a49eb4724..52224211a 100644 --- a/test/perf/objupdatespeed.py +++ b/test/perf/objupdatespeed.py @@ -1,8 +1,9 @@ -import testenv; testenv.configure_for_tests() -import time, gc, resource +import time, resource from sqlalchemy import * from sqlalchemy.orm import * -from testlib import * +from sqlalchemy.test import * +from sqlalchemy.test.util import gc_collect + NUM = 100 @@ -72,14 +73,14 @@ def all(): session = create_session() - gc.collect() + gc_collect() usage.snap() t = time.clock() people = orm_select(session) t2 = time.clock() usage('load objects') - gc.collect() + gc_collect() usage.snap() t = time.clock() update_and_flush(session, people) diff --git a/test/perf/ormsession.py b/test/perf/ormsession.py index cdffa51a9..f9f9dee8b 100644 --- a/test/perf/ormsession.py +++ b/test/perf/ormsession.py @@ -1,11 +1,10 @@ -import testenv; testenv.configure_for_tests() import time from datetime import datetime from sqlalchemy import * from sqlalchemy.orm import * -from testlib import * -from testlib.profiling import profiled +from sqlalchemy.test import * +from sqlalchemy.test.profiling import profiled class Item(object): def __repr__(self): diff --git a/test/perf/poolload.py b/test/perf/poolload.py index 8d66da84f..62c66fbae 100644 --- a/test/perf/poolload.py +++ b/test/perf/poolload.py @@ -1,9 +1,8 @@ # load test of connection pool -import testenv; testenv.configure_for_tests() import thread, time from sqlalchemy import * import sqlalchemy.pool as pool -from testlib import testing +from sqlalchemy.test import testing db = create_engine(testing.db.url, pool_timeout=30, echo_pool=True) metadata = MetaData(db) diff --git a/test/perf/sessions.py b/test/perf/sessions.py index f4be1ee93..0d4cc1f01 100644 --- a/test/perf/sessions.py +++ b/test/perf/sessions.py @@ -1,17 +1,17 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * -import gc -from testlib import TestBase, AssertsExecutionResults, profiling, testing -from orm import _fixtures +from sqlalchemy.test.compat import gc_collect +from sqlalchemy.test import TestBase, AssertsExecutionResults, profiling, testing +from test.orm import _fixtures # in this test we are specifically looking for time spent in the attributes.InstanceState.__cleanup() method. ITERATIONS = 100 class SessionTest(TestBase, AssertsExecutionResults): - def setUpAll(self): + @classmethod + def setup_class(cls): global t1, t2, metadata,T1, T2 metadata = MetaData(testing.db) t1 = Table('t1', metadata, @@ -46,7 +46,8 @@ class SessionTest(TestBase, AssertsExecutionResults): }) mapper(T2, t2) - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() clear_mappers() @@ -60,7 +61,7 @@ class SessionTest(TestBase, AssertsExecutionResults): sess.close() del sess - gc.collect() + gc_collect() @profiling.profiled('dirty', report=True) def test_session_dirty(self): @@ -74,11 +75,11 @@ class SessionTest(TestBase, AssertsExecutionResults): t2.c2 = 'this is some modified text' del t1s - gc.collect() + gc_collect() sess.close() del sess - gc.collect() + gc_collect() @profiling.profiled('noclose', report=True) def test_session_noclose(self): @@ -89,9 +90,6 @@ class SessionTest(TestBase, AssertsExecutionResults): t1s[index].t2s del sess - gc.collect() - + gc_collect() -if __name__ == '__main__': - testenv.main() diff --git a/test/perf/wsgi.py b/test/perf/wsgi.py index 6fc8149bc..549c92ade 100644 --- a/test/perf/wsgi.py +++ b/test/perf/wsgi.py @@ -1,11 +1,10 @@ #!/usr/bin/python """Uses ``wsgiref``, standard in Python 2.5 and also in the cheeseshop.""" -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * import thread -from testlib import * +from sqlalchemy.test import * port = 8000 diff --git a/test/sql/test_constraints.py b/test/sql/test_constraints.py index 8abeb3533..4ad52604d 100644 --- a/test/sql/test_constraints.py +++ b/test/sql/test_constraints.py @@ -1,10 +1,14 @@ -from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +from sqlalchemy.test.testing import assert_raises, assert_raises_message from sqlalchemy import * -from sqlalchemy import exc +from sqlalchemy import exc, schema from sqlalchemy.test import * from sqlalchemy.test import config, engines +from sqlalchemy.engine import ddl +from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.assertsql import AllOf, RegexSQL, ExactSQL, CompiledSQL +from sqlalchemy.dialects.postgresql import base as postgresql -class ConstraintTest(TestBase, AssertsExecutionResults): +class ConstraintTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): def setup(self): global metadata @@ -33,11 +37,8 @@ class ConstraintTest(TestBase, AssertsExecutionResults): def test_double_fk_usage_raises(self): f = ForeignKey('b.id') - assert_raises(exc.InvalidRequestError, Table, "a", metadata, - Column('x', Integer, f), - Column('y', Integer, f) - ) - + Column('x', Integer, f) + assert_raises(exc.InvalidRequestError, Column, "y", Integer, f) def test_circular_constraint(self): a = Table("a", metadata, @@ -78,18 +79,9 @@ class ConstraintTest(TestBase, AssertsExecutionResults): metadata.create_all() foo.insert().execute(id=1,x=9,y=5) - try: - foo.insert().execute(id=2,x=5,y=9) - assert False - except exc.SQLError: - assert True - + assert_raises(exc.SQLError, foo.insert().execute, id=2,x=5,y=9) bar.insert().execute(id=1,x=10) - try: - bar.insert().execute(id=2,x=5) - assert False - except exc.SQLError: - assert True + assert_raises(exc.SQLError, bar.insert().execute, id=2,x=5) def test_unique_constraint(self): foo = Table('foo', metadata, @@ -106,16 +98,8 @@ class ConstraintTest(TestBase, AssertsExecutionResults): foo.insert().execute(id=2, value='value2') bar.insert().execute(id=1, value='a', value2='a') bar.insert().execute(id=2, value='a', value2='b') - try: - foo.insert().execute(id=3, value='value1') - assert False - except exc.SQLError: - assert True - try: - bar.insert().execute(id=3, value='a', value2='b') - assert False - except exc.SQLError: - assert True + assert_raises(exc.SQLError, foo.insert().execute, id=3, value='value1') + assert_raises(exc.SQLError, bar.insert().execute, id=3, value='a', value2='b') def test_index_create(self): employees = Table('employees', metadata, @@ -174,35 +158,22 @@ class ConstraintTest(TestBase, AssertsExecutionResults): Index('sport_announcer', events.c.sport, events.c.announcer, unique=True) Index('idx_winners', events.c.winner) - index_names = [ ix.name for ix in events.indexes ] - assert 'ix_events_name' in index_names - assert 'ix_events_location' in index_names - assert 'sport_announcer' in index_names - assert 'idx_winners' in index_names - assert len(index_names) == 4 - - capt = [] - connection = testing.db.connect() - # TODO: hacky, put a real connection proxy in - ex = connection._Connection__execute_context - def proxy(context): - capt.append(context.statement) - capt.append(repr(context.parameters)) - ex(context) - connection._Connection__execute_context = proxy - schemagen = testing.db.dialect.schemagenerator(testing.db.dialect, connection) - schemagen.traverse(events) - - assert capt[0].strip().startswith('CREATE TABLE events') - - s = set([capt[x].strip() for x in [2,4,6,8]]) - - assert s == set([ - 'CREATE UNIQUE INDEX ix_events_name ON events (name)', - 'CREATE INDEX ix_events_location ON events (location)', - 'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)', - 'CREATE INDEX idx_winners ON events (winner)' - ]) + eq_( + set([ ix.name for ix in events.indexes ]), + set(['ix_events_name', 'ix_events_location', 'sport_announcer', 'idx_winners']) + ) + + self.assert_sql_execution( + testing.db, + lambda: events.create(testing.db), + RegexSQL("^CREATE TABLE events"), + AllOf( + ExactSQL('CREATE UNIQUE INDEX ix_events_name ON events (name)'), + ExactSQL('CREATE INDEX ix_events_location ON events (location)'), + ExactSQL('CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)'), + ExactSQL('CREATE INDEX idx_winners ON events (winner)') + ) + ) # verify that the table is functional events.insert().execute(id=1, name='hockey finals', location='rink', @@ -214,84 +185,57 @@ class ConstraintTest(TestBase, AssertsExecutionResults): dialect = testing.db.dialect.__class__() dialect.max_identifier_length = 20 - schemagen = dialect.schemagenerator(dialect, None) - schemagen.execute = lambda : None - t1 = Table("sometable", MetaData(), Column("foo", Integer)) - schemagen.visit_index(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo)) - eq_(schemagen.buffer.getvalue(), "CREATE INDEX this_name_is_t_1 ON sometable (foo)") - schemagen.buffer.truncate(0) - schemagen.visit_index(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo)) - eq_(schemagen.buffer.getvalue(), "CREATE INDEX this_other_nam_2 ON sometable (foo)") - - schemadrop = dialect.schemadropper(dialect, None) - schemadrop.execute = lambda: None - assert_raises(exc.IdentifierError, schemadrop.visit_index, Index("this_name_is_too_long_for_what_were_doing", t1.c.foo)) + self.assert_compile( + schema.CreateIndex(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo)), + "CREATE INDEX this_name_is_t_1 ON sometable (foo)", + dialect=dialect + ) + + self.assert_compile( + schema.CreateIndex(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo)), + "CREATE INDEX this_other_nam_1 ON sometable (foo)", + dialect=dialect + ) -class ConstraintCompilationTest(TestBase, AssertsExecutionResults): - class accum(object): - def __init__(self): - self.statements = [] - def __call__(self, sql, *a, **kw): - self.statements.append(sql) - def __contains__(self, substring): - for s in self.statements: - if substring in s: - return True - return False - def __str__(self): - return '\n'.join([repr(x) for x in self.statements]) - def clear(self): - del self.statements[:] - - def setup(self): - self.sql = self.accum() - opts = config.db_opts.copy() - opts['strategy'] = 'mock' - opts['executor'] = self.sql - self.engine = engines.testing_engine(options=opts) - +class ConstraintCompilationTest(TestBase, AssertsCompiledSQL): def _test_deferrable(self, constraint_factory): - meta = MetaData(self.engine) - t = Table('tbl', meta, + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer), constraint_factory(deferrable=True)) - t.create() - assert 'DEFERRABLE' in self.sql, self.sql - assert 'NOT DEFERRABLE' not in self.sql, self.sql - self.sql.clear() - meta.clear() - - t = Table('tbl', meta, + + sql = str(schema.CreateTable(t).compile(bind=testing.db)) + assert 'DEFERRABLE' in sql, sql + assert 'NOT DEFERRABLE' not in sql, sql + + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer), constraint_factory(deferrable=False)) - t.create() - assert 'NOT DEFERRABLE' in self.sql - self.sql.clear() - meta.clear() - t = Table('tbl', meta, + sql = str(schema.CreateTable(t).compile(bind=testing.db)) + assert 'NOT DEFERRABLE' in sql + + + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer), constraint_factory(deferrable=True, initially='IMMEDIATE')) - t.create() - assert 'NOT DEFERRABLE' not in self.sql - assert 'INITIALLY IMMEDIATE' in self.sql - self.sql.clear() - meta.clear() + sql = str(schema.CreateTable(t).compile(bind=testing.db)) + assert 'NOT DEFERRABLE' not in sql + assert 'INITIALLY IMMEDIATE' in sql - t = Table('tbl', meta, + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer), constraint_factory(deferrable=True, initially='DEFERRED')) - t.create() + sql = str(schema.CreateTable(t).compile(bind=testing.db)) - assert 'NOT DEFERRABLE' not in self.sql - assert 'INITIALLY DEFERRED' in self.sql, self.sql + assert 'NOT DEFERRABLE' not in sql + assert 'INITIALLY DEFERRED' in sql def test_deferrable_pk(self): factory = lambda **kw: PrimaryKeyConstraint('a', **kw) @@ -302,15 +246,16 @@ class ConstraintCompilationTest(TestBase, AssertsExecutionResults): self._test_deferrable(factory) def test_deferrable_column_fk(self): - meta = MetaData(self.engine) - t = Table('tbl', meta, + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer, ForeignKey('tbl.a', deferrable=True, initially='DEFERRED'))) - t.create() - assert 'DEFERRABLE' in self.sql, self.sql - assert 'INITIALLY DEFERRED' in self.sql, self.sql + + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE tbl (a INTEGER, b INTEGER, FOREIGN KEY(b) REFERENCES tbl (a) DEFERRABLE INITIALLY DEFERRED)", + ) def test_deferrable_unique(self): factory = lambda **kw: UniqueConstraint('b', **kw) @@ -321,15 +266,105 @@ class ConstraintCompilationTest(TestBase, AssertsExecutionResults): self._test_deferrable(factory) def test_deferrable_column_check(self): - meta = MetaData(self.engine) - t = Table('tbl', meta, + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer, CheckConstraint('a < b', deferrable=True, initially='DEFERRED'))) - t.create() - assert 'DEFERRABLE' in self.sql, self.sql - assert 'INITIALLY DEFERRED' in self.sql, self.sql + + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE tbl (a INTEGER, b INTEGER CHECK (a < b) DEFERRABLE INITIALLY DEFERRED)" + ) + + def test_use_alter(self): + m = MetaData() + t = Table('t', m, + Column('a', Integer), + ) + + t2 = Table('t2', m, + Column('a', Integer, ForeignKey('t.a', use_alter=True, name='fk_ta')), + Column('b', Integer, ForeignKey('t.a', name='fk_tb')), # to ensure create ordering ... + ) + + e = engines.mock_engine(dialect_name='postgresql') + m.create_all(e) + m.drop_all(e) + + e.assert_sql([ + 'CREATE TABLE t (a INTEGER)', + 'CREATE TABLE t2 (a INTEGER, b INTEGER, CONSTRAINT fk_tb FOREIGN KEY(b) REFERENCES t (a))', + 'ALTER TABLE t2 ADD CONSTRAINT fk_ta FOREIGN KEY(a) REFERENCES t (a)', + 'ALTER TABLE t2 DROP CONSTRAINT fk_ta', + 'DROP TABLE t2', + 'DROP TABLE t' + ]) + + + def test_add_drop_constraint(self): + m = MetaData() + + t = Table('tbl', m, + Column('a', Integer), + Column('b', Integer) + ) + + t2 = Table('t2', m, + Column('a', Integer), + Column('b', Integer) + ) + + constraint = CheckConstraint('a < b',name="my_test_constraint", deferrable=True,initially='DEFERRED', table=t) + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE tbl ADD CONSTRAINT my_test_constraint CHECK (a < b) DEFERRABLE INITIALLY DEFERRED" + ) + + self.assert_compile( + schema.DropConstraint(constraint), + "ALTER TABLE tbl DROP CONSTRAINT my_test_constraint" + ) + + self.assert_compile( + schema.DropConstraint(constraint, cascade=True), + "ALTER TABLE tbl DROP CONSTRAINT my_test_constraint CASCADE" + ) + constraint = ForeignKeyConstraint(["b"], ["t2.a"]) + t.append_constraint(constraint) + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE tbl ADD FOREIGN KEY(b) REFERENCES t2 (a)" + ) + constraint = ForeignKeyConstraint([t.c.a], [t2.c.b]) + t.append_constraint(constraint) + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE tbl ADD FOREIGN KEY(a) REFERENCES t2 (b)" + ) + + constraint = UniqueConstraint("a", "b", name="uq_cst") + t2.append_constraint(constraint) + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE t2 ADD CONSTRAINT uq_cst UNIQUE (a, b)" + ) + + constraint = UniqueConstraint(t2.c.a, t2.c.b, name="uq_cs2") + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE t2 ADD CONSTRAINT uq_cs2 UNIQUE (a, b)" + ) + + assert t.c.a.primary_key is False + constraint = PrimaryKeyConstraint(t.c.a) + assert t.c.a.primary_key is True + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE tbl ADD PRIMARY KEY (a)" + ) + + diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index 964157466..5638dad77 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -3,7 +3,7 @@ import datetime from sqlalchemy import Sequence, Column, func from sqlalchemy.sql import select, text import sqlalchemy as sa -from sqlalchemy.test import testing +from sqlalchemy.test import testing, engines from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean from sqlalchemy.test.schema import Table from sqlalchemy.test.testing import eq_ @@ -37,7 +37,7 @@ class DefaultTest(testing.TestBase): # since its a "branched" connection conn.close() - use_function_defaults = testing.against('postgres', 'mssql', 'maxdb') + use_function_defaults = testing.against('postgresql', 'mssql', 'maxdb') is_oracle = testing.against('oracle') # select "count(1)" returns different results on different DBs also @@ -146,7 +146,7 @@ class DefaultTest(testing.TestBase): assert_raises_message(sa.exc.ArgumentError, ex_msg, sa.ColumnDefault, fn) - + def test_arg_signature(self): def fn1(): pass def fn2(): pass @@ -276,7 +276,7 @@ class DefaultTest(testing.TestBase): assert r.lastrow_has_defaults() eq_(set(r.context.postfetch_cols), set([t.c.col3, t.c.col5, t.c.col4, t.c.col6])) - + eq_(t.select(t.c.col1==54).execute().fetchall(), [(54, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today, None)]) @@ -284,7 +284,7 @@ class DefaultTest(testing.TestBase): @testing.fails_on('firebird', 'Data type unknown') def test_insertmany(self): # MySQL-Python 1.2.2 breaks functions in execute_many :( - if (testing.against('mysql') and + if (testing.against('mysql') and not testing.against('+zxjdbc') and testing.db.dialect.dbapi.version_info[:3] == (1, 2, 2)): return @@ -304,12 +304,12 @@ class DefaultTest(testing.TestBase): def test_insert_values(self): t.insert(values={'col3':50}).execute() l = t.select().execute() - eq_(50, l.fetchone()['col3']) + eq_(50, l.first()['col3']) @testing.fails_on('firebird', 'Data type unknown') def test_updatemany(self): # MySQL-Python 1.2.2 breaks functions in execute_many :( - if (testing.against('mysql') and + if (testing.against('mysql') and not testing.against('+zxjdbc') and testing.db.dialect.dbapi.version_info[:3] == (1, 2, 2)): return @@ -337,11 +337,11 @@ class DefaultTest(testing.TestBase): @testing.fails_on('firebird', 'Data type unknown') def test_update(self): r = t.insert().execute() - pk = r.last_inserted_ids()[0] + pk = r.inserted_primary_key[0] t.update(t.c.col1==pk).execute(col4=None, col5=None) ctexec = currenttime.scalar() l = t.select(t.c.col1==pk).execute() - l = l.fetchone() + l = l.first() eq_(l, (pk, 'im the update', f2, None, None, ctexec, True, False, 13, datetime.date.today(), 'py')) @@ -350,43 +350,12 @@ class DefaultTest(testing.TestBase): @testing.fails_on('firebird', 'Data type unknown') def test_update_values(self): r = t.insert().execute() - pk = r.last_inserted_ids()[0] + pk = r.inserted_primary_key[0] t.update(t.c.col1==pk, values={'col3': 55}).execute() l = t.select(t.c.col1==pk).execute() - l = l.fetchone() + l = l.first() eq_(55, l['col3']) - @testing.fails_on_everything_except('postgres') - def test_passive_override(self): - """ - Primarily for postgres, tests that when we get a primary key column - back from reflecting a table which has a default value on it, we - pre-execute that DefaultClause upon insert, even though DefaultClause - says "let the database execute this", because in postgres we must have - all the primary key values in memory before insert; otherwise we can't - locate the just inserted row. - - """ - # TODO: move this to dialect/postgres - try: - meta = MetaData(testing.db) - testing.db.execute(""" - CREATE TABLE speedy_users - ( - speedy_user_id SERIAL PRIMARY KEY, - - user_name VARCHAR NOT NULL, - user_password VARCHAR NOT NULL - ); - """, None) - - t = Table("speedy_users", meta, autoload=True) - t.insert().execute(user_name='user', user_password='lala') - l = t.select().execute().fetchall() - eq_(l, [(1, 'user', 'lala')]) - finally: - testing.db.execute("drop table speedy_users", None) - class PKDefaultTest(_base.TablesTest): __requires__ = ('subqueries',) @@ -400,18 +369,27 @@ class PKDefaultTest(_base.TablesTest): Column('id', Integer, primary_key=True, default=sa.select([func.max(t2.c.nextid)]).as_scalar()), Column('data', String(30))) - - @testing.fails_on('mssql', 'FIXME: unknown') + + @testing.requires.returning + def test_with_implicit_returning(self): + self._test(True) + + def test_regular(self): + self._test(False) + @testing.resolve_artifact_names - def test_basic(self): - t2.insert().execute(nextid=1) - r = t1.insert().execute(data='hi') - eq_([1], r.last_inserted_ids()) - - t2.insert().execute(nextid=2) - r = t1.insert().execute(data='there') - eq_([2], r.last_inserted_ids()) + def _test(self, returning): + if not returning and not testing.db.dialect.implicit_returning: + engine = testing.db + else: + engine = engines.testing_engine(options={'implicit_returning':returning}) + engine.execute(t2.insert(), nextid=1) + r = engine.execute(t1.insert(), data='hi') + eq_([1], r.inserted_primary_key) + engine.execute(t2.insert(), nextid=2) + r = engine.execute(t1.insert(), data='there') + eq_([2], r.inserted_primary_key) class PKIncrementTest(_base.TablesTest): run_define_tables = 'each' @@ -430,29 +408,31 @@ class PKIncrementTest(_base.TablesTest): def _test_autoincrement(self, bind): ids = set() rs = bind.execute(aitable.insert(), int1=1) - last = rs.last_inserted_ids()[0] + last = rs.inserted_primary_key[0] self.assert_(last) self.assert_(last not in ids) ids.add(last) rs = bind.execute(aitable.insert(), str1='row 2') - last = rs.last_inserted_ids()[0] + last = rs.inserted_primary_key[0] self.assert_(last) self.assert_(last not in ids) ids.add(last) rs = bind.execute(aitable.insert(), int1=3, str1='row 3') - last = rs.last_inserted_ids()[0] + last = rs.inserted_primary_key[0] self.assert_(last) self.assert_(last not in ids) ids.add(last) rs = bind.execute(aitable.insert(values={'int1':func.length('four')})) - last = rs.last_inserted_ids()[0] + last = rs.inserted_primary_key[0] self.assert_(last) self.assert_(last not in ids) ids.add(last) + eq_(ids, set([1,2,3,4])) + eq_(list(bind.execute(aitable.select().order_by(aitable.c.id))), [(1, 1, None), (2, None, 'row 2'), (3, 3, 'row 3'), (4, 4, None)]) @@ -510,8 +490,8 @@ class AutoIncrementTest(_base.TablesTest): single.create() r = single.insert().execute() - id_ = r.last_inserted_ids()[0] - assert id_ is not None + id_ = r.inserted_primary_key[0] + eq_(id_, 1) eq_(1, sa.select([func.count(sa.text('*'))], from_obj=single).scalar()) def test_autoincrement_fk(self): @@ -522,7 +502,7 @@ class AutoIncrementTest(_base.TablesTest): nodes.create() r = nodes.insert().execute(data='foo') - id_ = r.last_inserted_ids()[0] + id_ = r.inserted_primary_key[0] nodes.insert().execute(data='bar', parent_id=id_) @testing.fails_on('sqlite', 'FIXME: unknown') @@ -535,7 +515,7 @@ class AutoIncrementTest(_base.TablesTest): try: - # postgres + mysql strict will fail on first row, + # postgresql + mysql strict will fail on first row, # mysql in legacy mode fails on second row nonai.insert().execute(data='row 1') nonai.insert().execute(data='row 2') @@ -570,16 +550,17 @@ class SequenceTest(testing.TestBase): def testseqnonpk(self): """test sequences fire off as defaults on non-pk columns""" - result = sometable.insert().execute(name="somename") + engine = engines.testing_engine(options={'implicit_returning':False}) + result = engine.execute(sometable.insert(), name="somename") assert 'id' in result.postfetch_cols() - result = sometable.insert().execute(name="someother") + result = engine.execute(sometable.insert(), name="someother") assert 'id' in result.postfetch_cols() sometable.insert().execute( {'name':'name3'}, {'name':'name4'}) - eq_(sometable.select().execute().fetchall(), + eq_(sometable.select().order_by(sometable.c.id).execute().fetchall(), [(1, "somename", 1), (2, "someother", 2), (3, "name3", 3), @@ -590,8 +571,8 @@ class SequenceTest(testing.TestBase): cartitems.insert().execute(description='there') r = cartitems.insert().execute(description='lala') - assert r.last_inserted_ids() and r.last_inserted_ids()[0] is not None - id_ = r.last_inserted_ids()[0] + assert r.inserted_primary_key and r.inserted_primary_key[0] is not None + id_ = r.inserted_primary_key[0] eq_(1, sa.select([func.count(cartitems.c.cart_id)], diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index e9bf49ce3..7a0f12cac 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -24,7 +24,7 @@ class CompileTest(TestBase, AssertsCompiledSQL): bindtemplate = BIND_TEMPLATES[dialect.paramstyle] self.assert_compile(func.current_timestamp(), "CURRENT_TIMESTAMP", dialect=dialect) self.assert_compile(func.localtime(), "LOCALTIME", dialect=dialect) - if isinstance(dialect, firebird.dialect): + if isinstance(dialect, (firebird.dialect, maxdb.dialect, oracle.dialect)): self.assert_compile(func.nosuchfunction(), "nosuchfunction", dialect=dialect) else: self.assert_compile(func.nosuchfunction(), "nosuchfunction()", dialect=dialect) @@ -50,7 +50,7 @@ class CompileTest(TestBase, AssertsCompiledSQL): for ret, dialect in [ ('CURRENT_TIMESTAMP', sqlite.dialect()), - ('now()', postgres.dialect()), + ('now()', postgresql.dialect()), ('now()', mysql.dialect()), ('CURRENT_TIMESTAMP', oracle.dialect()) ]: @@ -62,9 +62,9 @@ class CompileTest(TestBase, AssertsCompiledSQL): for ret, dialect in [ ('random()', sqlite.dialect()), - ('random()', postgres.dialect()), + ('random()', postgresql.dialect()), ('rand()', mysql.dialect()), - ('random()', oracle.dialect()) + ('random', oracle.dialect()) ]: self.assert_compile(func.random(), ret, dialect=dialect) @@ -180,7 +180,10 @@ class CompileTest(TestBase, AssertsCompiledSQL): class ExecuteTest(TestBase): - + @engines.close_first + def tearDown(self): + pass + def test_standalone_execute(self): x = testing.db.func.current_date().execute().scalar() y = testing.db.func.current_date().select().execute().scalar() @@ -202,6 +205,7 @@ class ExecuteTest(TestBase): conn.close() assert (x == y == z) is True + @engines.close_first def test_update(self): """ Tests sending functions and SQL expressions to the VALUES and SET @@ -222,15 +226,15 @@ class ExecuteTest(TestBase): meta.create_all() try: t.insert(values=dict(value=func.length("one"))).execute() - assert t.select().execute().fetchone()['value'] == 3 + assert t.select().execute().first()['value'] == 3 t.update(values=dict(value=func.length("asfda"))).execute() - assert t.select().execute().fetchone()['value'] == 5 + assert t.select().execute().first()['value'] == 5 r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute() - id = r.last_inserted_ids()[0] - assert t.select(t.c.id==id).execute().fetchone()['value'] == 9 + id = r.inserted_primary_key[0] + assert t.select(t.c.id==id).execute().first()['value'] == 9 t.update(values={t.c.value:func.length("asdf")}).execute() - assert t.select().execute().fetchone()['value'] == 4 + assert t.select().execute().first()['value'] == 4 print "--------------------------" t2.insert().execute() t2.insert(values=dict(value=func.length("one"))).execute() @@ -245,18 +249,18 @@ class ExecuteTest(TestBase): t2.delete().execute() t2.insert(values=dict(value=func.length("one") + 8)).execute() - assert t2.select().execute().fetchone()['value'] == 11 + assert t2.select().execute().first()['value'] == 11 t2.update(values=dict(value=func.length("asfda"))).execute() - assert select([t2.c.value, t2.c.stuff]).execute().fetchone() == (5, "thisisstuff") + assert select([t2.c.value, t2.c.stuff]).execute().first() == (5, "thisisstuff") t2.update(values={t2.c.value:func.length("asfdaasdf"), t2.c.stuff:"foo"}).execute() - print "HI", select([t2.c.value, t2.c.stuff]).execute().fetchone() - assert select([t2.c.value, t2.c.stuff]).execute().fetchone() == (9, "foo") + print "HI", select([t2.c.value, t2.c.stuff]).execute().first() + assert select([t2.c.value, t2.c.stuff]).execute().first() == (9, "foo") finally: meta.drop_all() - @testing.fails_on_everything_except('postgres') + @testing.fails_on_everything_except('postgresql') def test_as_from(self): # TODO: shouldnt this work on oracle too ? x = testing.db.func.current_date().execute().scalar() @@ -266,7 +270,7 @@ class ExecuteTest(TestBase): # construct a column-based FROM object out of a function, like in [ticket:172] s = select([sql.column('date', type_=DateTime)], from_obj=[testing.db.func.current_date()]) - q = s.execute().fetchone()[s.c.date] + q = s.execute().first()[s.c.date] r = s.alias('datequery').select().scalar() assert x == y == z == w == q == r @@ -301,7 +305,7 @@ class ExecuteTest(TestBase): 'd': datetime.date(2010, 5, 1) }) rs = select([extract('year', table.c.dt), extract('month', table.c.d)]).execute() - row = rs.fetchone() + row = rs.first() assert row[0] == 2010 assert row[1] == 5 rs.close() diff --git a/test/sql/test_labels.py b/test/sql/test_labels.py index b946b0ae9..bcac7c01d 100644 --- a/test/sql/test_labels.py +++ b/test/sql/test_labels.py @@ -35,6 +35,7 @@ class LongLabelsTest(TestBase, AssertsCompiledSQL): maxlen = testing.db.dialect.max_identifier_length testing.db.dialect.max_identifier_length = IDENT_LENGTH + @engines.close_first def teardown(self): table1.delete().execute() @@ -92,10 +93,16 @@ class LongLabelsTest(TestBase, AssertsCompiledSQL): ], repr(result) def test_table_alias_names(self): - self.assert_compile( - table2.alias().select(), - "SELECT table_with_exactly_29_c_1.this_is_the_primarykey_column, table_with_exactly_29_c_1.this_is_the_data_column FROM table_with_exactly_29_characs AS table_with_exactly_29_c_1" - ) + if testing.against('oracle'): + self.assert_compile( + table2.alias().select(), + "SELECT table_with_exactly_29_c_1.this_is_the_primarykey_column, table_with_exactly_29_c_1.this_is_the_data_column FROM table_with_exactly_29_characs table_with_exactly_29_c_1" + ) + else: + self.assert_compile( + table2.alias().select(), + "SELECT table_with_exactly_29_c_1.this_is_the_primarykey_column, table_with_exactly_29_c_1.this_is_the_data_column FROM table_with_exactly_29_characs AS table_with_exactly_29_c_1" + ) ta = table2.alias() dialect = default.DefaultDialect() diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 51b933e45..0e3b9dff2 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -1,9 +1,11 @@ +from sqlalchemy.test.testing import eq_ import datetime from sqlalchemy import * from sqlalchemy import exc, sql from sqlalchemy.engine import default from sqlalchemy.test import * -from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.testing import eq_, assert_raises_message +from sqlalchemy.test.schema import Table, Column class QueryTest(TestBase): @@ -12,11 +14,11 @@ class QueryTest(TestBase): global users, users2, addresses, metadata metadata = MetaData(testing.db) users = Table('query_users', metadata, - Column('user_id', INT, primary_key = True), + Column('user_id', INT, primary_key=True, test_needs_autoincrement=True), Column('user_name', VARCHAR(20)), ) addresses = Table('query_addresses', metadata, - Column('address_id', Integer, primary_key=True), + Column('address_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', Integer, ForeignKey('query_users.user_id')), Column('address', String(30))) @@ -26,7 +28,8 @@ class QueryTest(TestBase): ) metadata.create_all() - def tearDown(self): + @engines.close_first + def teardown(self): addresses.delete().execute() users.delete().execute() users2.delete().execute() @@ -52,89 +55,133 @@ class QueryTest(TestBase): assert users.count().scalar() == 1 users.update(users.c.user_id == 7).execute(user_name = 'fred') - assert users.select(users.c.user_id==7).execute().fetchone()['user_name'] == 'fred' + assert users.select(users.c.user_id==7).execute().first()['user_name'] == 'fred' def test_lastrow_accessor(self): - """Tests the last_inserted_ids() and lastrow_has_id() functions.""" + """Tests the inserted_primary_key and lastrow_has_id() functions.""" - def insert_values(table, values): + def insert_values(engine, table, values): """ Inserts a row into a table, returns the full list of values INSERTed including defaults that fired off on the DB side and detects rows that had defaults and post-fetches. """ - result = table.insert().execute(**values) + result = engine.execute(table.insert(), **values) ret = values.copy() - for col, id in zip(table.primary_key, result.last_inserted_ids()): + for col, id in zip(table.primary_key, result.inserted_primary_key): ret[col.key] = id if result.lastrow_has_defaults(): - criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())]) - row = table.select(criterion).execute().fetchone() + criterion = and_(*[col==id for col, id in zip(table.primary_key, result.inserted_primary_key)]) + row = engine.execute(table.select(criterion)).first() for c in table.c: ret[c.key] = row[c] return ret - for supported, table, values, assertvalues in [ - ( - {'unsupported':['sqlite']}, - Table("t1", metadata, - Column('id', Integer, Sequence('t1_id_seq', optional=True), primary_key=True), - Column('foo', String(30), primary_key=True)), - {'foo':'hi'}, - {'id':1, 'foo':'hi'} - ), - ( - {'unsupported':['sqlite']}, - Table("t2", metadata, - Column('id', Integer, Sequence('t2_id_seq', optional=True), primary_key=True), - Column('foo', String(30), primary_key=True), - Column('bar', String(30), server_default='hi') + if testing.against('firebird', 'postgresql', 'oracle', 'mssql'): + test_engines = [ + engines.testing_engine(options={'implicit_returning':False}), + engines.testing_engine(options={'implicit_returning':True}), + ] + else: + test_engines = [testing.db] + + for engine in test_engines: + metadata = MetaData() + for supported, table, values, assertvalues in [ + ( + {'unsupported':['sqlite']}, + Table("t1", metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('foo', String(30), primary_key=True)), + {'foo':'hi'}, + {'id':1, 'foo':'hi'} ), - {'foo':'hi'}, - {'id':1, 'foo':'hi', 'bar':'hi'} - ), - ( - {'unsupported':[]}, - Table("t3", metadata, - Column("id", String(40), primary_key=True), - Column('foo', String(30), primary_key=True), - Column("bar", String(30)) + ( + {'unsupported':['sqlite']}, + Table("t2", metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('foo', String(30), primary_key=True), + Column('bar', String(30), server_default='hi') ), - {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"}, - {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"} - ), - ( - {'unsupported':[]}, - Table("t4", metadata, - Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True), - Column('foo', String(30), primary_key=True), - Column('bar', String(30), server_default='hi') + {'foo':'hi'}, + {'id':1, 'foo':'hi', 'bar':'hi'} ), - {'foo':'hi', 'id':1}, - {'id':1, 'foo':'hi', 'bar':'hi'} - ), - ( - {'unsupported':[]}, - Table("t5", metadata, - Column('id', String(10), primary_key=True), - Column('bar', String(30), server_default='hi') + ( + {'unsupported':[]}, + Table("t3", metadata, + Column("id", String(40), primary_key=True), + Column('foo', String(30), primary_key=True), + Column("bar", String(30)) + ), + {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"}, + {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"} ), - {'id':'id1'}, - {'id':'id1', 'bar':'hi'}, - ), - ]: - if testing.db.name in supported['unsupported']: - continue - try: - table.create() - i = insert_values(table, values) - assert i == assertvalues, repr(i) + " " + repr(assertvalues) - finally: - table.drop() + ( + {'unsupported':[]}, + Table("t4", metadata, + Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True), + Column('foo', String(30), primary_key=True), + Column('bar', String(30), server_default='hi') + ), + {'foo':'hi', 'id':1}, + {'id':1, 'foo':'hi', 'bar':'hi'} + ), + ( + {'unsupported':[]}, + Table("t5", metadata, + Column('id', String(10), primary_key=True), + Column('bar', String(30), server_default='hi') + ), + {'id':'id1'}, + {'id':'id1', 'bar':'hi'}, + ), + ]: + if testing.db.name in supported['unsupported']: + continue + try: + table.create(bind=engine, checkfirst=True) + i = insert_values(engine, table, values) + assert i == assertvalues, "tablename: %s %r %r" % (table.name, repr(i), repr(assertvalues)) + finally: + table.drop(bind=engine) + + @testing.fails_on('sqlite', "sqlite autoincremnt doesn't work with composite pks") + def test_misordered_lastrow(self): + related = Table('related', metadata, + Column('id', Integer, primary_key=True) + ) + t6 = Table("t6", metadata, + Column('manual_id', Integer, ForeignKey('related.id'), primary_key=True), + Column('auto_id', Integer, primary_key=True, test_needs_autoincrement=True), + ) + metadata.create_all() + r = related.insert().values(id=12).execute() + id = r.inserted_primary_key[0] + assert id==12 + + r = t6.insert().values(manual_id=id).execute() + eq_(r.inserted_primary_key, [12, 1]) + + def test_autoclose_on_insert(self): + if testing.against('firebird', 'postgresql', 'oracle', 'mssql'): + test_engines = [ + engines.testing_engine(options={'implicit_returning':False}), + engines.testing_engine(options={'implicit_returning':True}), + ] + else: + test_engines = [testing.db] + + for engine in test_engines: + + r = engine.execute(users.insert(), + {'user_name':'jack'}, + ) + assert r.closed + def test_row_iteration(self): users.insert().execute( {'user_id':7, 'user_name':'jack'}, @@ -147,7 +194,7 @@ class QueryTest(TestBase): l.append(row) self.assert_(len(l) == 3) - @testing.fails_on('firebird', 'Data type unknown') + @testing.fails_on('firebird', "kinterbasdb doesn't send full type information") @testing.requires.subqueries def test_anonymous_rows(self): users.insert().execute( @@ -161,6 +208,7 @@ class QueryTest(TestBase): assert row['anon_1'] == 8 assert row['anon_2'] == 10 + @testing.fails_on('firebird', "kinterbasdb doesn't send full type information") def test_order_by_label(self): """test that a label within an ORDER BY works on each backend. @@ -179,6 +227,11 @@ class QueryTest(TestBase): select([concat]).order_by(concat).execute().fetchall(), [("test: ed",), ("test: fred",), ("test: jack",)] ) + + eq_( + select([concat]).order_by(concat).execute().fetchall(), + [("test: ed",), ("test: fred",), ("test: jack",)] + ) concat = ("test: " + users.c.user_name).label('thedata') eq_( @@ -195,7 +248,7 @@ class QueryTest(TestBase): def test_row_comparison(self): users.insert().execute(user_id = 7, user_name = 'jack') - rp = users.select().execute().fetchone() + rp = users.select().execute().first() self.assert_(rp == rp) self.assert_(not(rp != rp)) @@ -207,8 +260,7 @@ class QueryTest(TestBase): self.assert_(not (rp != equal)) self.assert_(not (equal != equal)) - @testing.fails_on('mssql', 'No support for boolean logic in column select.') - @testing.fails_on('oracle', 'FIXME: unknown') + @testing.requires.boolean_col_expressions def test_or_and_as_columns(self): true, false = literal(True), literal(False) @@ -218,11 +270,11 @@ class QueryTest(TestBase): eq_(testing.db.execute(select([or_(false, false)])).scalar(), False) eq_(testing.db.execute(select([not_(or_(false, false))])).scalar(), True) - row = testing.db.execute(select([or_(false, false).label("x"), and_(true, false).label("y")])).fetchone() + row = testing.db.execute(select([or_(false, false).label("x"), and_(true, false).label("y")])).first() assert row.x == False assert row.y == False - row = testing.db.execute(select([or_(true, false).label("x"), and_(true, false).label("y")])).fetchone() + row = testing.db.execute(select([or_(true, false).label("x"), and_(true, false).label("y")])).first() assert row.x == True assert row.y == False @@ -253,6 +305,9 @@ class QueryTest(TestBase): eq_(expr.execute().fetchall(), result) + @testing.fails_on("firebird", "see dialect.test_firebird:MiscTest.test_percents_in_text") + @testing.fails_on("oracle", "neither % nor %% are accepted") + @testing.fails_on("+pg8000", "can't interpret result column from '%%'") @testing.emits_warning('.*now automatically escapes.*') def test_percents_in_text(self): for expr, result in ( @@ -277,7 +332,7 @@ class QueryTest(TestBase): eq_(select([users.c.user_id]).where(users.c.user_name.ilike('TWO')).execute().fetchall(), [(2, )]) - if testing.against('postgres'): + if testing.against('postgresql'): eq_(select([users.c.user_id]).where(users.c.user_name.like('one')).execute().fetchall(), [(1, )]) eq_(select([users.c.user_id]).where(users.c.user_name.like('TWO')).execute().fetchall(), []) @@ -373,7 +428,7 @@ class QueryTest(TestBase): s = select([datetable.alias('x').c.today]).as_scalar() s2 = select([datetable.c.id, s.label('somelabel')]) #print s2.c.somelabel.type - assert isinstance(s2.execute().fetchone()['somelabel'], datetime.datetime) + assert isinstance(s2.execute().first()['somelabel'], datetime.datetime) finally: datetable.drop() @@ -444,45 +499,58 @@ class QueryTest(TestBase): users.insert().execute(user_id=2, user_name='jack') addresses.insert().execute(address_id=1, user_id=2, address='foo@bar.com') - r = users.select(users.c.user_id==2).execute().fetchone() + r = users.select(users.c.user_id==2).execute().first() self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2) self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack') - - r = text("select * from query_users where user_id=2", bind=testing.db).execute().fetchone() + + r = text("select * from query_users where user_id=2", bind=testing.db).execute().first() self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2) self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack') - + # test slices - r = text("select * from query_addresses", bind=testing.db).execute().fetchone() + r = text("select * from query_addresses", bind=testing.db).execute().first() self.assert_(r[0:1] == (1,)) self.assert_(r[1:] == (2, 'foo@bar.com')) self.assert_(r[:-1] == (1, 2)) - + # test a little sqlite weirdness - with the UNION, cols come back as "query_users.user_id" in cursor.description r = text("select query_users.user_id, query_users.user_name from query_users " - "UNION select query_users.user_id, query_users.user_name from query_users", bind=testing.db).execute().fetchone() + "UNION select query_users.user_id, query_users.user_name from query_users", bind=testing.db).execute().first() self.assert_(r['user_id']) == 1 self.assert_(r['user_name']) == "john" # test using literal tablename.colname - r = text('select query_users.user_id AS "query_users.user_id", query_users.user_name AS "query_users.user_name" from query_users', bind=testing.db).execute().fetchone() + r = text('select query_users.user_id AS "query_users.user_id", ' + 'query_users.user_name AS "query_users.user_name" from query_users', + bind=testing.db).execute().first() self.assert_(r['query_users.user_id']) == 1 self.assert_(r['query_users.user_name']) == "john" # unary experssions - r = select([users.c.user_name.distinct()]).order_by(users.c.user_name).execute().fetchone() + r = select([users.c.user_name.distinct()]).order_by(users.c.user_name).execute().first() eq_(r[users.c.user_name], 'jack') eq_(r.user_name, 'jack') - r.close() + + def test_result_case_sensitivity(self): + """test name normalization for result sets.""" + row = testing.db.execute( + select([ + literal_column("1").label("case_insensitive"), + literal_column("2").label("CaseSensitive") + ]) + ).first() + + assert row.keys() == ["case_insensitive", "CaseSensitive"] + def test_row_as_args(self): users.insert().execute(user_id=1, user_name='john') - r = users.select(users.c.user_id==1).execute().fetchone() + r = users.select(users.c.user_id==1).execute().first() users.delete().execute() users.insert().execute(r) - assert users.select().execute().fetchall() == [(1, 'john')] - + eq_(users.select().execute().fetchall(), [(1, 'john')]) + def test_result_as_args(self): users.insert().execute([dict(user_id=1, user_name='john'), dict(user_id=2, user_name='ed')]) r = users.select().execute() @@ -496,13 +564,12 @@ class QueryTest(TestBase): def test_ambiguous_column(self): users.insert().execute(user_id=1, user_name='john') - r = users.outerjoin(addresses).select().execute().fetchone() - try: - print r['user_id'] - assert False - except exc.InvalidRequestError, e: - assert str(e) == "Ambiguous column name 'user_id' in result set! try 'use_labels' option on select statement." or \ - str(e) == "Ambiguous column name 'USER_ID' in result set! try 'use_labels' option on select statement." + r = users.outerjoin(addresses).select().execute().first() + assert_raises_message( + exc.InvalidRequestError, + "Ambiguous column name", + lambda: r['user_id'] + ) @testing.requires.subqueries def test_column_label_targeting(self): @@ -512,31 +579,29 @@ class QueryTest(TestBase): users.select().alias('foo'), users.select().alias(users.name), ): - row = s.select(use_labels=True).execute().fetchone() + row = s.select(use_labels=True).execute().first() assert row[s.c.user_id] == 7 assert row[s.c.user_name] == 'ed' def test_keys(self): users.insert().execute(user_id=1, user_name='foo') - r = users.select().execute().fetchone() + r = users.select().execute().first() eq_([x.lower() for x in r.keys()], ['user_id', 'user_name']) def test_items(self): users.insert().execute(user_id=1, user_name='foo') - r = users.select().execute().fetchone() + r = users.select().execute().first() eq_([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')]) def test_len(self): users.insert().execute(user_id=1, user_name='foo') - r = users.select().execute().fetchone() + r = users.select().execute().first() eq_(len(r), 2) - r.close() - r = testing.db.execute('select user_name, user_id from query_users').fetchone() + + r = testing.db.execute('select user_name, user_id from query_users').first() eq_(len(r), 2) - r.close() - r = testing.db.execute('select user_name from query_users').fetchone() + r = testing.db.execute('select user_name from query_users').first() eq_(len(r), 1) - r.close() def test_cant_execute_join(self): try: @@ -549,7 +614,7 @@ class QueryTest(TestBase): def test_column_order_with_simple_query(self): # should return values in column definition order users.insert().execute(user_id=1, user_name='foo') - r = users.select(users.c.user_id==1).execute().fetchone() + r = users.select(users.c.user_id==1).execute().first() eq_(r[0], 1) eq_(r[1], 'foo') eq_([x.lower() for x in r.keys()], ['user_id', 'user_name']) @@ -558,7 +623,7 @@ class QueryTest(TestBase): def test_column_order_with_text_query(self): # should return values in query order users.insert().execute(user_id=1, user_name='foo') - r = testing.db.execute('select user_name, user_id from query_users').fetchone() + r = testing.db.execute('select user_name, user_id from query_users').first() eq_(r[0], 'foo') eq_(r[1], 1) eq_([x.lower() for x in r.keys()], ['user_name', 'user_id']) @@ -580,7 +645,7 @@ class QueryTest(TestBase): shadowed.create(checkfirst=True) try: shadowed.insert().execute(shadow_id=1, shadow_name='The Shadow', parent='The Light', row='Without light there is no shadow', __parent='Hidden parent', __row='Hidden row') - r = shadowed.select(shadowed.c.shadow_id==1).execute().fetchone() + r = shadowed.select(shadowed.c.shadow_id==1).execute().first() self.assert_(r.shadow_id == r['shadow_id'] == r[shadowed.c.shadow_id] == 1) self.assert_(r.shadow_name == r['shadow_name'] == r[shadowed.c.shadow_name] == 'The Shadow') self.assert_(r.parent == r['parent'] == r[shadowed.c.parent] == 'The Light') @@ -622,13 +687,13 @@ class QueryTest(TestBase): # Null values are not outside any set assert len(r) == 0 - u = bindparam('search_key') + @testing.fails_on('firebird', "kinterbasdb doesn't send full type information") + def test_bind_in(self): + users.insert().execute(user_id = 7, user_name = 'jack') + users.insert().execute(user_id = 8, user_name = 'fred') + users.insert().execute(user_id = 9, user_name = None) - s = users.select(u.in_([])) - r = s.execute(search_key='john').fetchall() - assert len(r) == 0 - r = s.execute(search_key=None).fetchall() - assert len(r) == 0 + u = bindparam('search_key') s = users.select(not_(u.in_([]))) r = s.execute(search_key='john').fetchall() @@ -660,14 +725,15 @@ class QueryTest(TestBase): class PercentSchemaNamesTest(TestBase): """tests using percent signs, spaces in table and column names. - Doesn't pass for mysql, postgres, but this is really a + Doesn't pass for mysql, postgresql, but this is really a SQLAlchemy bug - we should be escaping out %% signs for this operation the same way we do for text() and column labels. """ + @classmethod @testing.crashes('mysql', 'mysqldb calls name % (params)') - @testing.crashes('postgres', 'postgres calls name % (params)') + @testing.crashes('postgresql', 'postgresql calls name % (params)') def setup_class(cls): global percent_table, metadata metadata = MetaData(testing.db) @@ -680,12 +746,12 @@ class PercentSchemaNamesTest(TestBase): @classmethod @testing.crashes('mysql', 'mysqldb calls name % (params)') - @testing.crashes('postgres', 'postgres calls name % (params)') + @testing.crashes('postgresql', 'postgresql calls name % (params)') def teardown_class(cls): metadata.drop_all() @testing.crashes('mysql', 'mysqldb calls name % (params)') - @testing.crashes('postgres', 'postgres calls name % (params)') + @testing.crashes('postgresql', 'postgresql calls name % (params)') def test_roundtrip(self): percent_table.insert().execute( {'percent%':5, '%(oneofthese)s':7, 'spaces % more spaces':12}, @@ -731,7 +797,7 @@ class PercentSchemaNamesTest(TestBase): percent_table.update().values({percent_table.c['%(oneofthese)s']:9, percent_table.c['spaces % more spaces']:15}).execute() eq_( - percent_table.select().order_by(percent_table.c['%(oneofthese)s']).execute().fetchall(), + percent_table.select().order_by(percent_table.c['percent%']).execute().fetchall(), [ (5, 9, 15), (7, 9, 15), @@ -852,7 +918,11 @@ class CompoundTest(TestBase): dict(col2="t3col2r2", col3="bbb", col4="aaa"), dict(col2="t3col2r3", col3="ccc", col4="bbb"), ]) - + + @engines.close_first + def teardown(self): + pass + @classmethod def teardown_class(cls): metadata.drop_all() @@ -878,6 +948,7 @@ class CompoundTest(TestBase): found2 = self._fetchall_sorted(u.alias('bar').select().execute()) eq_(found2, wanted) + @testing.fails_on('firebird', "doesn't like ORDER BY with UNIONs") def test_union_ordered(self): (s1, s2) = ( select([t1.c.col3.label('col3'), t1.c.col4.label('col4')], @@ -891,6 +962,7 @@ class CompoundTest(TestBase): ('ccc', 'aaa')] eq_(u.execute().fetchall(), wanted) + @testing.fails_on('firebird', "doesn't like ORDER BY with UNIONs") @testing.fails_on('maxdb', 'FIXME: unknown') @testing.requires.subqueries def test_union_ordered_alias(self): @@ -907,6 +979,7 @@ class CompoundTest(TestBase): eq_(u.alias('bar').select().execute().fetchall(), wanted) @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on') + @testing.fails_on('firebird', "has trouble extracting anonymous column from union subquery") @testing.fails_on('mysql', 'FIXME: unknown') @testing.fails_on('sqlite', 'FIXME: unknown') def test_union_all(self): @@ -925,6 +998,29 @@ class CompoundTest(TestBase): found2 = self._fetchall_sorted(e.alias('foo').select().execute()) eq_(found2, wanted) + def test_union_all_lightweight(self): + """like test_union_all, but breaks the sub-union into + a subquery with an explicit column reference on the outside, + more palatable to a wider variety of engines. + + """ + u = union( + select([t1.c.col3]), + select([t1.c.col3]), + ).alias() + + e = union_all( + select([t1.c.col3]), + select([u.c.col3]) + ) + + wanted = [('aaa',),('aaa',),('bbb',), ('bbb',), ('ccc',),('ccc',)] + found1 = self._fetchall_sorted(e.execute()) + eq_(found1, wanted) + + found2 = self._fetchall_sorted(e.alias('foo').select().execute()) + eq_(found2, wanted) + @testing.crashes('firebird', 'Does not support intersect') @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on') @testing.fails_on('mysql', 'FIXME: unknown') @@ -1330,3 +1426,6 @@ class OperatorTest(TestBase): order_by=flds.c.idcol).execute().fetchall(), [(2,),(1,)] ) + + + diff --git a/test/sql/test_quote.py b/test/sql/test_quote.py index 64e097b85..3198a07af 100644 --- a/test/sql/test_quote.py +++ b/test/sql/test_quote.py @@ -129,7 +129,7 @@ class QuoteTest(TestBase, AssertsCompiledSQL): def testlabels(self): """test the quoting of labels. - if labels arent quoted, a query in postgres in particular will fail since it produces: + if labels arent quoted, a query in postgresql in particular will fail since it produces: SELECT LaLa.lowercase, LaLa."UPPERCASE", LaLa."MixedCase", LaLa."ASC" FROM (SELECT DISTINCT "WorstCase1".lowercase AS lowercase, "WorstCase1"."UPPERCASE" AS UPPERCASE, "WorstCase1"."MixedCase" AS MixedCase, "WorstCase1"."ASC" AS ASC \nFROM "WorstCase1") AS LaLa diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py new file mode 100644 index 000000000..e076f3fe7 --- /dev/null +++ b/test/sql/test_returning.py @@ -0,0 +1,159 @@ +from sqlalchemy.test.testing import eq_ +from sqlalchemy import * +from sqlalchemy.test import * +from sqlalchemy.test.schema import Table, Column +from sqlalchemy.types import TypeDecorator + + +class ReturningTest(TestBase, AssertsExecutionResults): + __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access') + + def setup(self): + meta = MetaData(testing.db) + global table, GoofyType + + class GoofyType(TypeDecorator): + impl = String + + def process_bind_param(self, value, dialect): + if value is None: + return None + return "FOO" + value + + def process_result_value(self, value, dialect): + if value is None: + return None + return value + "BAR" + + table = Table('tables', meta, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('persons', Integer), + Column('full', Boolean), + Column('goofy', GoofyType(50)) + ) + table.create(checkfirst=True) + + def teardown(self): + table.drop() + + @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_column_targeting(self): + result = table.insert().returning(table.c.id, table.c.full).execute({'persons': 1, 'full': False}) + + row = result.first() + assert row[table.c.id] == row['id'] == 1 + assert row[table.c.full] == row['full'] == False + + result = table.insert().values(persons=5, full=True, goofy="somegoofy").\ + returning(table.c.persons, table.c.full, table.c.goofy).execute() + row = result.first() + assert row[table.c.persons] == row['persons'] == 5 + assert row[table.c.full] == row['full'] == True + assert row[table.c.goofy] == row['goofy'] == "FOOsomegoofyBAR" + + @testing.fails_on('firebird', "fb can't handle returning x AS y") + @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_labeling(self): + result = table.insert().values(persons=6).\ + returning(table.c.persons.label('lala')).execute() + row = result.first() + assert row['lala'] == 6 + + @testing.fails_on('firebird', "fb/kintersbasdb can't handle the bind params") + @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_anon_expressions(self): + result = table.insert().values(goofy="someOTHERgoofy").\ + returning(func.lower(table.c.goofy, type_=GoofyType)).execute() + row = result.first() + assert row[0] == "foosomeothergoofyBAR" + + result = table.insert().values(persons=12).\ + returning(table.c.persons + 18).execute() + row = result.first() + assert row[0] == 30 + + @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_update_returning(self): + table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) + + result = table.update(table.c.persons > 4, dict(full=True)).returning(table.c.id).execute() + eq_(result.fetchall(), [(1,)]) + + result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() + eq_(result2.fetchall(), [(1,True),(2,False)]) + + @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_insert_returning(self): + result = table.insert().returning(table.c.id).execute({'persons': 1, 'full': False}) + + eq_(result.fetchall(), [(1,)]) + + @testing.fails_on('postgresql', '') + @testing.fails_on('oracle', '') + def test_executemany(): + # return value is documented as failing with psycopg2/executemany + result2 = table.insert().returning(table).execute( + [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}]) + + if testing.against('firebird', 'mssql'): + # Multiple inserts only return the last row + eq_(result2.fetchall(), [(3,3,True, None)]) + else: + # nobody does this as far as we know (pg8000?) + eq_(result2.fetchall(), [(2, 2, False, None), (3,3,True, None)]) + + test_executemany() + + result3 = table.insert().returning(table.c.id).execute({'persons': 4, 'full': False}) + eq_([dict(row) for row in result3], [{'id': 4}]) + + + @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + @testing.fails_on_everything_except('postgresql', 'firebird') + def test_literal_returning(self): + if testing.against("postgresql"): + literal_true = "true" + else: + literal_true = "1" + + result4 = testing.db.execute('insert into tables (id, persons, "full") ' + 'values (5, 10, %s) returning persons' % literal_true) + eq_([dict(row) for row in result4], [{'persons': 10}]) + + @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_delete_returning(self): + table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) + + result = table.delete(table.c.persons > 4).returning(table.c.id).execute() + eq_(result.fetchall(), [(1,)]) + + result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() + eq_(result2.fetchall(), [(2,False),]) + +class SequenceReturningTest(TestBase): + __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access', 'mssql') + + def setup(self): + meta = MetaData(testing.db) + global table, seq + seq = Sequence('tid_seq') + table = Table('tables', meta, + Column('id', Integer, seq, primary_key=True), + Column('data', String(50)) + ) + table.create(checkfirst=True) + + def teardown(self): + table.drop() + + def test_insert(self): + r = table.insert().values(data='hi').returning(table.c.id).execute() + assert r.first() == (1, ) + assert seq.execute() == 2 diff --git a/test/sql/test_select.py b/test/sql/test_select.py index f70492fb3..9acc94eb2 100644 --- a/test/sql/test_select.py +++ b/test/sql/test_select.py @@ -5,7 +5,7 @@ from sqlalchemy import exc, sql, util from sqlalchemy.sql import table, column, label, compiler from sqlalchemy.sql.expression import ClauseList from sqlalchemy.engine import default -from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql +from sqlalchemy.databases import * from sqlalchemy.test import * table1 = table('mytable', @@ -149,12 +149,10 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A ) self.assert_compile( - select([cast("data", sqlite.SLInteger)], use_labels=True), # this will work with plain Integer in 0.6 + select([cast("data", Integer)], use_labels=True), # this will work with plain Integer in 0.6 "SELECT CAST(:param_1 AS INTEGER) AS anon_1" ) - - def test_nested_uselabels(self): """test nested anonymous label generation. this essentially tests the ANONYMOUS_LABEL regex. @@ -429,7 +427,12 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A def test_operators(self): for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'), - (operator.sub, '-'), (operator.div, '/'), + (operator.sub, '-'), + # Py3K + #(operator.truediv, '/'), + # Py2K + (operator.div, '/'), + # end Py2K ): for (lhs, rhs, res) in ( (5, table1.c.myid, ':myid_1 %s mytable.myid'), @@ -519,22 +522,22 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A (~table1.c.myid.like('somstr', escape='\\'), "mytable.myid NOT LIKE :myid_1 ESCAPE '\\'", None), (table1.c.myid.ilike('somstr', escape='\\'), "lower(mytable.myid) LIKE lower(:myid_1) ESCAPE '\\'", None), (~table1.c.myid.ilike('somstr', escape='\\'), "lower(mytable.myid) NOT LIKE lower(:myid_1) ESCAPE '\\'", None), - (table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid ILIKE %(myid_1)s ESCAPE '\\'", postgres.PGDialect()), - (~table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid NOT ILIKE %(myid_1)s ESCAPE '\\'", postgres.PGDialect()), + (table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid ILIKE %(myid_1)s ESCAPE '\\'", postgresql.PGDialect()), + (~table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid NOT ILIKE %(myid_1)s ESCAPE '\\'", postgresql.PGDialect()), (table1.c.name.ilike('%something%'), "lower(mytable.name) LIKE lower(:name_1)", None), - (table1.c.name.ilike('%something%'), "mytable.name ILIKE %(name_1)s", postgres.PGDialect()), + (table1.c.name.ilike('%something%'), "mytable.name ILIKE %(name_1)s", postgresql.PGDialect()), (~table1.c.name.ilike('%something%'), "lower(mytable.name) NOT LIKE lower(:name_1)", None), - (~table1.c.name.ilike('%something%'), "mytable.name NOT ILIKE %(name_1)s", postgres.PGDialect()), + (~table1.c.name.ilike('%something%'), "mytable.name NOT ILIKE %(name_1)s", postgresql.PGDialect()), ]: self.assert_compile(expr, check, dialect=dialect) def test_match(self): for expr, check, dialect in [ (table1.c.myid.match('somstr'), "mytable.myid MATCH ?", sqlite.SQLiteDialect()), - (table1.c.myid.match('somstr'), "MATCH (mytable.myid) AGAINST (%s IN BOOLEAN MODE)", mysql.MySQLDialect()), - (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", mssql.MSSQLDialect()), - (table1.c.myid.match('somstr'), "mytable.myid @@ to_tsquery(%(myid_1)s)", postgres.PGDialect()), - (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", oracle.OracleDialect()), + (table1.c.myid.match('somstr'), "MATCH (mytable.myid) AGAINST (%s IN BOOLEAN MODE)", mysql.dialect()), + (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", mssql.dialect()), + (table1.c.myid.match('somstr'), "mytable.myid @@ to_tsquery(%(myid_1)s)", postgresql.dialect()), + (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", oracle.dialect()), ]: self.assert_compile(expr, check, dialect=dialect) @@ -635,7 +638,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A select([table1.alias('foo')]) ,"SELECT foo.myid, foo.name, foo.description FROM mytable AS foo") - for dialect in (firebird.dialect(), oracle.dialect()): + for dialect in (oracle.dialect(),): self.assert_compile( select([table1.alias('foo')]) ,"SELECT foo.myid, foo.name, foo.description FROM mytable foo" @@ -748,7 +751,7 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = params={}, ) - dialect = postgres.dialect() + dialect = postgresql.dialect() self.assert_compile( text("select * from foo where lala=:bar and hoho=:whee", bindparams=[bindparam('bar',4), bindparam('whee',7)]), "select * from foo where lala=%(bar)s and hoho=%(whee)s", @@ -1122,10 +1125,10 @@ UNION SELECT mytable.myid FROM mytable" self.assert_compile(stmt, expected_positional_stmt, dialect=sqlite.dialect()) nonpositional = stmt.compile() positional = stmt.compile(dialect=sqlite.dialect()) - pp = positional.get_params() + pp = positional.params assert [pp[k] for k in positional.positiontup] == expected_default_params_list - assert nonpositional.get_params(**test_param_dict) == expected_test_params_dict, "expected :%s got %s" % (str(expected_test_params_dict), str(nonpositional.get_params(**test_param_dict))) - pp = positional.get_params(**test_param_dict) + assert nonpositional.construct_params(test_param_dict) == expected_test_params_dict, "expected :%s got %s" % (str(expected_test_params_dict), str(nonpositional.get_params(**test_param_dict))) + pp = positional.construct_params(test_param_dict) assert [pp[k] for k in positional.positiontup] == expected_test_params_list # check that params() doesnt modify original statement @@ -1144,7 +1147,7 @@ UNION SELECT mytable.myid FROM mytable" ":myid_1) AS anon_1 FROM mytable WHERE mytable.myid = (SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_1)") positional = s2.compile(dialect=sqlite.dialect()) - pp = positional.get_params() + pp = positional.params assert [pp[k] for k in positional.positiontup] == [12, 12] # check that conflicts with "unique" params are caught @@ -1163,11 +1166,11 @@ UNION SELECT mytable.myid FROM mytable" params = dict(('in%d' % i, i) for i in range(total_params)) sql = 'text clause %s' % ', '.join(in_clause) t = text(sql) - assert len(t.bindparams) == total_params + eq_(len(t.bindparams), total_params) c = t.compile() pp = c.construct_params(params) - assert len(set(pp)) == total_params - assert len(set(pp.values())) == total_params + eq_(len(set(pp)), total_params, '%s %s' % (len(set(pp)), len(pp))) + eq_(len(set(pp.values())), total_params) def test_bind_as_col(self): @@ -1291,28 +1294,28 @@ UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)") eq_(str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[0]) eq_(str(cast(tbl.c.v1, Numeric(12, 9)).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[1]) eq_(str(cast(tbl.c.ts, Date).compile(dialect=dialect)), 'CAST(casttest.ts AS %s)' %expected_results[2]) - eq_(str(cast(1234, TEXT).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3])) + eq_(str(cast(1234, Text).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3])) eq_(str(cast('test', String(20)).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[4])) # fixme: shoving all of this dialect-specific stuff in one test # is now officialy completely ridiculous AND non-obviously omits # coverage on other dialects. sel = select([tbl, cast(tbl.c.v1, Numeric)]).compile(dialect=dialect) if isinstance(dialect, type(mysql.dialect())): - eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) AS anon_1 \nFROM casttest") + eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL) AS anon_1 \nFROM casttest") else: - eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) AS anon_1 \nFROM casttest") + eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC) AS anon_1 \nFROM casttest") # first test with PostgreSQL engine - check_results(postgres.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(param_1)s') + check_results(postgresql.dialect(), ['NUMERIC', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(param_1)s') # then the Oracle engine - check_results(oracle.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':param_1') + check_results(oracle.dialect(), ['NUMERIC', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':param_1') # then the sqlite engine - check_results(sqlite.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?') + check_results(sqlite.dialect(), ['NUMERIC', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?') # then the MySQL engine - check_results(mysql.dialect(), ['DECIMAL(10, 2)', 'DECIMAL(12, 9)', 'DATE', 'CHAR', 'CHAR(20)'], '%s') + check_results(mysql.dialect(), ['DECIMAL', 'DECIMAL(12, 9)', 'DATE', 'CHAR', 'CHAR(20)'], '%s') self.assert_compile(cast(text('NULL'), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect()) self.assert_compile(cast(null(), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect()) @@ -1360,7 +1363,6 @@ UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)") s1 = select([table1.c.myid, table1.c.myid.label('foobar'), func.hoho(table1.c.name), func.lala(table1.c.name).label('gg')]) assert s1.c.keys() == ['myid', 'foobar', 'hoho(mytable.name)', 'gg'] - from sqlalchemy.databases.sqlite import SLNumeric meta = MetaData() t1 = Table('mytable', meta, Column('col1', Integer)) @@ -1368,7 +1370,7 @@ UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)") (table1.c.name, 'name', 'mytable.name', None), (table1.c.myid==12, 'mytable.myid = :myid_1', 'mytable.myid = :myid_1', 'anon_1'), (func.hoho(table1.c.myid), 'hoho(mytable.myid)', 'hoho(mytable.myid)', 'hoho_1'), - (cast(table1.c.name, SLNumeric), 'CAST(mytable.name AS NUMERIC(10, 2))', 'CAST(mytable.name AS NUMERIC(10, 2))', 'anon_1'), + (cast(table1.c.name, Numeric), 'CAST(mytable.name AS NUMERIC)', 'CAST(mytable.name AS NUMERIC)', 'anon_1'), (t1.c.col1, 'col1', 'mytable.col1', None), (column('some wacky thing'), 'some wacky thing', '"some wacky thing"', '') ): diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index b0501c913..95ca0d17b 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -416,7 +416,7 @@ class ReduceTest(TestBase, AssertsExecutionResults): Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True), ) - # this is essentially the union formed by the ORM's polymorphic_union function. + # this is essentially the union formed by the ORM's polymorphic_union function. # we define two versions with different ordering of selects. # the first selectable has the "real" column classified_page.magazine_page_id @@ -432,7 +432,6 @@ class ReduceTest(TestBase, AssertsExecutionResults): magazine_page_table.c.page_id, cast(null(), Integer).label('magazine_page_id') ]).select_from(page_table.join(magazine_page_table)), - ).alias('pjoin') eq_( diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 15799358a..9c90549e2 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -1,101 +1,63 @@ +# coding: utf-8 from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import decimal import datetime, os, re from sqlalchemy import * -from sqlalchemy import exc, types, util +from sqlalchemy import exc, types, util, schema from sqlalchemy.sql import operators from sqlalchemy.test.testing import eq_ import sqlalchemy.engine.url as url -from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird +from sqlalchemy.databases import * + from sqlalchemy.test import * class AdaptTest(TestBase): - def testadapt(self): - e1 = url.URL('postgres').get_dialect()() - e2 = url.URL('mysql').get_dialect()() - e3 = url.URL('sqlite').get_dialect()() - e4 = url.URL('firebird').get_dialect()() - - type = String(40) - - t1 = type.dialect_impl(e1) - t2 = type.dialect_impl(e2) - t3 = type.dialect_impl(e3) - t4 = type.dialect_impl(e4) - - impls = [t1, t2, t3, t4] - for i,ta in enumerate(impls): - for j,tb in enumerate(impls): - if i == j: - assert ta == tb # call me paranoid... :) + def test_uppercase_rendering(self): + """Test that uppercase types from types.py always render as their type. + + As of SQLA 0.6, using an uppercase type means you want specifically that + type. If the database in use doesn't support that DDL, it (the DB backend) + should raise an error - it means you should be using a lowercased (genericized) type. + + """ + + for dialect in [ + oracle.dialect(), + mysql.dialect(), + postgresql.dialect(), + sqlite.dialect(), + sybase.dialect(), + informix.dialect(), + maxdb.dialect(), + mssql.dialect()]: # TODO when dialects are complete: engines.all_dialects(): + for type_, expected in ( + (FLOAT, "FLOAT"), + (NUMERIC, "NUMERIC"), + (DECIMAL, "DECIMAL"), + (INTEGER, "INTEGER"), + (SMALLINT, "SMALLINT"), + (TIMESTAMP, "TIMESTAMP"), + (DATETIME, "DATETIME"), + (DATE, "DATE"), + (TIME, "TIME"), + (CLOB, "CLOB"), + (VARCHAR, "VARCHAR"), + (NVARCHAR, ("NVARCHAR", "NATIONAL VARCHAR")), + (CHAR, "CHAR"), + (NCHAR, ("NCHAR", "NATIONAL CHAR")), + (BLOB, "BLOB"), + (BOOLEAN, ("BOOLEAN", "BOOL")) + ): + if isinstance(expected, str): + expected = (expected, ) + for exp in expected: + compiled = type_().compile(dialect=dialect) + if exp in compiled: + break else: - assert ta != tb - - def testmsnvarchar(self): - dialect = mssql.MSSQLDialect() - # run the test twice to ensure the caching step works too - for x in range(0, 1): - col = Column('', Unicode(length=10)) - dialect_type = col.type.dialect_impl(dialect) - assert isinstance(dialect_type, mssql.MSNVarchar) - assert dialect_type.get_col_spec() == 'NVARCHAR(10)' - - - def testoracletimestamp(self): - dialect = oracle.OracleDialect() - t1 = oracle.OracleTimestamp - t2 = oracle.OracleTimestamp() - t3 = types.TIMESTAMP - assert isinstance(dialect.type_descriptor(t1), oracle.OracleTimestamp) - assert isinstance(dialect.type_descriptor(t2), oracle.OracleTimestamp) - assert isinstance(dialect.type_descriptor(t3), oracle.OracleTimestamp) - - def testmysqlbinary(self): - dialect = mysql.MySQLDialect() - t1 = mysql.MSVarBinary - t2 = mysql.MSVarBinary() - assert isinstance(dialect.type_descriptor(t1), mysql.MSVarBinary) - assert isinstance(dialect.type_descriptor(t2), mysql.MSVarBinary) - - def teststringadapt(self): - """test that String with no size becomes TEXT, *all* others stay as varchar/String""" - - oracle_dialect = oracle.OracleDialect() - mysql_dialect = mysql.MySQLDialect() - postgres_dialect = postgres.PGDialect() - firebird_dialect = firebird.FBDialect() - - for dialect, start, test in [ - (oracle_dialect, String(), oracle.OracleString), - (oracle_dialect, VARCHAR(), oracle.OracleString), - (oracle_dialect, String(50), oracle.OracleString), - (oracle_dialect, Unicode(), oracle.OracleString), - (oracle_dialect, UnicodeText(), oracle.OracleText), - (oracle_dialect, NCHAR(), oracle.OracleString), - (oracle_dialect, oracle.OracleRaw(50), oracle.OracleRaw), - (mysql_dialect, String(), mysql.MSString), - (mysql_dialect, VARCHAR(), mysql.MSString), - (mysql_dialect, String(50), mysql.MSString), - (mysql_dialect, Unicode(), mysql.MSString), - (mysql_dialect, UnicodeText(), mysql.MSText), - (mysql_dialect, NCHAR(), mysql.MSNChar), - (postgres_dialect, String(), postgres.PGString), - (postgres_dialect, VARCHAR(), postgres.PGString), - (postgres_dialect, String(50), postgres.PGString), - (postgres_dialect, Unicode(), postgres.PGString), - (postgres_dialect, UnicodeText(), postgres.PGText), - (postgres_dialect, NCHAR(), postgres.PGString), - (firebird_dialect, String(), firebird.FBString), - (firebird_dialect, VARCHAR(), firebird.FBString), - (firebird_dialect, String(50), firebird.FBString), - (firebird_dialect, Unicode(), firebird.FBString), - (firebird_dialect, UnicodeText(), firebird.FBText), - (firebird_dialect, NCHAR(), firebird.FBString), - ]: - assert isinstance(start.dialect_impl(dialect), test), "wanted %r got %r" % (test, start.dialect_impl(dialect)) - - + assert False, "%r matches none of %r for dialect %s" % (compiled, expected, dialect.name) + class UserDefinedTest(TestBase): """tests user-defined types.""" @@ -131,7 +93,7 @@ class UserDefinedTest(TestBase): def setup_class(cls): global users, metadata - class MyType(types.TypeEngine): + class MyType(types.UserDefinedType): def get_col_spec(self): return "VARCHAR(100)" def bind_processor(self, dialect): @@ -267,124 +229,105 @@ class ColumnsTest(TestBase, AssertsExecutionResults): for aCol in testTable.c: eq_( expectedResults[aCol.name], - db.dialect.schemagenerator(db.dialect, db, None, None).\ + db.dialect.ddl_compiler(db.dialect, schema.CreateTable(testTable)).\ get_column_specification(aCol)) class UnicodeTest(TestBase, AssertsExecutionResults): """tests the Unicode type. also tests the TypeDecorator with instances in the types package.""" + @classmethod def setup_class(cls): - global unicode_table + global unicode_table, metadata metadata = MetaData(testing.db) unicode_table = Table('unicode_table', metadata, Column('id', Integer, Sequence('uni_id_seq', optional=True), primary_key=True), Column('unicode_varchar', Unicode(250)), Column('unicode_text', UnicodeText), - Column('plain_varchar', String(250)) ) - unicode_table.create() + metadata.create_all() + @classmethod def teardown_class(cls): - unicode_table.drop() + metadata.drop_all() + @engines.close_first def teardown(self): unicode_table.delete().execute() def test_round_trip(self): - assert unicode_table.c.unicode_varchar.type.length == 250 - rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' - unicodedata = rawdata.decode('utf-8') - if testing.against('sqlite'): - rawdata = "something" - - unicode_table.insert().execute(unicode_varchar=unicodedata, - unicode_text=unicodedata, - plain_varchar=rawdata) - x = unicode_table.select().execute().fetchone() + unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: « S’il vous plaît… dessine-moi un mouton! »" + + unicode_table.insert().execute(unicode_varchar=unicodedata,unicode_text=unicodedata) + + x = unicode_table.select().execute().first() self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) - if isinstance(x['plain_varchar'], unicode): - # SQLLite and MSSQL return non-unicode data as unicode - self.assert_(testing.against('sqlite', 'mssql')) - if not testing.against('sqlite'): - self.assert_(x['plain_varchar'] == unicodedata) - else: - self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata) - def test_union(self): - """ensure compiler processing works for UNIONs""" + def test_round_trip_executemany(self): + # cx_oracle was producing different behavior for cursor.executemany() + # vs. cursor.execute() + + unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: « S’il vous plaît… dessine-moi un mouton! »" - rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' - unicodedata = rawdata.decode('utf-8') - if testing.against('sqlite'): - rawdata = "something" - unicode_table.insert().execute(unicode_varchar=unicodedata, - unicode_text=unicodedata, - plain_varchar=rawdata) - - x = union(select([unicode_table.c.unicode_varchar]), select([unicode_table.c.unicode_varchar])).execute().fetchone() + unicode_table.insert().execute( + dict(unicode_varchar=unicodedata,unicode_text=unicodedata), + dict(unicode_varchar=unicodedata,unicode_text=unicodedata) + ) + + x = unicode_table.select().execute().first() self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) + self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) - def test_assertions(self): - try: - unicode_table.insert().execute(unicode_varchar='not unicode') - assert False - except exc.SAWarning, e: - assert str(e) == "Unicode type received non-unicode bind param value 'not unicode'", str(e) + def test_union(self): + """ensure compiler processing works for UNIONs""" - unicode_engine = engines.utf8_engine(options={'convert_unicode':True, - 'assert_unicode':True}) - try: - try: - unicode_engine.execute(unicode_table.insert(), plain_varchar='im not unicode') - assert False - except exc.InvalidRequestError, e: - assert str(e) == "Unicode type received non-unicode bind param value 'im not unicode'" - - @testing.emits_warning('.*non-unicode bind') - def warns(): - # test that data still goes in if warning is emitted.... - unicode_table.insert().execute(unicode_varchar='not unicode') - assert (select([unicode_table.c.unicode_varchar]).execute().fetchall() == [('not unicode', )]) - warns() + unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: « S’il vous plaît… dessine-moi un mouton! »" - finally: - unicode_engine.dispose() + unicode_table.insert().execute(unicode_varchar=unicodedata,unicode_text=unicodedata) + + x = union(select([unicode_table.c.unicode_varchar]), select([unicode_table.c.unicode_varchar])).execute().first() + self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) - @testing.fails_on('oracle', 'FIXME: unknown') + @testing.fails_on('oracle', 'oracle converts empty strings to a blank space') def test_blank_strings(self): unicode_table.insert().execute(unicode_varchar=u'') assert select([unicode_table.c.unicode_varchar]).scalar() == u'' - def test_engine_parameter(self): - """tests engine-wide unicode conversion""" - prev_unicode = testing.db.engine.dialect.convert_unicode - prev_assert = testing.db.engine.dialect.assert_unicode - try: - testing.db.engine.dialect.convert_unicode = True - testing.db.engine.dialect.assert_unicode = False - rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' - unicodedata = rawdata.decode('utf-8') - if testing.against('sqlite', 'mssql'): - rawdata = "something" - unicode_table.insert().execute(unicode_varchar=unicodedata, - unicode_text=unicodedata, - plain_varchar=rawdata) - x = unicode_table.select().execute().fetchone() - self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) - self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) - if not testing.against('sqlite', 'mssql'): - self.assert_(isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == unicodedata) - finally: - testing.db.engine.dialect.convert_unicode = prev_unicode - testing.db.engine.dialect.convert_unicode = prev_assert - - @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on') - @testing.fails_on('firebird', 'Data type unknown') - def test_length_function(self): - """checks the database correctly understands the length of a unicode string""" - teststr = u'aaa\x1234' - self.assert_(testing.db.func.length(teststr).scalar() == len(teststr)) + def test_parameters(self): + """test the dialect convert_unicode parameters.""" + + unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: « S’il vous plaît… dessine-moi un mouton! »" + + u = Unicode(assert_unicode=True) + uni = u.dialect_impl(testing.db.dialect).bind_processor(testing.db.dialect) + # Py3K + #assert_raises(exc.InvalidRequestError, uni, b'x') + # Py2K + assert_raises(exc.InvalidRequestError, uni, 'x') + # end Py2K + + u = Unicode() + uni = u.dialect_impl(testing.db.dialect).bind_processor(testing.db.dialect) + # Py3K + #assert_raises(exc.SAWarning, uni, b'x') + # Py2K + assert_raises(exc.SAWarning, uni, 'x') + # end Py2K + + unicode_engine = engines.utf8_engine(options={'convert_unicode':True,'assert_unicode':True}) + unicode_engine.dialect.supports_unicode_binds = False + + s = String() + uni = s.dialect_impl(unicode_engine.dialect).bind_processor(unicode_engine.dialect) + # Py3K + #assert_raises(exc.InvalidRequestError, uni, b'x') + #assert isinstance(uni(unicodedata), bytes) + # Py2K + assert_raises(exc.InvalidRequestError, uni, 'x') + assert isinstance(uni(unicodedata), str) + # end Py2K + + assert uni(unicodedata) == unicodedata.encode('utf-8') class BinaryTest(TestBase, AssertsExecutionResults): __excluded_on__ = ( @@ -409,18 +352,19 @@ class BinaryTest(TestBase, AssertsExecutionResults): return value binary_table = Table('binary_table', MetaData(testing.db), - Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True), - Column('data', Binary), - Column('data_slice', Binary(100)), - Column('misc', String(30)), - # construct PickleType with non-native pickle module, since cPickle uses relative module - # loading and confuses this test's parent package 'sql' with the 'sqlalchemy.sql' package relative - # to the 'types' module - Column('pickled', PickleType), - Column('mypickle', MyPickleType) + Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True), + Column('data', Binary), + Column('data_slice', Binary(100)), + Column('misc', String(30)), + # construct PickleType with non-native pickle module, since cPickle uses relative module + # loading and confuses this test's parent package 'sql' with the 'sqlalchemy.sql' package relative + # to the 'types' module + Column('pickled', PickleType), + Column('mypickle', MyPickleType) ) binary_table.create() + @engines.close_first def teardown(self): binary_table.delete().execute() @@ -428,42 +372,65 @@ class BinaryTest(TestBase, AssertsExecutionResults): def teardown_class(cls): binary_table.drop() - @testing.fails_on('mssql', 'MSSQl BINARY type right pads the fixed length with \x00') - def testbinary(self): + def test_round_trip(self): testobj1 = pickleable.Foo('im foo 1') testobj2 = pickleable.Foo('im foo 2') testobj3 = pickleable.Foo('im foo 3') stream1 =self.load_stream('binary_data_one.dat') stream2 =self.load_stream('binary_data_two.dat') - binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_slice=stream1[0:100], pickled=testobj1, mypickle=testobj3) - binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_slice=stream2[0:99], pickled=testobj2) - binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_slice=stream2[0:99], pickled=None) + binary_table.insert().execute( + primary_id=1, + misc='binary_data_one.dat', + data=stream1, + data_slice=stream1[0:100], + pickled=testobj1, + mypickle=testobj3) + binary_table.insert().execute( + primary_id=2, + misc='binary_data_two.dat', + data=stream2, + data_slice=stream2[0:99], + pickled=testobj2) + binary_table.insert().execute( + primary_id=3, + misc='binary_data_two.dat', + data=None, + data_slice=stream2[0:99], + pickled=None) for stmt in ( binary_table.select(order_by=binary_table.c.primary_id), text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType, 'mypickle':MyPickleType}, bind=testing.db) ): + eq_data = lambda x, y: eq_(list(x), list(y)) + if util.jython: + _eq_data = eq_data + def eq_data(x, y): + # Jython currently returns arrays + from array import ArrayType + if isinstance(y, ArrayType): + return eq_(x, y.tostring()) + return _eq_data(x, y) l = stmt.execute().fetchall() - eq_(list(stream1), list(l[0]['data'])) - eq_(list(stream1[0:100]), list(l[0]['data_slice'])) - eq_(list(stream2), list(l[1]['data'])) + eq_data(stream1, l[0]['data']) + eq_data(stream1[0:100], l[0]['data_slice']) + eq_data(stream2, l[1]['data']) eq_(testobj1, l[0]['pickled']) eq_(testobj2, l[1]['pickled']) eq_(testobj3.moredata, l[0]['mypickle'].moredata) eq_(l[0]['mypickle'].stuff, 'this is the right stuff') - def load_stream(self, name, len=12579): + def load_stream(self, name): f = os.path.join(os.path.dirname(__file__), "..", name) - # put a number less than the typical MySQL default BLOB size - return file(f).read(len) + return open(f, mode='rb').read() class ExpressionTest(TestBase, AssertsExecutionResults): @classmethod def setup_class(cls): global test_table, meta - class MyCustomType(types.TypeEngine): + class MyCustomType(types.UserDefinedType): def get_col_spec(self): return "INT" def bind_processor(self, dialect): @@ -547,7 +514,6 @@ class DateTest(TestBase, AssertsExecutionResults): db = testing.db if testing.against('oracle'): - import sqlalchemy.databases.oracle as oracle insert_data = [ (7, 'jack', datetime.datetime(2005, 11, 10, 0, 0), @@ -576,7 +542,7 @@ class DateTest(TestBase, AssertsExecutionResults): time_micro = 999 # Missing or poor microsecond support: - if testing.against('mssql', 'mysql', 'firebird'): + if testing.against('mssql', 'mysql', 'firebird', '+zxjdbc'): datetime_micro, time_micro = 0, 0 # No microseconds for TIME elif testing.against('maxdb'): @@ -608,7 +574,7 @@ class DateTest(TestBase, AssertsExecutionResults): Column('user_date', Date), Column('user_time', Time)] - if testing.against('sqlite', 'postgres'): + if testing.against('sqlite', 'postgresql'): insert_data.append( (11, 'historic', datetime.datetime(1850, 11, 10, 11, 52, 35, datetime_micro), @@ -676,8 +642,8 @@ class DateTest(TestBase, AssertsExecutionResults): t.drop(checkfirst=True) class StringTest(TestBase, AssertsExecutionResults): - @testing.fails_on('mysql', 'FIXME: unknown') - @testing.fails_on('oracle', 'FIXME: unknown') + + @testing.requires.unbounded_varchar def test_nolength_string(self): metadata = MetaData(testing.db) foo = Table('foo', metadata, Column('one', String)) @@ -700,10 +666,10 @@ class NumericTest(TestBase, AssertsExecutionResults): metadata = MetaData(testing.db) numeric_table = Table('numeric_table', metadata, Column('id', Integer, Sequence('numeric_id_seq', optional=True), primary_key=True), - Column('numericcol', Numeric(asdecimal=False)), - Column('floatcol', Float), - Column('ncasdec', Numeric), - Column('fcasdec', Float(asdecimal=True)) + Column('numericcol', Numeric(precision=10, scale=2, asdecimal=False)), + Column('floatcol', Float(precision=10, )), + Column('ncasdec', Numeric(precision=10, scale=2)), + Column('fcasdec', Float(precision=10, asdecimal=True)) ) metadata.create_all() @@ -711,6 +677,7 @@ class NumericTest(TestBase, AssertsExecutionResults): def teardown_class(cls): metadata.drop_all() + @engines.close_first def teardown(self): numeric_table.delete().execute() @@ -719,6 +686,7 @@ class NumericTest(TestBase, AssertsExecutionResults): from decimal import Decimal numeric_table.insert().execute( numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.75) + numeric_table.insert().execute( numericcol=Decimal("3.5"), floatcol=Decimal("5.6"), ncasdec=Decimal("12.4"), fcasdec=Decimal("15.75")) @@ -744,33 +712,6 @@ class NumericTest(TestBase, AssertsExecutionResults): assert isinstance(row['ncasdec'], decimal.Decimal) assert isinstance(row['fcasdec'], decimal.Decimal) - def test_length_deprecation(self): - assert_raises(exc.SADeprecationWarning, Numeric, length=8) - - @testing.uses_deprecated(".*is deprecated for Numeric") - def go(): - n = Numeric(length=12) - assert n.scale == 12 - go() - - n = Numeric(scale=12) - for dialect in engines.all_dialects(): - n2 = dialect.type_descriptor(n) - eq_(n2.scale, 12, dialect.name) - - # test colspec generates successfully using 'scale' - assert n2.get_col_spec() - - # test constructor of the dialect-specific type - n3 = n2.__class__(scale=5) - eq_(n3.scale, 5, dialect.name) - - @testing.uses_deprecated(".*is deprecated for Numeric") - def go(): - n3 = n2.__class__(length=6) - eq_(n3.scale, 6, dialect.name) - go() - class IntervalTest(TestBase, AssertsExecutionResults): @classmethod @@ -783,6 +724,7 @@ class IntervalTest(TestBase, AssertsExecutionResults): ) metadata.create_all() + @engines.close_first def teardown(self): interval_table.delete().execute() @@ -790,14 +732,16 @@ class IntervalTest(TestBase, AssertsExecutionResults): def teardown_class(cls): metadata.drop_all() + @testing.fails_on("+pg8000", "Not yet known how to pass values of the INTERVAL type") + @testing.fails_on("postgresql+zxjdbc", "Not yet known how to pass values of the INTERVAL type") def test_roundtrip(self): delta = datetime.datetime(2006, 10, 5) - datetime.datetime(2005, 8, 17) interval_table.insert().execute(interval=delta) - assert interval_table.select().execute().fetchone()['interval'] == delta + assert interval_table.select().execute().first()['interval'] == delta def test_null(self): interval_table.insert().execute(id=1, inverval=None) - assert interval_table.select().execute().fetchone()['interval'] is None + assert interval_table.select().execute().first()['interval'] is None class BooleanTest(TestBase, AssertsExecutionResults): @classmethod @@ -825,30 +769,6 @@ class BooleanTest(TestBase, AssertsExecutionResults): assert(res2==[(2, False)]) class PickleTest(TestBase): - def test_noeq_deprecation(self): - p1 = PickleType() - - assert_raises(DeprecationWarning, - p1.compare_values, pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2) - ) - - assert_raises(DeprecationWarning, - p1.compare_values, pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2) - ) - - @testing.uses_deprecated() - def go(): - # test actual dumps comparison - assert p1.compare_values(pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2)) - assert p1.compare_values(pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2)) - go() - - assert p1.compare_values({1:2, 3:4}, {3:4, 1:2}) - - p2 = PickleType(mutable=False) - assert not p2.compare_values(pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2)) - assert not p2.compare_values(pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2)) - def test_eq_comparison(self): p1 = PickleType() diff --git a/test/sql/test_unicode.py b/test/sql/test_unicode.py index d75913267..6551594f3 100644 --- a/test/sql/test_unicode.py +++ b/test/sql/test_unicode.py @@ -56,6 +56,7 @@ class UnicodeSchemaTest(TestBase): ) metadata.create_all() + @engines.close_first def teardown(self): if metadata.tables: t3.delete().execute() @@ -125,11 +126,11 @@ class EscapesDefaultsTest(testing.TestBase): # reset the identifier preparer, so that we can force it to cache # a unicode identifier engine.dialect.identifier_preparer = engine.dialect.preparer(engine.dialect) - select([column(u'special_col')]).select_from(t1).execute() + select([column(u'special_col')]).select_from(t1).execute().close() assert isinstance(engine.dialect.identifier_preparer.format_sequence(Sequence('special_col')), unicode) # now execute, run the sequence. it should run in u"Special_col.nextid" or similar as - # a unicode object; cx_oracle asserts that this is None or a String (postgres lets it pass thru). + # a unicode object; cx_oracle asserts that this is None or a String (postgresql lets it pass thru). # ensure that base.DefaultRunner is encoding. t1.insert().execute(data='foo') finally: diff --git a/test/zblog/mappers.py b/test/zblog/mappers.py index 5203bd866..126d2c568 100644 --- a/test/zblog/mappers.py +++ b/test/zblog/mappers.py @@ -1,7 +1,7 @@ """mapper.py - defines mappers for domain objects, mapping operations""" -import tables, user -from blog import * +from test.zblog import tables, user +from test.zblog.blog import * from sqlalchemy import * from sqlalchemy.orm import * import sqlalchemy.util as util diff --git a/test/zblog/tables.py b/test/zblog/tables.py index 36c7aeb8b..4907259e1 100644 --- a/test/zblog/tables.py +++ b/test/zblog/tables.py @@ -1,12 +1,12 @@ """application table metadata objects are described here.""" from sqlalchemy import * - +from sqlalchemy.test.schema import Table, Column metadata = MetaData() users = Table('users', metadata, - Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key=True), + Column('user_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_name', String(30), nullable=False), Column('fullname', String(100), nullable=False), Column('password', String(40), nullable=False), @@ -14,14 +14,14 @@ users = Table('users', metadata, ) blogs = Table('blogs', metadata, - Column('blog_id', Integer, Sequence('blog_id_seq', optional=True), primary_key=True), + Column('blog_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('owner_id', Integer, ForeignKey('users.user_id'), nullable=False), Column('name', String(100), nullable=False), Column('description', String(500)) ) posts = Table('posts', metadata, - Column('post_id', Integer, Sequence('post_id_seq', optional=True), primary_key=True), + Column('post_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('blog_id', Integer, ForeignKey('blogs.blog_id'), nullable=False), Column('user_id', Integer, ForeignKey('users.user_id'), nullable=False), Column('datetime', DateTime, nullable=False), @@ -31,7 +31,7 @@ posts = Table('posts', metadata, ) topics = Table('topics', metadata, - Column('topic_id', Integer, primary_key=True), + Column('topic_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('keyword', String(50), nullable=False), Column('description', String(500)) ) @@ -43,7 +43,7 @@ topic_xref = Table('topic_post_xref', metadata, ) comments = Table('comments', metadata, - Column('comment_id', Integer, primary_key=True), + Column('comment_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', Integer, ForeignKey('users.user_id'), nullable=False), Column('post_id', Integer, ForeignKey('posts.post_id'), nullable=False), Column('datetime', DateTime, nullable=False), diff --git a/test/zblog/test_zblog.py b/test/zblog/test_zblog.py index 8170766cb..5e46c1ceb 100644 --- a/test/zblog/test_zblog.py +++ b/test/zblog/test_zblog.py @@ -1,9 +1,9 @@ from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.test import * -import mappers, tables -from user import * -from blog import * +from test.zblog import mappers, tables +from test.zblog.user import * +from test.zblog.blog import * class ZBlogTest(TestBase, AssertsExecutionResults): diff --git a/test/zblog/user.py b/test/zblog/user.py index 0a13002cd..30f1e3da1 100644 --- a/test/zblog/user.py +++ b/test/zblog/user.py @@ -14,9 +14,9 @@ groups = [user, administrator] def cryptpw(password, salt=None): if salt is None: - salt = string.join([chr(random.randint(ord('a'), ord('z'))), - chr(random.randint(ord('a'), ord('z')))],'') - return sha(password + salt).hexdigest() + salt = "".join([chr(random.randint(ord('a'), ord('z'))), + chr(random.randint(ord('a'), ord('z')))]) + return sha((password+ salt).encode('ascii')).hexdigest() def checkpw(password, dbpw): return cryptpw(password, dbpw[:2]) == dbpw |
