Exquisique committed
Commit f8f40b7 · 1 Parent(s): cfbb0ee
Files changed (2)
  1. config.json +2 -2
  2. model.py +32 -13
config.json CHANGED
@@ -1,5 +1,5 @@
 {
-  "model_type": "gpt",
+  "model_type": "babylang",
   "architectures": ["GPT"],
   "vocab_size": 50257,
   "block_size": 128,
@@ -10,7 +10,7 @@
   "bias": true,
   "auto_map": {
     "AutoConfig": "model.GPTConfig",
-    "AutoModel": "model.GPT"
+    "AutoModelForCausalLM": "model.GPT"
   },
   "torch_dtype": "float32",
   "transformers_version": "4.42.0"
model.py CHANGED
@@ -3,9 +3,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 import math
 from transformers import PreTrainedModel, PretrainedConfig
+from transformers.modeling_outputs import CausalLMOutputWithPast
 
 class GPTConfig(PretrainedConfig):
-    model_type = "gpt"
+    model_type = "babylang"
 
     def __init__(
         self,
@@ -50,13 +51,20 @@ class CausalSelfAttention(nn.Module):
         if not self.flash:
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
 
-    def forward(self, x):
+    def forward(self, x, layer_past=None):
         B, T, C = x.size()
         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
 
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            k = torch.cat((past_key, k), dim=-2)
+            v = torch.cat((past_value, v), dim=-2)
+
+        present = (k, v)
+
         if self.flash:
             y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True)
         else:
@@ -68,7 +76,7 @@ class CausalSelfAttention(nn.Module):
 
         y = y.transpose(1, 2).contiguous().view(B, T, C)
         y = self.resid_dropout(self.c_proj(y))
-        return y
+        return y, present
 
 class MLP(nn.Module):
     def __init__(self, config):
@@ -89,14 +97,14 @@ class Block(nn.Module):
         self.ln2 = LayerNorm(config.n_embd, config.bias)
         self.mlp = MLP(config)
 
-    def forward(self, x):
-        x = x + self.attn(self.ln1(x))
+    def forward(self, x, layer_past=None):
+        x = x + self.attn(self.ln1(x), layer_past=layer_past)
         x = x + self.mlp(self.ln2(x))
         return x
 
 class GPT(PreTrainedModel):
     config_class = GPTConfig
-    base_model_prefix = "gpt"
+    base_model_prefix = "babylang"
 
     def __init__(self, config):
         super().__init__(config)
@@ -123,26 +131,37 @@ class GPT(PreTrainedModel):
         elif isinstance(module, nn.Embedding):
             nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
-    def forward(self, input_ids, labels=None):
+    def forward(self, input_ids, past_key_values=None, attention_mask=None, labels=None):
         device = input_ids.device
         b, t = input_ids.size()
         assert t <= self.config.block_size
         pos = torch.arange(0, t, dtype=torch.long, device=device)
 
+        if past_key_values is not None:
+            pos = pos[-1].unsqueeze(0)
+
         tok_emb = self.transformer.wte(input_ids)
         pos_emb = self.transformer.wpe(pos)
         x = self.transformer.drop(tok_emb + pos_emb)
-        for block in self.transformer.h:
-            x = block(x)
-        x = self.transformer.ln_f(x)
 
+        new_past_key_values = []
+        for i, block in enumerate(self.transformer.h):
+            x, past = block(x, layer_past=past_key_values[i] if past_key_values is not None else None)
+            new_past_key_values.append(past)
+
+        x = self.transformer.ln_f(x)
         logits = self.lm_head(x)
 
+        loss = None
         if labels is not None:
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1)
-            return {"loss": loss, "logits": logits}
-        else:
-            return {"logits": logits}
+
+        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=new_past_key_values)
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+        if past_key_values:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+        return {"input_ids": input_ids, "past_key_values": past_key_values}
 
     @torch.no_grad()
     def generate(self, input_ids, max_new_tokens, temperature=1.0, top_k=None):
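One wiring note on the cache path: CausalSelfAttention.forward now returns a (y, present) pair and GPT.forward unpacks two values from each block, yet Block.forward in this diff still adds the raw attention return value to x and returns x alone. A minimal sketch of a Block.forward that would match the new plumbing; this is an assumption about the intended wiring, not code from this commit:

# Hypothetical Block.forward consistent with the (output, present) return of
# CausalSelfAttention.forward and with GPT.forward's `x, past = block(...)` unpacking.
def forward(self, x, layer_past=None):
    # self.attn returns (output, present); unpack before the residual add
    attn_out, present = self.attn(self.ln1(x), layer_past=layer_past)
    x = x + attn_out
    x = x + self.mlp(self.ln2(x))
    # return the per-layer (key, value) pair so GPT.forward can collect it
    return x, present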