From 0df22fbc211550b11f9f65ce0d8ea10529350125 Mon Sep 17 00:00:00 2001 From: Dinghao Zhou Date: Mon, 10 Feb 2025 21:41:14 +0800 Subject: [PATCH] rm LLM in fsdp (#2688) --- wenet/utils/fsdp_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/wenet/utils/fsdp_utils.py b/wenet/utils/fsdp_utils.py index 77ca195953..33871f6f0c 100644 --- a/wenet/utils/fsdp_utils.py +++ b/wenet/utils/fsdp_utils.py @@ -5,7 +5,6 @@ from torch.distributed.fsdp.wrap import (lambda_auto_wrap_policy, transformer_auto_wrap_policy) -from wenet.LLM.decoder import DecoderOnly from wenet.branchformer.encoder_layer import BranchformerEncoderLayer from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer @@ -92,8 +91,6 @@ def check_gradient_checkpoint(model): if model.decoder.gradient_checkpointing: model.decoder.gradient_checkpointing = False ckpt_laye_types += list(WENET_DECODER_LAYERS_CLASSES.values()) - if isinstance(model.decoder, DecoderOnly): - ckpt_laye_types += [DecoderOnly] return tuple(ckpt_laye_types)