summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--taskflow/atom.py5
-rw-r--r--taskflow/jobs/backends/impl_zookeeper.py3
-rw-r--r--taskflow/persistence/backends/impl_zookeeper.py8
-rw-r--r--taskflow/tests/unit/test_arguments_passing.py10
-rw-r--r--taskflow/tests/unit/worker_based/test_worker.py2
-rw-r--r--taskflow/tests/utils.py6
6 files changed, 28 insertions, 6 deletions
diff --git a/taskflow/atom.py b/taskflow/atom.py
index 1c5e61e..82f7a5e 100644
--- a/taskflow/atom.py
+++ b/taskflow/atom.py
@@ -100,7 +100,10 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer,
required = {}
# add reqs to required mappings
if reqs:
- required.update((a, a) for a in reqs)
+ if isinstance(reqs, six.string_types):
+ required.update({reqs: reqs})
+ else:
+ required.update((a, a) for a in reqs)
# add req_args to required mappings if do_infer is set
if do_infer:
diff --git a/taskflow/jobs/backends/impl_zookeeper.py b/taskflow/jobs/backends/impl_zookeeper.py
index 3e52f65..87ccac6 100644
--- a/taskflow/jobs/backends/impl_zookeeper.py
+++ b/taskflow/jobs/backends/impl_zookeeper.py
@@ -720,7 +720,8 @@ class ZookeeperJobBoard(base.NotifyingJobBoard):
k_exceptions.KazooException) as e:
raise excp.JobFailure("Failed to connect to zookeeper", e)
try:
- kazoo_utils.check_compatible(self._client, MIN_ZK_VERSION)
+ if self._conf.get('check_compatible', True):
+ kazoo_utils.check_compatible(self._client, MIN_ZK_VERSION)
if self._worker is None and self._emit_notifications:
self._worker = futures.ThreadPoolExecutor(max_workers=1)
self._client.ensure_path(self.path)
diff --git a/taskflow/persistence/backends/impl_zookeeper.py b/taskflow/persistence/backends/impl_zookeeper.py
index 916a889..ae8096f 100644
--- a/taskflow/persistence/backends/impl_zookeeper.py
+++ b/taskflow/persistence/backends/impl_zookeeper.py
@@ -71,7 +71,7 @@ class ZkBackend(base.Backend):
return self._path
def get_connection(self):
- conn = ZkConnection(self, self._client)
+ conn = ZkConnection(self, self._client, self._conf)
if not self._validated:
conn.validate()
self._validated = True
@@ -88,9 +88,10 @@ class ZkBackend(base.Backend):
class ZkConnection(base.Connection):
- def __init__(self, backend, client):
+ def __init__(self, backend, client, conf):
self._backend = backend
self._client = client
+ self._conf = conf
self._book_path = paths.join(self._backend.path, "books")
self._flow_path = paths.join(self._backend.path, "flow_details")
self._atom_path = paths.join(self._backend.path, "atom_details")
@@ -101,7 +102,8 @@ class ZkConnection(base.Connection):
def validate(self):
with self._exc_wrapper():
try:
- k_utils.check_compatible(self._client, MIN_ZK_VERSION)
+ if self._conf.get('check_compatible', True):
+ k_utils.check_compatible(self._client, MIN_ZK_VERSION)
except exc.IncompatibleVersion as e:
raise exc.StorageFailure("Backend storage is not a"
" compatible version", e)
diff --git a/taskflow/tests/unit/test_arguments_passing.py b/taskflow/tests/unit/test_arguments_passing.py
index fb4744b..c84d853 100644
--- a/taskflow/tests/unit/test_arguments_passing.py
+++ b/taskflow/tests/unit/test_arguments_passing.py
@@ -149,6 +149,16 @@ class ArgumentsPassingTest(utils.EngineTestBase):
utils.TaskOneArg,
rebind=object())
+ def test_long_arg_name(self):
+ flow = utils.LongArgNameTask(requires='long_arg_name',
+ provides='result')
+ engine = self._make_engine(flow)
+ engine.storage.inject({'long_arg_name': 1})
+ engine.run()
+ self.assertEqual(engine.storage.fetch_all(), {
+ 'long_arg_name': 1, 'result': 1
+ })
+
class SingleThreadedEngineTest(ArgumentsPassingTest,
test.TestCase):
diff --git a/taskflow/tests/unit/worker_based/test_worker.py b/taskflow/tests/unit/worker_based/test_worker.py
index 7020a93..597a64a 100644
--- a/taskflow/tests/unit/worker_based/test_worker.py
+++ b/taskflow/tests/unit/worker_based/test_worker.py
@@ -34,7 +34,7 @@ class TestWorker(test.MockTestCase):
self.exchange = 'test-exchange'
self.topic = 'test-topic'
self.threads_count = 5
- self.endpoint_count = 22
+ self.endpoint_count = 23
# patch classes
self.executor_mock, self.executor_inst_mock = self.patchClass(
diff --git a/taskflow/tests/utils.py b/taskflow/tests/utils.py
index fbbb83c..9e6e2a3 100644
--- a/taskflow/tests/utils.py
+++ b/taskflow/tests/utils.py
@@ -95,6 +95,12 @@ class FakeTask(object):
pass
+class LongArgNameTask(task.Task):
+
+ def execute(self, long_arg_name):
+ return long_arg_name
+
+
if six.PY3:
RUNTIME_ERROR_CLASSES = ['RuntimeError', 'Exception',
'BaseException', 'object']