summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Johnston <paj@pajhome.org.uk>2007-06-13 18:53:16 +0000
committerPaul Johnston <paj@pajhome.org.uk>2007-06-13 18:53:16 +0000
commit2c65ce75360c6018ecc6160d6ef11e23ae628553 (patch)
tree79908f4ba077776287975fee2fb694c22e49e450
parente4cd7b2ed4303d2553692037ee74827e3f9dfa12 (diff)
downloadsqlalchemy-2c65ce75360c6018ecc6160d6ef11e23ae628553.tar.gz
Multiple MSSQL fixes; see ticket #581
-rw-r--r--lib/sqlalchemy/databases/mssql.py87
-rw-r--r--lib/sqlalchemy/sql.py1
-rw-r--r--test/sql/query.py13
-rw-r--r--test/sql/rowcount.py17
-rw-r--r--test/sql/testtypes.py23
5 files changed, 112 insertions, 29 deletions
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index 1cadbd14d..4336296dd 100644
--- a/lib/sqlalchemy/databases/mssql.py
+++ b/lib/sqlalchemy/databases/mssql.py
@@ -99,14 +99,14 @@ class MSDateTime(sqltypes.DateTime):
def get_col_spec(self):
return "DATETIME"
- def convert_bind_param(self, value, dialect):
- if hasattr(value, "isoformat"):
- #return value.isoformat(' ')
- # isoformat() bings on apodbapi -- reported/suggested by Peter Buschman
- return value.strftime('%Y-%m-%d %H:%M:%S')
- else:
- return value
+class MSDate(sqltypes.Date):
+ def __init__(self, *a, **kw):
+ super(MSDate, self).__init__(False)
+ def get_col_spec(self):
+ return "SMALLDATETIME"
+
+class MSDateTime_adodbapi(MSDateTime):
def convert_result_value(self, value, dialect):
# adodbapi will return datetimes with empty time values as datetime.date() objects.
# Promote them back to full datetime.datetime()
@@ -114,23 +114,34 @@ class MSDateTime(sqltypes.DateTime):
return datetime.datetime(value.year, value.month, value.day)
return value
-class MSDate(sqltypes.Date):
- def __init__(self, *a, **kw):
- super(MSDate, self).__init__(False)
+class MSDateTime_pyodbc(MSDateTime):
+ def convert_bind_param(self, value, dialect):
+ if value and not hasattr(value, 'second'):
+ return datetime.datetime(value.year, value.month, value.day)
+ else:
+ return value
- def get_col_spec(self):
- return "SMALLDATETIME"
-
+class MSDate_pyodbc(MSDate):
def convert_bind_param(self, value, dialect):
- if value and hasattr(value, "isoformat"):
- return value.strftime('%Y-%m-%d %H:%M')
- return value
+ if value and not hasattr(value, 'second'):
+ return datetime.datetime(value.year, value.month, value.day)
+ else:
+ return value
def convert_result_value(self, value, dialect):
+ # pyodbc returns SMALLDATETIME values as datetime.datetime(). truncate it back to datetime.date()
+ if value and hasattr(value, 'second'):
+ return value.date()
+ else:
+ return value
+
+class MSDate_pymssql(MSDate):
+ def convert_result_value(self, value, dialect):
# pymssql will return SMALLDATETIME values as datetime.datetime(), truncate it back to datetime.date()
if value and hasattr(value, 'second'):
return value.date()
- return value
+ else:
+ return value
class MSText(sqltypes.TEXT):
def get_col_spec(self):
@@ -143,7 +154,7 @@ class MSString(sqltypes.String):
def get_col_spec(self):
return "VARCHAR(%(length)s)" % {'length' : self.length}
-class MSNVarchar(MSString):
+class MSNVarchar(sqltypes.Unicode):
def get_col_spec(self):
if self.length:
return "NVARCHAR(%(length)s)" % {'length' : self.length}
@@ -191,6 +202,10 @@ class MSBoolean(sqltypes.Boolean):
else:
return value and True or False
+class MSTimeStamp(sqltypes.TIMESTAMP):
+ def get_col_spec(self):
+ return "TIMESTAMP"
+
def descriptor():
return {'name':'mssql',
'description':'MSSQL',
@@ -240,7 +255,7 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
if self.IINSERT:
# TODO: quoting rules for table name here ?
- self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.compiled.statement.table.name)
+ self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.compiled.statement.table.fullname)
super(MSSQLExecutionContext, self).pre_exec()
@@ -253,7 +268,7 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
if self.compiled.isinsert:
if self.IINSERT:
# TODO: quoting rules for table name here ?
- self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.name)
+ self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.fullname)
self.IINSERT = False
elif self.HASIDENT:
if self.dialect.use_scope_identity:
@@ -294,6 +309,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
sqltypes.TEXT : MSText,
sqltypes.CHAR: MSChar,
sqltypes.NCHAR: MSNChar,
+ sqltypes.TIMESTAMP: MSTimeStamp,
}
ischema_names = {
@@ -314,7 +330,8 @@ class MSSQLDialect(ansisql.ANSIDialect):
'binary' : MSBinary,
'bit': MSBoolean,
'real' : MSFloat,
- 'image' : MSBinary
+ 'image' : MSBinary,
+ 'timestamp': MSTimeStamp,
}
def __new__(cls, dbapi=None, *args, **kwargs):
@@ -330,7 +347,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
super(MSSQLDialect, self).__init__(**params)
self.auto_identity_insert = auto_identity_insert
self.text_as_varchar = False
- self.use_scope_identity = True
+ self.use_scope_identity = False
self.set_default_schema_name("dbo")
def dbapi(cls, module_name=None):
@@ -570,6 +587,16 @@ class MSSQLDialect_pymssql(MSSQLDialect):
return module
import_dbapi = classmethod(import_dbapi)
+ colspecs = MSSQLDialect.colspecs.copy()
+ colspecs[sqltypes.Date] = MSDate_pymssql
+
+ ischema_names = MSSQLDialect.ischema_names.copy()
+ ischema_names['smalldatetime'] = MSDate_pymssql
+
+ def __init__(self, **params):
+ super(MSSQLDialect_pymssql, self).__init__(**params)
+ self.use_scope_identity = True
+
def supports_sane_rowcount(self):
return True
@@ -641,12 +668,21 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
colspecs = MSSQLDialect.colspecs.copy()
colspecs[sqltypes.Unicode] = AdoMSNVarchar
+ colspecs[sqltypes.Date] = MSDate_pyodbc
+ colspecs[sqltypes.DateTime] = MSDateTime_pyodbc
+
ischema_names = MSSQLDialect.ischema_names.copy()
ischema_names['nvarchar'] = AdoMSNVarchar
+ ischema_names['smalldatetime'] = MSDate_pyodbc
+ ischema_names['datetime'] = MSDateTime_pyodbc
def supports_sane_rowcount(self):
return False
+ def supports_unicode_statements(self):
+ """indicate whether the DBAPI can receive SQL statements as Python unicode strings"""
+ return True
+
def make_connect_string(self, keys):
connectors = ["Driver={SQL Server}"]
connectors.append("Server=%s" % keys.get("host"))
@@ -674,12 +710,19 @@ class MSSQLDialect_adodbapi(MSSQLDialect):
colspecs = MSSQLDialect.colspecs.copy()
colspecs[sqltypes.Unicode] = AdoMSNVarchar
+ colspecs[sqltypes.DateTime] = MSDateTime_adodbapi
+
ischema_names = MSSQLDialect.ischema_names.copy()
ischema_names['nvarchar'] = AdoMSNVarchar
+ ischema_names['datetime'] = MSDateTime_adodbapi
def supports_sane_rowcount(self):
return True
+ def supports_unicode_statements(self):
+ """indicate whether the DBAPI can receive SQL statements as Python unicode strings"""
+ return True
+
def make_connect_string(self, keys):
connectors = ["Provider=SQLOLEDB"]
if 'port' in keys:
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index 41b61d4af..489e9d59e 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -2686,6 +2686,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
self.is_compound = True
self.is_where = False
self.is_scalar = False
+ self.is_subquery = False
self.selects = selects
diff --git a/test/sql/query.py b/test/sql/query.py
index d788544c0..6e43f8779 100644
--- a/test/sql/query.py
+++ b/test/sql/query.py
@@ -402,6 +402,19 @@ class QueryTest(PersistTest):
con.execute("""drop trigger paj""")
meta.drop_all()
+ @testbase.supported('mssql')
+ def test_insertid_schema(self):
+ meta = BoundMetaData(testbase.db)
+ con = testbase.db.connect()
+ con.execute('create schema paj')
+ tbl = Table('test', meta, Column('id', Integer, primary_key=True), schema='paj')
+ tbl.create()
+ try:
+ tbl.insert().execute({'id':1})
+ finally:
+ tbl.drop()
+ con.execute('drop schema paj')
+
class CompoundTest(PersistTest):
"""test compound statements like UNION, INTERSECT, particularly their ability to nest on
diff --git a/test/sql/rowcount.py b/test/sql/rowcount.py
index 05d0f2110..95cab898c 100644
--- a/test/sql/rowcount.py
+++ b/test/sql/rowcount.py
@@ -31,7 +31,7 @@ class FoundRowsTest(testbase.AssertMixin):
i.execute(*[{'name':n, 'department':d} for n, d in data])
def tearDown(self):
employees_table.delete().execute()
-
+
def tearDownAll(self):
employees_table.drop()
@@ -45,23 +45,26 @@ class FoundRowsTest(testbase.AssertMixin):
# WHERE matches 3, 3 rows changed
department = employees_table.c.department
r = employees_table.update(department=='C').execute(department='Z')
- assert r.rowcount == 3
-
+ if testbase.db.dialect.supports_sane_rowcount():
+ assert r.rowcount == 3
+
def test_update_rowcount2(self):
# WHERE matches 3, 0 rows changed
department = employees_table.c.department
r = employees_table.update(department=='C').execute(department='C')
- assert r.rowcount == 3
-
+ if testbase.db.dialect.supports_sane_rowcount():
+ assert r.rowcount == 3
+
def test_delete_rowcount(self):
# WHERE matches 3, 3 rows deleted
department = employees_table.c.department
r = employees_table.delete(department=='C').execute()
- assert r.rowcount == 3
+ if testbase.db.dialect.supports_sane_rowcount():
+ assert r.rowcount == 3
if __name__ == '__main__':
testbase.main()
-
+
diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py
index acf21b917..b2b747a33 100644
--- a/test/sql/testtypes.py
+++ b/test/sql/testtypes.py
@@ -190,6 +190,11 @@ class UnicodeTest(AssertMixin):
finally:
db.engine.dialect.convert_unicode = prev_unicode
+ def testlength(self):
+ """checks the database correctly understands the length of a unicode string"""
+ teststr = u'aaa\x1234'
+ self.assert_(db.func.length(teststr).scalar() == len(teststr))
+
class BinaryTest(AssertMixin):
def setUpAll(self):
global binary_table
@@ -313,6 +318,24 @@ class DateTest(AssertMixin):
#x = db.text("select * from query_users_with_date where user_datetime=:date", bindparams=[bindparam('date', )]).execute(date=datetime.datetime(2005, 11, 10, 11, 52, 35)).fetchall()
#print repr(x)
+ @testbase.unsupported('sqlite')
+ def testdate2(self):
+ t = Table('testdate', testbase.metadata, Column('id', Integer, primary_key=True),
+ Column('adate', Date), Column('adatetime', DateTime))
+ t.create()
+ try:
+ d1 = datetime.date(2007, 10, 30)
+ t.insert().execute(adate=d1, adatetime=d1)
+ d2 = datetime.datetime(2007, 10, 30)
+ t.insert().execute(adate=d2, adatetime=d2)
+
+ x = t.select().execute().fetchall()[0]
+ self.assert_(x.adate.__class__ == datetime.date)
+ self.assert_(x.adatetime.__class__ == datetime.datetime)
+
+ finally:
+ t.drop()
+
class TimezoneTest(AssertMixin):
"""test timezone-aware datetimes. psycopg will return a datetime with a tzinfo attached to it,
if postgres returns it. python then will not let you compare a datetime with a tzinfo to a datetime