diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/mysql/base.py')
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 255 |
1 files changed, 166 insertions, 89 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 58eb3afa0..c8e33bfb2 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -106,7 +106,7 @@ to be used. Transaction Isolation Level --------------------------- -:func:`.create_engine` accepts an ``isolation_level`` +:func:`.create_engine` accepts an :paramref:`.create_engine.isolation_level` parameter which results in the command ``SET SESSION TRANSACTION ISOLATION LEVEL <level>`` being invoked for every new connection. Valid values for this parameter are @@ -602,6 +602,14 @@ class _StringType(sqltypes.String): to_inspect=[_StringType, sqltypes.String]) +class _MatchType(sqltypes.Float, sqltypes.MatchType): + def __init__(self, **kw): + # TODO: float arguments? + sqltypes.Float.__init__(self) + sqltypes.MatchType.__init__(self) + + + class NUMERIC(_NumericType, sqltypes.NUMERIC): """MySQL NUMERIC type.""" @@ -1420,32 +1428,28 @@ class SET(_EnumeratedValues): Column('myset', SET("foo", "bar", "baz")) - :param values: The range of valid values for this SET. Values will be - quoted when generating the schema according to the quoting flag (see - below). - .. versionchanged:: 0.9.0 quoting is applied automatically to - :class:`.mysql.SET` in the same way as for :class:`.mysql.ENUM`. + The list of potential values is required in the case that this + set will be used to generate DDL for a table, or if the + :paramref:`.SET.retrieve_as_bitwise` flag is set to True. - :param charset: Optional, a column-level character set for this string - value. Takes precedence to 'ascii' or 'unicode' short-hand. + :param values: The range of valid values for this SET. - :param collation: Optional, a column-level collation for this string - value. Takes precedence to 'binary' short-hand. + :param convert_unicode: Same flag as that of + :paramref:`.String.convert_unicode`. - :param ascii: Defaults to False: short-hand for the ``latin1`` - character set, generates ASCII in schema. + :param collation: same as that of :paramref:`.String.collation` - :param unicode: Defaults to False: short-hand for the ``ucs2`` - character set, generates UNICODE in schema. + :param charset: same as that of :paramref:`.VARCHAR.charset`. - :param binary: Defaults to False: short-hand, pick the binary - collation type that matches the column's character set. Generates - BINARY in schema. This does not affect the type of data stored, - only the collation of character data. + :param ascii: same as that of :paramref:`.VARCHAR.ascii`. - :param quoting: Defaults to 'auto': automatically determine enum value - quoting. If all enum values are surrounded by the same quoting + :param unicode: same as that of :paramref:`.VARCHAR.unicode`. + + :param binary: same as that of :paramref:`.VARCHAR.binary`. + + :param quoting: Defaults to 'auto': automatically determine set value + quoting. If all values are surrounded by the same quoting character, then use 'quoted' mode. Otherwise, use 'unquoted' mode. 'quoted': values in enums are already quoted, they will be used @@ -1460,50 +1464,117 @@ class SET(_EnumeratedValues): .. versionadded:: 0.9.0 + :param retrieve_as_bitwise: if True, the data for the set type will be + persisted and selected using an integer value, where a set is coerced + into a bitwise mask for persistence. MySQL allows this mode which + has the advantage of being able to store values unambiguously, + such as the blank string ``''``. The datatype will appear + as the expression ``col + 0`` in a SELECT statement, so that the + value is coerced into an integer value in result sets. + This flag is required if one wishes + to persist a set that can store the blank string ``''`` as a value. + + .. warning:: + + When using :paramref:`.mysql.SET.retrieve_as_bitwise`, it is + essential that the list of set values is expressed in the + **exact same order** as exists on the MySQL database. + + .. versionadded:: 1.0.0 + + """ + self.retrieve_as_bitwise = kw.pop('retrieve_as_bitwise', False) values, length = self._init_values(values, kw) self.values = tuple(values) - + if not self.retrieve_as_bitwise and '' in values: + raise exc.ArgumentError( + "Can't use the blank value '' in a SET without " + "setting retrieve_as_bitwise=True") + if self.retrieve_as_bitwise: + self._bitmap = dict( + (value, 2 ** idx) + for idx, value in enumerate(self.values) + ) + self._bitmap.update( + (2 ** idx, value) + for idx, value in enumerate(self.values) + ) kw.setdefault('length', length) super(SET, self).__init__(**kw) + def column_expression(self, colexpr): + if self.retrieve_as_bitwise: + return colexpr + 0 + else: + return colexpr + def result_processor(self, dialect, coltype): - def process(value): - # The good news: - # No ',' quoting issues- commas aren't allowed in SET values - # The bad news: - # Plenty of driver inconsistencies here. - if isinstance(value, set): - # ..some versions convert '' to an empty set - if not value: - value.add('') - return value - # ...and some versions return strings - if value is not None: - return set(value.split(',')) - else: - return value + if self.retrieve_as_bitwise: + def process(value): + if value is not None: + value = int(value) + + return set( + util.map_bits(self._bitmap.__getitem__, value) + ) + else: + return None + else: + super_convert = super(SET, self).result_processor(dialect, coltype) + + def process(value): + if isinstance(value, util.string_types): + # MySQLdb returns a string, let's parse + if super_convert: + value = super_convert(value) + return set(re.findall(r'[^,]+', value)) + else: + # mysql-connector-python does a naive + # split(",") which throws in an empty string + if value is not None: + value.discard('') + return value return process def bind_processor(self, dialect): super_convert = super(SET, self).bind_processor(dialect) + if self.retrieve_as_bitwise: + def process(value): + if value is None: + return None + elif isinstance(value, util.int_types + util.string_types): + if super_convert: + return super_convert(value) + else: + return value + else: + int_value = 0 + for v in value: + int_value |= self._bitmap[v] + return int_value + else: - def process(value): - if value is None or isinstance( - value, util.int_types + util.string_types): - pass - else: - if None in value: - value = set(value) - value.remove(None) - value.add('') - value = ','.join(value) - if super_convert: - return super_convert(value) - else: - return value + def process(value): + # accept strings and int (actually bitflag) values directly + if value is not None and not isinstance( + value, util.int_types + util.string_types): + value = ",".join(value) + + if super_convert: + return super_convert(value) + else: + return value return process + def adapt(self, impltype, **kw): + kw['retrieve_as_bitwise'] = self.retrieve_as_bitwise + return util.constructor_copy( + self, impltype, + *self.values, + **kw + ) + # old names MSTime = TIME MSSet = SET @@ -1544,6 +1615,7 @@ colspecs = { sqltypes.Float: FLOAT, sqltypes.Time: TIME, sqltypes.Enum: ENUM, + sqltypes.MatchType: _MatchType } # Everything 3.23 through 5.1 excepting OpenGIS types. @@ -1758,10 +1830,10 @@ class MySQLCompiler(compiler.SQLCompiler): # creation of foreign key constraints fails." class MySQLDDLCompiler(compiler.DDLCompiler): - def create_table_constraints(self, table): + def create_table_constraints(self, table, **kw): """Get table constraints.""" constraint_string = super( - MySQLDDLCompiler, self).create_table_constraints(table) + MySQLDDLCompiler, self).create_table_constraints(table, **kw) # why self.dialect.name and not 'mysql'? because of drizzle is_innodb = 'engine' in table.dialect_options[self.dialect.name] and \ @@ -1787,9 +1859,11 @@ class MySQLDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kw): """Builds column DDL.""" - colspec = [self.preparer.format_column(column), - self.dialect.type_compiler.process(column.type) - ] + colspec = [ + self.preparer.format_column(column), + self.dialect.type_compiler.process( + column.type, type_expression=column) + ] default = self.get_column_default_string(column) if default is not None: @@ -1987,7 +2061,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def _mysql_type(self, type_): return isinstance(type_, (_StringType, _NumericType)) - def visit_NUMERIC(self, type_): + def visit_NUMERIC(self, type_, **kw): if type_.precision is None: return self._extend_numeric(type_, "NUMERIC") elif type_.scale is None: @@ -2000,7 +2074,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): {'precision': type_.precision, 'scale': type_.scale}) - def visit_DECIMAL(self, type_): + def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return self._extend_numeric(type_, "DECIMAL") elif type_.scale is None: @@ -2013,7 +2087,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): {'precision': type_.precision, 'scale': type_.scale}) - def visit_DOUBLE(self, type_): + def visit_DOUBLE(self, type_, **kw): if type_.precision is not None and type_.scale is not None: return self._extend_numeric(type_, "DOUBLE(%(precision)s, %(scale)s)" % @@ -2022,7 +2096,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, 'DOUBLE') - def visit_REAL(self, type_): + def visit_REAL(self, type_, **kw): if type_.precision is not None and type_.scale is not None: return self._extend_numeric(type_, "REAL(%(precision)s, %(scale)s)" % @@ -2031,7 +2105,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, 'REAL') - def visit_FLOAT(self, type_): + def visit_FLOAT(self, type_, **kw): if self._mysql_type(type_) and \ type_.scale is not None and \ type_.precision is not None: @@ -2043,7 +2117,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "FLOAT") - def visit_INTEGER(self, type_): + def visit_INTEGER(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, "INTEGER(%(display_width)s)" % @@ -2051,7 +2125,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "INTEGER") - def visit_BIGINT(self, type_): + def visit_BIGINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, "BIGINT(%(display_width)s)" % @@ -2059,7 +2133,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "BIGINT") - def visit_MEDIUMINT(self, type_): + def visit_MEDIUMINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, "MEDIUMINT(%(display_width)s)" % @@ -2067,14 +2141,14 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "MEDIUMINT") - def visit_TINYINT(self, type_): + def visit_TINYINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric(type_, "TINYINT(%s)" % type_.display_width) else: return self._extend_numeric(type_, "TINYINT") - def visit_SMALLINT(self, type_): + def visit_SMALLINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric(type_, "SMALLINT(%(display_width)s)" % @@ -2083,55 +2157,55 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "SMALLINT") - def visit_BIT(self, type_): + def visit_BIT(self, type_, **kw): if type_.length is not None: return "BIT(%s)" % type_.length else: return "BIT" - def visit_DATETIME(self, type_): + def visit_DATETIME(self, type_, **kw): if getattr(type_, 'fsp', None): return "DATETIME(%d)" % type_.fsp else: return "DATETIME" - def visit_DATE(self, type_): + def visit_DATE(self, type_, **kw): return "DATE" - def visit_TIME(self, type_): + def visit_TIME(self, type_, **kw): if getattr(type_, 'fsp', None): return "TIME(%d)" % type_.fsp else: return "TIME" - def visit_TIMESTAMP(self, type_): + def visit_TIMESTAMP(self, type_, **kw): if getattr(type_, 'fsp', None): return "TIMESTAMP(%d)" % type_.fsp else: return "TIMESTAMP" - def visit_YEAR(self, type_): + def visit_YEAR(self, type_, **kw): if type_.display_width is None: return "YEAR" else: return "YEAR(%s)" % type_.display_width - def visit_TEXT(self, type_): + def visit_TEXT(self, type_, **kw): if type_.length: return self._extend_string(type_, {}, "TEXT(%d)" % type_.length) else: return self._extend_string(type_, {}, "TEXT") - def visit_TINYTEXT(self, type_): + def visit_TINYTEXT(self, type_, **kw): return self._extend_string(type_, {}, "TINYTEXT") - def visit_MEDIUMTEXT(self, type_): + def visit_MEDIUMTEXT(self, type_, **kw): return self._extend_string(type_, {}, "MEDIUMTEXT") - def visit_LONGTEXT(self, type_): + def visit_LONGTEXT(self, type_, **kw): return self._extend_string(type_, {}, "LONGTEXT") - def visit_VARCHAR(self, type_): + def visit_VARCHAR(self, type_, **kw): if type_.length: return self._extend_string( type_, {}, "VARCHAR(%d)" % type_.length) @@ -2140,14 +2214,14 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): "VARCHAR requires a length on dialect %s" % self.dialect.name) - def visit_CHAR(self, type_): + def visit_CHAR(self, type_, **kw): if type_.length: return self._extend_string(type_, {}, "CHAR(%(length)s)" % {'length': type_.length}) else: return self._extend_string(type_, {}, "CHAR") - def visit_NVARCHAR(self, type_): + def visit_NVARCHAR(self, type_, **kw): # We'll actually generate the equiv. "NATIONAL VARCHAR" instead # of "NVARCHAR". if type_.length: @@ -2159,7 +2233,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): "NVARCHAR requires a length on dialect %s" % self.dialect.name) - def visit_NCHAR(self, type_): + def visit_NCHAR(self, type_, **kw): # We'll actually generate the equiv. # "NATIONAL CHAR" instead of "NCHAR". if type_.length: @@ -2169,31 +2243,31 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_string(type_, {'national': True}, "CHAR") - def visit_VARBINARY(self, type_): + def visit_VARBINARY(self, type_, **kw): return "VARBINARY(%d)" % type_.length - def visit_large_binary(self, type_): + def visit_large_binary(self, type_, **kw): return self.visit_BLOB(type_) - def visit_enum(self, type_): + def visit_enum(self, type_, **kw): if not type_.native_enum: return super(MySQLTypeCompiler, self).visit_enum(type_) else: return self._visit_enumerated_values("ENUM", type_, type_.enums) - def visit_BLOB(self, type_): + def visit_BLOB(self, type_, **kw): if type_.length: return "BLOB(%d)" % type_.length else: return "BLOB" - def visit_TINYBLOB(self, type_): + def visit_TINYBLOB(self, type_, **kw): return "TINYBLOB" - def visit_MEDIUMBLOB(self, type_): + def visit_MEDIUMBLOB(self, type_, **kw): return "MEDIUMBLOB" - def visit_LONGBLOB(self, type_): + def visit_LONGBLOB(self, type_, **kw): return "LONGBLOB" def _visit_enumerated_values(self, name, type_, enumerated_values): @@ -2204,15 +2278,15 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): name, ",".join(quoted_enums)) ) - def visit_ENUM(self, type_): + def visit_ENUM(self, type_, **kw): return self._visit_enumerated_values("ENUM", type_, type_._enumerated_values) - def visit_SET(self, type_): + def visit_SET(self, type_, **kw): return self._visit_enumerated_values("SET", type_, type_._enumerated_values) - def visit_BOOLEAN(self, type): + def visit_BOOLEAN(self, type, **kw): return "BOOL" @@ -2963,6 +3037,9 @@ class MySQLTableDefinitionParser(object): if issubclass(col_type, _EnumeratedValues): type_args = _EnumeratedValues._strip_values(type_args) + if issubclass(col_type, SET) and '' in type_args: + type_kw['retrieve_as_bitwise'] = True + type_instance = col_type(*type_args, **type_kw) col_args, col_kw = [], {} |