From 1ba4d3554d56757808bcf1665993dd4e7c3b787d Mon Sep 17 00:00:00 2001
From: EdanToledo
Date: Tue, 2 Apr 2024 21:32:44 +0000
Subject: [PATCH 1/2] chore: make using standardized advantages a choice

---
 stoix/configs/system/ff_awr.yaml        | 1 +
 stoix/configs/system/ff_dpo.yaml        | 1 +
 stoix/configs/system/ff_ppo.yaml        | 1 +
 stoix/configs/system/rec_ppo.yaml       | 1 +
 stoix/systems/awr/ff_awr.py             | 7 ++++++-
 stoix/systems/awr/ff_awr_continuous.py  | 7 ++++++-
 stoix/systems/ppo/ff_dpo_continuous.py  | 7 ++++++-
 stoix/systems/ppo/ff_ppo.py             | 7 ++++++-
 stoix/systems/ppo/ff_ppo_continuous.py  | 7 ++++++-
 stoix/systems/ppo/rec_ppo.py            | 7 ++++++-
 stoix/utils/loss.py                     | 2 --
 stoix/utils/multistep.py                | 5 +++++
 12 files changed, 45 insertions(+), 8 deletions(-)

diff --git a/stoix/configs/system/ff_awr.yaml b/stoix/configs/system/ff_awr.yaml
index 6c13a1ea..ef05f927 100644
--- a/stoix/configs/system/ff_awr.yaml
+++ b/stoix/configs/system/ff_awr.yaml
@@ -20,3 +20,4 @@ decay_learning_rates: False # Whether learning rates should be linearly decayed
 gae_lambda: 0.95 # The lambda parameter for the generalized advantage estimator.
 beta: 0.05 # The temperature of the exponentiated advantage weights.
 weight_clip: 20.0 # The maximum absolute value of the advantage weights.
+standardize_advantages: True # Whether to standardize the advantages.
diff --git a/stoix/configs/system/ff_dpo.yaml b/stoix/configs/system/ff_dpo.yaml
index 665248fd..28af5574 100644
--- a/stoix/configs/system/ff_dpo.yaml
+++ b/stoix/configs/system/ff_dpo.yaml
@@ -16,5 +16,6 @@ ent_coef: 0.001 # Entropy regularisation term for loss function.
 vf_coef: 1.0 # Critic weight in
 max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
 decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
+standardize_advantages: True # Whether to standardize the advantages.
 alpha : 2.0
 beta : 0.6
diff --git a/stoix/configs/system/ff_ppo.yaml b/stoix/configs/system/ff_ppo.yaml
index 00a5102c..a59a6580 100644
--- a/stoix/configs/system/ff_ppo.yaml
+++ b/stoix/configs/system/ff_ppo.yaml
@@ -16,3 +16,4 @@ ent_coef: 0.001 # Entropy regularisation term for loss function.
 vf_coef: 1.0 # Critic weight in
 max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
 decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
+standardize_advantages: True # Whether to standardize the advantages.
diff --git a/stoix/configs/system/rec_ppo.yaml b/stoix/configs/system/rec_ppo.yaml
index a5a7f063..ec125161 100644
--- a/stoix/configs/system/rec_ppo.yaml
+++ b/stoix/configs/system/rec_ppo.yaml
@@ -16,6 +16,7 @@ ent_coef: 0.01 # Entropy regularisation term for loss function.
 vf_coef: 0.5 # Critic weight in
 max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
 decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
+standardize_advantages: True # Whether to standardize the advantages.
 
 # --- Recurrent hyperparameters ---
 recurrent_chunk_size: ~ # The size of the chunks in which the recurrent sequences are divided during the training process.
diff --git a/stoix/systems/awr/ff_awr.py b/stoix/systems/awr/ff_awr.py
index 58374a88..7e44e214 100644
--- a/stoix/systems/awr/ff_awr.py
+++ b/stoix/systems/awr/ff_awr.py
@@ -245,7 +245,12 @@ def _actor_loss_fn(
             r_t = sequence.reward[:, :-1]
             d_t = (1 - sequence.done.astype(jnp.float32)[:, :-1]) * config.system.gamma
             advantages, _ = batch_truncated_generalized_advantage_estimation(
-                r_t, d_t, config.system.gae_lambda, v_t, time_major=False
+                r_t,
+                d_t,
+                config.system.gae_lambda,
+                v_t,
+                time_major=False,
+                standardize_advantages=config.system.standardize_advantages,
             )
             weights = jnp.exp(advantages / config.system.beta)
             weights = jnp.minimum(weights, config.system.weight_clip)
diff --git a/stoix/systems/awr/ff_awr_continuous.py b/stoix/systems/awr/ff_awr_continuous.py
index 4e7113d3..3db7f349 100644
--- a/stoix/systems/awr/ff_awr_continuous.py
+++ b/stoix/systems/awr/ff_awr_continuous.py
@@ -245,7 +245,12 @@ def _actor_loss_fn(
             r_t = sequence.reward[:, :-1]
             d_t = (1 - sequence.done.astype(jnp.float32)[:, :-1]) * config.system.gamma
             advantages, _ = batch_truncated_generalized_advantage_estimation(
-                r_t, d_t, config.system.gae_lambda, v_t, time_major=False
+                r_t,
+                d_t,
+                config.system.gae_lambda,
+                v_t,
+                time_major=False,
+                standardize_advantages=config.system.standardize_advantages,
             )
             weights = jnp.exp(advantages / config.system.beta)
             weights = jnp.minimum(weights, config.system.weight_clip)
diff --git a/stoix/systems/ppo/ff_dpo_continuous.py b/stoix/systems/ppo/ff_dpo_continuous.py
index 7580cb12..8644e2f8 100644
--- a/stoix/systems/ppo/ff_dpo_continuous.py
+++ b/stoix/systems/ppo/ff_dpo_continuous.py
@@ -107,7 +107,12 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
         d_t = 1.0 - traj_batch.done.astype(jnp.float32)
         d_t = (d_t * config.system.gamma).astype(jnp.float32)
         advantages, targets = batch_truncated_generalized_advantage_estimation(
-            r_t, d_t, config.system.gae_lambda, v_t, time_major=True
+            r_t,
+            d_t,
+            config.system.gae_lambda,
+            v_t,
+            time_major=True,
+            standardize_advantages=config.system.standardize_advantages,
         )
 
         def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
diff --git a/stoix/systems/ppo/ff_ppo.py b/stoix/systems/ppo/ff_ppo.py
index 6deb95dc..c8d779ac 100644
--- a/stoix/systems/ppo/ff_ppo.py
+++ b/stoix/systems/ppo/ff_ppo.py
@@ -107,7 +107,12 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
         d_t = 1.0 - traj_batch.done.astype(jnp.float32)
         d_t = (d_t * config.system.gamma).astype(jnp.float32)
         advantages, targets = batch_truncated_generalized_advantage_estimation(
-            r_t, d_t, config.system.gae_lambda, v_t, time_major=True
+            r_t,
+            d_t,
+            config.system.gae_lambda,
+            v_t,
+            time_major=True,
+            standardize_advantages=config.system.standardize_advantages,
         )
 
         def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
diff --git a/stoix/systems/ppo/ff_ppo_continuous.py b/stoix/systems/ppo/ff_ppo_continuous.py
index 83d4ccbc..d741e18b 100644
--- a/stoix/systems/ppo/ff_ppo_continuous.py
+++ b/stoix/systems/ppo/ff_ppo_continuous.py
@@ -107,7 +107,12 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
         d_t = 1.0 - traj_batch.done.astype(jnp.float32)
         d_t = (d_t * config.system.gamma).astype(jnp.float32)
         advantages, targets = batch_truncated_generalized_advantage_estimation(
-            r_t, d_t, config.system.gae_lambda, v_t, time_major=True
+            r_t,
+            d_t,
+            config.system.gae_lambda,
+            v_t,
+            time_major=True,
+            standardize_advantages=config.system.standardize_advantages,
         )
 
         def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
diff --git a/stoix/systems/ppo/rec_ppo.py b/stoix/systems/ppo/rec_ppo.py
index 26996646..f572d41f 100644
--- a/stoix/systems/ppo/rec_ppo.py
+++ b/stoix/systems/ppo/rec_ppo.py
@@ -176,7 +176,12 @@ def _env_step(
         d_t = 1.0 - traj_batch.done.astype(jnp.float32)
         d_t = (d_t * config.system.gamma).astype(jnp.float32)
         advantages, targets = batch_truncated_generalized_advantage_estimation(
-            r_t, d_t, config.system.gae_lambda, v_t, time_major=True
+            r_t,
+            d_t,
+            config.system.gae_lambda,
+            v_t,
+            time_major=True,
+            standardize_advantages=config.system.standardize_advantages,
         )
 
         def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
diff --git a/stoix/utils/loss.py b/stoix/utils/loss.py
index 9db1fb53..6423f7f9 100644
--- a/stoix/utils/loss.py
+++ b/stoix/utils/loss.py
@@ -15,7 +15,6 @@ def ppo_loss(
     pi_log_prob_t: chex.Array, b_pi_log_prob_t: chex.Array, gae_t: chex.Array, epsilon: float
 ) -> chex.Array:
     ratio = jnp.exp(pi_log_prob_t - b_pi_log_prob_t)
-    gae_t = (gae_t - gae_t.mean()) / (gae_t.std() + 1e-8)
     loss_actor1 = ratio * gae_t
     loss_actor2 = (
         jnp.clip(
@@ -38,7 +37,6 @@ def dpo_loss(
     beta: float,
 ) -> chex.Array:
     log_diff = pi_log_prob_t - b_pi_log_prob_t
-    gae_t = (gae_t - gae_t.mean()) / (gae_t.std() + 1e-8)
     ratio = jnp.exp(log_diff)
     is_pos = (gae_t >= 0.0).astype(jnp.float32)
     r1 = ratio - 1.0
diff --git a/stoix/utils/multistep.py b/stoix/utils/multistep.py
index f1061f48..9f0b8724 100644
--- a/stoix/utils/multistep.py
+++ b/stoix/utils/multistep.py
@@ -16,6 +16,7 @@ def batch_truncated_generalized_advantage_estimation(
     values: chex.Array,
     stop_target_gradients: bool = True,
     time_major: bool = False,
+    standardize_advantages: bool = False,
 ) -> Tuple[chex.Array, chex.Array]:
     """Computes truncated generalized advantage estimates for a sequence length k.
 
@@ -39,6 +40,7 @@
       to targets.
     time_major: If True, the first dimension of the input tensors is the time
       dimension.
+    standardize_advantages: If True, standardize the advantages.
 
   Returns:
     Multistep truncated generalized advantage estimation at times [0, k-1].
@@ -84,6 +86,9 @@ def _body(
         lambda x: jax.lax.stop_gradient(x), (advantage_t, target_values)
     )
 
+    if standardize_advantages:
+        advantage_t = jax.nn.standardize(advantage_t, axis=(0, 1))
+
     return advantage_t, target_values
 
 

From 9d58ba638e2595da0ff32e24711417d4b567ebf3 Mon Sep 17 00:00:00 2001
From: EdanToledo
Date: Tue, 2 Apr 2024 21:38:24 +0000
Subject: [PATCH 2/2] chore: revert awr

---
 stoix/configs/system/ff_awr.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/stoix/configs/system/ff_awr.yaml b/stoix/configs/system/ff_awr.yaml
index ef05f927..f554b2e0 100644
--- a/stoix/configs/system/ff_awr.yaml
+++ b/stoix/configs/system/ff_awr.yaml
@@ -20,4 +20,4 @@ decay_learning_rates: False # Whether learning rates should be linearly decayed
 gae_lambda: 0.95 # The lambda parameter for the generalized advantage estimator.
 beta: 0.05 # The temperature of the exponentiated advantage weights.
 weight_clip: 20.0 # The maximum absolute value of the advantage weights.
-standardize_advantages: True # Whether to standardize the advantages.
+standardize_advantages: False # Whether to standardize the advantages.
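
Note (not part of the patch): the change moves advantage standardization out of ppo_loss and dpo_loss, where it was applied to whatever gae_t reached each loss call, and into batch_truncated_generalized_advantage_estimation, where it is applied once over the full batch of computed advantages and gated by the new standardize_advantages config entry (set back to False for AWR by the second commit). The minimal standalone sketch below illustrates the two behaviours; the helper names and the [rollout_length, num_envs] shape are assumptions made for illustration, not code from the repository.

    import jax
    import jax.numpy as jnp

    def standardize_over_rollout(advantages: jnp.ndarray) -> jnp.ndarray:
        # What multistep.py now does when standardize_advantages=True:
        # zero mean / unit variance over both the time and batch axes.
        return jax.nn.standardize(advantages, axis=(0, 1))

    def standardize_per_loss_call(gae_t: jnp.ndarray) -> jnp.ndarray:
        # The behaviour removed from ppo_loss/dpo_loss in loss.py.
        return (gae_t - gae_t.mean()) / (gae_t.std() + 1e-8)

    if __name__ == "__main__":
        adv = jax.random.normal(jax.random.PRNGKey(0), (8, 4))  # assumed [rollout_length, num_envs]
        out = standardize_over_rollout(adv)
        print(out.mean(), out.std())  # approximately 0.0 and 1.0

Because the flag is read from config.system.standardize_advantages in every system touched by the patch, any system yaml that omits the new entry would fail at attribute lookup, which is presumably why all four configs gain the key even though AWR ultimately disables it.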