diff options
author | Min Pae <sputnik13@gmail.com> | 2015-02-15 16:55:06 -0800 |
---|---|---|
committer | Min Pae <sputnik13@gmail.com> | 2015-02-15 23:20:48 -0800 |
commit | 55110111851939cef650ff400f57598f7fb484d2 (patch) | |
tree | 3e80f66f2be64674962c9c44d4ebf3a3b3d2e892 | |
parent | 14009d23341e67ebc6031b04a75401521f8daaa2 (diff) | |
download | taskflow-55110111851939cef650ff400f57598f7fb484d2.tar.gz |
adding check for str/unicode type in requires
When the requires argument for an Atom is passed in as a string,
each character of the string is iterated over to build up a
requirement list. This works for simple one letter argument
names but not for long argument names.
Added check for str and unicode types to prevent iterating over a
string.
Change-Id: Ida584221b48966d26935fb2ede0075aabb7ce972
-rw-r--r-- | taskflow/atom.py | 5 | ||||
-rw-r--r-- | taskflow/tests/unit/test_arguments_passing.py | 10 | ||||
-rw-r--r-- | taskflow/tests/unit/worker_based/test_worker.py | 2 | ||||
-rw-r--r-- | taskflow/tests/utils.py | 6 |
4 files changed, 21 insertions, 2 deletions
diff --git a/taskflow/atom.py b/taskflow/atom.py index 1c5e61e..82f7a5e 100644 --- a/taskflow/atom.py +++ b/taskflow/atom.py @@ -100,7 +100,10 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer, required = {} # add reqs to required mappings if reqs: - required.update((a, a) for a in reqs) + if isinstance(reqs, six.string_types): + required.update({reqs: reqs}) + else: + required.update((a, a) for a in reqs) # add req_args to required mappings if do_infer is set if do_infer: diff --git a/taskflow/tests/unit/test_arguments_passing.py b/taskflow/tests/unit/test_arguments_passing.py index fb4744b..c84d853 100644 --- a/taskflow/tests/unit/test_arguments_passing.py +++ b/taskflow/tests/unit/test_arguments_passing.py @@ -149,6 +149,16 @@ class ArgumentsPassingTest(utils.EngineTestBase): utils.TaskOneArg, rebind=object()) + def test_long_arg_name(self): + flow = utils.LongArgNameTask(requires='long_arg_name', + provides='result') + engine = self._make_engine(flow) + engine.storage.inject({'long_arg_name': 1}) + engine.run() + self.assertEqual(engine.storage.fetch_all(), { + 'long_arg_name': 1, 'result': 1 + }) + class SingleThreadedEngineTest(ArgumentsPassingTest, test.TestCase): diff --git a/taskflow/tests/unit/worker_based/test_worker.py b/taskflow/tests/unit/worker_based/test_worker.py index 7020a93..597a64a 100644 --- a/taskflow/tests/unit/worker_based/test_worker.py +++ b/taskflow/tests/unit/worker_based/test_worker.py @@ -34,7 +34,7 @@ class TestWorker(test.MockTestCase): self.exchange = 'test-exchange' self.topic = 'test-topic' self.threads_count = 5 - self.endpoint_count = 22 + self.endpoint_count = 23 # patch classes self.executor_mock, self.executor_inst_mock = self.patchClass( diff --git a/taskflow/tests/utils.py b/taskflow/tests/utils.py index fbbb83c..9e6e2a3 100644 --- a/taskflow/tests/utils.py +++ b/taskflow/tests/utils.py @@ -95,6 +95,12 @@ class FakeTask(object): pass +class LongArgNameTask(task.Task): + + def execute(self, long_arg_name): + return long_arg_name + + if six.PY3: RUNTIME_ERROR_CLASSES = ['RuntimeError', 'Exception', 'BaseException', 'object'] |