import logging
from typing import Tuple

import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm

from modeling_llama import LlamaModel, LlamaConfig


class GPT(nn.Module):
    def __init__(
        self,
        gpt_config: dict,
        num_audio_tokens: int = 626,
        num_text_tokens: int = 21178,
        num_vq: int = 4,
        use_flash_attn: bool = False,
        device: torch.device = torch.device("cpu"),
        logger: logging.Logger = logging.getLogger(__name__),
    ):
        super().__init__()

        self.logger = logger

        self.device = device
        # Keep the Llama backbone on CPU when the target device is Apple MPS.
        self.device_gpt = device if "mps" not in str(device) else torch.device("cpu")

        self.num_vq = num_vq
        self.num_audio_tokens = num_audio_tokens

        self.use_flash_attn = use_flash_attn  # stored for callers; not consumed in the code shown here

        self.gpt, self.llama_config = self._build_llama(gpt_config, self.device_gpt)
        self.is_te_llama = False
        self.model_dim = int(self.gpt.config.hidden_size)
        # One embedding table per VQ codebook for the audio tokens.
        self.emb_code = nn.ModuleList(
            [
                nn.Embedding(
                    num_audio_tokens,
                    self.model_dim,
                    device=self.device_gpt,
                )
                for _ in range(num_vq)
            ],
        )
        # Embedding table for the text tokens.
        self.emb_text = nn.Embedding(
            num_text_tokens, self.model_dim, device=self.device_gpt
        )

        # Weight-normalized linear head projecting hidden states to text-token logits.
        self.head_text = weight_norm(
            nn.Linear(
                self.model_dim,
                num_text_tokens,
                bias=False,
                device=device,
            ),
            name="weight",
        )
        # One weight-normalized head per VQ codebook for the audio-token logits.
        self.head_code = nn.ModuleList(
            [
                weight_norm(
                    nn.Linear(
                        self.model_dim,
                        num_audio_tokens,
                        bias=False,
                        device=device,
                    ),
                    name="weight",
                )
                for _ in range(self.num_vq)
            ],
        )

    def from_pretrained(self, file_path: str):
        # Load a torch checkpoint; strict=False tolerates missing or unexpected keys.
        self.load_state_dict(
            torch.load(file_path, weights_only=True, mmap=True), strict=False
        )

    def _build_llama(
        self,
        config: dict,
        device: torch.device,
    ) -> Tuple[LlamaModel, LlamaConfig]:
        llama_config = LlamaConfig(**config)
        model = LlamaModel(llama_config)
        # The built-in input embedding is dropped: emb_text / emb_code supply
        # inputs_embeds to the backbone instead of token ids.
        del model.embed_tokens
        return model.to(device), llama_config
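

if __name__ == "__main__":
    # Minimal usage sketch. Assumption: the config values below are illustrative
    # only and do not match any shipped checkpoint; real gpt_config dicts come
    # from the model's own configuration file.
    demo_config = {
        "hidden_size": 768,
        "intermediate_size": 3072,
        "num_attention_heads": 12,
        "num_hidden_layers": 2,
        "max_position_embeddings": 4096,
        "vocab_size": 1,  # unused: embed_tokens is deleted in _build_llama
    }
    gpt = GPT(demo_config)

    # Text tokens are embedded with emb_text, then fed to the backbone as inputs_embeds.
    text_ids = torch.randint(0, 21178, (1, 16))
    text_emb = gpt.emb_text(text_ids.to(gpt.device_gpt))
    hidden = gpt.gpt(inputs_embeds=text_emb).last_hidden_state

    # Project hidden states to text logits and to per-codebook audio logits.
    text_logits = gpt.head_text(hidden)
    code_logits = torch.stack([head(hidden) for head in gpt.head_code], dim=-1)
    print(text_logits.shape, code_logits.shape)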