
Update lamb to be clearer, since it computes an L2 vector norm
loadams committed Jan 17, 2025
1 parent 08b1f55 · commit 8db1efb
Showing 1 changed file with 1 addition and 1 deletion.
deepspeed/runtime/fp16/onebit/lamb.py
@@ -177,7 +177,7 @@ def step(self, closure=None, grads=None):
         # This is used to reduce compression error during compression stage.
         momentum_scales = []
         for group in self.param_groups:
-            momentum_scales.append([(torch.linalg.norm(self.state[p]['exp_avg']) /
+            momentum_scales.append([(torch.linalg.vector_norm(self.state[p]['exp_avg']) /
                                      np.sqrt(torch.numel(self.state[p]['exp_avg']))).item()
                                     for p in group['params']])
         united_scale = sum([sum(x) for x in momentum_scales]) / sum([len(x) for x in momentum_scales])
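For context (not part of the commit): with the default ord=None and dim=None arguments used here, torch.linalg.norm already flattens its input and returns the Euclidean (L2) norm, so the rename to torch.linalg.vector_norm does not change the computed value; it only makes the intent explicit, because linalg.norm switches to matrix-norm semantics when an explicit ord is passed with a 2-D input. A minimal sketch (hypothetical tensors, not taken from the repository) illustrating both points, plus the fact that the full expression in the diff is the root-mean-square of each momentum buffer:

import torch

t = torch.randn(4, 8)

# With the default arguments used in lamb.py, both functions flatten the
# input and return its L2 norm, so the commit is a pure readability change:
assert torch.allclose(torch.linalg.norm(t), torch.linalg.vector_norm(t))

# The two APIs diverge only when an explicit ord is given for a 2-D tensor:
# linalg.norm then computes a matrix norm (ord=2 is the largest singular
# value), while vector_norm always treats the input as a flat vector.
print(torch.linalg.norm(t, ord=2))         # spectral norm of the matrix
print(torch.linalg.vector_norm(t, ord=2))  # sqrt of the sum of squares

# The expression in the diff, vector_norm(x) / sqrt(numel(x)), is the
# root-mean-square of the tensor x (here, each parameter's 'exp_avg'):
x = torch.randn(100)
rms = torch.linalg.vector_norm(x) / x.numel() ** 0.5
assert torch.allclose(rms, torch.sqrt(torch.mean(x ** 2)))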
