diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/mysql/enumerated.py')
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/enumerated.py | 31 |
1 files changed, 15 insertions, 16 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/enumerated.py b/lib/sqlalchemy/dialects/mysql/enumerated.py index 495bee5a8..e67177b2f 100644 --- a/lib/sqlalchemy/dialects/mysql/enumerated.py +++ b/lib/sqlalchemy/dialects/mysql/enumerated.py @@ -9,7 +9,7 @@ import re from .types import _StringType from ... import exc, sql, util -from ... import types as sqltypes +from ...sql import sqltypes class _EnumeratedValues(_StringType): @@ -55,11 +55,13 @@ class _EnumeratedValues(_StringType): return strip_values -class ENUM(sqltypes.Enum, _EnumeratedValues): +class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _EnumeratedValues): """MySQL ENUM type.""" __visit_name__ = 'ENUM' + native_enum = True + def __init__(self, *enums, **kw): """Construct an ENUM. @@ -114,21 +116,21 @@ class ENUM(sqltypes.Enum, _EnumeratedValues): """ kw.pop('strict', None) - validate_strings = kw.pop("validate_strings", False) - sqltypes.Enum.__init__( - self, validate_strings=validate_strings, *enums) - kw.pop('metadata', None) - kw.pop('schema', None) - kw.pop('name', None) - kw.pop('quote', None) - kw.pop('native_enum', None) - kw.pop('inherit_schema', None) - kw.pop('_create_events', None) + self._enum_init(enums, kw) _StringType.__init__(self, length=self.length, **kw) + @classmethod + def adapt_emulated_to_native(cls, impl, **kw): + """Produce a MySQL native :class:`.mysql.ENUM` from plain + :class:`.Enum`. + + """ + kw.setdefault("validate_strings", impl.validate_strings) + return cls(**kw) + def _setup_for_values(self, values, objects, kw): values, length = self._init_values(values, kw) - return sqltypes.Enum._setup_for_values(self, values, objects, kw) + return super(ENUM, self)._setup_for_values(values, objects, kw) def _object_value_for_elem(self, elem): # mysql sends back a blank string for any value that @@ -144,9 +146,6 @@ class ENUM(sqltypes.Enum, _EnumeratedValues): return util.generic_repr( self, to_inspect=[ENUM, _StringType, sqltypes.Enum]) - def adapt(self, cls, **kw): - return sqltypes.Enum.adapt(self, cls, **kw) - class SET(_EnumeratedValues): """MySQL SET type.""" |
