Perform gradient clipping on global batch when using gradient accumulation #6
base: main
Changes from 2 commits
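
At a high level, the change threads the optimizer's clipping thresholds into the gradient-accumulation wrapper so that clipping is applied once, to the gradient averaged over the whole global batch, just before the wrapped transformation runs. As rough orientation only, here is a minimal standalone JAX sketch of that idea, not the Pax implementation shown in the diff below; the names clip_by_global_norm and accumulate_and_clip and the eps guard are illustrative assumptions.

import jax
import jax.numpy as jnp

# Sketch: accumulate per-sub-batch gradients, average them over the global
# batch, then clip the *averaged* gradient by its global norm.
def clip_by_global_norm(grads, max_norm, eps=1e-12):
  leaves = jax.tree_util.tree_leaves(grads)
  global_norm = jnp.sqrt(sum(jnp.sum(g * g) for g in leaves))
  scale = jnp.minimum(1.0, max_norm / (global_norm + eps))
  return jax.tree_util.tree_map(lambda g: g * scale, grads)

def accumulate_and_clip(sub_batch_grads, max_norm):
  # Sum matching leaves across the list of per-sub-batch gradient pytrees,
  # then divide by the number of sub-batches to get the global-batch average.
  summed = jax.tree_util.tree_map(lambda *gs: sum(gs), *sub_batch_grads)
  averaged = jax.tree_util.tree_map(
      lambda g: g / len(sub_batch_grads), summed)
  # Clipping happens once, on the averaged (global-batch) gradient.
  return clip_by_global_norm(averaged, max_norm)

Clipping after averaging means the threshold is compared against the same quantity as in training without accumulation, so enabling accumulation does not change what the threshold means.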
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2022 Google LLC.
+# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -2707,6 +2707,8 @@ def _get_raw_grad_transformation(self, lr: optax.Schedule):


def sharded_static_accumulation(
    num_sub_batches: int,
+   clip_gradient_norm_to_value: float,
+   clip_gradient_single_norm_to_value: float,
    base_tx: ShardedGradientTransformation,
) -> ShardedGradientTransformation:
  """Gradient transformation for ShardedStaticAccumulator optimizer."""

@@ -2775,8 +2777,54 @@ def update_fn(updates: NestedJTensor,
                         lambda: new_count)

    def _run_base_tx():

+      def _compute_grad_norm(grads: NestedMap) -> JTensor:
+        """Computes total grad norm."""
+        grad_norms_squared = jax.tree_map(lambda x: jnp.sum(x * x), grads)
+        grad_norms_squared, _ = jax.tree_util.tree_flatten(grad_norms_squared)
+        return jnp.sqrt(jnp.sum(jnp.stack(grad_norms_squared)))
+
+      def scale_gradients(
+          raw_grads: NestedMap,
+          clip_grad_norm_to_value: Optional[float] = None,
+          clip_grad_single_norm_to_value: Optional[float] = None):
+
+        def clip_grads(grads, grad_norm):
+          if clip_grad_norm_to_value:

Review comment: maybe assert only one of them is true?

+            assert clip_grad_single_norm_to_value == 0.
+            grad_scale = jnp.minimum(
+                jnp.array(1, grad_norm.dtype),
+                jnp.array(clip_grad_norm_to_value, grad_norm.dtype)
+                / grad_norm)
+            grads = jax.tree_map(lambda g: g * grad_scale, grads)
+          elif clip_grad_single_norm_to_value:
+            assert clip_grad_norm_to_value == 0.
+            grad_single_norm = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)),
+                                            grads)
+
+            def scale_gradient(grad, norm):
+              return grad * jnp.minimum(
+                  jnp.array(1, norm.dtype),
+                  jnp.array(clip_grad_single_norm_to_value,
+                            norm.dtype) / norm)
+            grads = jax.tree_map(scale_gradient, grads, grad_single_norm)
+            grad_scale = jnp.array(1.0)
+          else:
+            # no clipping is needed.
+            grad_scale = jnp.array(1.0)
+          return grads, grad_scale
+
+        raw_grad_norm = _compute_grad_norm(raw_grads)

Review comment: IIUC, can we move the if/elif/else statement inside out and avoid redundant computation?
Reply: Definitely. I have addressed this with my latest commit.
[A sketch of this restructuring appears after this hunk.]

+        grads, grad_scale = clip_grads(raw_grads, raw_grad_norm)

Review comment: do we need to compute and return […]?
Reply: This is not needed. I no longer return […]

+        return grads
+
      averaged_updated = jax.tree_map(lambda acc: acc / num_sub_batches,
                                      new_accumulated_update)
+      averaged_updated = scale_gradients(averaged_updated,
+                                         clip_gradient_norm_to_value,
+                                         clip_gradient_single_norm_to_value)
      emission_updates, emission_base_state = base_tx.update(
          averaged_updated, state.base_state, params)
      return (emission_updates,

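Following up on the review thread above about moving the if/elif/else "inside out": a hedged sketch of how scale_gradients could be restructured so each norm is computed only in the branch that needs it, with an explicit check that at most one clipping mode is set. It assumes the module's existing jax/jnp imports and the _compute_grad_norm helper from this diff, and it illustrates the suggestion rather than the commit that actually addressed it.

def scale_gradients(raw_grads,
                    clip_grad_norm_to_value=0.0,
                    clip_grad_single_norm_to_value=0.0):
  assert not (clip_grad_norm_to_value and clip_grad_single_norm_to_value), (
      'At most one clipping threshold may be set.')
  if clip_grad_norm_to_value:
    # Global-norm clipping: one norm over the whole gradient pytree.
    grad_norm = _compute_grad_norm(raw_grads)
    grad_scale = jnp.minimum(
        jnp.array(1, grad_norm.dtype),
        jnp.array(clip_grad_norm_to_value, grad_norm.dtype) / grad_norm)
    return jax.tree_map(lambda g: g * grad_scale, raw_grads)
  if clip_grad_single_norm_to_value:
    # Per-tensor clipping: each leaf is scaled by its own norm.
    def _clip_one(g):
      norm = jnp.sqrt(jnp.sum(g * g))
      return g * jnp.minimum(
          jnp.array(1, norm.dtype),
          jnp.array(clip_grad_single_norm_to_value, norm.dtype) / norm)
    return jax.tree_map(_clip_one, raw_grads)
  # Neither threshold is set: no clipping.
  return raw_grads

Because the thresholds arrive as plain Python floats from the hparams, the branch is resolved at trace time, so only the norm that is actually needed ends up in the compiled update.
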
@@ -2849,4 +2897,5 @@ def _get_raw_grad_transformation(
      self, lr: optax.Schedule) -> GeneralGradientTransformation:
    p = self._hparams
    base_tx = self.base_optimizer._get_raw_grad_transformation(lr)  # pylint: disable=protected-access
-   return sharded_static_accumulation(p.num_sub_batches, base_tx)
+   return sharded_static_accumulation(p.num_sub_batches, p.clip_gradient_norm_to_value,
+                                      p.clip_gradient_single_norm_to_value, base_tx)

Review comment: Looking at praxis optimizers, clip_gradient_norm_to_value and clip_gradient_single_norm_to_value default to 0.0 and not None, right? So perhaps the types here should be float with a default of 0.0 instead of Optional?
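
If that reading of the praxis defaults is right (0.0 meaning "clipping disabled"), the nested scale_gradients helper could take plain floats defaulting to 0.0 instead of Optional[float] = None, which keeps the existing truthiness checks unchanged. A sketch of the suggested signature, offered as an assumption about a possible follow-up rather than the merged change:

def scale_gradients(
    raw_grads: NestedMap,
    clip_grad_norm_to_value: float = 0.0,
    clip_grad_single_norm_to_value: float = 0.0):
  # 0.0 disables the corresponding clipping mode, matching the hparam defaults.
  ...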