summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJenkins <jenkins@review.openstack.org>2015-02-11 08:39:50 +0000
committerGerrit Code Review <review@openstack.org>2015-02-11 08:39:50 +0000
commit761321dec705434befcc9005e16434a46d412c98 (patch)
tree88c97e5c3d99c966d476431497dd0d6720642698
parent687ec913790653f79badc8f5d656c86792e94271 (diff)
parent7f0c457e72a8946a01ff7a93c67e3d35e383728c (diff)
downloadtaskflow-761321dec705434befcc9005e16434a46d412c98.tar.gz
Merge "Map optional arguments as well as required arguments"
-rw-r--r--doc/source/arguments_and_results.rst27
-rw-r--r--taskflow/atom.py79
-rw-r--r--taskflow/engines/action_engine/actions/retry.py9
-rw-r--r--taskflow/engines/action_engine/actions/task.py18
-rw-r--r--taskflow/examples/optional_arguments.py93
-rw-r--r--taskflow/storage.py9
-rw-r--r--taskflow/tests/unit/test_storage.py14
-rw-r--r--taskflow/tests/unit/test_task.py9
8 files changed, 207 insertions, 51 deletions
diff --git a/doc/source/arguments_and_results.rst b/doc/source/arguments_and_results.rst
index cb2c876..998db88 100644
--- a/doc/source/arguments_and_results.rst
+++ b/doc/source/arguments_and_results.rst
@@ -22,9 +22,12 @@ are and how to use those ways to accomplish your desired usage pattern.
Task/retry arguments
Set of names of task/retry arguments available as the ``requires``
- property of the task/retry instance. When a task or retry object is
- about to be executed values with these names are retrieved from storage
- and passed to the ``execute`` method of the task/retry.
+ and/or ``optional`` property of the task/retry instance. When a task or
+ retry object is about to be executed values with these names are
+ retrieved from storage and passed to the ``execute`` method of the
+ task/retry. If any names in the ``requires`` property cannot be
+ found in storage, an exception will be thrown. Any names in the
+ ``optional`` property that cannot be found are ignored.
Task/retry results
Set of names of task/retry results (what task/retry provides) available
@@ -53,32 +56,26 @@ method of a task (or the |retry.execute| of a retry object).
.. doctest::
>>> class MyTask(task.Task):
- ... def execute(self, spam, eggs):
+ ... def execute(self, spam, eggs, bacon=None):
... return spam + eggs
...
>>> sorted(MyTask().requires)
['eggs', 'spam']
+ >>> sorted(MyTask().optional)
+ ['bacon']
Inference from the method signature is the ''simplest'' way to specify
-arguments. Optional arguments (with default values), and special arguments like
-``self``, ``*args`` and ``**kwargs`` are ignored during inference (as these
-names have special meaning/usage in python).
+arguments. Special arguments like ``self``, ``*args`` and ``**kwargs`` are
+ignored during inference (as these names have special meaning/usage in python).
.. doctest::
- >>> class MyTask(task.Task):
- ... def execute(self, spam, eggs=()):
- ... return spam + eggs
- ...
- >>> MyTask().requires
- set(['spam'])
- >>>
>>> class UniTask(task.Task):
... def execute(self, *args, **kwargs):
... pass
...
>>> UniTask().requires
- set([])
+ frozenset([])
.. make vim sphinx highlighter* happy**
diff --git a/taskflow/atom.py b/taskflow/atom.py
index d236ff9..3ece83f 100644
--- a/taskflow/atom.py
+++ b/taskflow/atom.py
@@ -82,33 +82,50 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer,
well as verify that the final argument mapping does not have missing or
extra arguments (where applicable).
"""
- atom_args = reflection.get_callable_args(function, required_only=True)
+
+ # build a list of required arguments based on function signature
+ req_args = reflection.get_callable_args(function, required_only=True)
+ all_args = reflection.get_callable_args(function, required_only=False)
+
+ # remove arguments that are part of ignore_list
if ignore_list:
for arg in ignore_list:
- if arg in atom_args:
- atom_args.remove(arg)
+ if arg in req_args:
+ req_args.remove(arg)
+ else:
+ ignore_list = []
- result = {}
+ required = {}
+ # add reqs to required mappings
if reqs:
- result.update((a, a) for a in reqs)
+ required.update((a, a) for a in reqs)
+
+ # add req_args to required mappings if do_infer is set
+ if do_infer:
+ required.update((a, a) for a in req_args)
+
+ # update required mappings based on rebind_args
+ required.update(_build_rebind_dict(req_args, rebind_args))
+
if do_infer:
- result.update((a, a) for a in atom_args)
- result.update(_build_rebind_dict(atom_args, rebind_args))
+ opt_args = set(all_args) - set(required) - set(ignore_list)
+ optional = dict((a, a) for a in opt_args)
+ else:
+ optional = {}
if not reflection.accepts_kwargs(function):
- all_args = reflection.get_callable_args(function, required_only=False)
- extra_args = set(result) - set(all_args)
+ extra_args = set(required) - set(all_args)
if extra_args:
extra_args_str = ', '.join(sorted(extra_args))
raise ValueError('Extra arguments given to atom %s: %s'
% (atom_name, extra_args_str))
# NOTE(imelnikov): don't use set to preserve order in error message
- missing_args = [arg for arg in atom_args if arg not in result]
+ missing_args = [arg for arg in req_args if arg not in required]
if missing_args:
raise ValueError('Missing arguments for atom %s: %s'
% (atom_name, ' ,'.join(missing_args)))
- return result
+ return required, optional
class Atom(object):
@@ -146,6 +163,13 @@ class Atom(object):
commences (this allows for providing atom *local* values that
do not need to be provided by other atoms/dependents).
:ivar inject: See parameter ``inject``.
+ :ivar requires: Any inputs this atom requires to function (if applicable).
+ NOTE(harlowja): there can be no intersection between what
+ this atom requires and what it produces (since this would
+ be an impossible dependency to satisfy).
+ :ivar optional: Any inputs that are optional for this atom's execute
+ method.
+
"""
def __init__(self, name=None, provides=None, inject=None):
@@ -153,11 +177,27 @@ class Atom(object):
self.save_as = _save_as_to_mapping(provides)
self.version = (1, 0)
self.inject = inject
+ self.requires = frozenset()
+ self.optional = frozenset()
def _build_arg_mapping(self, executor, requires=None, rebind=None,
auto_extract=True, ignore_list=None):
- self.rebind = _build_arg_mapping(self.name, requires, rebind,
- executor, auto_extract, ignore_list)
+ req_arg, opt_arg = _build_arg_mapping(self.name, requires, rebind,
+ executor, auto_extract,
+ ignore_list)
+
+ self.rebind = {}
+ if opt_arg:
+ self.rebind.update(opt_arg)
+ if req_arg:
+ self.rebind.update(req_arg)
+ self.requires = frozenset(req_arg.values())
+ self.optional = frozenset(opt_arg.values())
+ if self.inject:
+ inject_set = set(six.iterkeys(self.inject))
+ self.requires -= inject_set
+ self.optional -= inject_set
+
out_of_order = self.provides.intersection(self.requires)
if out_of_order:
raise exceptions.DependencyFailure(
@@ -185,16 +225,3 @@ class Atom(object):
dependency to satisfy).
"""
return set(self.save_as)
-
- @property
- def requires(self):
- """Any inputs this atom requires to function (if applicable).
-
- NOTE(harlowja): there can be no intersection between what this atom
- requires and what it produces (since this would be an impossible
- dependency to satisfy).
- """
- requires = set(self.rebind.values())
- if self.inject:
- requires = requires - set(six.iterkeys(self.inject))
- return requires
diff --git a/taskflow/engines/action_engine/actions/retry.py b/taskflow/engines/action_engine/actions/retry.py
index bd96c89..05496d9 100644
--- a/taskflow/engines/action_engine/actions/retry.py
+++ b/taskflow/engines/action_engine/actions/retry.py
@@ -54,9 +54,12 @@ class RetryAction(base.Action):
def _get_retry_args(self, retry, addons=None):
scope_walker = self._walker_factory(retry)
- arguments = self._storage.fetch_mapped_args(retry.rebind,
- atom_name=retry.name,
- scope_walker=scope_walker)
+ arguments = self._storage.fetch_mapped_args(
+ retry.rebind,
+ atom_name=retry.name,
+ scope_walker=scope_walker,
+ optional_args=retry.optional
+ )
history = self._storage.get_retry_history(retry.name)
arguments[retry_atom.EXECUTE_REVERT_HISTORY] = history
if addons:
diff --git a/taskflow/engines/action_engine/actions/task.py b/taskflow/engines/action_engine/actions/task.py
index 607b26d..8c64931 100644
--- a/taskflow/engines/action_engine/actions/task.py
+++ b/taskflow/engines/action_engine/actions/task.py
@@ -101,9 +101,12 @@ class TaskAction(base.Action):
def schedule_execution(self, task):
self.change_state(task, states.RUNNING, progress=0.0)
scope_walker = self._walker_factory(task)
- arguments = self._storage.fetch_mapped_args(task.rebind,
- atom_name=task.name,
- scope_walker=scope_walker)
+ arguments = self._storage.fetch_mapped_args(
+ task.rebind,
+ atom_name=task.name,
+ scope_walker=scope_walker,
+ optional_args=task.optional
+ )
if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS):
progress_callback = functools.partial(self._on_update_progress,
task)
@@ -124,9 +127,12 @@ class TaskAction(base.Action):
def schedule_reversion(self, task):
self.change_state(task, states.REVERTING, progress=0.0)
scope_walker = self._walker_factory(task)
- arguments = self._storage.fetch_mapped_args(task.rebind,
- atom_name=task.name,
- scope_walker=scope_walker)
+ arguments = self._storage.fetch_mapped_args(
+ task.rebind,
+ atom_name=task.name,
+ scope_walker=scope_walker,
+ optional_args=task.optional
+ )
task_uuid = self._storage.get_atom_uuid(task.name)
task_result = self._storage.get(task.name)
failures = self._storage.get_failures()
diff --git a/taskflow/examples/optional_arguments.py b/taskflow/examples/optional_arguments.py
new file mode 100644
index 0000000..66a4d38
--- /dev/null
+++ b/taskflow/examples/optional_arguments.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+# Copyright (C) 2015 Hewlett-Packard Development Company, L.P.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from taskflow import engines
+from taskflow.patterns import linear_flow
+from taskflow import task
+
+
+class TestTask(task.Task):
+ def execute(self, a, b=5):
+ result = a * b
+ return result
+
+flow_no_inject = linear_flow.Flow("flow").add(TestTask(provides='result'))
+flow_inject_a = linear_flow.Flow("flow").add(TestTask(provides='result',
+ inject={'a': 10}))
+flow_inject_b = linear_flow.Flow("flow").add(TestTask(provides='result',
+ inject={'b': 1000}))
+
+ASSERT = True
+
+
+class MyTest(unittest.TestCase):
+ def test_my_test(self):
+ print("Expected result = 15")
+ result = engines.run(flow_no_inject, store={'a': 3})
+ print(result)
+ if ASSERT:
+ self.assertEqual(result,
+ {'a': 3, 'result': 15}
+ )
+
+ print("Expected result = 39")
+ result = engines.run(flow_no_inject, store={'a': 3, 'b': 7})
+ print(result)
+ if ASSERT:
+ self.assertEqual(
+ result,
+ {'a': 3, 'b': 7, 'result': 21}
+ )
+
+ print("Expected result = 200")
+ result = engines.run(flow_inject_a, store={'a': 3})
+ print(result)
+ if ASSERT:
+ self.assertEqual(
+ result,
+ {'a': 3, 'result': 50}
+ )
+
+ print("Expected result = 400")
+ result = engines.run(flow_inject_a, store={'a': 3, 'b': 7})
+ print(result)
+ if ASSERT:
+ self.assertEqual(
+ result,
+ {'a': 3, 'b': 7, 'result': 70}
+ )
+
+ print("Expected result = 40")
+ result = engines.run(flow_inject_b, store={'a': 3})
+ print(result)
+ if ASSERT:
+ self.assertEqual(
+ result,
+ {'a': 3, 'result': 3000}
+ )
+
+ print("Expected result = 40")
+ result = engines.run(flow_inject_b, store={'a': 3, 'b': 7})
+ print(result)
+ if ASSERT:
+ self.assertEqual(
+ result,
+ {'a': 3, 'b': 7, 'result': 3000}
+ )
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/taskflow/storage.py b/taskflow/storage.py
index 8734b2a..e96874a 100644
--- a/taskflow/storage.py
+++ b/taskflow/storage.py
@@ -629,7 +629,8 @@ class Storage(object):
return results
def fetch_mapped_args(self, args_mapping,
- atom_name=None, scope_walker=None):
+ atom_name=None, scope_walker=None,
+ optional_args=None):
"""Fetch arguments for an atom using an atoms argument mapping."""
def _get_results(looking_for, provider):
@@ -668,10 +669,14 @@ class Storage(object):
return []
with self._lock.read_lock():
+ if optional_args is None:
+ optional_args = []
+
if atom_name and atom_name not in self._atom_name_to_uuid:
raise exceptions.NotFound("Unknown atom name: %s" % atom_name)
if not args_mapping:
return {}
+
# The order of lookup is the following:
#
# 1. Injected atom specific arguments.
@@ -705,6 +710,8 @@ class Storage(object):
try:
possible_providers = self._reverse_mapping[name]
except KeyError:
+ if bound_name in optional_args:
+ continue
raise exceptions.NotFound("Name %r is not mapped as a"
" produced output by any"
" providers" % name)
diff --git a/taskflow/tests/unit/test_storage.py b/taskflow/tests/unit/test_storage.py
index af6afb2..a27a811 100644
--- a/taskflow/tests/unit/test_storage.py
+++ b/taskflow/tests/unit/test_storage.py
@@ -351,6 +351,20 @@ class StorageTestMixin(object):
self.assertRaises(exceptions.NotFound,
s.fetch_mapped_args, {'viking': 'helmet'})
+ def test_fetch_optional_args_found(self):
+ s = self._get_storage()
+ s.inject({'foo': 'bar', 'spam': 'eggs'})
+ self.assertEqual(s.fetch_mapped_args({'viking': 'spam'},
+ optional_args=set(['viking'])),
+ {'viking': 'eggs'})
+
+ def test_fetch_optional_args_not_found(self):
+ s = self._get_storage()
+ s.inject({'foo': 'bar', 'spam': 'eggs'})
+ self.assertEqual(s.fetch_mapped_args({'viking': 'helmet'},
+ optional_args=set(['viking'])),
+ {})
+
def test_set_and_get_task_state(self):
s = self._get_storage()
state = states.PENDING
diff --git a/taskflow/tests/unit/test_task.py b/taskflow/tests/unit/test_task.py
index 50e783f..9a9ae1c 100644
--- a/taskflow/tests/unit/test_task.py
+++ b/taskflow/tests/unit/test_task.py
@@ -120,10 +120,19 @@ class TaskTest(test.TestCase):
def test_requires_ignores_optional(self):
my_task = DefaultArgTask()
self.assertEqual(my_task.requires, set(['spam']))
+ self.assertEqual(my_task.optional, set(['eggs']))
def test_requires_allows_optional(self):
my_task = DefaultArgTask(requires=('spam', 'eggs'))
self.assertEqual(my_task.requires, set(['spam', 'eggs']))
+ self.assertEqual(my_task.optional, set())
+
+ def test_rebind_includes_optional(self):
+ my_task = DefaultArgTask()
+ self.assertEqual(my_task.rebind, {
+ 'spam': 'spam',
+ 'eggs': 'eggs',
+ })
def test_rebind_all_args(self):
my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'})