Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perform gradient clipping on global batch when using gradient accumulation #6

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
53 changes: 51 additions & 2 deletions praxis/optimizers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 Google LLC.
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -2707,6 +2707,8 @@ def _get_raw_grad_transformation(self, lr: optax.Schedule):

def sharded_static_accumulation(
num_sub_batches: int,
clip_gradient_norm_to_value: float,
clip_gradient_single_norm_to_value: float,
base_tx: ShardedGradientTransformation,
) -> ShardedGradientTransformation:
"""Gradient transformation for ShardedStaticAccumulator optimizer."""
Expand Down Expand Up @@ -2775,8 +2777,54 @@ def update_fn(updates: NestedJTensor,
lambda: new_count)

def _run_base_tx():

def _compute_grad_norm(grads: NestedMap) -> JTensor:
"""Computes total grad norm."""
grad_norms_squared = jax.tree_map(lambda x: jnp.sum(x * x), grads)
grad_norms_squared, _ = jax.tree_util.tree_flatten(grad_norms_squared)
return jnp.sqrt(jnp.sum(jnp.stack(grad_norms_squared)))


def scale_gradients(
raw_grads: NestedMap,
clip_grad_norm_to_value: Optional[float] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looking at praxis optimizers, clip_gradient_norm_to_value and clip_gradient_single_norm_to_value default are 0.0 and not None right?

so perhaps the types here should be float and default 0.0 instead of Optional?

clip_grad_single_norm_to_value: Optional[float] = None):

def clip_grads(grads, grad_norm):
  """Optionally clips gradients per the enclosing clip settings.

  Reads `clip_grad_norm_to_value` and `clip_grad_single_norm_to_value`
  from the enclosing scope; at most one of them may be nonzero (0.0
  means "disabled", matching the praxis optimizer defaults).

  Args:
    grads: A pytree of gradient tensors.
    grad_norm: Scalar global norm of `grads`; used only for global-norm
      clipping.

  Returns:
    A `(grads, grad_scale)` tuple, where `grad_scale` is the global scale
    factor applied (1.0 when per-tensor clipping or no clipping was used).
  """
  # jax.tree_map is deprecated (removed in recent JAX); use
  # jax.tree_util.tree_map throughout.
  if clip_grad_norm_to_value:
    # Global-norm clipping: scale every leaf by min(1, cap / global_norm).
    assert clip_grad_single_norm_to_value == 0.
    grad_scale = jnp.minimum(
        jnp.array(1, grad_norm.dtype),
        jnp.array(clip_grad_norm_to_value, grad_norm.dtype)
        / grad_norm)
    grads = jax.tree_util.tree_map(lambda g: g * grad_scale, grads)
  elif clip_grad_single_norm_to_value:
    # Per-tensor clipping: each leaf is scaled by its own norm.
    assert clip_grad_norm_to_value == 0.
    grad_single_norm = jax.tree_util.tree_map(
        lambda x: jnp.sqrt(jnp.sum(x * x)), grads)

    def scale_gradient(grad, norm):
      # min(1, cap / leaf_norm), computed in the leaf norm's dtype.
      return grad * jnp.minimum(
          jnp.array(1, norm.dtype),
          jnp.array(clip_grad_single_norm_to_value,
                    norm.dtype) / norm)
    grads = jax.tree_util.tree_map(scale_gradient, grads, grad_single_norm)
    grad_scale = jnp.array(1.0)
  else:
    # No clipping configured.
    grad_scale = jnp.array(1.0)
  return grads, grad_scale

raw_grad_norm = _compute_grad_norm(raw_grads)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iiuc, if clip_grad_single_norm_to_value is True, then raw_grad_norm is not used and we have to compute grad_single_norm separately anyways?

can we move the if-elif-else statement inside out and avoid redundant computation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely. I have addressed this with my latest commit


grads, grad_scale = clip_grads(raw_grads, raw_grad_norm)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to compute and return grad_scale?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not needed. I no longer return grad_scale with the latest commit

return grads

averaged_updated = jax.tree_map(lambda acc: acc / num_sub_batches,
new_accumulated_update)
averaged_updated = scale_gradients(averaged_updated,
clip_gradient_norm_to_value,
clip_gradient_single_norm_to_value)
emission_updates, emission_base_state = base_tx.update(
averaged_updated, state.base_state, params)
return (emission_updates,
Expand Down Expand Up @@ -2849,4 +2897,5 @@ def _get_raw_grad_transformation(
self, lr: optax.Schedule) -> GeneralGradientTransformation:
p = self._hparams
base_tx = self.base_optimizer._get_raw_grad_transformation(lr) # pylint: disable=protected-access
return sharded_static_accumulation(p.num_sub_batches, base_tx)
return sharded_static_accumulation(p.num_sub_batches, p.clip_gradient_norm_to_value,
p.clip_gradient_single_norm_to_value, base_tx)