use getattr for .optim and allow variable num of args (#12184)
* use getattr for .optim and allow variable num of args

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* pylint

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
akoumpa authored and committed Feb 14, 2025
1 parent 7a00886 commit 0eb9e5d
Showing 1 changed file with 15 additions and 7 deletions.
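
For context, here is a minimal, self-contained sketch of the two patterns this commit applies; the classes and names below (ToyModule, ToyOptim, hook) are made up for illustration and are not NeMo code. It shows getattr with a default so an object without an .optim attribute does not raise AttributeError, and *args/**kwargs so a hook keeps working when the caller passes a variable number of extra arguments.

    # Illustration only: ToyModule / ToyOptim are hypothetical stand-ins, not NeMo classes.


    class ToyOptim:
        pass


    class ToyModule:
        pass  # note: no `optim` attribute defined


    def hook(trainer, module, *args, **kwargs):
        """Accepts extra arguments so it can be registered for hooks with differing signatures."""
        # getattr with a default returns None instead of raising AttributeError
        optim = getattr(module, "optim", None)
        if isinstance(optim, ToyOptim):
            print("module uses ToyOptim")
        else:
            print("no (recognized) optimizer attached")


    # Works even though ToyModule has no `.optim` and the caller passes extra args.
    hook("trainer", ToyModule(), "batch", batch_idx=0)
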
22 changes: 15 additions & 7 deletions nemo/lightning/pytorch/callbacks/debugging.py
@@ -23,11 +23,19 @@
 
 
 def collect_precision(tensor: torch.Tensor) -> Dict[str, str]:
-    return {"Precision": str(tensor.dtype)}
+    """Returns tensor's precision"""
+    if isinstance(tensor, torch.Tensor):
+        return {"Precision": str(tensor.dtype)}
+    else:
+        return {"Precision": "not-a-tensor"}
 
 
 def collect_precision_and_shape(tensor: torch.Tensor) -> Dict[str, str]:
-    return {"Shape": str(tensor.shape), "Precision": str(tensor.dtype)}
+    """Returns tensor's shape & precision"""
+    if isinstance(tensor, torch.Tensor):
+        return {"Shape": str(tensor.shape), "Precision": str(tensor.dtype)}
+    else:
+        return {"Shape": "not-a-tensor", "Precision": "not-a-tensor"}
 
 
 class ParameterDebugger(Callback):
@@ -106,20 +114,20 @@ def __init__(
         if isinstance(log_on_hooks, str):
             log_on_hooks = [log_on_hooks]
         for hook_name in log_on_hooks:
-            assert (
-                hook_name in valid_hooks
-            ), f"Hook {hook_name} supplied to log_on_hooks is not valid or can not be used. Valid hooks are {valid_hooks}"
+            assert hook_name in valid_hooks, (
+                "Hook {} supplied to log_on_hooks is not valid or " "can not be used. Valid hooks are {}"
+            ).format(hook_name, valid_hooks)
             setattr(self, hook_name, self._apply_user_funcs)
 
-    def _apply_user_funcs(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+    def _apply_user_funcs(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs) -> None:
         """
         Iterate over model parameters, find gradient tensor, apply and collect outputs of
         param_fn and grad_fn, and log outputs in a table.
         """
 
         def find_grad_tensor(param: torch.Tensor) -> Optional[torch.Tensor]:
             """If using MCore optimizer, search the grad buckets for param's grad tensor."""
-            if not isinstance(pl_module.optim, MegatronOptimizerModule):
+            if not isinstance(getattr(pl_module, 'optim', None), MegatronOptimizerModule):
                 return param.grad
 
             for buf in pl_module.buffers:
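
As a quick illustration of the changed helpers (assuming collect_precision and collect_precision_and_shape remain importable as module-level functions from nemo.lightning.pytorch.callbacks.debugging, as this diff suggests), they now degrade gracefully on non-tensor inputs instead of raising:

    import torch

    from nemo.lightning.pytorch.callbacks.debugging import (
        collect_precision,
        collect_precision_and_shape,
    )

    t = torch.ones(2, 3, dtype=torch.float16)
    print(collect_precision(t))            # {'Precision': 'torch.float16'}
    print(collect_precision_and_shape(t))  # {'Shape': 'torch.Size([2, 3])', 'Precision': 'torch.float16'}

    # Non-tensor inputs (e.g. a missing gradient passed as None) no longer raise AttributeError
    print(collect_precision(None))             # {'Precision': 'not-a-tensor'}
    print(collect_precision_and_shape(None))   # {'Shape': 'not-a-tensor', 'Precision': 'not-a-tensor'}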
