Commit
optimize(gpt): move to device_gpt
fumiama committed Aug 25, 2024
1 parent a79d297 commit 8bc721c
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions ChatTTS/model/gpt.py
@@ -39,7 +39,6 @@ def __init__(
 
         self.generator = torch.Generator(device=device)
 
-        self.config = gpt_config
         self.num_vq = int(gpt_config["num_vq"])
         self.num_audio_tokens = int(gpt_config["num_audio_tokens"])
         self.num_text_tokens = int(gpt_config["num_text_tokens"])
@@ -51,7 +50,7 @@ def __init__(
         if self.is_vllm:
             return
 
-        self.gpt, self.llama_config = self._build_llama(gpt_config, self.device_gpt)
+        self.llama_config = self._build_llama_config(gpt_config)
 
         self.emb_code = [ec.__call__ for ec in embed.emb_code]
         self.emb_text = embed.emb_text.__call__
@@ -60,9 +59,8 @@ def __init__(
 
     def from_pretrained(self, gpt_folder: str, embed_file_path: str, experimental=False):
         if self.is_vllm and platform.system().lower() == "linux":
-            from safetensors.torch import save_file
 
-            from .velocity import LLM, PostModel
+            from .velocity import LLM
 
             self.llm = LLM(
                 model=gpt_folder,
@@ -73,7 +71,8 @@ def from_pretrained(self, gpt_folder: str, embed_file_path: str, experimental=False):
             self.logger.info("vLLM model loaded")
             return
 
-        self.gpt: LlamaModel = LlamaModel.from_pretrained(gpt_folder)
+        self.gpt: LlamaModel = LlamaModel.from_pretrained(gpt_folder).to(self.device_gpt)
+        del self.gpt.embed_tokens
 
         if (
             experimental
@@ -108,10 +107,9 @@ def set(self, v: bool):
         def get(self) -> bool:
             return self._interrupt
 
-    def _build_llama(
+    def _build_llama_config(
         self,
         config: dict,
-        device: torch.device,
     ) -> Tuple[LlamaModel, LlamaConfig]:
 
         if self.use_flash_attn and is_flash_attn_2_available():
@@ -125,10 +123,7 @@ def _build_llama(
         else:
             llama_config = LlamaConfig(**config)
 
-        model = LlamaModel(llama_config)
-        del model.embed_tokens
-
-        return model.to(device), llama_config
+        return llama_config
 
     def prepare(self, compile=False):
         if self.use_flash_attn and is_flash_attn_2_available():
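Summary note: the net effect of this commit is to split model construction from device placement. __init__ now only builds the LlamaConfig via _build_llama_config, while from_pretrained loads the pretrained LlamaModel, moves it onto self.device_gpt, and deletes the stock embed_tokens table (ChatTTS supplies its own text/audio embeddings through a separate Embed module). The sketch below illustrates that flow only; it is not the repository's actual code. The GPTSketch class name is invented here, gpt_config and gpt_folder are placeholders for the real config dict and checkpoint directory, and the flash-attention branch of _build_llama_config is omitted.

# Minimal sketch of the refactored flow (assumes `transformers` is installed).
import torch
from transformers import LlamaConfig, LlamaModel


class GPTSketch:
    def __init__(self, gpt_config: dict, device_gpt: torch.device):
        self.device_gpt = device_gpt
        # After this commit, __init__ only builds the config ...
        self.llama_config = self._build_llama_config(gpt_config)

    def _build_llama_config(self, config: dict) -> LlamaConfig:
        # The real method also enables flash_attention_2 when available.
        return LlamaConfig(**config)

    def from_pretrained(self, gpt_folder: str) -> "GPTSketch":
        # ... while loading and device placement happen here, in one step.
        self.gpt = LlamaModel.from_pretrained(gpt_folder).to(self.device_gpt)
        # The default embedding table is unused, so it is dropped to save memory.
        del self.gpt.embed_tokens
        return self

As far as the diff shows, the old _build_llama instantiated a LlamaModel in __init__ only for from_pretrained to replace it with the checkpoint loaded from disk; doing the load and the .to(self.device_gpt) move in one place removes that redundant construction and makes the device placement explicit.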
