venkyyuvy committed
Commit b1d3adc · 1 Parent(s): ab09bcf

init commit

Files changed (4):
  1. app.py +39 -0
  2. config.py +30 -0
  3. gpt.py +150 -0
  4. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,39 @@
+ import pickle
+ import torch
+ import gradio as gr
+ from gpt import GPTLanguageModel, encode, decode
+
+
+ # character <-> integer mappings (gpt.py also loads these for encode/decode)
+ with open('stoi_itos.pkl', 'rb') as file:
+     stoi, itos = pickle.load(file)
+
+ lm = GPTLanguageModel()
+ lm.load_state_dict(torch.load('shakespeare_lm.pt', map_location='cpu'))
+ lm.eval()
+
+
+ def inference(prompt: str):
+     # encode the prompt and add a batch dimension: (T,) -> (1, T), as generate() expects
+     encoded_prompt = torch.tensor(encode(prompt), dtype=torch.long).unsqueeze(0)
+     with torch.no_grad():
+         generated = lm.generate(encoded_prompt, max_new_tokens=500)
+     return decode(generated[0].tolist())
+
+
+ gr_interface = gr.Interface(
+     inference,
+     inputs=[
+         gr.Textbox("man walking on the streets", label="Prompt"),
+     ],
+     outputs=[
+         # gr.Textbox has no height argument; lines controls the visible size of the box
+         gr.Textbox(label="Generated story", lines=10),
+     ],
+     title="Stories generated by a language model trained on Shakespeare's work",
+     examples=[
+         ["Sunrise rising"],
+         ["A big blast sound"],
+     ],
+ )
+ gr_interface.launch(debug=True)
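app.py expects two artifacts next to it that are not part of this commit: stoi_itos.pkl, a pickled (stoi, itos) pair of character-to-index and index-to-character mappings, and shakespeare_lm.pt, the trained state dict. As a hedged sketch only, this is how such a vocabulary file could be produced from a character-level corpus; the input.txt path is hypothetical and not defined anywhere in this repo:

import pickle

# hypothetical path to the Shakespeare training text (not part of this commit)
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(set(text))                      # character-level vocabulary
stoi = {ch: i for i, ch in enumerate(chars)}   # string -> integer
itos = {i: ch for i, ch in enumerate(chars)}   # integer -> string

# write the (stoi, itos) tuple that app.py and gpt.py unpickle
with open('stoi_itos.pkl', 'wb') as f:
    pickle.dump((stoi, itos), f)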
config.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ from dataclasses import dataclass
+
+ @dataclass
+ class BigramConfig:
+     batch_size: int = 32    # how many independent sequences will we process in parallel?
+     block_size: int = 8     # what is the maximum context length for predictions?
+     max_iters: int = 3000
+     eval_interval: int = 300
+     learning_rate: float = 1e-2
+     device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
+     eval_iters: int = 200
+
+ @dataclass
+ class GPTConfig:
+     batch_size: int = 64    # how many independent sequences will we process in parallel?
+     block_size: int = 256   # what is the maximum context length for predictions?
+     max_iters: int = 5000
+     eval_interval: int = 500
+     learning_rate: float = 3e-4
+     device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
+     eval_iters: int = 200
+     n_embd: int = 384
+     n_head: int = 6
+     n_layer: int = 6
+     dropout: float = 0.2
+     save_path: str = "."
+
+ bigram_config = BigramConfig()
+ gpt_config = GPTConfig()
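For a sense of scale: with n_layer=6, n_embd=384, and n_head=6, the transformer blocks alone hold roughly 10.6M parameters. A back-of-the-envelope sketch of that estimate (ignoring embeddings, layer norms, and biases), not a measured count:

from config import gpt_config as cfg

# per block: key/query/value/output projections (~4 * n_embd^2)
# plus the feed-forward expansion n_embd -> 4*n_embd -> n_embd (~8 * n_embd^2)
per_block = 12 * cfg.n_embd ** 2
print(f"~{cfg.n_layer * per_block / 1e6:.1f}M parameters in the blocks")  # ~10.6M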
gpt.py ADDED
@@ -0,0 +1,150 @@
+ import pickle
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ from config import gpt_config as config
+ # ------------
+
+ torch.manual_seed(1337)
+
+ # character <-> integer mappings built at training time (the same file app.py reads);
+ # they define the vocabulary size used by the model below
+ with open('stoi_itos.pkl', 'rb') as file:
+     stoi, itos = pickle.load(file)
+ vocab_size = len(stoi)
+
+ encode = lambda s: [stoi[c] for c in s]           # encoder: take a string, output a list of integers
+ decode = lambda l: ''.join([itos[i] for i in l])  # decoder: take a list of integers, output a string
+
+
+ class Head(nn.Module):
+     """ one head of self-attention """
+
+     def __init__(self, head_size):
+         super().__init__()
+         self.key = nn.Linear(config.n_embd, head_size, bias=False)
+         self.query = nn.Linear(config.n_embd, head_size, bias=False)
+         self.value = nn.Linear(config.n_embd, head_size, bias=False)
+         self.register_buffer('tril', torch.tril(torch.ones(config.block_size, config.block_size)))
+
+         self.dropout = nn.Dropout(config.dropout)
+
+     def forward(self, x):
+         # input of size (batch, time-step, channels)
+         # output of size (batch, time-step, head size)
+         B, T, C = x.shape
+         k = self.key(x)    # (B,T,hs)
+         q = self.query(x)  # (B,T,hs)
+         # compute attention scores ("affinities")
+         wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
+         wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
+         wei = F.softmax(wei, dim=-1)  # (B, T, T)
+         wei = self.dropout(wei)
+         # perform the weighted aggregation of the values
+         v = self.value(x)  # (B,T,hs)
+         out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)
+         return out
+
+ class MultiHeadAttention(nn.Module):
+     """ multiple heads of self-attention in parallel """
+
+     def __init__(self, num_heads, head_size):
+         super().__init__()
+         self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
+         self.proj = nn.Linear(head_size * num_heads, config.n_embd)
+         self.dropout = nn.Dropout(config.dropout)
+
+     def forward(self, x):
+         out = torch.cat([h(x) for h in self.heads], dim=-1)
+         out = self.dropout(self.proj(out))
+         return out
+
+ class FeedForward(nn.Module):
+     """ a simple linear layer followed by a non-linearity """
+
+     def __init__(self, n_embd):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(n_embd, 4 * n_embd),
+             nn.ReLU(),
+             nn.Linear(4 * n_embd, n_embd),
+             nn.Dropout(config.dropout),
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+ class Block(nn.Module):
+     """ Transformer block: communication followed by computation """
+
+     def __init__(self, n_embd, n_head):
+         # n_embd: embedding dimension, n_head: the number of heads we'd like
+         super().__init__()
+         head_size = n_embd // n_head
+         self.sa = MultiHeadAttention(n_head, head_size)
+         self.ffwd = FeedForward(n_embd)
+         self.ln1 = nn.LayerNorm(n_embd)
+         self.ln2 = nn.LayerNorm(n_embd)
+
+     def forward(self, x):
+         x = x + self.sa(self.ln1(x))
+         x = x + self.ffwd(self.ln2(x))
+         return x
+
+ class GPTLanguageModel(nn.Module):
+
+     def __init__(self):
+         super().__init__()
+         # each token directly reads off the logits for the next token from a lookup table
+         self.token_embedding_table = nn.Embedding(vocab_size, config.n_embd)
+         self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)
+         self.blocks = nn.Sequential(*[Block(config.n_embd, n_head=config.n_head) for _ in range(config.n_layer)])
+         self.ln_f = nn.LayerNorm(config.n_embd)  # final layer norm
+         self.lm_head = nn.Linear(config.n_embd, vocab_size)
+
+         # better init, not covered in the original GPT video, but important
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, idx, targets=None):
+         B, T = idx.shape
+
+         # idx and targets are both (B,T) tensors of integers
+         tok_emb = self.token_embedding_table(idx)  # (B,T,C)
+         pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T,C); idx.device keeps positions on the same device as the input
+         x = tok_emb + pos_emb  # (B,T,C)
+         x = self.blocks(x)  # (B,T,C)
+         x = self.ln_f(x)  # (B,T,C)
+         logits = self.lm_head(x)  # (B,T,vocab_size)
+
+         if targets is None:
+             loss = None
+         else:
+             B, T, C = logits.shape
+             logits = logits.view(B*T, C)
+             targets = targets.view(B*T)
+             loss = F.cross_entropy(logits, targets)
+
+         return logits, loss
+
+     def generate(self, idx, max_new_tokens):
+         # idx is a (B, T) array of indices in the current context
+         for _ in range(max_new_tokens):
+             # crop idx to the last block_size tokens
+             idx_cond = idx[:, -config.block_size:]
+             # get the predictions
+             logits, loss = self(idx_cond)
+             # focus only on the last time step
+             logits = logits[:, -1, :]  # becomes (B, C)
+             # apply softmax to get probabilities
+             probs = F.softmax(logits, dim=-1)  # (B, C)
+             # sample from the distribution
+             idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
+             # append sampled index to the running sequence
+             idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
+         return idx
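A quick smoke test for the module above, assuming stoi_itos.pkl sits in the working directory (gpt.py reads it at import time) and that the prompt uses only characters present in that vocabulary. With randomly initialised weights the sample is gibberish, but the shapes confirm the wiring:

import torch
from gpt import GPTLanguageModel, encode, decode

model = GPTLanguageModel()    # random weights unless a trained state dict is loaded
model.eval()

context = torch.tensor([encode("ROMEO:")], dtype=torch.long)      # shape (1, 6)
with torch.no_grad():
    logits, loss = model(context)                                 # logits: (1, 6, vocab_size); loss is None without targets
    sample = model.generate(context, max_new_tokens=20)           # shape (1, 26)
print(decode(sample[0].tolist()))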
requirements.txt ADDED
@@ -0,0 +1 @@
+ torch