Commit f8f40b7
Parent: cfbb0ee
Commit message: "test"

- config.json +2 -2
- model.py +32 -13
config.json CHANGED

@@ -1,5 +1,5 @@
 {
-  "model_type": "
+  "model_type": "babylang",
   "architectures": ["GPT"],
   "vocab_size": 50257,
   "block_size": 128,
@@ -10,7 +10,7 @@
   "bias": true,
   "auto_map": {
     "AutoConfig": "model.GPTConfig",
-    "
+    "AutoModelForCausalLM": "model.GPT"
   },
   "torch_dtype": "float32",
   "transformers_version": "4.42.0"
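For reference, a usage sketch (not part of the commit; "user/babylang" below is a placeholder repo id): with "auto_map" pointing AutoConfig at model.GPTConfig and AutoModelForCausalLM at model.GPT, the checkpoint can be loaded through the Transformers Auto classes, provided trust_remote_code=True is passed.

from transformers import AutoConfig, AutoModelForCausalLM

# "user/babylang" is a placeholder repo id for illustration only.
config = AutoConfig.from_pretrained("user/babylang", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("user/babylang", trust_remote_code=True)
print(config.model_type)  # "babylang"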
model.py CHANGED

@@ -3,9 +3,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 import math
 from transformers import PreTrainedModel, PretrainedConfig
+from transformers.modeling_outputs import CausalLMOutputWithPast
 
 class GPTConfig(PretrainedConfig):
-    model_type = "
+    model_type = "babylang"
 
     def __init__(
         self,
@@ -50,13 +51,20 @@ class CausalSelfAttention(nn.Module):
         if not self.flash:
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
 
-    def forward(self, x):
+    def forward(self, x, layer_past=None):
         B, T, C = x.size()
         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
 
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            k = torch.cat((past_key, k), dim=-2)
+            v = torch.cat((past_value, v), dim=-2)
+
+        present = (k, v)
+
         if self.flash:
             y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True)
         else:
@@ -68,7 +76,7 @@ class CausalSelfAttention(nn.Module):
 
         y = y.transpose(1, 2).contiguous().view(B, T, C)
         y = self.resid_dropout(self.c_proj(y))
-        return y
+        return y, present
 
 class MLP(nn.Module):
     def __init__(self, config):
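A note on the attention hunks above: when layer_past is supplied, the new keys and values are concatenated onto the cached ones along the sequence dimension (dim=-2), while the query covers only the new token(s). The standalone sketch below (not part of the commit; shapes are arbitrary, PyTorch >= 2.0 assumed) shows that concatenation, and why passing is_causal=True unconditionally looks questionable once a cache is present: with a non-square query/key shape, scaled_dot_product_attention aligns its causal mask to the top-left, which would hide most of the cached prefix from the new token.

import torch
import torch.nn.functional as F

B, n_head, head_dim = 1, 4, 16
past_key = torch.randn(B, n_head, 7, head_dim)    # 7 cached positions per head
past_value = torch.randn(B, n_head, 7, head_dim)
q_new = torch.randn(B, n_head, 1, head_dim)       # query for the single new token
k_new = torch.randn(B, n_head, 1, head_dim)
v_new = torch.randn(B, n_head, 1, head_dim)

# Cache concatenation as in the diff: stack along the sequence dimension.
k = torch.cat((past_key, k_new), dim=-2)           # (B, n_head, 8, head_dim)
v = torch.cat((past_value, v_new), dim=-2)

# The lone new query should attend to all 8 positions, so no causal flag here;
# is_causal=True with a 1x8 mask would leave only the first cached key visible.
y = F.scaled_dot_product_attention(q_new, k, v, is_causal=False)
print(y.shape)  # torch.Size([1, 4, 1, 16])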
@@ -89,14 +97,14 @@ class Block(nn.Module):
         self.ln2 = LayerNorm(config.n_embd, config.bias)
         self.mlp = MLP(config)
 
-    def forward(self, x):
-        x = x + self.attn(self.ln1(x))
+    def forward(self, x, layer_past=None):
+        x = x + self.attn(self.ln1(x), layer_past=layer_past)
         x = x + self.mlp(self.ln2(x))
         return x
 
 class GPT(PreTrainedModel):
     config_class = GPTConfig
-    base_model_prefix = "
+    base_model_prefix = "babylang"
 
     def __init__(self, config):
         super().__init__(config)
@@ -123,26 +131,37 @@ class GPT(PreTrainedModel):
         elif isinstance(module, nn.Embedding):
             nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
-    def forward(self, input_ids, labels=None):
+    def forward(self, input_ids, past_key_values=None, attention_mask=None, labels=None):
         device = input_ids.device
         b, t = input_ids.size()
         assert t <= self.config.block_size
         pos = torch.arange(0, t, dtype=torch.long, device=device)
 
+        if past_key_values is not None:
+            pos = pos[-1].unsqueeze(0)
+
         tok_emb = self.transformer.wte(input_ids)
         pos_emb = self.transformer.wpe(pos)
         x = self.transformer.drop(tok_emb + pos_emb)
-        for block in self.transformer.h:
-            x = block(x)
-        x = self.transformer.ln_f(x)
 
+        new_past_key_values = []
+        for i, block in enumerate(self.transformer.h):
+            x, past = block(x, layer_past=past_key_values[i] if past_key_values is not None else None)
+            new_past_key_values.append(past)
+
+        x = self.transformer.ln_f(x)
         logits = self.lm_head(x)
 
+        loss = None
         if labels is not None:
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1)
-
-
-
+
+        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=new_past_key_values)
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+        if past_key_values:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+        return {"input_ids": input_ids, "past_key_values": past_key_values}
 
     @torch.no_grad()
     def generate(self, input_ids, max_new_tokens, temperature=1.0, top_k=None):
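One inconsistency worth flagging, with a sketch (not part of the commit) of a possible reconciliation: CausalSelfAttention.forward now returns a (y, present) tuple and GPT.forward unpacks two values from each block, but Block.forward still adds the raw attention output to x and returns a single tensor, so the per-layer cache never reaches GPT.forward. Something along these lines would make the three signatures agree:

# Sketch only; assumes the rest of Block (ln1, ln2, attn, mlp) is unchanged.
class Block(nn.Module):
    def forward(self, x, layer_past=None):
        attn_out, present = self.attn(self.ln1(x), layer_past=layer_past)  # unpack the cache tuple
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x, present  # propagate the per-layer cache to GPT.forward

A related caveat: once prepare_inputs_for_generation trims the input to the last token, pos = pos[-1].unsqueeze(0) always evaluates to position 0; offsetting by the cached length (for example past_key_values[0][0].size(-2)) would keep the positional embeddings aligned. Note also that the hand-written generate method appears to shadow the stock Transformers GenerationMixin.generate, so prepare_inputs_for_generation and the cache path are only exercised if generation goes through the library's own loop.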