diff options
-rw-r--r-- | buildstream/_plugincontext.py | 47 |
1 files changed, 42 insertions, 5 deletions
diff --git a/buildstream/_plugincontext.py b/buildstream/_plugincontext.py index ca57a5b65..8c208af48 100644 --- a/buildstream/_plugincontext.py +++ b/buildstream/_plugincontext.py @@ -19,6 +19,9 @@ # Tristan Van Berkom <tristan.vanberkom@codethink.co.uk> import os +import inspect +import pkg_resources + from .exceptions import PluginError from . import utils @@ -51,6 +54,7 @@ class PluginContext(): self.assert_searchpath(searchpath) # The PluginSource object + self.plugin_base = plugin_base self.source = plugin_base.make_plugin_source(searchpath=searchpath) # lookup(): @@ -70,17 +74,38 @@ class PluginContext(): def ensure_plugin(self, kind): if kind not in self.types: - if kind not in self.source.list_plugins(): + source = None + dist, package = self.split_name(kind) + + if dist: + # Find the plugin on disk using setuptools - this + # potentially unpacks the file and puts it in a + # temporary directory, but it is otherwise guaranteed + # to exist. + plugin = pkg_resources.get_entry_info(dist, 'buildstream.plugins', package) + location = plugin.dist.get_resource_filename( + pkg_resources._manager, + plugin.module_name.replace('.', os.sep) + '.py' + ) + + # Set the plugin-base source to the setuptools directory + source = self.plugin_base.make_plugin_source(searchpath=[os.path.dirname(location)]) + + elif package in self.source.list_plugins(): + source = self.source + + if not source: raise PluginError("No {} type registered for kind '{}'" .format(self.base_type.__name__, kind)) - self.load_plugin(kind) + self.types[kind] = self.load_plugin(source, package) return self.types[kind] - def load_plugin(self, kind): + def load_plugin(self, source, kind): try: - plugin = self.source.load_plugin(kind) + plugin = source.load_plugin(kind) + except ImportError as e: raise PluginError("Failed to load {} plugin '{}': {}" .format(self.base_type.__name__, kind, e)) from e @@ -96,7 +121,19 @@ class PluginContext(): self.assert_plugin(kind, plugin_type) self.assert_version(kind, plugin_type) - self.types[kind] = plugin_type + return plugin_type + + def split_name(self, name): + if name.count(':') > 1: + raise PluginError("Plugin and package names must not contain ':'") + + try: + dist, kind = name.split(':', maxsplit=1) + except ValueError: + dist = None + kind = name + + return dist, kind def assert_plugin(self, kind, plugin_type): if kind in self.types: |