rasbt commited on
Commit
14c86d1
·
verified ·
1 Parent(s): 8e95284

Delete .ipynb_checkpoints/model-checkpoint.py

Browse files
.ipynb_checkpoints/model-checkpoint.py DELETED
@@ -1,334 +0,0 @@
1
- # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
2
- # Source for "Build a Large Language Model From Scratch"
3
- # https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb
4
-
5
-
6
- import torch
7
- import torch.nn as nn
8
-
9
-
10
- LLAMA32_CONFIG_1B = {
11
- "vocab_size": 128_256, # Vocabulary size
12
- "context_length": 8192, # Maximum context length to use (reduced to save memory)
13
- "orig_context_length": 131_072, # Context length that was used to train the model
14
- "emb_dim": 2048, # Embedding dimension
15
- "n_heads": 32, # Number of attention heads
16
- "n_layers": 16, # Number of layers
17
- "hidden_dim": 8192, # Size of the intermediate dimension in FeedForward
18
- "n_kv_groups": 8, # Key-Value groups for grouped-query attention
19
- "rope_base": 500_000.0, # The base in RoPE's "theta"
20
- "dtype": torch.bfloat16, # Lower-precision dtype to reduce memory usage
21
- "rope_freq": { # RoPE frequency scaling
22
- "factor": 32.0,
23
- "low_freq_factor": 1.0,
24
- "high_freq_factor": 4.0,
25
- "original_context_length": 8192,
26
- }
27
- }
28
-
29
- LLAMA32_CONFIG_3B = {
30
- "vocab_size": 128_256, # Vocabulary size
31
- "context_length": 8192, # Maximum context length to use (reduced to save memory)
32
- "orig_context_length": 131_072, # Context length that was used to train the model
33
- "emb_dim": 3072, # Embedding dimension
34
- "n_heads": 24, # Number of attention heads
35
- "n_layers": 28, # Number of layers
36
- "hidden_dim": 8192, # Size of the intermediate dimension in FeedForward
37
- "n_kv_groups": 8, # Key-Value groups for grouped-query attention
38
- "rope_base": 500_000.0, # The base in RoPE's "theta"
39
- "dtype": torch.bfloat16, # Lower-precision dtype to reduce memory usage
40
- "rope_freq": { # RoPE frequency scaling
41
- "factor": 32.0,
42
- "low_freq_factor": 1.0,
43
- "high_freq_factor": 4.0,
44
- "original_context_length": 8192,
45
- }
46
- }
47
-
48
-
49
- class Llama3Model(nn.Module):
50
- def __init__(self, cfg):
51
- super().__init__()
52
-
53
- # Main model parameters
54
- self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
55
-
56
- self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`
57
- [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
58
- )
59
-
60
- self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
61
- self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
62
-
63
- # Reusuable utilities
64
- self.register_buffer(
65
- "mask", torch.triu(torch.ones(cfg["context_length"], cfg["context_length"]), diagonal=1).bool(),
66
- persistent=False
67
- )
68
-
69
- if cfg["orig_context_length"] != cfg["context_length"]:
70
- cfg["rope_base"] = rescale_theta(
71
- cfg["rope_base"],
72
- cfg["orig_context_length"],
73
- cfg["context_length"]
74
- )
75
- cos, sin = compute_rope_params(
76
- head_dim=cfg["emb_dim"] // cfg["n_heads"],
77
- theta_base=cfg["rope_base"],
78
- context_length=cfg["context_length"],
79
- freq_config=cfg["rope_freq"]
80
- )
81
- self.register_buffer("cos", cos, persistent=False)
82
- self.register_buffer("sin", sin, persistent=False)
83
- self.cfg = cfg
84
-
85
- def forward(self, in_idx):
86
- # Forward pass
87
- tok_embeds = self.tok_emb(in_idx)
88
- x = tok_embeds
89
-
90
- for block in self.trf_blocks:
91
- x = block(x, self.mask, self.cos, self.sin)
92
- x = self.final_norm(x)
93
- logits = self.out_head(x.to(self.cfg["dtype"]))
94
- return logits
95
-
96
-
97
- class TransformerBlock(nn.Module):
98
- def __init__(self, cfg):
99
- super().__init__()
100
- self.att = GroupedQueryAttention(
101
- d_in=cfg["emb_dim"],
102
- d_out=cfg["emb_dim"],
103
- num_heads=cfg["n_heads"],
104
- num_kv_groups=cfg["n_kv_groups"],
105
- dtype=cfg["dtype"]
106
- )
107
- self.ff = FeedForward(cfg)
108
- self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
109
- self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
110
-
111
- def forward(self, x, mask, cos, sin):
112
- # Shortcut connection for attention block
113
- shortcut = x
114
- x = self.norm1(x)
115
- x = self.att(x, mask, cos, sin) # Shape [batch_size, num_tokens, emb_size]
116
- x = x + shortcut # Add the original input back
117
-
118
- # Shortcut connection for feed-forward block
119
- shortcut = x
120
- x = self.norm2(x)
121
- x = self.ff(x)
122
- x = x + shortcut # Add the original input back
123
-
124
- return x
125
-
126
-
127
- class FeedForward(nn.Module):
128
- def __init__(self, cfg):
129
- super().__init__()
130
- self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
131
- self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
132
- self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)
133
-
134
- def forward(self, x):
135
- x_fc1 = self.fc1(x)
136
- x_fc2 = self.fc2(x)
137
- x = nn.functional.silu(x_fc1) * x_fc2
138
- return self.fc3(x)
139
-
140
-
141
- class GroupedQueryAttention(nn.Module):
142
- def __init__(
143
- self, d_in, d_out, num_heads,
144
- num_kv_groups,
145
- dtype=None
146
- ):
147
- super().__init__()
148
- assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
149
- assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
150
-
151
- self.d_out = d_out
152
- self.num_heads = num_heads
153
- self.head_dim = d_out // num_heads
154
-
155
- self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
156
- self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
157
- self.num_kv_groups = num_kv_groups
158
- self.group_size = num_heads // num_kv_groups
159
-
160
- self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
161
- self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)
162
-
163
- def forward(self, x, mask, cos, sin):
164
- b, num_tokens, d_in = x.shape
165
-
166
- queries = self.W_query(x) # Shape: (b, num_tokens, d_out)
167
- keys = self.W_key(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)
168
- values = self.W_value(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)
169
-
170
- # Reshape queries, keys, and values
171
- queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
172
- keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)
173
- values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)
174
-
175
- # Transpose keys, values, and queries
176
- keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
177
- values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
178
- queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)
179
-
180
- # Apply RoPE
181
- keys = apply_rope(keys, cos, sin)
182
- queries = apply_rope(queries, cos, sin)
183
-
184
- # Expand keys and values to match the number of heads
185
- # Shape: (b, num_heads, num_tokens, head_dim)
186
- keys = keys.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
187
- values = values.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
188
- # For example, before repeat_interleave along dim=1 (query groups):
189
- # [K1, K2]
190
- # After repeat_interleave (each query group is repeated group_size times):
191
- # [K1, K1, K2, K2]
192
- # If we used regular repeat instead of repeat_interleave, we'd get:
193
- # [K1, K2, K1, K2]
194
-
195
- # Compute scaled dot-product attention (aka self-attention) with a causal mask
196
- # Shape: (b, num_heads, num_tokens, num_tokens)
197
- attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
198
-
199
- # Use the mask to fill attention scores
200
- attn_scores = attn_scores.masked_fill(mask[:num_tokens, :num_tokens], -torch.inf)
201
-
202
- attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
203
- assert keys.shape[-1] == self.head_dim
204
-
205
- # Shape: (b, num_tokens, num_heads, head_dim)
206
- context_vec = (attn_weights @ values).transpose(1, 2)
207
-
208
- # Combine heads, where self.d_out = self.num_heads * self.head_dim
209
- context_vec = context_vec.reshape(b, num_tokens, self.d_out)
210
- context_vec = self.out_proj(context_vec) # optional projection
211
-
212
- return context_vec
213
-
214
-
215
- def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):
216
- assert head_dim % 2 == 0, "Embedding dimension must be even"
217
-
218
- # Compute the inverse frequencies
219
- inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))
220
-
221
- # Frequency adjustments
222
- if freq_config is not None:
223
- low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"]
224
- high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"]
225
-
226
- wavelen = 2 * torch.pi / inv_freq
227
-
228
- inv_freq_llama = torch.where(
229
- wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq
230
- )
231
-
232
- smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / (
233
- freq_config["high_freq_factor"] - freq_config["low_freq_factor"]
234
- )
235
-
236
- smoothed_inv_freq = (
237
- (1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq
238
- )
239
-
240
- is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
241
- inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
242
- inv_freq = inv_freq_llama
243
-
244
- # Generate position indices
245
- positions = torch.arange(context_length, dtype=dtype)
246
-
247
- # Compute the angles
248
- angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)
249
-
250
- # Expand angles to match the head_dim
251
- angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)
252
-
253
- # Precompute sine and cosine
254
- cos = torch.cos(angles)
255
- sin = torch.sin(angles)
256
-
257
- return cos, sin
258
-
259
-
260
- def apply_rope(x, cos, sin):
261
- # x: (batch_size, num_heads, seq_len, head_dim)
262
- batch_size, num_heads, seq_len, head_dim = x.shape
263
- assert head_dim % 2 == 0, "Head dimension must be even"
264
-
265
- # Split x into first half and second half
266
- x1 = x[..., : head_dim // 2] # First half
267
- x2 = x[..., head_dim // 2:] # Second half
268
-
269
- # Adjust sin and cos shapes
270
- cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
271
- sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
272
-
273
- # Apply the rotary transformation
274
- rotated = torch.cat((-x2, x1), dim=-1)
275
- x_rotated = (x * cos) + (rotated * sin)
276
-
277
- # It's ok to use lower-precision after applying cos and sin rotation
278
- return x_rotated.to(dtype=x.dtype)
279
-
280
-
281
- def rescale_theta(theta_old, context_length_old, context_length_new):
282
- scaling_factor = context_length_new / context_length_old
283
- theta_new = theta_old * scaling_factor
284
- return theta_new
285
-
286
-
287
- def text_to_token_ids(text, tokenizer):
288
- encoded = tokenizer.encode(text)
289
- encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension
290
- return encoded_tensor
291
-
292
-
293
- def token_ids_to_text(token_ids, tokenizer):
294
- flat = token_ids.squeeze(0) # remove batch dimension
295
- return tokenizer.decode(flat.tolist())
296
-
297
-
298
- def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
299
-
300
- # For-loop is the same as before: Get logits, and only focus on last time step
301
- for _ in range(max_new_tokens):
302
- idx_cond = idx[:, -context_size:]
303
- with torch.no_grad():
304
- logits = model(idx_cond)
305
- logits = logits[:, -1, :]
306
-
307
- # Filter logits with top_k sampling
308
- if top_k is not None:
309
- # Keep only top_k values
310
- top_logits, _ = torch.topk(logits, top_k)
311
- min_val = top_logits[:, -1]
312
- logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
313
-
314
- # Apply temperature scaling
315
- if temperature > 0.0:
316
- logits = logits / temperature
317
-
318
- # Apply softmax to get probabilities
319
- probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
320
-
321
- # Sample from the distribution
322
- idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1)
323
-
324
- # Otherwise same as before: get idx of the vocab entry with the highest logits value
325
- else:
326
- idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1)
327
-
328
- if idx_next == eos_id: # Stop generating early if end-of-sequence token is encountered and eos_id is specified
329
- break
330
-
331
- # Same as before: append sampled index to the running sequence
332
- idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1)
333
-
334
- return idx