Exquisique committed on
Commit d2544e1 · 1 Parent(s): 23516c0

Make GPT inherit from PreTrainedModel for HF compatibility

Files changed (2)
  1. README.md +0 -1
  2. model.py +42 -17
README.md CHANGED
@@ -13,7 +13,6 @@ library_name: transformers
 pipeline_tag: text-generation
 ---
 
-# 🍼 BabyLangModel
 
 # 🍼 BabyLangModel
 
model.py CHANGED
@@ -2,6 +2,30 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import math
+from transformers import PreTrainedModel, PretrainedConfig
+
+class GPTConfig(PretrainedConfig):
+    model_type = "gpt"
+
+    def __init__(
+        self,
+        vocab_size=50257,
+        block_size=128,
+        n_layer=6,
+        n_head=6,
+        n_embd=384,
+        dropout=0.0,
+        bias=True,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.vocab_size = vocab_size
+        self.block_size = block_size
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.n_embd = n_embd
+        self.dropout = dropout
+        self.bias = bias
 
 class LayerNorm(nn.Module):
     def __init__(self, ndim, bias):
@@ -70,10 +94,11 @@ class Block(nn.Module):
         x = x + self.mlp(self.ln2(x))
         return x
 
-class GPT(nn.Module):
+class GPT(PreTrainedModel):
+    config_class = GPTConfig
+
     def __init__(self, config):
-        super().__init__()
-        self.config = config
+        super().__init__(config)
         self.transformer = nn.ModuleDict(dict(
             wte=nn.Embedding(config.vocab_size, config.n_embd),
             wpe=nn.Embedding(config.block_size, config.n_embd),
@@ -97,37 +122,37 @@ class GPT(nn.Module):
         elif isinstance(module, nn.Embedding):
             nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
-    def forward(self, idx, targets=None):
-        device = idx.device
-        b, t = idx.size()
+    def forward(self, input_ids, labels=None):
+        device = input_ids.device
+        b, t = input_ids.size()
         assert t <= self.config.block_size
         pos = torch.arange(0, t, dtype=torch.long, device=device)
 
-        tok_emb = self.transformer.wte(idx)
+        tok_emb = self.transformer.wte(input_ids)
         pos_emb = self.transformer.wpe(pos)
         x = self.transformer.drop(tok_emb + pos_emb)
         for block in self.transformer.h:
             x = block(x)
         x = self.transformer.ln_f(x)
 
-        if targets is not None:
+        if labels is not None:
             logits = self.lm_head(x)
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
-            return logits, loss
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1)
+            return {'logits': logits, 'loss': loss}
         else:
             logits = self.lm_head(x[:, [-1], :])
-            return logits, None
+            return {'logits': logits}
 
     @torch.no_grad()
-    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+    def generate(self, input_ids, max_new_tokens, temperature=1.0, top_k=None):
         for _ in range(max_new_tokens):
-            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
-            logits, _ = self(idx_cond)
-            logits = logits[:, -1, :] / temperature
+            idx_cond = input_ids if input_ids.size(1) <= self.config.block_size else input_ids[:, -self.config.block_size:]
+            out = self(idx_cond)
+            logits = out['logits'][:, -1, :] / temperature
             if top_k is not None:
                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                 logits[logits < v[:, [-1]]] = -float('Inf')
             probs = F.softmax(logits, dim=-1)
             idx_next = torch.multinomial(probs, num_samples=1)
-            idx = torch.cat((idx, idx_next), dim=1)
-        return idx
+            input_ids = torch.cat((input_ids, idx_next), dim=1)
+        return input_ids
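
For context, a minimal usage sketch of what the PreTrainedModel inheritance enables, assuming the GPT and GPTConfig classes from model.py above; the output directory name and the random input are placeholders for illustration, not part of the commit.

import torch
from model import GPT, GPTConfig  # the classes touched by this commit

config = GPTConfig()   # defaults from the diff: 6 layers, 6 heads, 384-dim embeddings
model = GPT(config)

# PreTrainedModel supplies save_pretrained / from_pretrained, so the model
# round-trips through the standard HF checkpoint layout. "baby-lang-model"
# is an arbitrary local directory, not a Hub repo named in the commit.
model.save_pretrained("baby-lang-model")
reloaded = GPT.from_pretrained("baby-lang-model")

# forward() now follows HF conventions: input_ids in, a dict out, with
# 'loss' included when labels are passed.
input_ids = torch.randint(0, config.vocab_size, (1, 16))
out = reloaded(input_ids, labels=input_ids)
print(out['loss'].item(), out['logits'].shape)

# The overridden generate() keeps its original sampling signature.
sample = reloaded.generate(input_ids, max_new_tokens=8, top_k=50)
print(sample.shape)  # (1, 24)

Returning a dict keyed by 'logits' and 'loss', and accepting labels instead of targets, should also let the model plug into training utilities that expect the standard labels/loss convention.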