summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/mysql/enumerated.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/dialects/mysql/enumerated.py')
-rw-r--r--lib/sqlalchemy/dialects/mysql/enumerated.py31
1 files changed, 15 insertions, 16 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/enumerated.py b/lib/sqlalchemy/dialects/mysql/enumerated.py
index 495bee5a8..e67177b2f 100644
--- a/lib/sqlalchemy/dialects/mysql/enumerated.py
+++ b/lib/sqlalchemy/dialects/mysql/enumerated.py
@@ -9,7 +9,7 @@ import re
from .types import _StringType
from ... import exc, sql, util
-from ... import types as sqltypes
+from ...sql import sqltypes
class _EnumeratedValues(_StringType):
@@ -55,11 +55,13 @@ class _EnumeratedValues(_StringType):
return strip_values
-class ENUM(sqltypes.Enum, _EnumeratedValues):
+class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _EnumeratedValues):
"""MySQL ENUM type."""
__visit_name__ = 'ENUM'
+ native_enum = True
+
def __init__(self, *enums, **kw):
"""Construct an ENUM.
@@ -114,21 +116,21 @@ class ENUM(sqltypes.Enum, _EnumeratedValues):
"""
kw.pop('strict', None)
- validate_strings = kw.pop("validate_strings", False)
- sqltypes.Enum.__init__(
- self, validate_strings=validate_strings, *enums)
- kw.pop('metadata', None)
- kw.pop('schema', None)
- kw.pop('name', None)
- kw.pop('quote', None)
- kw.pop('native_enum', None)
- kw.pop('inherit_schema', None)
- kw.pop('_create_events', None)
+ self._enum_init(enums, kw)
_StringType.__init__(self, length=self.length, **kw)
+ @classmethod
+ def adapt_emulated_to_native(cls, impl, **kw):
+ """Produce a MySQL native :class:`.mysql.ENUM` from plain
+ :class:`.Enum`.
+
+ """
+ kw.setdefault("validate_strings", impl.validate_strings)
+ return cls(**kw)
+
def _setup_for_values(self, values, objects, kw):
values, length = self._init_values(values, kw)
- return sqltypes.Enum._setup_for_values(self, values, objects, kw)
+ return super(ENUM, self)._setup_for_values(values, objects, kw)
def _object_value_for_elem(self, elem):
# mysql sends back a blank string for any value that
@@ -144,9 +146,6 @@ class ENUM(sqltypes.Enum, _EnumeratedValues):
return util.generic_repr(
self, to_inspect=[ENUM, _StringType, sqltypes.Enum])
- def adapt(self, cls, **kw):
- return sqltypes.Enum.adapt(self, cls, **kw)
-
class SET(_EnumeratedValues):
"""MySQL SET type."""