File size: 2,543 Bytes
c02bdcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from safetensors.torch import safe_open
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm


class Embed(nn.Module):
    def __init__(
        self, hidden_size: int, num_audio_tokens: int, num_text_tokens: int, num_vq=4
    ):
        super().__init__()

        self.num_vq = num_vq
        self.num_audio_tokens = num_audio_tokens

        self.model_dim = hidden_size
        self.emb_code = nn.ModuleList(
            [nn.Embedding(num_audio_tokens, self.model_dim) for _ in range(num_vq)],
        )
        self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)

        self.head_text = weight_norm(
            nn.Linear(self.model_dim, num_text_tokens, bias=False),
            name="weight",
        )
        self.head_code = nn.ModuleList(
            [
                weight_norm(
                    nn.Linear(self.model_dim, num_audio_tokens, bias=False),
                    name="weight",
                )
                for _ in range(self.num_vq)
            ],
        )

    @torch.inference_mode()
    def from_pretrained(self, filename: str, device: torch.device):
        state_dict_tensors = {}
        with safe_open(filename, framework="pt") as f:
            for k in f.keys():
                state_dict_tensors[k] = f.get_tensor(k)
        self.load_state_dict(state_dict_tensors)
        self.to(device)

    def __call__(
        self, input_ids: torch.Tensor, text_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        get_emb
        """
        return super().__call__(input_ids, text_mask)

    @torch.inference_mode()
    def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
        """
        get_emb
        """
        device = next(self.parameters()).device
        emb_text: torch.Tensor = self.emb_text(
            input_ids[text_mask].narrow(1, 0, 1).squeeze_(1).to(device)
        )

        text_mask_inv = text_mask.logical_not().to(device)
        masked_input_ids: torch.Tensor = input_ids[text_mask_inv].to(device)

        emb_code = [
            self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)
        ]
        emb_code = torch.stack(emb_code, 2).sum(2)

        emb = torch.zeros(
            (input_ids.shape[:-1]) + (emb_text.shape[-1],),
            device=emb_text.device,
            dtype=emb_text.dtype,
        )
        emb[text_mask] = emb_text
        emb[text_mask_inv] = emb_code.to(emb.dtype)

        del emb_text, emb_code, text_mask_inv

        return emb