summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMin Pae <sputnik13@gmail.com>2015-01-30 22:49:28 -0800
committerMin Pae <sputnik13@gmail.com>2015-02-10 22:32:38 -0800
commit7f0c457e72a8946a01ff7a93c67e3d35e383728c (patch)
treed0c1d8a330775e7215aa94c8ef669f1e9c18947c
parenteae693406ef3205d33c3bc712e19041d0a19f3bc (diff)
downloadtaskflow-7f0c457e72a8946a01ff7a93c67e3d35e383728c.tar.gz
Map optional arguments as well as required arguments
Optional arguments that are not explicitly required are being ignored when arguments are being mapped based on inference from atoms' execute method signatures. This patch adds support for mapping optional arguments in addition to required arguments. Change-Id: I440c02dcd901a563df512e33754b13e3c05d4155
-rw-r--r--doc/source/arguments_and_results.rst27
-rw-r--r--taskflow/atom.py79
-rw-r--r--taskflow/engines/action_engine/actions/retry.py9
-rw-r--r--taskflow/engines/action_engine/actions/task.py18
-rw-r--r--taskflow/examples/optional_arguments.py93
-rw-r--r--taskflow/storage.py9
-rw-r--r--taskflow/tests/unit/test_storage.py14
-rw-r--r--taskflow/tests/unit/test_task.py9
8 files changed, 207 insertions, 51 deletions
diff --git a/doc/source/arguments_and_results.rst b/doc/source/arguments_and_results.rst
index cb2c876..998db88 100644
--- a/doc/source/arguments_and_results.rst
+++ b/doc/source/arguments_and_results.rst
@@ -22,9 +22,12 @@ are and how to use those ways to accomplish your desired usage pattern.
Task/retry arguments
Set of names of task/retry arguments available as the ``requires``
- property of the task/retry instance. When a task or retry object is
- about to be executed values with these names are retrieved from storage
- and passed to the ``execute`` method of the task/retry.
+ and/or ``optional`` property of the task/retry instance. When a task or
+ retry object is about to be executed values with these names are
+ retrieved from storage and passed to the ``execute`` method of the
+ task/retry. If any names in the ``requires`` property cannot be
+ found in storage, an exception will be thrown. Any names in the
+ ``optional`` property that cannot be found are ignored.
Task/retry results
Set of names of task/retry results (what task/retry provides) available
@@ -53,32 +56,26 @@ method of a task (or the |retry.execute| of a retry object).
.. doctest::
>>> class MyTask(task.Task):
- ... def execute(self, spam, eggs):
+ ... def execute(self, spam, eggs, bacon=None):
... return spam + eggs
...
>>> sorted(MyTask().requires)
['eggs', 'spam']
+ >>> sorted(MyTask().optional)
+ ['bacon']
Inference from the method signature is the ''simplest'' way to specify
-arguments. Optional arguments (with default values), and special arguments like
-``self``, ``*args`` and ``**kwargs`` are ignored during inference (as these
-names have special meaning/usage in python).
+arguments. Special arguments like ``self``, ``*args`` and ``**kwargs`` are
+ignored during inference (as these names have special meaning/usage in python).
.. doctest::
- >>> class MyTask(task.Task):
- ... def execute(self, spam, eggs=()):
- ... return spam + eggs
- ...
- >>> MyTask().requires
- set(['spam'])
- >>>
>>> class UniTask(task.Task):
... def execute(self, *args, **kwargs):
... pass
...
>>> UniTask().requires
- set([])
+ frozenset([])
.. make vim sphinx highlighter* happy**
diff --git a/taskflow/atom.py b/taskflow/atom.py
index d236ff9..3ece83f 100644
--- a/taskflow/atom.py
+++ b/taskflow/atom.py
@@ -82,33 +82,50 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer,
well as verify that the final argument mapping does not have missing or
extra arguments (where applicable).
"""
- atom_args = reflection.get_callable_args(function, required_only=True)
+
+ # build a list of required arguments based on function signature
+ req_args = reflection.get_callable_args(function, required_only=True)
+ all_args = reflection.get_callable_args(function, required_only=False)
+
+ # remove arguments that are part of ignore_list
if ignore_list:
for arg in ignore_list:
- if arg in atom_args:
- atom_args.remove(arg)
+ if arg in req_args:
+ req_args.remove(arg)
+ else:
+ ignore_list = []
- result = {}
+ required = {}
+ # add reqs to required mappings
if reqs:
- result.update((a, a) for a in reqs)
+ required.update((a, a) for a in reqs)
+
+ # add req_args to required mappings if do_infer is set
+ if do_infer:
+ required.update((a, a) for a in req_args)
+
+ # update required mappings based on rebind_args
+ required.update(_build_rebind_dict(req_args, rebind_args))
+
if do_infer:
- result.update((a, a) for a in atom_args)
- result.update(_build_rebind_dict(atom_args, rebind_args))
+ opt_args = set(all_args) - set(required) - set(ignore_list)
+ optional = dict((a, a) for a in opt_args)
+ else:
+ optional = {}
if not reflection.accepts_kwargs(function):
- all_args = reflection.get_callable_args(function, required_only=False)
- extra_args = set(result) - set(all_args)
+ extra_args = set(required) - set(all_args)
if extra_args:
extra_args_str = ', '.join(sorted(extra_args))
raise ValueError('Extra arguments given to atom %s: %s'
% (atom_name, extra_args_str))
# NOTE(imelnikov): don't use set to preserve order in error message
- missing_args = [arg for arg in atom_args if arg not in result]
+ missing_args = [arg for arg in req_args if arg not in required]
if missing_args:
raise ValueError('Missing arguments for atom %s: %s'
% (atom_name, ' ,'.join(missing_args)))
- return result
+ return required, optional
class Atom(object):
@@ -146,6 +163,13 @@ class Atom(object):
commences (this allows for providing atom *local* values that
do not need to be provided by other atoms/dependents).
:ivar inject: See parameter ``inject``.
+ :ivar requires: Any inputs this atom requires to function (if applicable).
+ NOTE(harlowja): there can be no intersection between what
+ this atom requires and what it produces (since this would
+ be an impossible dependency to satisfy).
+ :ivar optional: Any inputs that are optional for this atom's execute
+ method.
+
"""
def __init__(self, name=None, provides=None, inject=None):
@@ -153,11 +177,27 @@ class Atom(object):
self.save_as = _save_as_to_mapping(provides)
self.version = (1, 0)
self.inject = inject
+ self.requires = frozenset()
+ self.optional = frozenset()
def _build_arg_mapping(self, executor, requires=None, rebind=None,
auto_extract=True, ignore_list=None):
- self.rebind = _build_arg_mapping(self.name, requires, rebind,
- executor, auto_extract, ignore_list)
+ req_arg, opt_arg = _build_arg_mapping(self.name, requires, rebind,
+ executor, auto_extract,
+ ignore_list)
+
+ self.rebind = {}
+ if opt_arg:
+ self.rebind.update(opt_arg)
+ if req_arg:
+ self.rebind.update(req_arg)
+ self.requires = frozenset(req_arg.values())
+ self.optional = frozenset(opt_arg.values())
+ if self.inject:
+ inject_set = set(six.iterkeys(self.inject))
+ self.requires -= inject_set
+ self.optional -= inject_set
+
out_of_order = self.provides.intersection(self.requires)
if out_of_order:
raise exceptions.DependencyFailure(
@@ -185,16 +225,3 @@ class Atom(object):
dependency to satisfy).
"""
return set(self.save_as)
-
- @property
- def requires(self):
- """Any inputs this atom requires to function (if applicable).
-
- NOTE(harlowja): there can be no intersection between what this atom
- requires and what it produces (since this would be an impossible
- dependency to satisfy).
- """
- requires = set(self.rebind.values())
- if self.inject:
- requires = requires - set(six.iterkeys(self.inject))
- return requires
diff --git a/taskflow/engines/action_engine/actions/retry.py b/taskflow/engines/action_engine/actions/retry.py
index bd96c89..05496d9 100644
--- a/taskflow/engines/action_engine/actions/retry.py
+++ b/taskflow/engines/action_engine/actions/retry.py
@@ -54,9 +54,12 @@ class RetryAction(base.Action):
def _get_retry_args(self, retry, addons=None):
scope_walker = self._walker_factory(retry)
- arguments = self._storage.fetch_mapped_args(retry.rebind,
- atom_name=retry.name,
- scope_walker=scope_walker)
+ arguments = self._storage.fetch_mapped_args(
+ retry.rebind,
+ atom_name=retry.name,
+ scope_walker=scope_walker,
+ optional_args=retry.optional
+ )
history = self._storage.get_retry_history(retry.name)
arguments[retry_atom.EXECUTE_REVERT_HISTORY] = history
if addons:
diff --git a/taskflow/engines/action_engine/actions/task.py b/taskflow/engines/action_engine/actions/task.py
index 607b26d..8c64931 100644
--- a/taskflow/engines/action_engine/actions/task.py
+++ b/taskflow/engines/action_engine/actions/task.py
@@ -101,9 +101,12 @@ class TaskAction(base.Action):
def schedule_execution(self, task):
self.change_state(task, states.RUNNING, progress=0.0)
scope_walker = self._walker_factory(task)
- arguments = self._storage.fetch_mapped_args(task.rebind,
- atom_name=task.name,
- scope_walker=scope_walker)
+ arguments = self._storage.fetch_mapped_args(
+ task.rebind,
+ atom_name=task.name,
+ scope_walker=scope_walker,
+ optional_args=task.optional
+ )
if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS):
progress_callback = functools.partial(self._on_update_progress,
task)
@@ -124,9 +127,12 @@ class TaskAction(base.Action):
def schedule_reversion(self, task):
self.change_state(task, states.REVERTING, progress=0.0)
scope_walker = self._walker_factory(task)
- arguments = self._storage.fetch_mapped_args(task.rebind,
- atom_name=task.name,
- scope_walker=scope_walker)
+ arguments = self._storage.fetch_mapped_args(
+ task.rebind,
+ atom_name=task.name,
+ scope_walker=scope_walker,
+ optional_args=task.optional
+ )
task_uuid = self._storage.get_atom_uuid(task.name)
task_result = self._storage.get(task.name)
failures = self._storage.get_failures()
diff --git a/taskflow/examples/optional_arguments.py b/taskflow/examples/optional_arguments.py
new file mode 100644
index 0000000..66a4d38
--- /dev/null
+++ b/taskflow/examples/optional_arguments.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+# Copyright (C) 2015 Hewlett-Packard Development Company, L.P.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from taskflow import engines
+from taskflow.patterns import linear_flow
+from taskflow import task
+
+
+class TestTask(task.Task):
+ def execute(self, a, b=5):
+ result = a * b
+ return result
+
+flow_no_inject = linear_flow.Flow("flow").add(TestTask(provides='result'))
+flow_inject_a = linear_flow.Flow("flow").add(TestTask(provides='result',
+ inject={'a': 10}))
+flow_inject_b = linear_flow.Flow("flow").add(TestTask(provides='result',
+ inject={'b': 1000}))
+
+ASSERT = True
+
+
+class MyTest(unittest.TestCase):
+ def test_my_test(self):
+ print("Expected result = 15")
+ result = engines.run(flow_no_inject, store={'a': 3})
+ print(result)
+ if ASSERT:
+ self.assertEqual(result,
+ {'a': 3, 'result': 15}
+ )
+
+ print("Expected result = 39")
+ result = engines.run(flow_no_inject, store={'a': 3, 'b': 7})
+ print(result)
+ if ASSERT:
+ self.assertEqual(
+ result,
+ {'a': 3, 'b': 7, 'result': 21}
+ )
+
+ print("Expected result = 200")
+ result = engines.run(flow_inject_a, store={'a': 3})
+ print(result)
+ if ASSERT:
+ self.assertEqual(
+ result,
+ {'a': 3, 'result': 50}
+ )
+
+ print("Expected result = 400")
+ result = engines.run(flow_inject_a, store={'a': 3, 'b': 7})
+ print(result)
+ if ASSERT:
+ self.assertEqual(
+ result,
+ {'a': 3, 'b': 7, 'result': 70}
+ )
+
+ print("Expected result = 40")
+ result = engines.run(flow_inject_b, store={'a': 3})
+ print(result)
+ if ASSERT:
+ self.assertEqual(
+ result,
+ {'a': 3, 'result': 3000}
+ )
+
+ print("Expected result = 40")
+ result = engines.run(flow_inject_b, store={'a': 3, 'b': 7})
+ print(result)
+ if ASSERT:
+ self.assertEqual(
+ result,
+ {'a': 3, 'b': 7, 'result': 3000}
+ )
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/taskflow/storage.py b/taskflow/storage.py
index df80148..7b63d58 100644
--- a/taskflow/storage.py
+++ b/taskflow/storage.py
@@ -635,7 +635,8 @@ class Storage(object):
return results
def fetch_mapped_args(self, args_mapping,
- atom_name=None, scope_walker=None):
+ atom_name=None, scope_walker=None,
+ optional_args=None):
"""Fetch arguments for an atom using an atoms argument mapping."""
def _get_results(looking_for, provider):
@@ -674,10 +675,14 @@ class Storage(object):
return []
with self._lock.read_lock():
+ if optional_args is None:
+ optional_args = []
+
if atom_name and atom_name not in self._atom_name_to_uuid:
raise exceptions.NotFound("Unknown atom name: %s" % atom_name)
if not args_mapping:
return {}
+
# The order of lookup is the following:
#
# 1. Injected atom specific arguments.
@@ -711,6 +716,8 @@ class Storage(object):
try:
possible_providers = self._reverse_mapping[name]
except KeyError:
+ if bound_name in optional_args:
+ continue
raise exceptions.NotFound("Name %r is not mapped as a"
" produced output by any"
" providers" % name)
diff --git a/taskflow/tests/unit/test_storage.py b/taskflow/tests/unit/test_storage.py
index 5f521af..d221a7f 100644
--- a/taskflow/tests/unit/test_storage.py
+++ b/taskflow/tests/unit/test_storage.py
@@ -354,6 +354,20 @@ class StorageTestMixin(object):
self.assertRaises(exceptions.NotFound,
s.fetch_mapped_args, {'viking': 'helmet'})
+ def test_fetch_optional_args_found(self):
+ s = self._get_storage()
+ s.inject({'foo': 'bar', 'spam': 'eggs'})
+ self.assertEqual(s.fetch_mapped_args({'viking': 'spam'},
+ optional_args=set(['viking'])),
+ {'viking': 'eggs'})
+
+ def test_fetch_optional_args_not_found(self):
+ s = self._get_storage()
+ s.inject({'foo': 'bar', 'spam': 'eggs'})
+ self.assertEqual(s.fetch_mapped_args({'viking': 'helmet'},
+ optional_args=set(['viking'])),
+ {})
+
def test_set_and_get_task_state(self):
s = self._get_storage()
state = states.PENDING
diff --git a/taskflow/tests/unit/test_task.py b/taskflow/tests/unit/test_task.py
index 50e783f..9a9ae1c 100644
--- a/taskflow/tests/unit/test_task.py
+++ b/taskflow/tests/unit/test_task.py
@@ -120,10 +120,19 @@ class TaskTest(test.TestCase):
def test_requires_ignores_optional(self):
my_task = DefaultArgTask()
self.assertEqual(my_task.requires, set(['spam']))
+ self.assertEqual(my_task.optional, set(['eggs']))
def test_requires_allows_optional(self):
my_task = DefaultArgTask(requires=('spam', 'eggs'))
self.assertEqual(my_task.requires, set(['spam', 'eggs']))
+ self.assertEqual(my_task.optional, set())
+
+ def test_rebind_includes_optional(self):
+ my_task = DefaultArgTask()
+ self.assertEqual(my_task.rebind, {
+ 'spam': 'spam',
+ 'eggs': 'eggs',
+ })
def test_rebind_all_args(self):
my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'})