diff options
Diffstat (limited to 'examples/pybullet/gym/pybullet_envs/agents/networks.py')
-rw-r--r-- | examples/pybullet/gym/pybullet_envs/agents/networks.py | 61 |
1 file changed, 29 insertions, 32 deletions
diff --git a/examples/pybullet/gym/pybullet_envs/agents/networks.py b/examples/pybullet/gym/pybullet_envs/agents/networks.py index 3d5de1fbb..3b6802d02 100644 --- a/examples/pybullet/gym/pybullet_envs/agents/networks.py +++ b/examples/pybullet/gym/pybullet_envs/agents/networks.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Network definitions for the PPO algorithm.""" from __future__ import absolute_import @@ -24,13 +23,10 @@ import operator import tensorflow as tf - -NetworkOutput = collections.namedtuple( - 'NetworkOutput', 'policy, mean, logstd, value, state') +NetworkOutput = collections.namedtuple('NetworkOutput', 'policy, mean, logstd, value, state') -def feed_forward_gaussian( - config, action_size, observations, unused_length, state=None): +def feed_forward_gaussian(config, action_size, observations, unused_length, state=None): """Independent feed forward networks for policy and value. 
The policy network outputs the mean action and the log standard deviation @@ -50,20 +46,22 @@ def feed_forward_gaussian( factor=config.init_mean_factor) logstd_initializer = tf.random_normal_initializer(config.init_logstd, 1e-10) flat_observations = tf.reshape(observations, [ - tf.shape(observations)[0], tf.shape(observations)[1], - functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)]) + tf.shape(observations)[0], + tf.shape(observations)[1], + functools.reduce(operator.mul, + observations.shape.as_list()[2:], 1) + ]) with tf.variable_scope('policy'): x = flat_observations for size in config.policy_layers: x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu) - mean = tf.contrib.layers.fully_connected( - x, action_size, tf.tanh, - weights_initializer=mean_weights_initializer) - logstd = tf.get_variable( - 'logstd', mean.shape[2:], tf.float32, logstd_initializer) - logstd = tf.tile( - logstd[None, None], - [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2)) + mean = tf.contrib.layers.fully_connected(x, + action_size, + tf.tanh, + weights_initializer=mean_weights_initializer) + logstd = tf.get_variable('logstd', mean.shape[2:], tf.float32, logstd_initializer) + logstd = tf.tile(logstd[None, None], + [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2)) with tf.variable_scope('value'): x = flat_observations for size in config.value_layers: @@ -72,13 +70,11 @@ def feed_forward_gaussian( mean = tf.check_numerics(mean, 'mean') logstd = tf.check_numerics(logstd, 'logstd') value = tf.check_numerics(value, 'value') - policy = tf.contrib.distributions.MultivariateNormalDiag( - mean, tf.exp(logstd)) + policy = tf.contrib.distributions.MultivariateNormalDiag(mean, tf.exp(logstd)) return NetworkOutput(policy, mean, logstd, value, state) -def recurrent_gaussian( - config, action_size, observations, length, state=None): +def recurrent_gaussian(config, action_size, observations, length, state=None): """Independent recurrent 
policy and feed forward value networks. The policy network outputs the mean action and the log standard deviation @@ -100,21 +96,23 @@ def recurrent_gaussian( logstd_initializer = tf.random_normal_initializer(config.init_logstd, 1e-10) cell = tf.contrib.rnn.GRUBlockCell(config.policy_layers[-1]) flat_observations = tf.reshape(observations, [ - tf.shape(observations)[0], tf.shape(observations)[1], - functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)]) + tf.shape(observations)[0], + tf.shape(observations)[1], + functools.reduce(operator.mul, + observations.shape.as_list()[2:], 1) + ]) with tf.variable_scope('policy'): x = flat_observations for size in config.policy_layers[:-1]: x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu) x, state = tf.nn.dynamic_rnn(cell, x, length, state, tf.float32) - mean = tf.contrib.layers.fully_connected( - x, action_size, tf.tanh, - weights_initializer=mean_weights_initializer) - logstd = tf.get_variable( - 'logstd', mean.shape[2:], tf.float32, logstd_initializer) - logstd = tf.tile( - logstd[None, None], - [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2)) + mean = tf.contrib.layers.fully_connected(x, + action_size, + tf.tanh, + weights_initializer=mean_weights_initializer) + logstd = tf.get_variable('logstd', mean.shape[2:], tf.float32, logstd_initializer) + logstd = tf.tile(logstd[None, None], + [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2)) with tf.variable_scope('value'): x = flat_observations for size in config.value_layers: @@ -123,7 +121,6 @@ def recurrent_gaussian( mean = tf.check_numerics(mean, 'mean') logstd = tf.check_numerics(logstd, 'logstd') value = tf.check_numerics(value, 'value') - policy = tf.contrib.distributions.MultivariateNormalDiag( - mean, tf.exp(logstd)) + policy = tf.contrib.distributions.MultivariateNormalDiag(mean, tf.exp(logstd)) # assert state.shape.as_list()[0] is not None return NetworkOutput(policy, mean, logstd, value, state) |