diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-04-24 13:00:30 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-04-24 13:00:30 -0400 |
| commit | 71c00115747d2fb13423b0b18e728b402f117528 (patch) | |
| tree | 64362d2cab5db6af78b45c0304ad98e1c0ab5a0f /lib/sqlalchemy | |
| parent | 998c66fa8b1997453c793da5faa7d4cc436739b2 (diff) | |
| download | sqlalchemy-71c00115747d2fb13423b0b18e728b402f117528.tar.gz | |
- [feature] Added a new system
for registration of new dialects in-process
without using an entrypoint. See the
docs for "Registering New Dialects".
[ticket:2462]
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/__init__.py | 28 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/url.py | 57 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/__init__.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/langhelpers.py | 39 |
4 files changed, 83 insertions, 43 deletions
diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py index 2d4832412..16eb32e21 100644 --- a/lib/sqlalchemy/dialects/__init__.py +++ b/lib/sqlalchemy/dialects/__init__.py @@ -17,3 +17,31 @@ __all__ = ( 'sqlite', 'sybase', ) + +from sqlalchemy import util + +def _auto_fn(name): + """default dialect importer. + + plugs into the :class:`.PluginLoader` + as a first-hit system. + + """ + if "." in name: + dialect, driver = name.split(".") + else: + dialect = name + driver = "base" + try: + module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects + except ImportError: + return None + + module = getattr(module, dialect) + if hasattr(module, driver): + module = getattr(module, driver) + return lambda: module.dialect + else: + return None + +registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn)
\ No newline at end of file diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index 392ecda11..5bbdb9d65 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -14,6 +14,7 @@ be used directly and is also accepted directly by ``create_engine()``. import re, urllib from sqlalchemy import exc, util +from sqlalchemy.engine import base class URL(object): @@ -96,49 +97,21 @@ class URL(object): to this URL's driver name. """ - try: - if '+' in self.drivername: - dialect, driver = self.drivername.split('+') - else: - dialect, driver = self.drivername, 'base' - - module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects - module = getattr(module, dialect) - if hasattr(module, driver): - module = getattr(module, driver) - else: - module = self._load_entry_point() - if module is None: - raise exc.ArgumentError( - "Could not determine dialect for '%s'." % - self.drivername) - - return module.dialect - except ImportError: - module = self._load_entry_point() - if module is not None: - return module - else: - raise exc.ArgumentError( - "Could not determine dialect for '%s'." % self.drivername) - - def _load_entry_point(self): - """attempt to load this url's dialect from entry points, or return None - if pkg_resources is not installed or there is no matching entry point. - - Raise ImportError if the actual load fails. - - """ - try: - import pkg_resources - except ImportError: - return None - - for res in pkg_resources.iter_entry_points('sqlalchemy.dialects'): - if res.name == self.drivername.replace("+", "."): - return res.load() + if '+' not in self.drivername: + name = self.drivername + else: + name = self.drivername.replace('+', '.') + from sqlalchemy.dialects import registry + cls = registry.load(name) + # check for legacy dialects that + # would return a module with 'dialect' as the + # actual class + if hasattr(cls, 'dialect') and \ + isinstance(cls.dialect, type) and \ + issubclass(cls.dialect, base.Dialect): + return cls.dialect else: - return None + return cls def translate_connect_args(self, names=[], **kw): """Translate url attributes into a dictionary of connection arguments. diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 13914aa7d..76c3c829d 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -27,7 +27,7 @@ from langhelpers import iterate_attributes, class_hierarchy, \ duck_type_collection, assert_arg_type, symbol, dictlike_iteritems,\ classproperty, set_creation_order, warn_exception, warn, NoneType,\ constructor_copy, methods_equivalent, chop_traceback, asint,\ - generic_repr, counter + generic_repr, counter, PluginLoader from deprecations import warn_deprecated, warn_pending_deprecation, \ deprecated, pending_deprecation diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index d266c9664..9e5b0e4ad 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -52,6 +52,45 @@ def decorator(target): return update_wrapper(decorated, fn) return update_wrapper(decorate, target) +class PluginLoader(object): + def __init__(self, group, auto_fn=None): + self.group = group + self.impls = {} + self.auto_fn = auto_fn + + def load(self, name): + if name in self.impls: + return self.impls[name]() + + if self.auto_fn: + loader = self.auto_fn(name) + if loader: + self.impls[name] = loader + return loader() + + try: + import pkg_resources + except ImportError: + pass + else: + for impl in pkg_resources.iter_entry_points( + self.group, name): + self.impls[name] = impl.load + return impl.load() + + from sqlalchemy import exc + raise exc.ArgumentError( + "Can't load plugin: %s:%s" % + (self.group, name)) + + def register(self, name, modulepath, objname): + def load(): + mod = __import__(modulepath) + for token in modulepath.split(".")[1:]: + mod = getattr(mod, token) + return getattr(mod, objname) + self.impls[name] = load + def get_cls_kwargs(cls): """Return the full set of inherited kwargs for the given `cls`. |
