rm LLM in fsdp (#2688)
Mddct authored Feb 10, 2025
1 parent 0e09274 commit 0df22fb
Showing 1 changed file with 0 additions and 3 deletions.
wenet/utils/fsdp_utils.py: 3 changes (0 additions & 3 deletions)
@@ -5,7 +5,6 @@
 
 from torch.distributed.fsdp.wrap import (lambda_auto_wrap_policy,
                                          transformer_auto_wrap_policy)
-from wenet.LLM.decoder import DecoderOnly
 from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
 from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer
 from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer
@@ -92,8 +91,6 @@ def check_gradient_checkpoint(model):
     if model.decoder.gradient_checkpointing:
         model.decoder.gradient_checkpointing = False
         ckpt_laye_types += list(WENET_DECODER_LAYERS_CLASSES.values())
-        if isinstance(model.decoder, DecoderOnly):
-            ckpt_laye_types += [DecoderOnly]
     return tuple(ckpt_laye_types)
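
For context, here is a minimal sketch (not part of this commit, and not the actual code in wenet/utils/fsdp_utils.py) of how a tuple of layer classes like the one returned by check_gradient_checkpoint is typically consumed with PyTorch FSDP: the classes drive both the transformer auto-wrap policy and the check_fn of apply_activation_checkpointing, which re-applies gradient checkpointing after the module-internal flags have been switched off. The helper name shard_with_fsdp is hypothetical, and the sketch assumes torch.distributed is already initialized.

    import functools

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        apply_activation_checkpointing, checkpoint_wrapper, CheckpointImpl)

    def shard_with_fsdp(model, layer_classes):
        # layer_classes: e.g. the tuple returned by check_gradient_checkpoint(model)
        wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=set(layer_classes),
        )
        # Shard the model, wrapping each listed layer class as its own FSDP unit.
        model = FSDP(model, auto_wrap_policy=wrap_policy)

        # Re-enable gradient checkpointing at the wrapper level, since the
        # module-internal gradient_checkpointing flags were turned off above.
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=functools.partial(
                checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
            check_fn=lambda module: isinstance(module, tuple(layer_classes)),
        )
        return model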

