tranquilkd committed
Commit bac9d3f · 1 Parent(s): 5a42e50

Initial commit

Files changed (3)
  1. app.py +122 -0
  2. model.py +357 -0
  3. smollm2_HF.pth +3 -0
app.py ADDED
@@ -0,0 +1,122 @@
+ import os
+ import torch
+ import gradio as gr
+ from typing import Optional
+ from dataclasses import dataclass
+ from transformers import AutoTokenizer
+ from model import Transformer
+
+
+ @dataclass
+ class ModelArgs:
+     # Arch params
+     dim: int = 576
+     intermediate_dim: int = 1536
+     n_layers: int = 30
+     n_heads: int = 9
+     n_kv_heads: Optional[int] = 3
+     vocab_size: int = 49152  # vocabulary size of the cosmo2 tokenizer
+     norm_eps: float = 1.0e-05
+     init_scale: float = 0.041666666666666664
+     rope_theta: int = 10000
+     dropout: float = 0.1
+
+     # Training params
+     seed: int = 42
+     max_batch_size: int = 2
+     max_seq_len: int = 2048
+     steps: int = 5050
+     breakpoint_step: int = 5000
+     warmup_steps_frac: float = 0.5
+     save_interval: int = 1000
+     eval_interval: int = 500
+     log_interval: int = 1
+     grad_accum_steps: int = 8
+     checkpoint_path = os.path.join(os.getcwd(), "checkpoints")
+     device: str = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # Optimizer
+     initial_lr: float = 5e-4
+     adam_beta1: float = 0.9
+     adam_beta2: float = 0.95
+     adam_eps: float = 1.0e-08
+     weight_decay: float = 0.01
+     use_fused: bool = True
+
+
+ # Initialize model and tokenizer
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+ config = ModelArgs()
+ config.device = device
+ model = Transformer(config)
+
+ # Load trained weights from the checkpoint file
+ def load_checkpoint(model, path, device):
+     try:
+         checkpoint = torch.load(path, map_location=device)
+         # Strip the torch.compile prefix and skip the saved KV-cache buffers
+         model.load_state_dict({k.replace("_orig_mod.", ""): v for k, v in checkpoint.items()
+                                if 'cached_keys' not in k and 'cached_values' not in k})
+         return model
+     except Exception as e:
+         print(f"Error loading checkpoint: {e}")
+         return None
+
+ model = load_checkpoint(model, "smollm2_HF.pth", device)
+ if model is None:
+     raise RuntimeError("Could not load weights from smollm2_HF.pth")
+ model.to(device)
+ model.eval()
+
+ def generate_text(prompt,
+                   min_length: int = 28,
+                   max_length: int = 40,
+                   temperature: float = 0.7,
+                   top_k: int = 50,
+                   top_p: float = 0.7
+                   ):
+     """Generate text from a prompt"""
+     input_ids = tokenizer(prompt,
+                           padding=True,
+                           truncation=True,
+                           max_length=config.max_seq_len,
+                           return_tensors="pt")["input_ids"].to(device)
+
+     generated = model.generate(
+         input_ids,
+         max_length=max_length,
+         min_length=min_length,
+         pad_token_id=tokenizer.pad_token_id,
+         do_sample=True,
+         temperature=temperature,
+         top_k=top_k,
+         top_p=top_p
+     )
+
+     return tokenizer.decode(generated[0], skip_special_tokens=True)
+
+ # Gradio interface
+ def gradio_interface(prompt, min_length, max_length, temperature, top_k, top_p):
+     return generate_text(prompt,
+                          min_length=int(min_length),
+                          max_length=int(max_length),
+                          temperature=float(temperature),
+                          top_k=int(top_k),
+                          top_p=float(top_p))
+
+ iface = gr.Interface(
+     fn=gradio_interface,
+     inputs=[
+         gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
+         gr.Slider(minimum=10, maximum=500, label="Min Length"),
+         gr.Slider(minimum=10, maximum=500, label="Max Length"),
+         gr.Slider(minimum=0.1, maximum=2.0, label="Temperature"),
+         gr.Slider(minimum=1, maximum=100, label="Top K"),
+         gr.Slider(minimum=0.1, maximum=1.0, label="Top P")
+     ],
+     outputs=gr.Textbox(label="Generated Text"),
+     title="SmolLM2-135M Text Generation",
+     description="SmolLM2-135M trained on cosmopedia-v2 with just 5000 steps",
+     # One value per input component; min_length and top_p use the generate_text defaults
+     examples=[
+         ["I found the love", 28, 50, 0.7, 50, 0.7],
+         ["When the sun comes up", 28, 40, 0.8, 40, 0.7],
+         ["The slow marching of ", 28, 60, 0.9, 45, 0.7]
+     ],
+ )
+
+
+ if __name__ == "__main__":
+     iface.launch()
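
A quick way to exercise the generation path without launching the UI is to import generate_text directly. The snippet below is a minimal smoke-test sketch, not part of this commit; the prompt and sampling values are arbitrary, and it assumes smollm2_HF.pth is present in the working directory.

# Minimal smoke test for app.py. Importing app loads the tokenizer,
# builds the model, and loads the checkpoint as a side effect.
from app import generate_text

sample = generate_text("Once upon a time",
                       min_length=20,
                       max_length=40,
                       temperature=0.7,
                       top_k=50,
                       top_p=0.9)
print(sample)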
model.py ADDED
@@ -0,0 +1,357 @@
+ import math
+ import logging
+ import torch
+ import torch.nn.functional as F
+ from typing import Optional
+ from torch import nn
+
+ logger = logging.getLogger(__name__)
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim, eps):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x):
+         # Root Mean Square Layer Normalization
+         rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+         return x * rms * self.weight
+
+
+ class RotaryEmbedding(nn.Module):
+     def __init__(self, dim, max_seq_len=2048, theta=10000):
+         super().__init__()
+         self.dim = dim
+         self.max_seq_len = max_seq_len
+
+         # Inverse frequencies for the rotary embedding
+         freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("freqs", freqs)
+
+         # Precompute cos/sin tables for all positions up to max_seq_len
+         t = torch.arange(max_seq_len, dtype=self.freqs.dtype)
+         freqs = torch.outer(t, self.freqs)
+
+         cos = freqs.cos()
+         sin = freqs.sin()
+         self.register_buffer('cos', cos)
+         self.register_buffer('sin', sin)
+
+     def rotate_half(self, x):
+         rot_dim = x.shape[-1]
+         x1 = x[..., :rot_dim // 2]
+         x2 = x[..., rot_dim // 2:]
+         return torch.cat((-x2, x1), dim=-1)
+
+     def apply_rotary_emb(self, t, x):
+         # Only the first rot_dim features of each head are rotated;
+         # any remaining features pass through unchanged.
+         rot_dim = self.freqs.shape[-1]
+         cos = self.cos[t, :rot_dim]
+         sin = self.sin[t, :rot_dim]
+
+         rotated_x = (x[..., :rot_dim] * cos) + (self.rotate_half(x[..., :rot_dim]) * sin)
+         if x.shape[-1] > rot_dim:
+             rotated_x = torch.cat((rotated_x, x[..., rot_dim:]), dim=-1)
+         return rotated_x
+
+     def forward(self, x, seq_dim=-2):
+         seq_len = x.shape[seq_dim]
+         t = torch.arange(seq_len, device=x.device)
+         return self.apply_rotary_emb(t, x)
+
+
+ class Attention(nn.Module):
+     def __init__(self, args):
+         super().__init__()
+         self.dim = args.dim
+         self.num_heads = args.n_heads
+         self.kv_heads = args.n_kv_heads
+         self.head_dim = args.dim // args.n_heads
+         self.kv_head_dim = args.dim // args.n_kv_heads
+
+         assert self.head_dim * args.n_heads == args.dim, "args.dim must be divisible by args.n_heads"
+         assert self.kv_head_dim * args.n_kv_heads == args.dim, "args.dim must be divisible by args.n_kv_heads"
+
+         self.query_proj = nn.Linear(args.dim, args.dim, bias=False)
+         self.key_proj = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
+         self.value_proj = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
+
+         self.rope = RotaryEmbedding(self.head_dim)
+
+         self.out_proj = nn.Linear(args.dim, args.dim, bias=False)
+         self.dropout = nn.Dropout(args.dropout)
+
+         # Caching storage (keys and values)
+         self.register_buffer('cached_keys', None)
+         self.register_buffer('cached_values', None)
+
+     def forward(self, x, mask=None, use_cache=False):
+         batch_size, seq_len, C = x.size()
+
+         query = self.query_proj(x)
+         key = self.key_proj(x)
+         value = self.value_proj(x)
+
+         # Reshape for attention computation
+         query = query.view(batch_size, seq_len, self.num_heads, self.head_dim)
+         key = key.view(batch_size, seq_len, self.kv_heads, self.head_dim)
+         value = value.view(batch_size, seq_len, self.kv_heads, self.head_dim)
+
+         # Transpose for attention computation
+         query = query.transpose(1, 2)   # [batch, num_heads, seq_len, head_dim]
+         key = key.transpose(1, 2)       # [batch, num_kv_heads, seq_len, head_dim]
+         value = value.transpose(1, 2)   # [batch, num_kv_heads, seq_len, head_dim]
+
+         query = self.rope(query)
+         key = self.rope(key)
+
+         # # If kv_heads are less than num_heads, repeat them
+         # if self.kv_heads < self.num_heads:
+         #     key = key.repeat_interleave(self.num_heads // self.kv_heads, dim=1)
+         #     value = value.repeat_interleave(self.num_heads // self.kv_heads, dim=1)
+
+         # # Compute attention
+         # attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
+         # if mask is not None:
+         #     attn_weights = attn_weights + mask
+         # attn_weights = F.softmax(attn_weights, dim=-1)
+
+         # # Compute output
+         # output = torch.matmul(attn_weights, value)
+
+         # Fused causal attention (requires a PyTorch version that supports enable_gqa);
+         # the `mask` argument is unused on this path because is_causal=True builds the causal mask,
+         # and dropout is only applied while training.
+         output = F.scaled_dot_product_attention(query, key, value,
+                                                 is_causal=True,
+                                                 dropout_p=self.dropout.p if self.training else 0.0,
+                                                 enable_gqa=True)
+
+         # Update cache only if using cache
+         if use_cache:
+             self.cached_keys = key
+             self.cached_values = value
+         else:
+             # Reset cached values during training (to prevent unwanted accumulation)
+             self.cached_keys = None
+             self.cached_values = None
+
+         output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)  # [batch, seq_len, num_heads * head_dim]
+         return self.out_proj(output)
+
+ class FeedForward(nn.Module):
+     def __init__(self, args):
+         """
+         Initialize the FeedForward module (SwiGLU-style gating).
+
+         Args:
+             args (ModelArgs): Model configuration parameters; uses args.dim
+                 and args.intermediate_dim.
+
+         Attributes:
+             w1 (nn.Linear): Linear transformation for the gate projection.
+             w2 (nn.Linear): Linear transformation for the output projection.
+             w3 (nn.Linear): Linear transformation for the up projection.
+
+         """
+         super().__init__()
+         self.w1 = nn.Linear(args.dim, args.intermediate_dim, bias=False)
+         self.w2 = nn.Linear(args.intermediate_dim, args.dim, bias=False)
+         self.w3 = nn.Linear(args.dim, args.intermediate_dim, bias=False)
+
+     def forward(self, x):
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, layer_id: int, args):
+         """
+         Initialize a TransformerBlock.
+
+         Args:
+             layer_id (int): Identifier for the layer.
+             args (ModelArgs): Model configuration parameters.
+
+         Attributes:
+             n_heads (int): Number of attention heads.
+             dim (int): Dimension size of the model.
+             head_dim (int): Dimension size of each attention head.
+             attention (Attention): Attention module.
+             feed_forward (FeedForward): FeedForward module.
+             layer_id (int): Identifier for the layer.
+             attention_norm (RMSNorm): Layer normalization applied before attention.
+             ffn_norm (RMSNorm): Layer normalization applied before the feedforward layer.
+
+         """
+         super().__init__()
+         self.n_heads = args.n_heads
+         self.dim = args.dim
+         self.head_dim = args.dim // args.n_heads
+         self.attention = Attention(args)
+         self.feed_forward = FeedForward(args)
+         self.layer_id = layer_id
+         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         mask: Optional[torch.Tensor],
+         use_cache: bool
+     ):
+         """
+         Perform a forward pass through the TransformerBlock.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+             mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
+             use_cache (bool): Whether to use the KV cache.
+
+         Returns:
+             torch.Tensor: Output tensor after applying attention and feedforward layers.
+
+         """
+         h = x + self.attention(self.attention_norm(x), mask=mask, use_cache=use_cache)
+         out = h + self.feed_forward(self.ffn_norm(h))
+         return out
+
+
+ class Transformer(nn.Module):
+     def __init__(self, args):
+         """
+         Initialize a Transformer model.
+
+         Args:
+             args (ModelArgs): Model configuration parameters.
+
+         Attributes:
+             args (ModelArgs): Model configuration parameters.
+             vocab_size (int): Vocabulary size.
+             n_layers (int): Number of layers in the model.
+             tok_embeddings (nn.Embedding): Token embeddings.
+             layers (torch.nn.ModuleList): List of Transformer blocks.
+             norm (RMSNorm): Layer normalization for the model output.
+
+         """
+         super().__init__()
+         self.args = args
+         self.vocab_size = args.vocab_size
+         self.n_layers = args.n_layers
+
+         self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
+
+         self.layers = torch.nn.ModuleList()
+         for layer_id in range(args.n_layers):
+             self.layers.append(TransformerBlock(layer_id, args))
+
+         self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+         # The output head is tied to the token embedding weights
+         # (see F.linear(h, self.tok_embeddings.weight) in forward),
+         # so no separate nn.Linear head is created here.
+         # self.output = nn.Linear(
+         #     args.dim, args.vocab_size, bias=False
+         # )
+
+         # # weight sharing
+         # self.output.weight = self.tok_embeddings.weight
+
+         # weight initialization
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         std = self.args.init_scale
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             # if module.bias is not None:
+             #     module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+
+     def forward(self, tokens: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = False):
+         """
+         Perform a forward pass through the Transformer model.
+
+         Args:
+             tokens (torch.Tensor): Input token indices.
+             mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
+             use_cache (bool): Whether to use the KV cache.
+
+         Returns:
+             torch.Tensor: Output logits after applying the Transformer model.
+
+         """
+         _, seqlen = tokens.shape
+         h = self.tok_embeddings(tokens)
+
+         if mask is None:
+             # Additive causal mask; Attention currently relies on is_causal=True in
+             # scaled_dot_product_attention, so this mask only matters for the
+             # (commented-out) manual attention path.
+             mask = torch.triu(torch.ones((seqlen, seqlen),
+                                          dtype=torch.bool,
+                                          device=tokens.device),
+                               diagonal=1)
+             mask = mask.unsqueeze(0).unsqueeze(0)
+             mask = mask * -1e4
+
+         for layer in self.layers:
+             h = layer(h, mask, use_cache)
+         h = self.norm(h)
+         # output = self.output(h).float()
+         output = F.linear(h, self.tok_embeddings.weight)  # weight-tied output head
+         return output
+
+     def generate(self,
+                  input_ids,
+                  max_length,
+                  min_length=None,
+                  num_return_sequences=1,
+                  pad_token_id=None,
+                  do_sample=True,
+                  temperature=0.8,
+                  top_k=50,
+                  top_p=0.95
+                  ):
+         self.eval()
+         min_length = min_length if min_length is not None else input_ids.shape[1]
+
+         with torch.no_grad():
+             for ret_seq in range(num_return_sequences):
+                 logger.info(f"Sequence #{ret_seq + 1}:")
+                 for _ in range(max_length - input_ids.shape[1]):
+                     outputs = self(input_ids, use_cache=True)
+                     next_token_logits = outputs[:, -1, :]
+
+                     # Apply temperature
+                     next_token_logits = next_token_logits / temperature
+
+                     # Apply top-k filtering
+                     if top_k > 0:
+                         indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
+                         next_token_logits[indices_to_remove] = float('-inf')
+
+                     # Apply top-p (nucleus) filtering
+                     if top_p < 1.0:
+                         sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                         cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+                         sorted_indices_to_remove = cumulative_probs > top_p
+                         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                         sorted_indices_to_remove[..., 0] = 0
+                         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                         next_token_logits[indices_to_remove] = float('-inf')
+
+                     # Sample from the filtered distribution
+                     if do_sample:
+                         probs = torch.softmax(next_token_logits, dim=-1)
+                         next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+                     else:
+                         next_tokens = torch.argmax(next_token_logits, dim=-1)
+
+                     input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
+
+                     # Stop once every sequence has produced the pad token,
+                     # but only after min_length tokens have been generated
+                     if (pad_token_id is not None
+                             and (next_tokens == pad_token_id).all()
+                             and input_ids.shape[1] >= min_length):
+                         break
+
+         return input_ids
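
To see how a ModelArgs-style configuration drives the Transformer above, a small shape check can be sketched as follows. TestArgs is a tiny made-up configuration for illustration only; the real config is ModelArgs in app.py, and the check assumes a PyTorch version that supports enable_gqa in scaled_dot_product_attention.

# Hypothetical shape check for model.py (not part of the commit).
import torch
from model import Transformer

class TestArgs:
    dim = 64
    intermediate_dim = 128
    n_layers = 2
    n_heads = 4
    n_kv_heads = 2
    vocab_size = 128
    norm_eps = 1e-5
    init_scale = 0.02
    dropout = 0.0

model = Transformer(TestArgs())
tokens = torch.randint(0, TestArgs.vocab_size, (1, 16))   # [batch=1, seq_len=16]
logits = model(tokens)
assert logits.shape == (1, 16, TestArgs.vocab_size)        # [batch, seq_len, vocab_size]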
smollm2_HF.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e5f29f2b3a2075407a1d43c74cc25ff2af4efbce0e73f0fe2a10eb4f16a69044
+ size 553939176
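
The smollm2_HF.pth entry is a Git LFS pointer, not the weights themselves. One possible way to fetch the actual checkpoint with the huggingface_hub client is sketched below; the repo_id and repo_type are placeholders, since the hosting repository is not named in this commit.

# Hypothetical download sketch; repo_id/repo_type are placeholders, not taken from the commit.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="<user>/<repo>",   # placeholder: the repository that hosts this file
    repo_type="space",         # placeholder: use "model" if hosted as a model repo
    filename="smollm2_HF.pth",
)
print(ckpt_path)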