summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-12-13 21:06:38 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-12-13 21:06:38 +0000
commite1a52eb7dfb19edf3baeff6d2878b6b0afb9a04d (patch)
treec85e9a9fa87166b08f8ab12dbcfd107c48b78220 /lib/sqlalchemy
parent8061aaaed94924067122dc7568e8cb0e55eda329 (diff)
downloadsqlalchemy-e1a52eb7dfb19edf3baeff6d2878b6b0afb9a04d.tar.gz
- patch that makes MySQL rowcount work correctly! [ticket:396]
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/databases/mysql.py8
-rw-r--r--lib/sqlalchemy/orm/mapper.py2
2 files changed, 8 insertions, 2 deletions
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index c795ae7d4..19dedd826 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -13,6 +13,7 @@ import sqlalchemy.exceptions as exceptions
try:
import MySQLdb as mysql
+ import MySQLdb.constants.CLIENT as CLIENT_FLAGS
except:
mysql = None
@@ -270,6 +271,11 @@ class MySQLDialect(ansisql.ANSIDialect):
coercetype('use_unicode', bool) # this could break SA Unicode type
coercetype('charset', str) # this could break SA Unicode type
# TODO: what about options like "ssl", "cursorclass" and "conv" ?
+
+ client_flag = opts.get('client_flag', 0)
+ client_flag |= CLIENT_FLAGS.FOUND_ROWS
+ opts['client_flag'] = client_flag
+
return [[], opts]
def create_execution_context(self):
@@ -279,7 +285,7 @@ class MySQLDialect(ansisql.ANSIDialect):
return sqltypes.adapt_type(typeobj, colspecs)
def supports_sane_rowcount(self):
- return False
+ return True
def compiler(self, statement, bindparams, **kwargs):
return MySQLCompiler(self, statement, bindparams, **kwargs)
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index d7449c7ca..78fd3a1cc 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -960,7 +960,7 @@ class Mapper(object):
mapper._postfetch(connection, table, obj, c, c.last_updated_params())
updated_objects.add(obj)
- rows += c.cursor.rowcount
+ rows += c.rowcount
if c.supports_sane_rowcount() and rows != len(update):
raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (rows, len(update)))