diff options
-rw-r--r-- | taskflow/atom.py | 5 | ||||
-rw-r--r-- | taskflow/jobs/backends/impl_zookeeper.py | 3 | ||||
-rw-r--r-- | taskflow/persistence/backends/impl_zookeeper.py | 8 | ||||
-rw-r--r-- | taskflow/tests/unit/test_arguments_passing.py | 10 | ||||
-rw-r--r-- | taskflow/tests/unit/worker_based/test_worker.py | 2 | ||||
-rw-r--r-- | taskflow/tests/utils.py | 6 |
6 files changed, 28 insertions, 6 deletions
diff --git a/taskflow/atom.py b/taskflow/atom.py index 1c5e61e..82f7a5e 100644 --- a/taskflow/atom.py +++ b/taskflow/atom.py @@ -100,7 +100,10 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer, required = {} # add reqs to required mappings if reqs: - required.update((a, a) for a in reqs) + if isinstance(reqs, six.string_types): + required.update({reqs: reqs}) + else: + required.update((a, a) for a in reqs) # add req_args to required mappings if do_infer is set if do_infer: diff --git a/taskflow/jobs/backends/impl_zookeeper.py b/taskflow/jobs/backends/impl_zookeeper.py index 3e52f65..87ccac6 100644 --- a/taskflow/jobs/backends/impl_zookeeper.py +++ b/taskflow/jobs/backends/impl_zookeeper.py @@ -720,7 +720,8 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): k_exceptions.KazooException) as e: raise excp.JobFailure("Failed to connect to zookeeper", e) try: - kazoo_utils.check_compatible(self._client, MIN_ZK_VERSION) + if self._conf.get('check_compatible', True): + kazoo_utils.check_compatible(self._client, MIN_ZK_VERSION) if self._worker is None and self._emit_notifications: self._worker = futures.ThreadPoolExecutor(max_workers=1) self._client.ensure_path(self.path) diff --git a/taskflow/persistence/backends/impl_zookeeper.py b/taskflow/persistence/backends/impl_zookeeper.py index 916a889..ae8096f 100644 --- a/taskflow/persistence/backends/impl_zookeeper.py +++ b/taskflow/persistence/backends/impl_zookeeper.py @@ -71,7 +71,7 @@ class ZkBackend(base.Backend): return self._path def get_connection(self): - conn = ZkConnection(self, self._client) + conn = ZkConnection(self, self._client, self._conf) if not self._validated: conn.validate() self._validated = True @@ -88,9 +88,10 @@ class ZkBackend(base.Backend): class ZkConnection(base.Connection): - def __init__(self, backend, client): + def __init__(self, backend, client, conf): self._backend = backend self._client = client + self._conf = conf self._book_path = paths.join(self._backend.path, "books") self._flow_path = paths.join(self._backend.path, "flow_details") self._atom_path = paths.join(self._backend.path, "atom_details") @@ -101,7 +102,8 @@ class ZkConnection(base.Connection): def validate(self): with self._exc_wrapper(): try: - k_utils.check_compatible(self._client, MIN_ZK_VERSION) + if self._conf.get('check_compatible', True): + k_utils.check_compatible(self._client, MIN_ZK_VERSION) except exc.IncompatibleVersion as e: raise exc.StorageFailure("Backend storage is not a" " compatible version", e) diff --git a/taskflow/tests/unit/test_arguments_passing.py b/taskflow/tests/unit/test_arguments_passing.py index fb4744b..c84d853 100644 --- a/taskflow/tests/unit/test_arguments_passing.py +++ b/taskflow/tests/unit/test_arguments_passing.py @@ -149,6 +149,16 @@ class ArgumentsPassingTest(utils.EngineTestBase): utils.TaskOneArg, rebind=object()) + def test_long_arg_name(self): + flow = utils.LongArgNameTask(requires='long_arg_name', + provides='result') + engine = self._make_engine(flow) + engine.storage.inject({'long_arg_name': 1}) + engine.run() + self.assertEqual(engine.storage.fetch_all(), { + 'long_arg_name': 1, 'result': 1 + }) + class SingleThreadedEngineTest(ArgumentsPassingTest, test.TestCase): diff --git a/taskflow/tests/unit/worker_based/test_worker.py b/taskflow/tests/unit/worker_based/test_worker.py index 7020a93..597a64a 100644 --- a/taskflow/tests/unit/worker_based/test_worker.py +++ b/taskflow/tests/unit/worker_based/test_worker.py @@ -34,7 +34,7 @@ class TestWorker(test.MockTestCase): self.exchange = 'test-exchange' self.topic = 'test-topic' self.threads_count = 5 - self.endpoint_count = 22 + self.endpoint_count = 23 # patch classes self.executor_mock, self.executor_inst_mock = self.patchClass( diff --git a/taskflow/tests/utils.py b/taskflow/tests/utils.py index fbbb83c..9e6e2a3 100644 --- a/taskflow/tests/utils.py +++ b/taskflow/tests/utils.py @@ -95,6 +95,12 @@ class FakeTask(object): pass +class LongArgNameTask(task.Task): + + def execute(self, long_arg_name): + return long_arg_name + + if six.PY3: RUNTIME_ERROR_CLASSES = ['RuntimeError', 'Exception', 'BaseException', 'object'] |