diff --git a/deepspeed/runtime/fp16/onebit/lamb.py b/deepspeed/runtime/fp16/onebit/lamb.py index 89b6f40a308c..9e7bae816ecd 100644 --- a/deepspeed/runtime/fp16/onebit/lamb.py +++ b/deepspeed/runtime/fp16/onebit/lamb.py @@ -177,7 +177,7 @@ def step(self, closure=None, grads=None): # This is used to reduce compression error during compression stage. momentum_scales = [] for group in self.param_groups: - momentum_scales.append([(torch.linalg.norm(self.state[p]['exp_avg']) / + momentum_scales.append([(torch.linalg.vector_norm(self.state[p]['exp_avg']) / np.sqrt(torch.numel(self.state[p]['exp_avg']))).item() for p in group['params']]) united_scale = sum([sum(x) for x in momentum_scales]) / sum([len(x) for x in momentum_scales])