import torch
from torch import nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
)
from collections import namedtuple

from .llama import CustomAttentionLLaMa


class MyLLaMaConfig(PretrainedConfig):
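    """Configuration for the CustomAttentionLLaMa wrapper (registered below as model_type "LLaMa")."""
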
    model_type = "LLaMa"

    def __init__(
        self,
        embed_dim: int = 1536,
        n_layers: int = 24,
        n_heads: int = 24,
        n_chckpnt_segments: int = 24,
        **kwargs,
    ):
        self.embed_dim = embed_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_chckpnt_segments = n_chckpnt_segments
        super().__init__(**kwargs)


class MyLLaMa(PreTrainedModel):
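    """Wrap CustomAttentionLLaMa as a PreTrainedModel so it can be used
    through the transformers Auto* factories registered below."""
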
    config_class = MyLLaMaConfig

    def __init__(self, config: MyLLaMaConfig):
        super().__init__(config)
        self.model = CustomAttentionLLaMa(
            config.embed_dim,
            config.n_layers,
            config.n_heads,
            dropout=0,
            n_chckpnt_segments=config.n_chckpnt_segments,
        )

    def load_state_dict(self, state_dict, **kwargs):
        # Incoming checkpoints store the RMSNorm scale parameters as "*.weight",
        # while the CustomAttentionLLaMa modules register them as "*.gamma";
        # rename the affected keys before delegating to the default loader.
        for key in list(state_dict.keys()):
            for old in ("rmsnorm1.weight", "rmsnorm2.weight", "rmsnorm.weight"):
                if old in key:
                    new_key = key.replace(old, old.replace("weight", "gamma"))
                    state_dict[new_key] = state_dict.pop(key)
                    break

        return super().load_state_dict(state_dict, **kwargs)

    def forward(self, tensor, labels=None):
        seq_len = tensor.shape[1]
        device = self.model.embed.weight.device

        # Additive causal mask: 0 on and below the diagonal, -inf above it,
        # so each position can only attend to itself and earlier positions.
        att_mask = (
            torch.where(
                torch.triu(torch.ones((seq_len, seq_len))) == 1,
                0,
                -torch.inf,
            )
            .transpose(0, 1)
            .to(device)
        )

        # Padding mask: True for real tokens, False for padding.
        pad_mask = (tensor != self.model.tokenizer.pad_token_id).to(device)

        logits = self.model(tensor, att_mask, pad_mask)["logits"]

        # cross_entropy expects class scores in dim 1, so the wrapped model's
        # logits are consumed as-is for the loss and transposed to
        # (batch, seq_len, vocab) for the Hugging Face output convention.
        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)

        Output = namedtuple("Output", ["logits", "loss"])
        return Output(logits=logits.transpose(1, 2), loss=loss)


AutoConfig.register("LLaMa", MyLLaMaConfig)
AutoModel.register(MyLLaMaConfig, MyLLaMa)
AutoModelForCausalLM.register(MyLLaMaConfig, MyLLaMa)
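

# Minimal usage sketch, not a definitive entry point: it assumes the relative
# `.llama` import resolves (run as a module inside its package, e.g.
# `python -m <package>.<this_module>`) and that CustomAttentionLLaMa accepts
# the small illustrative sizes below. It only demonstrates that, after the
# registrations above, the wrapper is reachable through the Auto* factories.
if __name__ == "__main__":
    cfg = MyLLaMaConfig(embed_dim=128, n_layers=2, n_heads=4, n_chckpnt_segments=2)
    model = AutoModelForCausalLM.from_config(cfg)
    print(type(model).__name__, sum(p.numel() for p in model.parameters()))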