diff options
author | Min Pae <sputnik13@gmail.com> | 2015-01-30 22:49:28 -0800 |
---|---|---|
committer | Min Pae <sputnik13@gmail.com> | 2015-02-10 22:32:38 -0800 |
commit | 7f0c457e72a8946a01ff7a93c67e3d35e383728c (patch) | |
tree | d0c1d8a330775e7215aa94c8ef669f1e9c18947c | |
parent | eae693406ef3205d33c3bc712e19041d0a19f3bc (diff) | |
download | taskflow-7f0c457e72a8946a01ff7a93c67e3d35e383728c.tar.gz |
Map optional arguments as well as required arguments
Optional arguments that are not explicitly required are being ignored
when arguments are being mapped based on inference from atoms' execute
method signatures. This patch adds support for mapping optional
arguments in addition to required arguments.
Change-Id: I440c02dcd901a563df512e33754b13e3c05d4155
-rw-r--r-- | doc/source/arguments_and_results.rst | 27 | ||||
-rw-r--r-- | taskflow/atom.py | 79 | ||||
-rw-r--r-- | taskflow/engines/action_engine/actions/retry.py | 9 | ||||
-rw-r--r-- | taskflow/engines/action_engine/actions/task.py | 18 | ||||
-rw-r--r-- | taskflow/examples/optional_arguments.py | 93 | ||||
-rw-r--r-- | taskflow/storage.py | 9 | ||||
-rw-r--r-- | taskflow/tests/unit/test_storage.py | 14 | ||||
-rw-r--r-- | taskflow/tests/unit/test_task.py | 9 |
8 files changed, 207 insertions, 51 deletions
diff --git a/doc/source/arguments_and_results.rst b/doc/source/arguments_and_results.rst index cb2c876..998db88 100644 --- a/doc/source/arguments_and_results.rst +++ b/doc/source/arguments_and_results.rst @@ -22,9 +22,12 @@ are and how to use those ways to accomplish your desired usage pattern. Task/retry arguments Set of names of task/retry arguments available as the ``requires`` - property of the task/retry instance. When a task or retry object is - about to be executed values with these names are retrieved from storage - and passed to the ``execute`` method of the task/retry. + and/or ``optional`` property of the task/retry instance. When a task or + retry object is about to be executed values with these names are + retrieved from storage and passed to the ``execute`` method of the + task/retry. If any names in the ``requires`` property cannot be + found in storage, an exception will be thrown. Any names in the + ``optional`` property that cannot be found are ignored. Task/retry results Set of names of task/retry results (what task/retry provides) available @@ -53,32 +56,26 @@ method of a task (or the |retry.execute| of a retry object). .. doctest:: >>> class MyTask(task.Task): - ... def execute(self, spam, eggs): + ... def execute(self, spam, eggs, bacon=None): ... return spam + eggs ... >>> sorted(MyTask().requires) ['eggs', 'spam'] + >>> sorted(MyTask().optional) + ['bacon'] Inference from the method signature is the ''simplest'' way to specify -arguments. Optional arguments (with default values), and special arguments like -``self``, ``*args`` and ``**kwargs`` are ignored during inference (as these -names have special meaning/usage in python). +arguments. Special arguments like ``self``, ``*args`` and ``**kwargs`` are +ignored during inference (as these names have special meaning/usage in python). .. doctest:: - >>> class MyTask(task.Task): - ... def execute(self, spam, eggs=()): - ... return spam + eggs - ... - >>> MyTask().requires - set(['spam']) - >>> >>> class UniTask(task.Task): ... def execute(self, *args, **kwargs): ... pass ... >>> UniTask().requires - set([]) + frozenset([]) .. make vim sphinx highlighter* happy** diff --git a/taskflow/atom.py b/taskflow/atom.py index d236ff9..3ece83f 100644 --- a/taskflow/atom.py +++ b/taskflow/atom.py @@ -82,33 +82,50 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer, well as verify that the final argument mapping does not have missing or extra arguments (where applicable). """ - atom_args = reflection.get_callable_args(function, required_only=True) + + # build a list of required arguments based on function signature + req_args = reflection.get_callable_args(function, required_only=True) + all_args = reflection.get_callable_args(function, required_only=False) + + # remove arguments that are part of ignore_list if ignore_list: for arg in ignore_list: - if arg in atom_args: - atom_args.remove(arg) + if arg in req_args: + req_args.remove(arg) + else: + ignore_list = [] - result = {} + required = {} + # add reqs to required mappings if reqs: - result.update((a, a) for a in reqs) + required.update((a, a) for a in reqs) + + # add req_args to required mappings if do_infer is set + if do_infer: + required.update((a, a) for a in req_args) + + # update required mappings based on rebind_args + required.update(_build_rebind_dict(req_args, rebind_args)) + if do_infer: - result.update((a, a) for a in atom_args) - result.update(_build_rebind_dict(atom_args, rebind_args)) + opt_args = set(all_args) - set(required) - set(ignore_list) + optional = dict((a, a) for a in opt_args) + else: + optional = {} if not reflection.accepts_kwargs(function): - all_args = reflection.get_callable_args(function, required_only=False) - extra_args = set(result) - set(all_args) + extra_args = set(required) - set(all_args) if extra_args: extra_args_str = ', '.join(sorted(extra_args)) raise ValueError('Extra arguments given to atom %s: %s' % (atom_name, extra_args_str)) # NOTE(imelnikov): don't use set to preserve order in error message - missing_args = [arg for arg in atom_args if arg not in result] + missing_args = [arg for arg in req_args if arg not in required] if missing_args: raise ValueError('Missing arguments for atom %s: %s' % (atom_name, ' ,'.join(missing_args))) - return result + return required, optional class Atom(object): @@ -146,6 +163,13 @@ class Atom(object): commences (this allows for providing atom *local* values that do not need to be provided by other atoms/dependents). :ivar inject: See parameter ``inject``. + :ivar requires: Any inputs this atom requires to function (if applicable). + NOTE(harlowja): there can be no intersection between what + this atom requires and what it produces (since this would + be an impossible dependency to satisfy). + :ivar optional: Any inputs that are optional for this atom's execute + method. + """ def __init__(self, name=None, provides=None, inject=None): @@ -153,11 +177,27 @@ class Atom(object): self.save_as = _save_as_to_mapping(provides) self.version = (1, 0) self.inject = inject + self.requires = frozenset() + self.optional = frozenset() def _build_arg_mapping(self, executor, requires=None, rebind=None, auto_extract=True, ignore_list=None): - self.rebind = _build_arg_mapping(self.name, requires, rebind, - executor, auto_extract, ignore_list) + req_arg, opt_arg = _build_arg_mapping(self.name, requires, rebind, + executor, auto_extract, + ignore_list) + + self.rebind = {} + if opt_arg: + self.rebind.update(opt_arg) + if req_arg: + self.rebind.update(req_arg) + self.requires = frozenset(req_arg.values()) + self.optional = frozenset(opt_arg.values()) + if self.inject: + inject_set = set(six.iterkeys(self.inject)) + self.requires -= inject_set + self.optional -= inject_set + out_of_order = self.provides.intersection(self.requires) if out_of_order: raise exceptions.DependencyFailure( @@ -185,16 +225,3 @@ class Atom(object): dependency to satisfy). """ return set(self.save_as) - - @property - def requires(self): - """Any inputs this atom requires to function (if applicable). - - NOTE(harlowja): there can be no intersection between what this atom - requires and what it produces (since this would be an impossible - dependency to satisfy). - """ - requires = set(self.rebind.values()) - if self.inject: - requires = requires - set(six.iterkeys(self.inject)) - return requires diff --git a/taskflow/engines/action_engine/actions/retry.py b/taskflow/engines/action_engine/actions/retry.py index bd96c89..05496d9 100644 --- a/taskflow/engines/action_engine/actions/retry.py +++ b/taskflow/engines/action_engine/actions/retry.py @@ -54,9 +54,12 @@ class RetryAction(base.Action): def _get_retry_args(self, retry, addons=None): scope_walker = self._walker_factory(retry) - arguments = self._storage.fetch_mapped_args(retry.rebind, - atom_name=retry.name, - scope_walker=scope_walker) + arguments = self._storage.fetch_mapped_args( + retry.rebind, + atom_name=retry.name, + scope_walker=scope_walker, + optional_args=retry.optional + ) history = self._storage.get_retry_history(retry.name) arguments[retry_atom.EXECUTE_REVERT_HISTORY] = history if addons: diff --git a/taskflow/engines/action_engine/actions/task.py b/taskflow/engines/action_engine/actions/task.py index 607b26d..8c64931 100644 --- a/taskflow/engines/action_engine/actions/task.py +++ b/taskflow/engines/action_engine/actions/task.py @@ -101,9 +101,12 @@ class TaskAction(base.Action): def schedule_execution(self, task): self.change_state(task, states.RUNNING, progress=0.0) scope_walker = self._walker_factory(task) - arguments = self._storage.fetch_mapped_args(task.rebind, - atom_name=task.name, - scope_walker=scope_walker) + arguments = self._storage.fetch_mapped_args( + task.rebind, + atom_name=task.name, + scope_walker=scope_walker, + optional_args=task.optional + ) if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS): progress_callback = functools.partial(self._on_update_progress, task) @@ -124,9 +127,12 @@ class TaskAction(base.Action): def schedule_reversion(self, task): self.change_state(task, states.REVERTING, progress=0.0) scope_walker = self._walker_factory(task) - arguments = self._storage.fetch_mapped_args(task.rebind, - atom_name=task.name, - scope_walker=scope_walker) + arguments = self._storage.fetch_mapped_args( + task.rebind, + atom_name=task.name, + scope_walker=scope_walker, + optional_args=task.optional + ) task_uuid = self._storage.get_atom_uuid(task.name) task_result = self._storage.get(task.name) failures = self._storage.get_failures() diff --git a/taskflow/examples/optional_arguments.py b/taskflow/examples/optional_arguments.py new file mode 100644 index 0000000..66a4d38 --- /dev/null +++ b/taskflow/examples/optional_arguments.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2015 Hewlett-Packard Development Company, L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +from taskflow import engines +from taskflow.patterns import linear_flow +from taskflow import task + + +class TestTask(task.Task): + def execute(self, a, b=5): + result = a * b + return result + +flow_no_inject = linear_flow.Flow("flow").add(TestTask(provides='result')) +flow_inject_a = linear_flow.Flow("flow").add(TestTask(provides='result', + inject={'a': 10})) +flow_inject_b = linear_flow.Flow("flow").add(TestTask(provides='result', + inject={'b': 1000})) + +ASSERT = True + + +class MyTest(unittest.TestCase): + def test_my_test(self): + print("Expected result = 15") + result = engines.run(flow_no_inject, store={'a': 3}) + print(result) + if ASSERT: + self.assertEqual(result, + {'a': 3, 'result': 15} + ) + + print("Expected result = 39") + result = engines.run(flow_no_inject, store={'a': 3, 'b': 7}) + print(result) + if ASSERT: + self.assertEqual( + result, + {'a': 3, 'b': 7, 'result': 21} + ) + + print("Expected result = 200") + result = engines.run(flow_inject_a, store={'a': 3}) + print(result) + if ASSERT: + self.assertEqual( + result, + {'a': 3, 'result': 50} + ) + + print("Expected result = 400") + result = engines.run(flow_inject_a, store={'a': 3, 'b': 7}) + print(result) + if ASSERT: + self.assertEqual( + result, + {'a': 3, 'b': 7, 'result': 70} + ) + + print("Expected result = 40") + result = engines.run(flow_inject_b, store={'a': 3}) + print(result) + if ASSERT: + self.assertEqual( + result, + {'a': 3, 'result': 3000} + ) + + print("Expected result = 40") + result = engines.run(flow_inject_b, store={'a': 3, 'b': 7}) + print(result) + if ASSERT: + self.assertEqual( + result, + {'a': 3, 'b': 7, 'result': 3000} + ) + +if __name__ == '__main__': + unittest.main() diff --git a/taskflow/storage.py b/taskflow/storage.py index df80148..7b63d58 100644 --- a/taskflow/storage.py +++ b/taskflow/storage.py @@ -635,7 +635,8 @@ class Storage(object): return results def fetch_mapped_args(self, args_mapping, - atom_name=None, scope_walker=None): + atom_name=None, scope_walker=None, + optional_args=None): """Fetch arguments for an atom using an atoms argument mapping.""" def _get_results(looking_for, provider): @@ -674,10 +675,14 @@ class Storage(object): return [] with self._lock.read_lock(): + if optional_args is None: + optional_args = [] + if atom_name and atom_name not in self._atom_name_to_uuid: raise exceptions.NotFound("Unknown atom name: %s" % atom_name) if not args_mapping: return {} + # The order of lookup is the following: # # 1. Injected atom specific arguments. @@ -711,6 +716,8 @@ class Storage(object): try: possible_providers = self._reverse_mapping[name] except KeyError: + if bound_name in optional_args: + continue raise exceptions.NotFound("Name %r is not mapped as a" " produced output by any" " providers" % name) diff --git a/taskflow/tests/unit/test_storage.py b/taskflow/tests/unit/test_storage.py index 5f521af..d221a7f 100644 --- a/taskflow/tests/unit/test_storage.py +++ b/taskflow/tests/unit/test_storage.py @@ -354,6 +354,20 @@ class StorageTestMixin(object): self.assertRaises(exceptions.NotFound, s.fetch_mapped_args, {'viking': 'helmet'}) + def test_fetch_optional_args_found(self): + s = self._get_storage() + s.inject({'foo': 'bar', 'spam': 'eggs'}) + self.assertEqual(s.fetch_mapped_args({'viking': 'spam'}, + optional_args=set(['viking'])), + {'viking': 'eggs'}) + + def test_fetch_optional_args_not_found(self): + s = self._get_storage() + s.inject({'foo': 'bar', 'spam': 'eggs'}) + self.assertEqual(s.fetch_mapped_args({'viking': 'helmet'}, + optional_args=set(['viking'])), + {}) + def test_set_and_get_task_state(self): s = self._get_storage() state = states.PENDING diff --git a/taskflow/tests/unit/test_task.py b/taskflow/tests/unit/test_task.py index 50e783f..9a9ae1c 100644 --- a/taskflow/tests/unit/test_task.py +++ b/taskflow/tests/unit/test_task.py @@ -120,10 +120,19 @@ class TaskTest(test.TestCase): def test_requires_ignores_optional(self): my_task = DefaultArgTask() self.assertEqual(my_task.requires, set(['spam'])) + self.assertEqual(my_task.optional, set(['eggs'])) def test_requires_allows_optional(self): my_task = DefaultArgTask(requires=('spam', 'eggs')) self.assertEqual(my_task.requires, set(['spam', 'eggs'])) + self.assertEqual(my_task.optional, set()) + + def test_rebind_includes_optional(self): + my_task = DefaultArgTask() + self.assertEqual(my_task.rebind, { + 'spam': 'spam', + 'eggs': 'eggs', + }) def test_rebind_all_args(self): my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'}) |