summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMin Pae <sputnik13@gmail.com>2015-02-15 16:55:06 -0800
committerMin Pae <sputnik13@gmail.com>2015-02-15 23:20:48 -0800
commit55110111851939cef650ff400f57598f7fb484d2 (patch)
tree3e80f66f2be64674962c9c44d4ebf3a3b3d2e892
parent14009d23341e67ebc6031b04a75401521f8daaa2 (diff)
downloadtaskflow-55110111851939cef650ff400f57598f7fb484d2.tar.gz
adding check for str/unicode type in requires
When the requires argument for an Atom is passed in as a string, each character of the string is iterated over to build up a requirement list. This works for simple one letter argument names but not for long argument names. Added check for str and unicode types to prevent iterating over a string. Change-Id: Ida584221b48966d26935fb2ede0075aabb7ce972
-rw-r--r--taskflow/atom.py5
-rw-r--r--taskflow/tests/unit/test_arguments_passing.py10
-rw-r--r--taskflow/tests/unit/worker_based/test_worker.py2
-rw-r--r--taskflow/tests/utils.py6
4 files changed, 21 insertions, 2 deletions
diff --git a/taskflow/atom.py b/taskflow/atom.py
index 1c5e61e..82f7a5e 100644
--- a/taskflow/atom.py
+++ b/taskflow/atom.py
@@ -100,7 +100,10 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer,
required = {}
# add reqs to required mappings
if reqs:
- required.update((a, a) for a in reqs)
+ if isinstance(reqs, six.string_types):
+ required.update({reqs: reqs})
+ else:
+ required.update((a, a) for a in reqs)
# add req_args to required mappings if do_infer is set
if do_infer:
diff --git a/taskflow/tests/unit/test_arguments_passing.py b/taskflow/tests/unit/test_arguments_passing.py
index fb4744b..c84d853 100644
--- a/taskflow/tests/unit/test_arguments_passing.py
+++ b/taskflow/tests/unit/test_arguments_passing.py
@@ -149,6 +149,16 @@ class ArgumentsPassingTest(utils.EngineTestBase):
utils.TaskOneArg,
rebind=object())
+ def test_long_arg_name(self):
+ flow = utils.LongArgNameTask(requires='long_arg_name',
+ provides='result')
+ engine = self._make_engine(flow)
+ engine.storage.inject({'long_arg_name': 1})
+ engine.run()
+ self.assertEqual(engine.storage.fetch_all(), {
+ 'long_arg_name': 1, 'result': 1
+ })
+
class SingleThreadedEngineTest(ArgumentsPassingTest,
test.TestCase):
diff --git a/taskflow/tests/unit/worker_based/test_worker.py b/taskflow/tests/unit/worker_based/test_worker.py
index 7020a93..597a64a 100644
--- a/taskflow/tests/unit/worker_based/test_worker.py
+++ b/taskflow/tests/unit/worker_based/test_worker.py
@@ -34,7 +34,7 @@ class TestWorker(test.MockTestCase):
self.exchange = 'test-exchange'
self.topic = 'test-topic'
self.threads_count = 5
- self.endpoint_count = 22
+ self.endpoint_count = 23
# patch classes
self.executor_mock, self.executor_inst_mock = self.patchClass(
diff --git a/taskflow/tests/utils.py b/taskflow/tests/utils.py
index fbbb83c..9e6e2a3 100644
--- a/taskflow/tests/utils.py
+++ b/taskflow/tests/utils.py
@@ -95,6 +95,12 @@ class FakeTask(object):
pass
+class LongArgNameTask(task.Task):
+
+ def execute(self, long_arg_name):
+ return long_arg_name
+
+
if six.PY3:
RUNTIME_ERROR_CLASSES = ['RuntimeError', 'Exception',
'BaseException', 'object']