-
Notifications
You must be signed in to change notification settings - Fork 43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Perform gradient clipping on global batch when using gradient accumulation #6
Open
ashors1
wants to merge
9
commits into
google:main
Choose a base branch
from
ashors1:ga_grad_clip
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 7 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
a8dfeeb
perform gradient clipping on global batch when using ShardedStaticAcc…
ashors1 4380135
remove AUTHORS file
ashors1 08e4292
minor refactor, do not return grad_scale
ashors1 400cb40
Merge branch 'main' of github.com:ashors1/praxis into ga_grad_clip
ashors1 42932ea
fix indent
ashors1 54bdc12
fix formatting, small ga bug fix
ashors1 44c67f7
sync with upstream
ashors1 d5051c1
Merge branch 'main' of github.com:ashors1/praxis into ga_grad_clip
ashors1 40a6d80
address PR comments
ashors1 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2713,6 +2713,8 @@ def _get_raw_grad_transformation(self, lr: optax.Schedule): | |
|
||
def sharded_static_accumulation( | ||
num_sub_batches: int, | ||
clip_gradient_norm_to_value: float, | ||
clip_gradient_single_norm_to_value: float, | ||
base_tx: ShardedGradientTransformation, | ||
) -> ShardedGradientTransformation: | ||
"""Gradient transformation for ShardedStaticAccumulator optimizer.""" | ||
|
@@ -2781,10 +2783,54 @@ def update_fn(updates: NestedJTensor, | |
lambda: new_count) | ||
|
||
def _run_base_tx(): | ||
|
||
def _compute_grad_norm(grads: NestedMap) -> JTensor: | ||
"""Computes total grad norm.""" | ||
grad_norms_squared = jax.tree_map(lambda x: jnp.sum(x * x), grads) | ||
grad_norms_squared, _ = jax.tree_util.tree_flatten(grad_norms_squared) | ||
return jnp.sqrt(jnp.sum(jnp.stack(grad_norms_squared))) | ||
|
||
|
||
def scale_gradients( | ||
raw_grads: NestedMap, | ||
clip_grad_norm_to_value: Optional[float] = None, | ||
clip_grad_single_norm_to_value: Optional[float] = None): | ||
|
||
def clip_grads(grads): | ||
if clip_grad_norm_to_value: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe assert only one of them is true? |
||
assert clip_grad_single_norm_to_value == 0. | ||
|
||
grad_norm = _compute_grad_norm(raw_grads) | ||
|
||
grad_scale = jnp.minimum( | ||
jnp.array(1, grad_norm.dtype), | ||
jnp.array(clip_grad_norm_to_value, grad_norm.dtype) | ||
/ grad_norm) | ||
grads = jax.tree_map(lambda g: g * grad_scale, grads) | ||
elif clip_grad_single_norm_to_value: | ||
assert clip_grad_norm_to_value == 0. | ||
grad_single_norm = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), | ||
grads) | ||
|
||
def scale_gradient(grad, norm): | ||
return grad * jnp.minimum( | ||
jnp.array(1, norm.dtype), | ||
jnp.array(clip_grad_single_norm_to_value, | ||
norm.dtype) / norm) | ||
grads = jax.tree_map(scale_gradient, grads, grad_single_norm) | ||
|
||
return grads | ||
|
||
grads = clip_grads(raw_grads) | ||
return grads | ||
|
||
averaged_updated = jax.tree_map(lambda acc: acc / num_sub_batches, | ||
new_accumulated_update) | ||
scaled_updated = scale_gradients(averaged_updated, | ||
clip_gradient_norm_to_value, | ||
clip_gradient_single_norm_to_value) | ||
emission_updates, emission_base_state = base_tx.update( | ||
averaged_updated, state.base_state, params) # pytype: disable=attribute-error # jax-ndarray | ||
scaled_updated, state.base_state, params) # pytype: disable=attribute-error # jax-ndarray | ||
return (emission_updates, | ||
jax.tree_map(lambda u: jnp.zeros_like(u, dtype=jnp.float32), | ||
updates), emission_base_state) | ||
|
@@ -2855,4 +2901,5 @@ def _get_raw_grad_transformation( | |
self, lr: optax.Schedule) -> GeneralGradientTransformation: | ||
p = self._hparams | ||
base_tx = self.base_optimizer._get_raw_grad_transformation(lr) # pylint: disable=protected-access | ||
return sharded_static_accumulation(p.num_sub_batches, base_tx) | ||
return sharded_static_accumulation(p.num_sub_batches, p.clip_gradient_norm_to_value, | ||
p.clip_gradient_single_norm_to_value, base_tx) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looking at praxis optimizers, clip_gradient_norm_to_value and clip_gradient_single_norm_to_value default are 0.0 and not None right?
so perhaps the types here should be float and default 0.0 instead of Optional?