summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJenkins <jenkins@review.openstack.org>2015-02-16 19:15:44 +0000
committerGerrit Code Review <review@openstack.org>2015-02-16 19:15:44 +0000
commit77e6b99afc316ebc8e9200703bdb843452e7b62d (patch)
tree7e14bf940d618eaff6dc778d00dccf7b82accbd6
parenta3f126f0e66fd5cf98db52e6ce7a03de1c4ccd3e (diff)
parent55110111851939cef650ff400f57598f7fb484d2 (diff)
downloadtaskflow-77e6b99afc316ebc8e9200703bdb843452e7b62d.tar.gz
Merge "adding check for str/unicode type in requires"
-rw-r--r--taskflow/atom.py5
-rw-r--r--taskflow/tests/unit/test_arguments_passing.py10
-rw-r--r--taskflow/tests/unit/worker_based/test_worker.py2
-rw-r--r--taskflow/tests/utils.py6
4 files changed, 21 insertions, 2 deletions
diff --git a/taskflow/atom.py b/taskflow/atom.py
index 1c5e61e..82f7a5e 100644
--- a/taskflow/atom.py
+++ b/taskflow/atom.py
@@ -100,7 +100,10 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer,
required = {}
# add reqs to required mappings
if reqs:
- required.update((a, a) for a in reqs)
+ if isinstance(reqs, six.string_types):
+ required.update({reqs: reqs})
+ else:
+ required.update((a, a) for a in reqs)
# add req_args to required mappings if do_infer is set
if do_infer:
diff --git a/taskflow/tests/unit/test_arguments_passing.py b/taskflow/tests/unit/test_arguments_passing.py
index fb4744b..c84d853 100644
--- a/taskflow/tests/unit/test_arguments_passing.py
+++ b/taskflow/tests/unit/test_arguments_passing.py
@@ -149,6 +149,16 @@ class ArgumentsPassingTest(utils.EngineTestBase):
utils.TaskOneArg,
rebind=object())
+ def test_long_arg_name(self):
+ flow = utils.LongArgNameTask(requires='long_arg_name',
+ provides='result')
+ engine = self._make_engine(flow)
+ engine.storage.inject({'long_arg_name': 1})
+ engine.run()
+ self.assertEqual(engine.storage.fetch_all(), {
+ 'long_arg_name': 1, 'result': 1
+ })
+
class SingleThreadedEngineTest(ArgumentsPassingTest,
test.TestCase):
diff --git a/taskflow/tests/unit/worker_based/test_worker.py b/taskflow/tests/unit/worker_based/test_worker.py
index 7020a93..597a64a 100644
--- a/taskflow/tests/unit/worker_based/test_worker.py
+++ b/taskflow/tests/unit/worker_based/test_worker.py
@@ -34,7 +34,7 @@ class TestWorker(test.MockTestCase):
self.exchange = 'test-exchange'
self.topic = 'test-topic'
self.threads_count = 5
- self.endpoint_count = 22
+ self.endpoint_count = 23
# patch classes
self.executor_mock, self.executor_inst_mock = self.patchClass(
diff --git a/taskflow/tests/utils.py b/taskflow/tests/utils.py
index fbbb83c..9e6e2a3 100644
--- a/taskflow/tests/utils.py
+++ b/taskflow/tests/utils.py
@@ -95,6 +95,12 @@ class FakeTask(object):
pass
+class LongArgNameTask(task.Task):
+
+ def execute(self, long_arg_name):
+ return long_arg_name
+
+
if six.PY3:
RUNTIME_ERROR_CLASSES = ['RuntimeError', 'Exception',
'BaseException', 'object']