From 9da81f57dcf7ffb1e43ef1f331635a6a85bb2da7 Mon Sep 17 00:00:00 2001 From: dhruvat Date: Fri, 6 Aug 2021 16:43:06 +0100 Subject: [PATCH] Added continuous retrace ops to trfl losses. PiperOrigin-RevId: 389183481 --- trfl/__init__.py | 2 + trfl/continuous_retrace_ops.py | 230 ++++++++++++++++++++++++++++ trfl/continuous_retrace_ops_test.py | 126 +++++++++++++++ 3 files changed, 358 insertions(+) create mode 100644 trfl/continuous_retrace_ops.py create mode 100644 trfl/continuous_retrace_ops_test.py diff --git a/trfl/__init__.py b/trfl/__init__.py index 946e72d..50bbe18 100644 --- a/trfl/__init__.py +++ b/trfl/__init__.py @@ -25,6 +25,8 @@ from trfl.base_ops import assert_rank_and_shape_compatibility from trfl.base_ops import best_effort_shape from trfl.clipping_ops import huber_loss +from trfl.continuous_retrace_ops import retrace_from_action_log_probs +from trfl.continuous_retrace_ops import retrace_from_importance_weights from trfl.discrete_policy_gradient_ops import discrete_policy_entropy_loss from trfl.discrete_policy_gradient_ops import discrete_policy_gradient from trfl.discrete_policy_gradient_ops import discrete_policy_gradient_loss diff --git a/trfl/continuous_retrace_ops.py b/trfl/continuous_retrace_ops.py new file mode 100644 index 0000000..7c3791e --- /dev/null +++ b/trfl/continuous_retrace_ops.py @@ -0,0 +1,230 @@ +# Copyright 2018 The trfl Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""TensorFlow ops for the Retrace algorithm and continuous actions. + +Safe and Efficient Off-Policy Reinforcement Learning +R. Munos, T. Stepleton, A. Harutyunyan, M. G. Bellemare +https://arxiv.org/abs/1606.02647 + +This variant is commonly used to update the Q function in RS0, which +additionally uses SVG or a SVG variant to update the policy. + +Learning by Playing - Solving Sparse Reward Tasks from Scratch +M. Riedmiller, R. Hafner, T. Lampe, M. Neunert, J. Degrave, T. Van de Wiele, +V. Mnih, N. Heess, J. T. Springenberg +https://arxiv.org/abs/1802.10567 + +Learning Continuous Control Policies by Stochastic Value Gradients +N. Heess, G. Wayne, D. Silver, T. Lillicrap, Y. Tassa, T. Erez +https://arxiv.org/abs/1510.09142 + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import tensorflow.compat.v1 as tf + + +QTraceReturns = collections.namedtuple("QTraceReturns", [ + "qs", "importance_weights", "log_importance_weights", + "truncated_importance_weights", "deltas", "vs_minus_q_xs" +]) + + +def retrace_from_action_log_probs( + behaviour_action_log_probs, + target_action_log_probs, + discounts, + rewards, + q_values, + values, + bootstrap_value, + lambda_=1., + name="retrace_from_action_log_probs"): + """Constructs Q/Retrace ops. + + This is an implementation of Retrace. 
In the description of the arguments
+  the notation is as follows: `T` refers to the sequence size over which
+  the return is calculated and `B` denotes the batch size.
+
+  Args:
+    behaviour_action_log_probs: Log-probabilities of the taken actions under
+      the behaviour policy. Shape [T, B].
+    target_action_log_probs: Log-probabilities of the same actions under the
+      target policy. Shape [T, B].
+    discounts: Also called pcontinues. Discount encountered when following
+      the behaviour policy. Shape [T, B].
+    rewards: A tensor containing rewards generated by following the behaviour
+      policy. Shape [T, B].
+    q_values: Q-function estimates wrt. the target policy. Shape [T, B].
+    values: Value function estimates wrt. the target policy. Shape [T, B].
+    bootstrap_value: Value function estimate at time `T`. Shape [B].
+    lambda_: Mix between 1-step (lambda_=0) and n-step (lambda_=1).
+    name: The name scope that all qtrace ops will be created in.
+
+  Returns:
+    A `QTraceReturns` namedtuple containing:
+
+      * qs: The Retrace regression/policy gradient targets.
+        Can be used to calculate estimates of the advantage for policy
+        gradients or as regression target for Q-value functions. Shape [T, B].
+      * importance_weights: Importance sampling weights. Shape [T, B].
+      * log_importance_weights: Log of the importance sampling weights.
+        Shape [T, B].
+      * truncated_importance_weights: Called c_t in the paper. Shape [T, B].
+      * deltas: Temporal difference errors,
+        r_t + discount_t * V(x_{t+1}) - Q(x_t, a_t). Shape [T, B].
+      * vs_minus_q_xs: Q-Retrace targets - Q(x_s, u_s). Shape [T, B].
+  """
+  # Turn arguments to tensors.
+  behaviour_action_log_probs = tf.convert_to_tensor(
+      behaviour_action_log_probs, dtype=tf.float32)
+  target_action_log_probs = tf.convert_to_tensor(
+      target_action_log_probs, dtype=tf.float32)
+  values = tf.convert_to_tensor(values, dtype=tf.float32)
+  q_values = tf.convert_to_tensor(q_values, dtype=tf.float32)
+  bootstrap_value = tf.convert_to_tensor(bootstrap_value, dtype=tf.float32)
+  discounts = tf.convert_to_tensor(discounts, dtype=tf.float32)
+  rewards = tf.convert_to_tensor(rewards, dtype=tf.float32)
+
+  # Make sure tensor ranks are as expected.
+  behaviour_action_log_probs.get_shape().assert_has_rank(2)
+  target_action_log_probs.get_shape().assert_has_rank(2)
+  values.get_shape().assert_has_rank(2)
+  q_values.get_shape().assert_has_rank(2)
+  bootstrap_value.get_shape().assert_has_rank(1)
+  discounts.get_shape().assert_has_rank(2)
+  rewards.get_shape().assert_has_rank(2)
+
+  with tf.name_scope(
+      name,
+      values=[
+          behaviour_action_log_probs, target_action_log_probs, discounts,
+          rewards, q_values, values, bootstrap_value
+      ]):
+    log_rhos = target_action_log_probs - behaviour_action_log_probs
+    return retrace_from_importance_weights(
+        log_rhos=log_rhos,
+        discounts=discounts,
+        rewards=rewards,
+        q_values=q_values,
+        values=values,
+        bootstrap_value=bootstrap_value,
+        lambda_=lambda_)
+
+
+def retrace_from_importance_weights(log_rhos,
+                                    discounts,
+                                    rewards,
+                                    q_values,
+                                    values,
+                                    bootstrap_value,
+                                    lambda_=1.0,
+                                    name="retrace_from_importance_weights"):
+  """Constructs Q/Retrace ops.
+
+  This is an implementation of Retrace. In the description of the arguments
+  the notation is as follows: `T` refers to the sequence size over which
+  the return is calculated and `B` denotes the batch size.
+
+  Args:
+    log_rhos: Log importance sampling weights, i.e. the log of the ratio
+      between the target and behaviour policy probabilities of the taken
+      actions. Shape [T, B].
+    discounts: Also called pcontinues. Discount encountered when following
+      the behaviour policy. Shape [T, B].
+    rewards: A tensor containing rewards generated by following the behaviour
+      policy. Shape [T, B].
+    q_values: Q-function estimates wrt. the target policy. Shape [T, B].
+    values: Value function estimates wrt. the target policy. Shape [T, B].
+    bootstrap_value: Value function estimate at time `T`. Shape [B].
+    lambda_: Mix between 1-step (lambda_=0) and n-step (lambda_=1).
+    name: The name scope that all qtrace ops will be created in.
+
+  Returns:
+    A `QTraceReturns` namedtuple containing:
+
+      * qs: The Retrace regression/policy gradient targets.
+        Can be used to calculate estimates of the advantage for policy
+        gradients or as regression target for Q-value functions. Shape [T, B].
+      * importance_weights: Importance sampling weights. Shape [T, B].
+      * log_importance_weights: Log of the importance sampling weights.
+        Shape [T, B].
+      * truncated_importance_weights: Called c_t in the paper. Shape [T, B].
+      * deltas: Temporal difference errors,
+        r_t + discount_t * V(x_{t+1}) - Q(x_t, a_t). Shape [T, B].
+      * vs_minus_q_xs: Q-Retrace targets - Q(x_s, u_s). Shape [T, B].
+
+  Raises:
+    ValueError: If the input tensors do not have consistent ranks.
+  """
+  # Make sure tensor ranks are consistent.
+  rho_rank = log_rhos.get_shape().ndims  # Usually 2.
+  q_values.get_shape().assert_has_rank(rho_rank)
+  values.get_shape().assert_has_rank(rho_rank)
+  bootstrap_value.get_shape().assert_has_rank(rho_rank - 1)
+  discounts.get_shape().assert_has_rank(rho_rank)
+  rewards.get_shape().assert_has_rank(rho_rank)
+
+  lambda_ = tf.convert_to_tensor(lambda_, dtype=tf.float32)
+
+  with tf.name_scope(
+      name, values=[log_rhos, discounts, rewards, values, bootstrap_value]):
+    rhos = tf.exp(log_rhos)
+
+    # Truncated importance weights, called c_t in the paper.
+    cs = tf.minimum(1.0, rhos, name="cs")
+
+    # Shift cs one step to the left and set the last c to 1, so that c_{t+1}
+    # lines up with delta_t in the recursion below.
+    cs = tf.concat([cs[1:], tf.ones_like(cs[-1:])], axis=0)
+    cs *= lambda_
+
+    # Append the bootstrap value to get the values shifted by one time step,
+    # i.e. [V(x_1), ..., V(x_{T-1}), bootstrap_value].
+    values_t_plus_1 = tf.concat(
+        [values[1:], tf.expand_dims(bootstrap_value, 0)], axis=0)
+
+    # delta_t = (r_t + discount * V(x_{t+1}) - Q(x_t, a_t))
+    deltas = (rewards + discounts * values_t_plus_1 - q_values)
+
+    # Note that all sequences are reversed; computation starts from the back.
+    sequences = (
+        tf.reverse(discounts, axis=[0]),
+        tf.reverse(cs, axis=[0]),
+        tf.reverse(deltas, axis=[0]),
+    )
+
+    # The Retrace corrections are calculated through a scan from the back to
+    # the beginning of the given trajectory.
+    def scanfunc(acc, sequence_item):
+      discount_t, c_t, delta_t = sequence_item
+      return delta_t + discount_t * c_t * acc
+
+    initial_values = tf.zeros_like(bootstrap_value)
+    vs_minus_q_xs = tf.scan(
+        fn=scanfunc,
+        elems=sequences,
+        initializer=initial_values,
+        parallel_iterations=1,
+        back_prop=False,
+        name="scan")
+    # Reverse the results back to original order.
+    vs_minus_q_xs = tf.reverse(vs_minus_q_xs, [0], name="vs_minus_q_xs")
+
+    # Add Q(x_s, a_s) to get the Retrace targets.
+    qs = tf.add(vs_minus_q_xs, q_values, name="qs")
+
+    result = QTraceReturns(
+        qs=tf.stop_gradient(qs),
+        importance_weights=tf.stop_gradient(rhos),
+        log_importance_weights=tf.stop_gradient(log_rhos),
+        truncated_importance_weights=tf.stop_gradient(cs),
+        deltas=tf.stop_gradient(deltas),
+        vs_minus_q_xs=tf.stop_gradient(vs_minus_q_xs))
+    return result
diff --git a/trfl/continuous_retrace_ops_test.py b/trfl/continuous_retrace_ops_test.py
new file mode 100644
index 0000000..b062dc2
--- /dev/null
+++ b/trfl/continuous_retrace_ops_test.py
@@ -0,0 +1,126 @@
+# Copyright 2018 The trfl Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for continuous_retrace_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +import numpy as np +import six +from six.moves import range +import tensorflow.compat.v1 as tf + +from trfl import continuous_retrace_ops + + +def _shaped_arange(*shape): + """Runs np.arange, converts to float and reshapes.""" + return np.arange(np.prod(shape), dtype=np.float32).reshape(*shape) + + +def _ground_truth_calculation(discounts, log_rhos, rewards, q_values, values, + bootstrap_value, lambda_): + """Calculates the ground truth for Retrace in python/numpy.""" + qs = [] + seq_len = len(discounts) + rhos = np.exp(log_rhos) + cs = np.minimum(rhos, 1.0) + cs *= lambda_ + # This is a very inefficient way to calculate the Retrace ground truth. + values_t_plus_1 = np.concatenate([values, bootstrap_value[None, :]], axis=0) + for s in range(seq_len): + q_s = np.copy(q_values[s]) # Very important copy... + delta = rewards[s] + discounts[s] * values_t_plus_1[s + 1] - q_values[s] + q_s += delta + for t in range(s + 1, seq_len): + q_s += ( + np.prod(discounts[s:t], axis=0) * np.prod(cs[s + 1:t + 1], axis=0) * + (rewards[t] + discounts[t] * values_t_plus_1[t + 1] - q_values[t])) + qs.append(q_s) + qs = np.stack(qs, axis=0) + return qs + + +class ContinuousRetraceTest(tf.test.TestCase): + + def testSingleElem(self): + """Tests Retrace with a single element batch and lambda set to 1.0.""" + batch_size = 1 + lambda_ = 1.0 + self._main_test(batch_size, lambda_) + + def testLargerBatch(self): + """Tests Retrace with a larger batch.""" + batch_size = 2 + lambda_ = 1.0 + self._main_test(batch_size, lambda_) + + def testLowerLambda(self): + """Tests Retrace with a lower lambda.""" + batch_size = 2 + lambda_ = 0.5 + self._main_test(batch_size, lambda_) + + def _main_test(self, batch_size, lambda_): + """Tests Retrace against ground truth data calculated in python.""" + seq_len = 5 + # Create log_rhos such that rho will span from near-zero to above the + # clipping thresholds. In particular, calculate log_rhos in [-2.5, 2.5), + # so that rho is in approx [0.08, 12.2). + log_rhos = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len) + log_rhos = 5 * (log_rhos - 0.5) # [0.0, 1.0) -> [-2.5, 2.5). 
+    values = {
+        "discounts":
+            np.array(  # T, B where B_i: [0.9 / (i+1)] * T
+                [[0.9 / (b + 1)
+                  for b in range(batch_size)]
+                 for _ in range(seq_len)]),
+        "rewards":
+            _shaped_arange(seq_len, batch_size),
+        "q_values":
+            _shaped_arange(seq_len, batch_size) / batch_size,
+        "values":
+            _shaped_arange(seq_len, batch_size) / batch_size,
+        "bootstrap_value":
+            _shaped_arange(batch_size) + 1.0,  # B
+        "log_rhos":
+            log_rhos
+    }
+    # Create a float32 placeholder with a fully dynamic shape for each input.
+    placeholders = {
+        key: tf.placeholder(tf.float32, shape=[None] * val.ndim)
+        for key, val in six.iteritems(values)
+    }
+
+    retrace_returns = continuous_retrace_ops.retrace_from_importance_weights(
+        lambda_=lambda_, **placeholders)
+
+    feed_dict = {placeholders[k]: v for k, v in values.items()}
+    with self.test_session() as sess:
+      retrace_output_values = sess.run(retrace_returns, feed_dict=feed_dict)
+
+    ground_truth_data = _ground_truth_calculation(lambda_=lambda_, **values)
+
+    self.assertAllClose(ground_truth_data, retrace_output_values.qs)
+
+
+if __name__ == "__main__":
+  tf.test.main()
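Below is a minimal usage sketch of the new ops, following the TF1 graph/session style used in the patch. The placeholder shapes, the random feed values and the squared-error critic loss are illustrative assumptions for the example, not part of the trfl API.

import numpy as np
import tensorflow.compat.v1 as tf

from trfl import continuous_retrace_ops

tf.disable_v2_behavior()  # The ops are written against the TF1 graph API.

T, B = 5, 2  # Sequence length and batch size (illustrative).

# Placeholders for a length-T trajectory generated by the behaviour policy.
behaviour_logp = tf.placeholder(tf.float32, [T, B])
target_logp = tf.placeholder(tf.float32, [T, B])
discounts = tf.placeholder(tf.float32, [T, B])
rewards = tf.placeholder(tf.float32, [T, B])
q_values = tf.placeholder(tf.float32, [T, B])
values = tf.placeholder(tf.float32, [T, B])
bootstrap_value = tf.placeholder(tf.float32, [B])

retrace = continuous_retrace_ops.retrace_from_action_log_probs(
    behaviour_action_log_probs=behaviour_logp,
    target_action_log_probs=target_logp,
    discounts=discounts,
    rewards=rewards,
    q_values=q_values,
    values=values,
    bootstrap_value=bootstrap_value,
    lambda_=1.0)

# retrace.qs is already wrapped in stop_gradient, so it can be used directly
# as a regression target; the squared-error loss here is only an example.
q_loss = 0.5 * tf.reduce_mean(tf.square(retrace.qs - q_values))

with tf.Session() as sess:
  feed = {
      behaviour_logp: np.random.randn(T, B).astype(np.float32),
      target_logp: np.random.randn(T, B).astype(np.float32),
      discounts: np.full([T, B], 0.99, dtype=np.float32),
      rewards: np.random.randn(T, B).astype(np.float32),
      q_values: np.random.randn(T, B).astype(np.float32),
      values: np.random.randn(T, B).astype(np.float32),
      bootstrap_value: np.random.randn(B).astype(np.float32),
  }
  print(sess.run(q_loss, feed_dict=feed))

In an RS0-style setup, the same `retrace.qs` targets would be regressed by the Q-network while the policy is updated separately with SVG, as described in the module docstring.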