Diffstat (limited to 'examples/pybullet/gym/pybullet_envs/agents/networks.py')
-rw-r--r--  examples/pybullet/gym/pybullet_envs/agents/networks.py  61
1 file changed, 29 insertions(+), 32 deletions(-)
diff --git a/examples/pybullet/gym/pybullet_envs/agents/networks.py b/examples/pybullet/gym/pybullet_envs/agents/networks.py
index 3d5de1fbb..3b6802d02 100644
--- a/examples/pybullet/gym/pybullet_envs/agents/networks.py
+++ b/examples/pybullet/gym/pybullet_envs/agents/networks.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Network definitions for the PPO algorithm."""
from __future__ import absolute_import
@@ -24,13 +23,10 @@ import operator
import tensorflow as tf
-
-NetworkOutput = collections.namedtuple(
-    'NetworkOutput', 'policy, mean, logstd, value, state')
+NetworkOutput = collections.namedtuple('NetworkOutput', 'policy, mean, logstd, value, state')
-def feed_forward_gaussian(
-    config, action_size, observations, unused_length, state=None):
+def feed_forward_gaussian(config, action_size, observations, unused_length, state=None):
  """Independent feed forward networks for policy and value.
  The policy network outputs the mean action and the log standard deviation
@@ -50,20 +46,22 @@ def feed_forward_gaussian(
      factor=config.init_mean_factor)
  logstd_initializer = tf.random_normal_initializer(config.init_logstd, 1e-10)
  flat_observations = tf.reshape(observations, [
-      tf.shape(observations)[0], tf.shape(observations)[1],
-      functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
+      tf.shape(observations)[0],
+      tf.shape(observations)[1],
+      functools.reduce(operator.mul,
+                       observations.shape.as_list()[2:], 1)
+  ])
  with tf.variable_scope('policy'):
    x = flat_observations
    for size in config.policy_layers:
      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-    mean = tf.contrib.layers.fully_connected(
-        x, action_size, tf.tanh,
-        weights_initializer=mean_weights_initializer)
-    logstd = tf.get_variable(
-        'logstd', mean.shape[2:], tf.float32, logstd_initializer)
-    logstd = tf.tile(
-        logstd[None, None],
-        [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
+    mean = tf.contrib.layers.fully_connected(x,
+                                             action_size,
+                                             tf.tanh,
+                                             weights_initializer=mean_weights_initializer)
+    logstd = tf.get_variable('logstd', mean.shape[2:], tf.float32, logstd_initializer)
+    logstd = tf.tile(logstd[None, None],
+                     [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
  with tf.variable_scope('value'):
    x = flat_observations
    for size in config.value_layers:
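
Illustrative aside: the hunk above only rewraps the Gaussian policy head, but that head is the core of feed_forward_gaussian: a tanh-squashed mean layer plus one learned logstd vector that is tiled across the batch and time dimensions and wrapped into a diagonal Gaussian. A minimal NumPy-only sketch of the same computation follows; the names, sizes, and the NumPy substitute for the tf.contrib calls are made up for illustration and are not code from this file.

import numpy as np

def gaussian_policy_head(features, action_size, logstd_init=-1.0, seed=0):
  """features: [batch, time, hidden]; returns mean, logstd, a sample, log-prob."""
  rng = np.random.default_rng(seed)
  batch, time, hidden = features.shape
  # Mean head: a tanh-squashed linear layer, like the fully_connected call above.
  w = rng.normal(scale=0.1, size=(hidden, action_size))
  mean = np.tanh(features @ w)                      # [batch, time, action_size]
  # One learned logstd vector, shared across batch and time (cf. tf.tile above).
  logstd = np.full((action_size,), logstd_init)
  logstd = np.broadcast_to(logstd, mean.shape)      # [batch, time, action_size]
  # Diagonal Gaussian: sample an action and its per-step log probability.
  sample = mean + np.exp(logstd) * rng.standard_normal(mean.shape)
  log_prob = -0.5 * np.sum(
      ((sample - mean) / np.exp(logstd)) ** 2 + 2 * logstd + np.log(2 * np.pi),
      axis=-1)
  return mean, logstd, sample, log_prob

mean, logstd, sample, log_prob = gaussian_policy_head(
    np.zeros((2, 3, 8)), action_size=4)
print(sample.shape, log_prob.shape)  # (2, 3, 4) (2, 3)

Keeping logstd as a single state-independent parameter vector, rather than predicting it from the observation, is exactly what the tf.get_variable plus tf.tile pair above implements.
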
@@ -72,13 +70,11 @@ def feed_forward_gaussian(
  mean = tf.check_numerics(mean, 'mean')
  logstd = tf.check_numerics(logstd, 'logstd')
  value = tf.check_numerics(value, 'value')
-  policy = tf.contrib.distributions.MultivariateNormalDiag(
-      mean, tf.exp(logstd))
+  policy = tf.contrib.distributions.MultivariateNormalDiag(mean, tf.exp(logstd))
  return NetworkOutput(policy, mean, logstd, value, state)
-def recurrent_gaussian(
-    config, action_size, observations, length, state=None):
+def recurrent_gaussian(config, action_size, observations, length, state=None):
  """Independent recurrent policy and feed forward value networks.
  The policy network outputs the mean action and the log standard deviation
@@ -100,21 +96,23 @@ def recurrent_gaussian(
  logstd_initializer = tf.random_normal_initializer(config.init_logstd, 1e-10)
  cell = tf.contrib.rnn.GRUBlockCell(config.policy_layers[-1])
  flat_observations = tf.reshape(observations, [
-      tf.shape(observations)[0], tf.shape(observations)[1],
-      functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
+      tf.shape(observations)[0],
+      tf.shape(observations)[1],
+      functools.reduce(operator.mul,
+                       observations.shape.as_list()[2:], 1)
+  ])
  with tf.variable_scope('policy'):
    x = flat_observations
    for size in config.policy_layers[:-1]:
      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
    x, state = tf.nn.dynamic_rnn(cell, x, length, state, tf.float32)
-    mean = tf.contrib.layers.fully_connected(
-        x, action_size, tf.tanh,
-        weights_initializer=mean_weights_initializer)
-    logstd = tf.get_variable(
-        'logstd', mean.shape[2:], tf.float32, logstd_initializer)
-    logstd = tf.tile(
-        logstd[None, None],
-        [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
+    mean = tf.contrib.layers.fully_connected(x,
+                                             action_size,
+                                             tf.tanh,
+                                             weights_initializer=mean_weights_initializer)
+    logstd = tf.get_variable('logstd', mean.shape[2:], tf.float32, logstd_initializer)
+    logstd = tf.tile(logstd[None, None],
+                     [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
  with tf.variable_scope('value'):
    x = flat_observations
    for size in config.value_layers:
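
Illustrative aside: in the recurrent variant above, the policy's last layer is a GRUBlockCell run through tf.nn.dynamic_rnn with the per-sequence length tensor, so steps past a sequence's end neither produce outputs nor advance the recurrent state. A toy NumPy sketch of that masking behaviour follows; the loop, the stand-in cell, and all names are made up for illustration and are not the real GRU.

import numpy as np

def simple_dynamic_rnn(step_fn, inputs, lengths, initial_state):
  """Toy analogue of tf.nn.dynamic_rnn: outputs past `lengths` are zero and
  the state stops updating once a sequence has ended."""
  batch, time, _ = inputs.shape
  state = initial_state.copy()
  outputs = np.zeros((batch, time, state.shape[-1]))
  for t in range(time):
    new_state = step_fn(inputs[:, t], state)
    alive = (t < lengths)[:, None]            # which sequences are still running
    state = np.where(alive, new_state, state)
    outputs[:, t] = np.where(alive, new_state, 0.0)
  return outputs, state

# Stand-in "cell": a single tanh layer mixing input and state (not a real GRU).
def toy_cell(x_t, h):
  return np.tanh(0.5 * x_t[:, :h.shape[-1]] + 0.5 * h)

outputs, final_state = simple_dynamic_rnn(
    toy_cell, np.ones((2, 4, 3)), lengths=np.array([4, 2]),
    initial_state=np.zeros((2, 3)))
print(outputs.shape, final_state.shape)  # (2, 4, 3) (2, 3)

The final state from dynamic_rnn is what recurrent_gaussian returns in NetworkOutput.state, presumably so the caller can carry it between successive batches of steps.
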
@@ -123,7 +121,6 @@ def recurrent_gaussian(
  mean = tf.check_numerics(mean, 'mean')
  logstd = tf.check_numerics(logstd, 'logstd')
  value = tf.check_numerics(value, 'value')
-  policy = tf.contrib.distributions.MultivariateNormalDiag(
-      mean, tf.exp(logstd))
+  policy = tf.contrib.distributions.MultivariateNormalDiag(mean, tf.exp(logstd))
  # assert state.shape.as_list()[0] is not None
  return NetworkOutput(policy, mean, logstd, value, state)
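
Illustrative aside: a hedged usage sketch for either constructor. It assumes TensorFlow 1.x with tf.contrib available (which this module requires) and an attribute-style config object; the Config class, layer sizes, and observation/action dimensions below are made-up placeholders, not values taken from this commit.

import tensorflow as tf
from pybullet_envs.agents import networks

class Config(object):
  # Illustrative hyper-parameters; the real ones live in the agent config.
  policy_layers = (200, 100)
  value_layers = (200, 100)
  init_mean_factor = 0.1
  init_logstd = -1

observations = tf.placeholder(tf.float32, [None, None, 28])  # [batch, time, obs_size]
length = tf.placeholder(tf.int32, [None])                    # valid steps per sequence
output = networks.feed_forward_gaussian(
    Config(), action_size=8, observations=observations, unused_length=length)
action = output.policy.sample()   # output.policy is a MultivariateNormalDiag
value = output.value              # value head's per-step estimate

In the agents training setup the chosen constructor is normally referenced from the configuration (e.g. as its network entry) rather than called by hand; the direct call here is only meant to show the [batch, time, ...] tensor shapes involved.
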