diff --git a/wenet/utils/fsdp_utils.py b/wenet/utils/fsdp_utils.py index 77ca19595..33871f6f0 100644 --- a/wenet/utils/fsdp_utils.py +++ b/wenet/utils/fsdp_utils.py @@ -5,7 +5,6 @@ from torch.distributed.fsdp.wrap import (lambda_auto_wrap_policy, transformer_auto_wrap_policy) -from wenet.LLM.decoder import DecoderOnly from wenet.branchformer.encoder_layer import BranchformerEncoderLayer from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer @@ -92,8 +91,6 @@ def check_gradient_checkpoint(model): if model.decoder.gradient_checkpointing: model.decoder.gradient_checkpointing = False ckpt_laye_types += list(WENET_DECODER_LAYERS_CLASSES.values()) - if isinstance(model.decoder, DecoderOnly): - ckpt_laye_types += [DecoderOnly] return tuple(ckpt_laye_types)