summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/mysql/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/dialects/mysql/base.py')
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py255
1 files changed, 166 insertions, 89 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 58eb3afa0..c8e33bfb2 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -106,7 +106,7 @@ to be used.
Transaction Isolation Level
---------------------------
-:func:`.create_engine` accepts an ``isolation_level``
+:func:`.create_engine` accepts an :paramref:`.create_engine.isolation_level`
parameter which results in the command ``SET SESSION
TRANSACTION ISOLATION LEVEL <level>`` being invoked for
every new connection. Valid values for this parameter are
@@ -602,6 +602,14 @@ class _StringType(sqltypes.String):
to_inspect=[_StringType, sqltypes.String])
+class _MatchType(sqltypes.Float, sqltypes.MatchType):
+ def __init__(self, **kw):
+ # TODO: float arguments?
+ sqltypes.Float.__init__(self)
+ sqltypes.MatchType.__init__(self)
+
+
+
class NUMERIC(_NumericType, sqltypes.NUMERIC):
"""MySQL NUMERIC type."""
@@ -1420,32 +1428,28 @@ class SET(_EnumeratedValues):
Column('myset', SET("foo", "bar", "baz"))
- :param values: The range of valid values for this SET. Values will be
- quoted when generating the schema according to the quoting flag (see
- below).
- .. versionchanged:: 0.9.0 quoting is applied automatically to
- :class:`.mysql.SET` in the same way as for :class:`.mysql.ENUM`.
+ The list of potential values is required in the case that this
+ set will be used to generate DDL for a table, or if the
+ :paramref:`.SET.retrieve_as_bitwise` flag is set to True.
- :param charset: Optional, a column-level character set for this string
- value. Takes precedence to 'ascii' or 'unicode' short-hand.
+ :param values: The range of valid values for this SET.
- :param collation: Optional, a column-level collation for this string
- value. Takes precedence to 'binary' short-hand.
+ :param convert_unicode: Same flag as that of
+ :paramref:`.String.convert_unicode`.
- :param ascii: Defaults to False: short-hand for the ``latin1``
- character set, generates ASCII in schema.
+ :param collation: same as that of :paramref:`.String.collation`
- :param unicode: Defaults to False: short-hand for the ``ucs2``
- character set, generates UNICODE in schema.
+ :param charset: same as that of :paramref:`.VARCHAR.charset`.
- :param binary: Defaults to False: short-hand, pick the binary
- collation type that matches the column's character set. Generates
- BINARY in schema. This does not affect the type of data stored,
- only the collation of character data.
+ :param ascii: same as that of :paramref:`.VARCHAR.ascii`.
- :param quoting: Defaults to 'auto': automatically determine enum value
- quoting. If all enum values are surrounded by the same quoting
+ :param unicode: same as that of :paramref:`.VARCHAR.unicode`.
+
+ :param binary: same as that of :paramref:`.VARCHAR.binary`.
+
+ :param quoting: Defaults to 'auto': automatically determine set value
+ quoting. If all values are surrounded by the same quoting
character, then use 'quoted' mode. Otherwise, use 'unquoted' mode.
'quoted': values in enums are already quoted, they will be used
@@ -1460,50 +1464,117 @@ class SET(_EnumeratedValues):
.. versionadded:: 0.9.0
+ :param retrieve_as_bitwise: if True, the data for the set type will be
+ persisted and selected using an integer value, where a set is coerced
+ into a bitwise mask for persistence. MySQL allows this mode which
+ has the advantage of being able to store values unambiguously,
+ such as the blank string ``''``. The datatype will appear
+ as the expression ``col + 0`` in a SELECT statement, so that the
+ value is coerced into an integer value in result sets.
+ This flag is required if one wishes
+ to persist a set that can store the blank string ``''`` as a value.
+
+ .. warning::
+
+ When using :paramref:`.mysql.SET.retrieve_as_bitwise`, it is
+ essential that the list of set values is expressed in the
+ **exact same order** as exists on the MySQL database.
+
+ .. versionadded:: 1.0.0
+
+
"""
+ self.retrieve_as_bitwise = kw.pop('retrieve_as_bitwise', False)
values, length = self._init_values(values, kw)
self.values = tuple(values)
-
+ if not self.retrieve_as_bitwise and '' in values:
+ raise exc.ArgumentError(
+ "Can't use the blank value '' in a SET without "
+ "setting retrieve_as_bitwise=True")
+ if self.retrieve_as_bitwise:
+ self._bitmap = dict(
+ (value, 2 ** idx)
+ for idx, value in enumerate(self.values)
+ )
+ self._bitmap.update(
+ (2 ** idx, value)
+ for idx, value in enumerate(self.values)
+ )
kw.setdefault('length', length)
super(SET, self).__init__(**kw)
+ def column_expression(self, colexpr):
+ if self.retrieve_as_bitwise:
+ return colexpr + 0
+ else:
+ return colexpr
+
def result_processor(self, dialect, coltype):
- def process(value):
- # The good news:
- # No ',' quoting issues- commas aren't allowed in SET values
- # The bad news:
- # Plenty of driver inconsistencies here.
- if isinstance(value, set):
- # ..some versions convert '' to an empty set
- if not value:
- value.add('')
- return value
- # ...and some versions return strings
- if value is not None:
- return set(value.split(','))
- else:
- return value
+ if self.retrieve_as_bitwise:
+ def process(value):
+ if value is not None:
+ value = int(value)
+
+ return set(
+ util.map_bits(self._bitmap.__getitem__, value)
+ )
+ else:
+ return None
+ else:
+ super_convert = super(SET, self).result_processor(dialect, coltype)
+
+ def process(value):
+ if isinstance(value, util.string_types):
+ # MySQLdb returns a string, let's parse
+ if super_convert:
+ value = super_convert(value)
+ return set(re.findall(r'[^,]+', value))
+ else:
+ # mysql-connector-python does a naive
+ # split(",") which throws in an empty string
+ if value is not None:
+ value.discard('')
+ return value
return process
def bind_processor(self, dialect):
super_convert = super(SET, self).bind_processor(dialect)
+ if self.retrieve_as_bitwise:
+ def process(value):
+ if value is None:
+ return None
+ elif isinstance(value, util.int_types + util.string_types):
+ if super_convert:
+ return super_convert(value)
+ else:
+ return value
+ else:
+ int_value = 0
+ for v in value:
+ int_value |= self._bitmap[v]
+ return int_value
+ else:
- def process(value):
- if value is None or isinstance(
- value, util.int_types + util.string_types):
- pass
- else:
- if None in value:
- value = set(value)
- value.remove(None)
- value.add('')
- value = ','.join(value)
- if super_convert:
- return super_convert(value)
- else:
- return value
+ def process(value):
+ # accept strings and int (actually bitflag) values directly
+ if value is not None and not isinstance(
+ value, util.int_types + util.string_types):
+ value = ",".join(value)
+
+ if super_convert:
+ return super_convert(value)
+ else:
+ return value
return process
+ def adapt(self, impltype, **kw):
+ kw['retrieve_as_bitwise'] = self.retrieve_as_bitwise
+ return util.constructor_copy(
+ self, impltype,
+ *self.values,
+ **kw
+ )
+
# old names
MSTime = TIME
MSSet = SET
@@ -1544,6 +1615,7 @@ colspecs = {
sqltypes.Float: FLOAT,
sqltypes.Time: TIME,
sqltypes.Enum: ENUM,
+ sqltypes.MatchType: _MatchType
}
# Everything 3.23 through 5.1 excepting OpenGIS types.
@@ -1758,10 +1830,10 @@ class MySQLCompiler(compiler.SQLCompiler):
# creation of foreign key constraints fails."
class MySQLDDLCompiler(compiler.DDLCompiler):
- def create_table_constraints(self, table):
+ def create_table_constraints(self, table, **kw):
"""Get table constraints."""
constraint_string = super(
- MySQLDDLCompiler, self).create_table_constraints(table)
+ MySQLDDLCompiler, self).create_table_constraints(table, **kw)
# why self.dialect.name and not 'mysql'? because of drizzle
is_innodb = 'engine' in table.dialect_options[self.dialect.name] and \
@@ -1787,9 +1859,11 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kw):
"""Builds column DDL."""
- colspec = [self.preparer.format_column(column),
- self.dialect.type_compiler.process(column.type)
- ]
+ colspec = [
+ self.preparer.format_column(column),
+ self.dialect.type_compiler.process(
+ column.type, type_expression=column)
+ ]
default = self.get_column_default_string(column)
if default is not None:
@@ -1987,7 +2061,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
def _mysql_type(self, type_):
return isinstance(type_, (_StringType, _NumericType))
- def visit_NUMERIC(self, type_):
+ def visit_NUMERIC(self, type_, **kw):
if type_.precision is None:
return self._extend_numeric(type_, "NUMERIC")
elif type_.scale is None:
@@ -2000,7 +2074,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
{'precision': type_.precision,
'scale': type_.scale})
- def visit_DECIMAL(self, type_):
+ def visit_DECIMAL(self, type_, **kw):
if type_.precision is None:
return self._extend_numeric(type_, "DECIMAL")
elif type_.scale is None:
@@ -2013,7 +2087,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
{'precision': type_.precision,
'scale': type_.scale})
- def visit_DOUBLE(self, type_):
+ def visit_DOUBLE(self, type_, **kw):
if type_.precision is not None and type_.scale is not None:
return self._extend_numeric(type_,
"DOUBLE(%(precision)s, %(scale)s)" %
@@ -2022,7 +2096,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, 'DOUBLE')
- def visit_REAL(self, type_):
+ def visit_REAL(self, type_, **kw):
if type_.precision is not None and type_.scale is not None:
return self._extend_numeric(type_,
"REAL(%(precision)s, %(scale)s)" %
@@ -2031,7 +2105,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, 'REAL')
- def visit_FLOAT(self, type_):
+ def visit_FLOAT(self, type_, **kw):
if self._mysql_type(type_) and \
type_.scale is not None and \
type_.precision is not None:
@@ -2043,7 +2117,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, "FLOAT")
- def visit_INTEGER(self, type_):
+ def visit_INTEGER(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
type_, "INTEGER(%(display_width)s)" %
@@ -2051,7 +2125,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, "INTEGER")
- def visit_BIGINT(self, type_):
+ def visit_BIGINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
type_, "BIGINT(%(display_width)s)" %
@@ -2059,7 +2133,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, "BIGINT")
- def visit_MEDIUMINT(self, type_):
+ def visit_MEDIUMINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
type_, "MEDIUMINT(%(display_width)s)" %
@@ -2067,14 +2141,14 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, "MEDIUMINT")
- def visit_TINYINT(self, type_):
+ def visit_TINYINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(type_,
"TINYINT(%s)" % type_.display_width)
else:
return self._extend_numeric(type_, "TINYINT")
- def visit_SMALLINT(self, type_):
+ def visit_SMALLINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(type_,
"SMALLINT(%(display_width)s)" %
@@ -2083,55 +2157,55 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, "SMALLINT")
- def visit_BIT(self, type_):
+ def visit_BIT(self, type_, **kw):
if type_.length is not None:
return "BIT(%s)" % type_.length
else:
return "BIT"
- def visit_DATETIME(self, type_):
+ def visit_DATETIME(self, type_, **kw):
if getattr(type_, 'fsp', None):
return "DATETIME(%d)" % type_.fsp
else:
return "DATETIME"
- def visit_DATE(self, type_):
+ def visit_DATE(self, type_, **kw):
return "DATE"
- def visit_TIME(self, type_):
+ def visit_TIME(self, type_, **kw):
if getattr(type_, 'fsp', None):
return "TIME(%d)" % type_.fsp
else:
return "TIME"
- def visit_TIMESTAMP(self, type_):
+ def visit_TIMESTAMP(self, type_, **kw):
if getattr(type_, 'fsp', None):
return "TIMESTAMP(%d)" % type_.fsp
else:
return "TIMESTAMP"
- def visit_YEAR(self, type_):
+ def visit_YEAR(self, type_, **kw):
if type_.display_width is None:
return "YEAR"
else:
return "YEAR(%s)" % type_.display_width
- def visit_TEXT(self, type_):
+ def visit_TEXT(self, type_, **kw):
if type_.length:
return self._extend_string(type_, {}, "TEXT(%d)" % type_.length)
else:
return self._extend_string(type_, {}, "TEXT")
- def visit_TINYTEXT(self, type_):
+ def visit_TINYTEXT(self, type_, **kw):
return self._extend_string(type_, {}, "TINYTEXT")
- def visit_MEDIUMTEXT(self, type_):
+ def visit_MEDIUMTEXT(self, type_, **kw):
return self._extend_string(type_, {}, "MEDIUMTEXT")
- def visit_LONGTEXT(self, type_):
+ def visit_LONGTEXT(self, type_, **kw):
return self._extend_string(type_, {}, "LONGTEXT")
- def visit_VARCHAR(self, type_):
+ def visit_VARCHAR(self, type_, **kw):
if type_.length:
return self._extend_string(
type_, {}, "VARCHAR(%d)" % type_.length)
@@ -2140,14 +2214,14 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
"VARCHAR requires a length on dialect %s" %
self.dialect.name)
- def visit_CHAR(self, type_):
+ def visit_CHAR(self, type_, **kw):
if type_.length:
return self._extend_string(type_, {}, "CHAR(%(length)s)" %
{'length': type_.length})
else:
return self._extend_string(type_, {}, "CHAR")
- def visit_NVARCHAR(self, type_):
+ def visit_NVARCHAR(self, type_, **kw):
# We'll actually generate the equiv. "NATIONAL VARCHAR" instead
# of "NVARCHAR".
if type_.length:
@@ -2159,7 +2233,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
"NVARCHAR requires a length on dialect %s" %
self.dialect.name)
- def visit_NCHAR(self, type_):
+ def visit_NCHAR(self, type_, **kw):
# We'll actually generate the equiv.
# "NATIONAL CHAR" instead of "NCHAR".
if type_.length:
@@ -2169,31 +2243,31 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_string(type_, {'national': True}, "CHAR")
- def visit_VARBINARY(self, type_):
+ def visit_VARBINARY(self, type_, **kw):
return "VARBINARY(%d)" % type_.length
- def visit_large_binary(self, type_):
+ def visit_large_binary(self, type_, **kw):
return self.visit_BLOB(type_)
- def visit_enum(self, type_):
+ def visit_enum(self, type_, **kw):
if not type_.native_enum:
return super(MySQLTypeCompiler, self).visit_enum(type_)
else:
return self._visit_enumerated_values("ENUM", type_, type_.enums)
- def visit_BLOB(self, type_):
+ def visit_BLOB(self, type_, **kw):
if type_.length:
return "BLOB(%d)" % type_.length
else:
return "BLOB"
- def visit_TINYBLOB(self, type_):
+ def visit_TINYBLOB(self, type_, **kw):
return "TINYBLOB"
- def visit_MEDIUMBLOB(self, type_):
+ def visit_MEDIUMBLOB(self, type_, **kw):
return "MEDIUMBLOB"
- def visit_LONGBLOB(self, type_):
+ def visit_LONGBLOB(self, type_, **kw):
return "LONGBLOB"
def _visit_enumerated_values(self, name, type_, enumerated_values):
@@ -2204,15 +2278,15 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
name, ",".join(quoted_enums))
)
- def visit_ENUM(self, type_):
+ def visit_ENUM(self, type_, **kw):
return self._visit_enumerated_values("ENUM", type_,
type_._enumerated_values)
- def visit_SET(self, type_):
+ def visit_SET(self, type_, **kw):
return self._visit_enumerated_values("SET", type_,
type_._enumerated_values)
- def visit_BOOLEAN(self, type):
+ def visit_BOOLEAN(self, type, **kw):
return "BOOL"
@@ -2963,6 +3037,9 @@ class MySQLTableDefinitionParser(object):
if issubclass(col_type, _EnumeratedValues):
type_args = _EnumeratedValues._strip_values(type_args)
+ if issubclass(col_type, SET) and '' in type_args:
+ type_kw['retrieve_as_bitwise'] = True
+
type_instance = col_type(*type_args, **type_kw)
col_args, col_kw = [], {}