Perform gradient clipping on global batch when using gradient accumulation #6

Open
wants to merge 9 commits into main
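In outline, the change clips by the norm of the averaged, globally accumulated gradient once all sub-batches have been summed, rather than leaving the accumulated update unclipped before it reaches the base optimizer. A minimal standalone sketch of that idea, independent of the diff below (function names here are illustrative, not the praxis API):

import jax
import jax.numpy as jnp

def _global_norm(grads):
  # Total L2 norm across every leaf of the gradient pytree.
  squares = jax.tree_util.tree_leaves(
      jax.tree_map(lambda x: jnp.sum(x * x), grads))
  return jnp.sqrt(jnp.sum(jnp.stack(squares)))

def clip_global_batch_grads(accumulated_grads, num_sub_batches, clip_norm):
  # Average the accumulated sub-batch gradients, then clip the averaged
  # (global-batch) gradient by its total norm: scale = min(1, clip / norm).
  averaged = jax.tree_map(lambda g: g / num_sub_batches, accumulated_grads)
  norm = _global_norm(averaged)
  scale = jnp.minimum(jnp.array(1.0, norm.dtype),
                      jnp.array(clip_norm, norm.dtype) / norm)
  return jax.tree_map(lambda g: g * scale, averaged)

The diff threads the existing clip_gradient_norm_to_value and clip_gradient_single_norm_to_value settings into sharded_static_accumulation so that this clipping is applied to the averaged update before base_tx.update is called.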
51 changes: 49 additions & 2 deletions praxis/optimizers.py
@@ -2713,6 +2713,8 @@ def _get_raw_grad_transformation(self, lr: optax.Schedule):
 
 def sharded_static_accumulation(
     num_sub_batches: int,
+    clip_gradient_norm_to_value: float,
+    clip_gradient_single_norm_to_value: float,
     base_tx: ShardedGradientTransformation,
 ) -> ShardedGradientTransformation:
   """Gradient transformation for ShardedStaticAccumulator optimizer."""
@@ -2781,10 +2783,54 @@ def update_fn(updates: NestedJTensor,
                              lambda: new_count)
 
     def _run_base_tx():
+
+      def _compute_grad_norm(grads: NestedMap) -> JTensor:
+        """Computes total grad norm."""
+        grad_norms_squared = jax.tree_map(lambda x: jnp.sum(x * x), grads)
+        grad_norms_squared, _ = jax.tree_util.tree_flatten(grad_norms_squared)
+        return jnp.sqrt(jnp.sum(jnp.stack(grad_norms_squared)))
+
+
+      def scale_gradients(
+          raw_grads: NestedMap,
+          clip_grad_norm_to_value: Optional[float] = None,
Review comment (Member):
looking at praxis optimizers, clip_gradient_norm_to_value and clip_gradient_single_norm_to_value default are 0.0 and not None right?

so perhaps the types here should be float and default 0.0 instead of Optional?

+          clip_grad_single_norm_to_value: Optional[float] = None):
+
+        def clip_grads(grads):
+          if clip_grad_norm_to_value:
Review comment (Member):
maybe assert only one of them is true?

+            assert clip_grad_single_norm_to_value == 0.
+
+            grad_norm = _compute_grad_norm(raw_grads)
+
+            grad_scale = jnp.minimum(
+                jnp.array(1, grad_norm.dtype),
+                jnp.array(clip_grad_norm_to_value, grad_norm.dtype)
+                / grad_norm)
+            grads = jax.tree_map(lambda g: g * grad_scale, grads)
+          elif clip_grad_single_norm_to_value:
+            assert clip_grad_norm_to_value == 0.
+            grad_single_norm = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)),
+                                            grads)
+
+            def scale_gradient(grad, norm):
+              return grad * jnp.minimum(
+                  jnp.array(1, norm.dtype),
+                  jnp.array(clip_grad_single_norm_to_value,
+                            norm.dtype) / norm)
+            grads = jax.tree_map(scale_gradient, grads, grad_single_norm)
+
+          return grads
+
+        grads = clip_grads(raw_grads)
+        return grads
+
       averaged_updated = jax.tree_map(lambda acc: acc / num_sub_batches,
                                       new_accumulated_update)
+      scaled_updated = scale_gradients(averaged_updated,
+                                       clip_gradient_norm_to_value,
+                                       clip_gradient_single_norm_to_value)
       emission_updates, emission_base_state = base_tx.update(
-          averaged_updated, state.base_state, params)  # pytype: disable=attribute-error # jax-ndarray
+          scaled_updated, state.base_state, params)  # pytype: disable=attribute-error # jax-ndarray
       return (emission_updates,
               jax.tree_map(lambda u: jnp.zeros_like(u, dtype=jnp.float32),
                            updates), emission_base_state)
@@ -2855,4 +2901,5 @@ def _get_raw_grad_transformation(
       self, lr: optax.Schedule) -> GeneralGradientTransformation:
     p = self._hparams
     base_tx = self.base_optimizer._get_raw_grad_transformation(lr)  # pylint: disable=protected-access
-    return sharded_static_accumulation(p.num_sub_batches, base_tx)
+    return sharded_static_accumulation(p.num_sub_batches, p.clip_gradient_norm_to_value,
+                                       p.clip_gradient_single_norm_to_value, base_tx)
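Taken together, the two review comments point toward plain float parameters defaulting to 0.0 (matching, per the reviewer, the praxis optimizer hparams) plus an explicit check that only one clipping mode is enabled. A rough sketch of that variant of the helper, with the clipping body elided (the combined assert is illustrative, not part of the PR):

def scale_gradients(raw_grads,
                    clip_grad_norm_to_value: float = 0.0,
                    clip_grad_single_norm_to_value: float = 0.0):
  # Suggested guard: at most one of the two clipping modes may be enabled.
  assert not (clip_grad_norm_to_value and clip_grad_single_norm_to_value), (
      'Set at most one of clip_grad_norm_to_value and '
      'clip_grad_single_norm_to_value.')
  ...  # clipping logic unchanged from the diff above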