from collections import namedtuple

import torch
from torch import nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
)

from .llama import CustomAttentionLLaMa


class MyLLaMaConfig(PretrainedConfig):
    """Configuration for the custom-attention LLaMa wrapper."""

    model_type = "LLaMa"

    def __init__(
        self,
        embed_dim: int = 1536,
        n_layers: int = 24,
        n_heads: int = 24,
        n_chckpnt_segments: int = 24,
        **kwargs,
    ):
        self.embed_dim = embed_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_chckpnt_segments = n_chckpnt_segments
        super().__init__(**kwargs)


class MyLLaMa(PreTrainedModel):
    """PreTrainedModel wrapper around CustomAttentionLLaMa for use with the Auto* API."""

    config_class = MyLLaMaConfig

    def __init__(self, config: MyLLaMaConfig):
        super().__init__(config)
        self.model = CustomAttentionLLaMa(
            config.embed_dim,
            config.n_layers,
            config.n_heads,
            dropout=0,
            n_chckpnt_segments=config.n_chckpnt_segments,
        )
    def load_state_dict(self, state_dict, **kwargs):
        # Checkpoints store the RMSNorm scale as "*.weight", while the custom
        # modules name the same parameter "gamma"; remap the keys before loading.
        for key in list(state_dict.keys()):
            if "rmsnorm1.weight" in key:
                new_key = key.replace("rmsnorm1.weight", "rmsnorm1.gamma")
                state_dict[new_key] = state_dict.pop(key)
            elif "rmsnorm2.weight" in key:
                new_key = key.replace("rmsnorm2.weight", "rmsnorm2.gamma")
                state_dict[new_key] = state_dict.pop(key)
            elif "rmsnorm.weight" in key:
                new_key = key.replace("rmsnorm.weight", "rmsnorm.gamma")
                state_dict[new_key] = state_dict.pop(key)
        return super().load_state_dict(state_dict, **kwargs)
    def forward(self, tensor, labels=None):
        seq_len = tensor.shape[1]
        # Additive causal mask: 0 on and below the diagonal, -inf above it.
        att_mask = (
            torch.where(
                torch.triu(torch.ones((seq_len, seq_len))) == 1,
                0,
                -torch.inf,
            )
            .transpose(0, 1)
            .to(self.model.embed.weight.device)
        )
        # Boolean padding mask: True for real tokens, False for padding.
        pad_mask = (tensor != self.model.tokenizer.pad_token_id).to(
            self.model.embed.weight.device
        )
        logits = self.model(tensor, att_mask, pad_mask)["logits"]
        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)
        # Return a namedtuple instance (not the class itself) so .logits and
        # .loss behave like fields of a lightweight output object.
        Output = namedtuple("Output", ["logits", "loss"])
        return Output(logits=logits.transpose(1, 2), loss=loss)

# Make the custom config and model discoverable through the Auto* factory classes.
AutoConfig.register("LLaMa", MyLLaMaConfig)
AutoModel.register(MyLLaMaConfig, MyLLaMa)
AutoModelForCausalLM.register(MyLLaMaConfig, MyLLaMa)
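
# Usage sketch (illustrative only, not part of the module): once the
# registrations above have run, the wrapper can be driven entirely through the
# Auto* factories. The checkpoint directory name below is a placeholder.
#
#     config = MyLLaMaConfig(embed_dim=1536, n_layers=24, n_heads=24)
#     model = AutoModelForCausalLM.from_config(config)
#     model.save_pretrained("my_llama_checkpoint")
#     model = AutoModelForCausalLM.from_pretrained("my_llama_checkpoint")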