From 8db1efbfa968a8d596a1ebfe2bf6a78e546c0ea6 Mon Sep 17 00:00:00 2001
From: Logan Adams
Date: Fri, 17 Jan 2025 10:21:40 -0800
Subject: [PATCH] Update lamb to be more clear since it does an L2 vector norm

---
 deepspeed/runtime/fp16/onebit/lamb.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/runtime/fp16/onebit/lamb.py b/deepspeed/runtime/fp16/onebit/lamb.py
index 89b6f40a308c..9e7bae816ecd 100644
--- a/deepspeed/runtime/fp16/onebit/lamb.py
+++ b/deepspeed/runtime/fp16/onebit/lamb.py
@@ -177,7 +177,7 @@ def step(self, closure=None, grads=None):
         # This is used to reduce compression error during compression stage.
         momentum_scales = []
         for group in self.param_groups:
-            momentum_scales.append([(torch.linalg.norm(self.state[p]['exp_avg']) /
+            momentum_scales.append([(torch.linalg.vector_norm(self.state[p]['exp_avg']) /
                                      np.sqrt(torch.numel(self.state[p]['exp_avg']))).item()
                                     for p in group['params']])
         united_scale = sum([sum(x) for x in momentum_scales]) / sum([len(x) for x in momentum_scales])
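
For context: with their default arguments, both torch.linalg.norm and torch.linalg.vector_norm flatten the input tensor and compute the Euclidean (L2) norm, so the change above is behavior-preserving; vector_norm simply states the intent explicitly, whereas linalg.norm also doubles as a matrix-norm entry point when ord or dim is supplied. The quantity being computed, ||exp_avg||_2 / sqrt(numel(exp_avg)), is the RMS of the momentum entries, used as a per-parameter scale during the compression stage. Below is a minimal sketch of the equivalence; the exp_avg tensor is a hypothetical stand-in for the optimizer state self.state[p]['exp_avg']:

    import numpy as np
    import torch

    exp_avg = torch.randn(4, 8)  # hypothetical momentum buffer

    # With default arguments, both calls flatten the tensor and
    # take the L2 vector norm, so the two scales are identical.
    old_scale = (torch.linalg.norm(exp_avg) / np.sqrt(torch.numel(exp_avg))).item()
    new_scale = (torch.linalg.vector_norm(exp_avg) / np.sqrt(torch.numel(exp_avg))).item()

    assert abs(old_scale - new_scale) < 1e-6  # RMS of the momentum entries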