from safetensors.torch import safe_open

import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm


class Embed(nn.Module):
    def __init__(
        self, hidden_size: int, num_audio_tokens: int, num_text_tokens: int, num_vq=4
    ):
        super().__init__()

        self.num_vq = num_vq
        self.num_audio_tokens = num_audio_tokens
        self.model_dim = hidden_size

        # One embedding table per VQ codebook for the audio tokens,
        # plus a single embedding table for the text tokens.
        self.emb_code = nn.ModuleList(
            [nn.Embedding(num_audio_tokens, self.model_dim) for _ in range(num_vq)],
        )
        self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)

        # Weight-normalized output heads projecting hidden states back to
        # the text vocabulary and to each audio codebook's vocabulary.
        self.head_text = weight_norm(
            nn.Linear(self.model_dim, num_text_tokens, bias=False),
            name="weight",
        )
        self.head_code = nn.ModuleList(
            [
                weight_norm(
                    nn.Linear(self.model_dim, num_audio_tokens, bias=False),
                    name="weight",
                )
                for _ in range(self.num_vq)
            ],
        )

    @torch.inference_mode()
    def from_pretrained(self, filename: str, device: torch.device):
        # Load weights from a safetensors checkpoint and move the module to `device`.
        state_dict_tensors = {}
        with safe_open(filename, framework="pt") as f:
            for k in f.keys():
                state_dict_tensors[k] = f.get_tensor(k)
        self.load_state_dict(state_dict_tensors)
        self.to(device)

    def __call__(
        self, input_ids: torch.Tensor, text_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        get_emb
        """
        return super().__call__(input_ids, text_mask)

    @torch.inference_mode()
    def forward(
        self, input_ids: torch.Tensor, text_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        get_emb

        Build input embeddings for a mixed text/audio sequence: positions where
        `text_mask` is True are embedded with `emb_text`; the remaining positions
        are embedded by summing the `num_vq` codebook embeddings.
        """
        device = next(self.parameters()).device

        # Text positions: only the first codebook slot carries the text token id.
        emb_text: torch.Tensor = self.emb_text(
            input_ids[text_mask].narrow(1, 0, 1).squeeze_(1).to(device)
        )

        # Audio positions: embed each codebook's token and sum across codebooks.
        text_mask_inv = text_mask.logical_not().to(device)
        masked_input_ids: torch.Tensor = input_ids[text_mask_inv].to(device)
        emb_code = [
            self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)
        ]
        emb_code = torch.stack(emb_code, 2).sum(2)

        # Scatter the text and audio embeddings back into a single
        # (batch, seq_len, model_dim) tensor.
        emb = torch.zeros(
            (input_ids.shape[:-1]) + (emb_text.shape[-1],),
            device=emb_text.device,
            dtype=emb_text.dtype,
        )
        emb[text_mask] = emb_text
        emb[text_mask_inv] = emb_code.to(emb.dtype)

        del emb_text, emb_code, text_mask_inv

        return emb
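

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal sketch of how this module might be driven, assuming `input_ids`
# is laid out as (batch, seq_len, num_vq) token ids and `text_mask` is a
# (batch, seq_len) boolean mask marking the text positions. The sizes below
# are placeholder values, not those of any real checkpoint.
if __name__ == "__main__":
    embed = Embed(
        hidden_size=768,         # placeholder model width
        num_audio_tokens=1024,   # placeholder audio codebook size
        num_text_tokens=32000,   # placeholder text vocabulary size
        num_vq=4,
    )

    batch, seq_len, num_vq = 1, 8, 4
    input_ids = torch.zeros(batch, seq_len, num_vq, dtype=torch.long)

    # First half of the sequence is text, the rest is audio code tokens.
    text_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
    text_mask[:, : seq_len // 2] = True

    emb = embed(input_ids, text_mask)
    print(emb.shape)  # torch.Size([1, 8, 768])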