import logging
from typing import Tuple

import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm

from modeling_llama import LlamaModel, LlamaConfig


class GPT(nn.Module):
    """Llama backbone with separate text / audio-code embeddings and heads.

    The backbone's own ``embed_tokens`` module is deleted in
    :meth:`_build_llama`. Token embedding is instead done by ``emb_text``
    (text vocabulary) and ``emb_code`` (one embedding per audio-code stream).
    Projection back to vocabulary logits uses bias-free linear heads wrapped
    in ``weight_norm``.
    """

    def __init__(
        self,
        gpt_config: dict,
        num_audio_tokens: int = 626,
        num_text_tokens: int = 21178,
        num_vq: int = 4,
        use_flash_attn: bool = False,
        device: torch.device = torch.device("cpu"),
        logger: logging.Logger = logging.getLogger(__name__),
    ):
        """Build the backbone, input embeddings and output heads.

        Args:
            gpt_config: keyword arguments forwarded to ``LlamaConfig``.
            num_audio_tokens: vocabulary size of each audio-code stream.
            num_text_tokens: text vocabulary size.
            num_vq: number of audio-code streams; one embedding and one
                output head are created per stream.
            use_flash_attn: stored on the instance only. NOTE(review):
                nothing in this class reads it; confirm whether
                ``_build_llama`` is meant to configure the backbone from it.
            device: target device. The output heads are always created here.
            logger: logger used for diagnostics (e.g. by ``from_pretrained``).
        """
        super().__init__()

        self.logger = logger
        self.device = device
        # When `device` is MPS, the backbone and the input embeddings are placed
        # on CPU instead; the output heads below still use `device`.
        # (Presumably due to MPS operator coverage — TODO confirm.)
        self.device_gpt = (
            device if "mps" not in str(device) else torch.device("cpu")
        )

        self.num_vq = num_vq
        self.num_audio_tokens = num_audio_tokens
        self.use_flash_attn = use_flash_attn

        self.gpt, self.llama_config = self._build_llama(
            gpt_config, self.device_gpt
        )
        # Always False in this class; presumably read elsewhere — TODO confirm.
        self.is_te_llama = False
        self.model_dim = int(self.gpt.config.hidden_size)

        # Input embeddings live next to the backbone (device_gpt).
        self.emb_code = nn.ModuleList(
            [
                nn.Embedding(
                    num_audio_tokens,
                    self.model_dim,
                    device=self.device_gpt,
                )
                for _ in range(num_vq)
            ],
        )
        self.emb_text = nn.Embedding(
            num_text_tokens, self.model_dim, device=self.device_gpt
        )

        # Output heads live on the requested `device` (differs from device_gpt
        # only when `device` is MPS).
        self.head_text = weight_norm(
            nn.Linear(
                self.model_dim,
                num_text_tokens,
                bias=False,
                device=device,
            ),
            name="weight",
        )
        self.head_code = nn.ModuleList(
            [
                weight_norm(
                    nn.Linear(
                        self.model_dim,
                        num_audio_tokens,
                        bias=False,
                        device=device,
                    ),
                    name="weight",
                )
                for _ in range(self.num_vq)
            ],
        )

    def from_pretrained(self, file_path: str):
        """Load weights from a state dict saved with ``torch.save``.

        Loading is non-strict: parameters missing from the file and entries
        in the file with no matching parameter are tolerated. Both are
        logged so that a mismatched checkpoint does not go unnoticed.

        The file is memory-mapped and mapped to CPU first, so a checkpoint
        saved from a CUDA device can still be opened on a CPU-only host.
        ``load_state_dict`` then copies each tensor into the matching
        parameter on that parameter's device.

        Args:
            file_path: path to the saved state dict.
        """
        state_dict = torch.load(
            file_path, map_location="cpu", weights_only=True, mmap=True
        )
        incompatible = self.load_state_dict(state_dict, strict=False)

        if incompatible.missing_keys:
            self.logger.warning(
                "%d parameter(s) not found in %s: %s",
                len(incompatible.missing_keys),
                file_path,
                incompatible.missing_keys,
            )
        if incompatible.unexpected_keys:
            self.logger.info(
                "%d unused entry(ies) in %s: %s",
                len(incompatible.unexpected_keys),
                file_path,
                incompatible.unexpected_keys,
            )

    def _build_llama(
        self,
        config: dict,
        device: torch.device,
    ) -> Tuple[LlamaModel, LlamaConfig]:
        """Create the Llama backbone without its token-embedding layer.

        Args:
            config: keyword arguments forwarded to ``LlamaConfig``.
            device: device the backbone is moved to.

        Returns:
            ``(model, llama_config)``, where ``model`` has already been moved
            to ``device``.
        """
        llama_config = LlamaConfig(**config)
        model = LlamaModel(llama_config)
        # Embedding is handled by GPT.emb_text / GPT.emb_code instead.
        del model.embed_tokens
        return model.to(device), llama_config