diff options
Diffstat (limited to 'examples/pybullet/gym/pybullet_envs/agents/tools/in_graph_env.py')
-rw-r--r-- | examples/pybullet/gym/pybullet_envs/agents/tools/in_graph_env.py | 43 |
1 files changed, 19 insertions, 24 deletions
diff --git a/examples/pybullet/gym/pybullet_envs/agents/tools/in_graph_env.py b/examples/pybullet/gym/pybullet_envs/agents/tools/in_graph_env.py index 33ff31d07..6a71f4516 100644 --- a/examples/pybullet/gym/pybullet_envs/agents/tools/in_graph_env.py +++ b/examples/pybullet/gym/pybullet_envs/agents/tools/in_graph_env.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Put an OpenAI Gym environment into the TensorFlow graph.""" from __future__ import absolute_import @@ -42,16 +41,15 @@ class InGraphEnv(object): action_shape = self._parse_shape(self._env.action_space) action_dtype = self._parse_dtype(self._env.action_space) with tf.name_scope('environment'): - self._observ = tf.Variable( - tf.zeros(observ_shape, observ_dtype), name='observ', trainable=False) - self._action = tf.Variable( - tf.zeros(action_shape, action_dtype), name='action', trainable=False) - self._reward = tf.Variable( - 0.0, dtype=tf.float32, name='reward', trainable=False) - self._done = tf.Variable( - True, dtype=tf.bool, name='done', trainable=False) - self._step = tf.Variable( - 0, dtype=tf.int32, name='step', trainable=False) + self._observ = tf.Variable(tf.zeros(observ_shape, observ_dtype), + name='observ', + trainable=False) + self._action = tf.Variable(tf.zeros(action_shape, action_dtype), + name='action', + trainable=False) + self._reward = tf.Variable(0.0, dtype=tf.float32, name='reward', trainable=False) + self._done = tf.Variable(True, dtype=tf.bool, name='done', trainable=False) + self._step = tf.Variable(0, dtype=tf.int32, name='step', trainable=False) def __getattr__(self, name): """Forward unimplemented attributes to the original environment. @@ -79,17 +77,14 @@ class InGraphEnv(object): if action.dtype in (tf.float16, tf.float32, tf.float64): action = tf.check_numerics(action, 'action') observ_dtype = self._parse_dtype(self._env.observation_space) - observ, reward, done = tf.py_func( - lambda a: self._env.step(a)[:3], [action], - [observ_dtype, tf.float32, tf.bool], name='step') + observ, reward, done = tf.py_func(lambda a: self._env.step(a)[:3], [action], + [observ_dtype, tf.float32, tf.bool], + name='step') observ = tf.check_numerics(observ, 'observ') reward = tf.check_numerics(reward, 'reward') - return tf.group( - self._observ.assign(observ), - self._action.assign(action), - self._reward.assign(reward), - self._done.assign(done), - self._step.assign_add(1)) + return tf.group(self._observ.assign(observ), self._action.assign(action), + self._reward.assign(reward), self._done.assign(done), + self._step.assign_add(1)) def reset(self): """Reset the environment. @@ -100,10 +95,10 @@ class InGraphEnv(object): observ_dtype = self._parse_dtype(self._env.observation_space) observ = tf.py_func(self._env.reset, [], observ_dtype, name='reset') observ = tf.check_numerics(observ, 'observ') - with tf.control_dependencies([ - self._observ.assign(observ), - self._reward.assign(0), - self._done.assign(False)]): + with tf.control_dependencies( + [self._observ.assign(observ), + self._reward.assign(0), + self._done.assign(False)]): return tf.identity(observ) @property |