Diffstat (limited to 'examples/pybullet/gym/pybullet_envs/agents/tools/in_graph_env.py')
-rw-r--r--  examples/pybullet/gym/pybullet_envs/agents/tools/in_graph_env.py  43
1 file changed, 19 insertions(+), 24 deletions(-)
diff --git a/examples/pybullet/gym/pybullet_envs/agents/tools/in_graph_env.py b/examples/pybullet/gym/pybullet_envs/agents/tools/in_graph_env.py
index 33ff31d07..6a71f4516 100644
--- a/examples/pybullet/gym/pybullet_envs/agents/tools/in_graph_env.py
+++ b/examples/pybullet/gym/pybullet_envs/agents/tools/in_graph_env.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """Put an OpenAI Gym environment into the TensorFlow graph."""
 from __future__ import absolute_import
@@ -42,16 +41,15 @@ class InGraphEnv(object):
     action_shape = self._parse_shape(self._env.action_space)
     action_dtype = self._parse_dtype(self._env.action_space)
     with tf.name_scope('environment'):
-      self._observ = tf.Variable(
-          tf.zeros(observ_shape, observ_dtype), name='observ', trainable=False)
-      self._action = tf.Variable(
-          tf.zeros(action_shape, action_dtype), name='action', trainable=False)
-      self._reward = tf.Variable(
-          0.0, dtype=tf.float32, name='reward', trainable=False)
-      self._done = tf.Variable(
-          True, dtype=tf.bool, name='done', trainable=False)
-      self._step = tf.Variable(
-          0, dtype=tf.int32, name='step', trainable=False)
+      self._observ = tf.Variable(tf.zeros(observ_shape, observ_dtype),
+                                 name='observ',
+                                 trainable=False)
+      self._action = tf.Variable(tf.zeros(action_shape, action_dtype),
+                                 name='action',
+                                 trainable=False)
+      self._reward = tf.Variable(0.0, dtype=tf.float32, name='reward', trainable=False)
+      self._done = tf.Variable(True, dtype=tf.bool, name='done', trainable=False)
+      self._step = tf.Variable(0, dtype=tf.int32, name='step', trainable=False)

   def __getattr__(self, name):
     """Forward unimplemented attributes to the original environment.
@@ -79,17 +77,14 @@
       if action.dtype in (tf.float16, tf.float32, tf.float64):
         action = tf.check_numerics(action, 'action')
       observ_dtype = self._parse_dtype(self._env.observation_space)
-      observ, reward, done = tf.py_func(
-          lambda a: self._env.step(a)[:3], [action],
-          [observ_dtype, tf.float32, tf.bool], name='step')
+      observ, reward, done = tf.py_func(lambda a: self._env.step(a)[:3], [action],
+                                        [observ_dtype, tf.float32, tf.bool],
+                                        name='step')
       observ = tf.check_numerics(observ, 'observ')
       reward = tf.check_numerics(reward, 'reward')
-      return tf.group(
-          self._observ.assign(observ),
-          self._action.assign(action),
-          self._reward.assign(reward),
-          self._done.assign(done),
-          self._step.assign_add(1))
+      return tf.group(self._observ.assign(observ), self._action.assign(action),
+                      self._reward.assign(reward), self._done.assign(done),
+                      self._step.assign_add(1))

   def reset(self):
     """Reset the environment.
@@ -100,10 +95,10 @@
     observ_dtype = self._parse_dtype(self._env.observation_space)
     observ = tf.py_func(self._env.reset, [], observ_dtype, name='reset')
     observ = tf.check_numerics(observ, 'observ')
-    with tf.control_dependencies([
-        self._observ.assign(observ),
-        self._reward.assign(0),
-        self._done.assign(False)]):
+    with tf.control_dependencies(
+        [self._observ.assign(observ),
+         self._reward.assign(0),
+         self._done.assign(False)]):
       return tf.identity(observ)

   @property
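
The hunks above only re-wrap how the ops are constructed, but the pattern the file implements is easy to miss in diff form: the Gym environment's step() and reset() run inside tf.py_func calls, and the latest transition is cached in non-trainable variables (observ, action, reward, done, step) that the rest of the graph reads through ordinary dependencies. Below is a minimal, self-contained sketch of that pattern against the same TF 1.x graph API; the variable names, the Pendulum-v0 environment, and the explicit 32-bit casts are illustrative assumptions, not code taken from in_graph_env.py.

# Sketch only: mirrors the tf.py_func + non-trainable-variable pattern above.
import gym
import numpy as np
import tensorflow as tf

env = gym.make('Pendulum-v0')

observ_var = tf.Variable(tf.zeros(env.observation_space.shape, tf.float32),
                         name='observ', trainable=False)
reward_var = tf.Variable(0.0, dtype=tf.float32, name='reward', trainable=False)
done_var = tf.Variable(True, dtype=tf.bool, name='done', trainable=False)
action = tf.placeholder(tf.float32, env.action_space.shape, name='action')


def _step(a):
  # Runs outside the graph; cast to the dtypes declared for tf.py_func below.
  observ, reward, done, _ = env.step(a)
  return observ.astype(np.float32), np.float32(reward), np.bool_(done)


observ, reward, done = tf.py_func(_step, [action],
                                  [tf.float32, tf.float32, tf.bool], name='step')
simulate = tf.group(observ_var.assign(observ), reward_var.assign(reward),
                    done_var.assign(done))
reset = observ_var.assign(
    tf.py_func(lambda: env.reset().astype(np.float32), [], tf.float32,
               name='reset'))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(reset)
  for _ in range(5):
    sess.run(simulate, {action: env.action_space.sample()})
    print(sess.run([reward_var, done_var]))

Caching the transition in variables, rather than returning the tf.py_func outputs directly, is what lets a larger graph (for example a batched rollout loop) read observ/reward/done as ordinary tensors; that is exactly what simulate() and reset() in the class above provide.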