jatingocodeo commited on
Commit
d1cb366
·
verified ·
1 Parent(s): f4e7e95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +466 -0
app.py CHANGED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PretrainedConfig
4
+ from huggingface_hub import hf_hub_download
5
+ import json
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import math
9
+
10
+ # Define the model architecture
11
+ class SmolLM2Config(PretrainedConfig):
12
+ model_type = "smollm2"
13
+
14
+ def __init__(
15
+ self,
16
+ vocab_size=49152,
17
+ hidden_size=576,
18
+ intermediate_size=1536,
19
+ num_hidden_layers=30,
20
+ num_attention_heads=9,
21
+ num_key_value_heads=3,
22
+ hidden_act="silu",
23
+ max_position_embeddings=2048,
24
+ initializer_range=0.041666666666666664,
25
+ rms_norm_eps=1e-5,
26
+ use_cache=True,
27
+ pad_token_id=None,
28
+ bos_token_id=0,
29
+ eos_token_id=0,
30
+ tie_word_embeddings=True,
31
+ rope_theta=10000.0,
32
+ **kwargs
33
+ ):
34
+ self.vocab_size = vocab_size
35
+ self.hidden_size = hidden_size
36
+ self.intermediate_size = intermediate_size
37
+ self.num_hidden_layers = num_hidden_layers
38
+ self.num_attention_heads = num_attention_heads
39
+ self.num_key_value_heads = num_key_value_heads
40
+ self.hidden_act = hidden_act
41
+ self.max_position_embeddings = max_position_embeddings
42
+ self.initializer_range = initializer_range
43
+ self.rms_norm_eps = rms_norm_eps
44
+ self.use_cache = use_cache
45
+ self.rope_theta = rope_theta
46
+ super().__init__(
47
+ pad_token_id=pad_token_id,
48
+ bos_token_id=bos_token_id,
49
+ eos_token_id=eos_token_id,
50
+ tie_word_embeddings=tie_word_embeddings,
51
+ **kwargs
52
+ )
53
+
54
+ # Register the model architecture
55
+ from transformers import AutoConfig
56
+ AutoConfig.register("smollm2", SmolLM2Config)
57
+
58
+ class RMSNorm(nn.Module):
59
+ def __init__(self, hidden_size, eps=1e-5):
60
+ super().__init__()
61
+ self.weight = nn.Parameter(torch.ones(hidden_size))
62
+ self.eps = eps
63
+
64
+ def forward(self, x):
65
+ variance = x.pow(2).mean(-1, keepdim=True)
66
+ x = x * torch.rsqrt(variance + self.eps)
67
+ return self.weight * x
68
+
69
+ def precompute_rope_frequencies(dim: int, max_position_embeddings: int, theta: float = 10000.0):
70
+ position = torch.arange(max_position_embeddings).unsqueeze(1) # [seq_len, 1]
71
+ div_term = theta ** (torch.arange(0, dim, 2).float() / dim) # [dim/2]
72
+ freqs = position / div_term # [seq_len, dim/2]
73
+ return freqs
74
+
75
+ def apply_rotary_embeddings(x: torch.Tensor, freqs: torch.Tensor):
76
+ # x shape: [batch, seq_len, heads, head_dim]
77
+ # freqs shape: [seq_len, head_dim/2]
78
+ x_rot = x.float()
79
+
80
+ # Reshape freqs to match x's dimensions
81
+ freqs = freqs.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim/2]
82
+
83
+ # Split channels for rotation
84
+ x1, x2 = x_rot[..., :x_rot.shape[-1]//2], x_rot[..., x_rot.shape[-1]//2:]
85
+
86
+ # Apply rotary embeddings
87
+ cos = torch.cos(freqs).to(x.device)
88
+ sin = torch.sin(freqs).to(x.device)
89
+
90
+ # Ensure broadcasting dimensions match
91
+ cos = cos.expand_as(x1)
92
+ sin = sin.expand_as(x1)
93
+
94
+ # Rotate x1 and x2
95
+ x1_rot = x1 * cos - x2 * sin
96
+ x2_rot = x2 * cos + x1 * sin
97
+
98
+ # Concatenate back
99
+ return torch.cat([x1_rot, x2_rot], dim=-1).to(x.dtype)
100
+
101
+ class LlamaAttention(nn.Module):
102
+ def __init__(self, config: SmolLM2Config):
103
+ super().__init__()
104
+ self.hidden_size = config.hidden_size
105
+ self.num_heads = config.num_attention_heads
106
+ self.num_kv_heads = config.num_key_value_heads
107
+ self.head_dim = config.hidden_size // config.num_attention_heads
108
+
109
+ # Adjust projections to match head dimensions
110
+ self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
111
+ self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
112
+ self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
113
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
114
+
115
+ # Initialize rotary embeddings
116
+ self.register_buffer(
117
+ "rope_freqs",
118
+ precompute_rope_frequencies(
119
+ self.head_dim, # Use full head_dim for frequencies
120
+ config.max_position_embeddings,
121
+ config.rope_theta
122
+ ),
123
+ persistent=False
124
+ )
125
+
126
+ def forward(self, hidden_states, attention_mask=None):
127
+ batch_size, seq_length, _ = hidden_states.size()
128
+
129
+ # Project and reshape
130
+ q = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)
131
+ k = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
132
+ v = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
133
+
134
+ # Apply rotary embeddings
135
+ q = apply_rotary_embeddings(q, self.rope_freqs[:seq_length])
136
+ k = apply_rotary_embeddings(k, self.rope_freqs[:seq_length])
137
+
138
+ # Repeat k/v heads if num_kv_heads < num_heads
139
+ if self.num_kv_heads < self.num_heads:
140
+ k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
141
+ v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
142
+
143
+ # Scaled dot-product attention
144
+ q = q.transpose(1, 2) # (batch, num_heads, seq_len, head_dim)
145
+ k = k.transpose(1, 2)
146
+ v = v.transpose(1, 2)
147
+
148
+ attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
149
+
150
+ if attention_mask is not None:
151
+ attention_scores = attention_scores + attention_mask
152
+
153
+ attention_probs = F.softmax(attention_scores, dim=-1)
154
+ context = torch.matmul(attention_probs, v)
155
+
156
+ context = context.transpose(1, 2).contiguous()
157
+ context = context.view(batch_size, seq_length, -1)
158
+
159
+ return self.o_proj(context)
160
+
161
+ class LlamaMLP(nn.Module):
162
+ def __init__(self, config: SmolLM2Config):
163
+ super().__init__()
164
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
165
+ self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
166
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
167
+ self.act_fn = nn.SiLU()
168
+
169
+ def forward(self, x):
170
+ gate = self.act_fn(self.gate_proj(x))
171
+ up = self.up_proj(x)
172
+ return self.down_proj(gate * up)
173
+
174
+ class LlamaDecoderLayer(nn.Module):
175
+ def __init__(self, config: SmolLM2Config):
176
+ super().__init__()
177
+ self.self_attn = LlamaAttention(config)
178
+ self.mlp = LlamaMLP(config)
179
+ self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
180
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
181
+
182
+ def forward(self, hidden_states, attention_mask=None):
183
+ residual = hidden_states
184
+ hidden_states = self.input_layernorm(hidden_states)
185
+ hidden_states = self.self_attn(hidden_states, attention_mask)
186
+ hidden_states = residual + hidden_states
187
+
188
+ residual = hidden_states
189
+ hidden_states = self.post_attention_layernorm(hidden_states)
190
+ hidden_states = self.mlp(hidden_states)
191
+ hidden_states = residual + hidden_states
192
+
193
+ return hidden_states
194
+
195
+ class SmolLM2ForCausalLM(PreTrainedModel):
196
+ config_class = SmolLM2Config
197
+
198
+ def __init__(self, config):
199
+ super().__init__(config)
200
+ self.config = config
201
+
202
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
203
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
204
+ self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
205
+
206
+ # Add lm_head before weight tying
207
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
208
+
209
+ # Initialize weights
210
+ self.apply(self._init_weights)
211
+
212
+ # Tie weights if configured
213
+ if config.tie_word_embeddings:
214
+ self.lm_head.weight = self.embed_tokens.weight
215
+
216
+ def _init_weights(self, module):
217
+ if isinstance(module, nn.Linear):
218
+ torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
219
+ if module.bias is not None:
220
+ torch.nn.init.zeros_(module.bias)
221
+ elif isinstance(module, nn.Embedding):
222
+ torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
223
+
224
+ def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
225
+ hidden_states = self.embed_tokens(input_ids)
226
+
227
+ # Create causal attention mask if none provided
228
+ if attention_mask is None:
229
+ # Create causal mask
230
+ seq_length = input_ids.size(1)
231
+ # [batch_size, 1, seq_length, seq_length]
232
+ causal_mask = torch.triu(
233
+ torch.ones((seq_length, seq_length), dtype=torch.bool, device=input_ids.device),
234
+ diagonal=1
235
+ ).unsqueeze(0).unsqueeze(0)
236
+ attention_mask = torch.zeros(
237
+ (1, 1, seq_length, seq_length),
238
+ dtype=hidden_states.dtype,
239
+ device=hidden_states.device
240
+ )
241
+ attention_mask.masked_fill_(causal_mask, float("-inf"))
242
+
243
+ for layer in self.layers:
244
+ hidden_states = layer(hidden_states, attention_mask)
245
+
246
+ hidden_states = self.norm(hidden_states)
247
+ logits = self.lm_head(hidden_states)
248
+
249
+ loss = None
250
+ if labels is not None:
251
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
252
+
253
+ return logits if loss is None else (loss, logits)
254
+
255
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
256
+ return {
257
+ "input_ids": input_ids,
258
+ "attention_mask": kwargs.get("attention_mask", None)
259
+ }
260
+
261
+ def generate(
262
+ self,
263
+ input_ids,
264
+ max_length=100,
265
+ temperature=0.7,
266
+ top_k=50,
267
+ do_sample=True,
268
+ num_return_sequences=1,
269
+ pad_token_id=None,
270
+ eos_token_id=None,
271
+ **kwargs
272
+ ):
273
+ cur_len = input_ids.shape[1]
274
+ batch_size = input_ids.shape[0]
275
+
276
+ if max_length < cur_len:
277
+ max_length = cur_len
278
+
279
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
280
+
281
+ while cur_len < max_length:
282
+ # Prepare model inputs
283
+ model_inputs = self.prepare_inputs_for_generation(input_ids)
284
+
285
+ # Forward pass
286
+ with torch.no_grad():
287
+ outputs = self(**model_inputs)
288
+ next_token_logits = outputs[:, -1, :]
289
+
290
+ # Temperature scaling
291
+ if temperature != 1.0 and temperature > 0:
292
+ next_token_logits = next_token_logits / temperature
293
+
294
+ # Top-k filtering
295
+ if top_k > 0:
296
+ indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
297
+ next_token_logits[indices_to_remove] = float('-inf')
298
+
299
+ # Sample or greedy
300
+ if do_sample:
301
+ probs = F.softmax(next_token_logits, dim=-1)
302
+ next_tokens = torch.multinomial(probs, num_samples=1)
303
+ else:
304
+ next_tokens = torch.argmax(next_token_logits, dim=-1)
305
+ next_tokens = next_tokens.unsqueeze(-1)
306
+
307
+ # Append next tokens
308
+ input_ids = torch.cat([input_ids, next_tokens], dim=-1)
309
+ cur_len = input_ids.shape[1]
310
+
311
+ # Early stopping if all sequences have reached the EOS token
312
+ if eos_token_id is not None:
313
+ unfinished_sequences = unfinished_sequences.mul(
314
+ next_tokens.squeeze(-1).ne(eos_token_id).long()
315
+ )
316
+ if unfinished_sequences.max() == 0:
317
+ break
318
+
319
+ return input_ids
320
+
321
+ # Register the model
322
+ AutoModelForCausalLM.register(SmolLM2Config, SmolLM2ForCausalLM)
323
+
324
+ # Cache for model and tokenizer
325
+ MODEL = None
326
+ TOKENIZER = None
327
+ CONFIG = None
328
+
329
+ def initialize():
330
+ global MODEL, TOKENIZER, CONFIG
331
+
332
+ if MODEL is None:
333
+ print("Loading model and tokenizer...")
334
+ model_id = "jatingocodeo/SmolLM2"
335
+
336
+ try:
337
+ # Download and load config
338
+ print("Loading config...")
339
+ config_path = hf_hub_download(repo_id=model_id, filename="config.json")
340
+ with open(config_path, 'r') as f:
341
+ config_dict = json.load(f)
342
+ CONFIG = SmolLM2Config(**config_dict)
343
+
344
+ # Load tokenizer
345
+ print("Loading tokenizer...")
346
+ TOKENIZER = AutoTokenizer.from_pretrained(
347
+ model_id,
348
+ model_max_length=CONFIG.max_position_embeddings,
349
+ padding_side="left",
350
+ truncation_side="left",
351
+ trust_remote_code=True
352
+ )
353
+
354
+ # Make sure we're using the correct special tokens
355
+ special_tokens = {
356
+ 'bos_token': '<|endoftext|>',
357
+ 'eos_token': '<|endoftext|>',
358
+ 'unk_token': '<|endoftext|>',
359
+ 'pad_token': '<|endoftext|>' # Using endoftext as pad token since it's not specified
360
+ }
361
+ TOKENIZER.add_special_tokens(special_tokens)
362
+
363
+ # Load model weights
364
+ print("Loading model...")
365
+ weights_path = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
366
+
367
+ # Initialize model
368
+ MODEL = SmolLM2ForCausalLM(CONFIG)
369
+
370
+ # Resize token embeddings to match tokenizer
371
+ MODEL.resize_token_embeddings(len(TOKENIZER))
372
+
373
+ # Load state dict
374
+ state_dict = torch.load(weights_path, map_location="cpu")
375
+ MODEL.load_state_dict(state_dict)
376
+
377
+ # Move model to device
378
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
379
+ MODEL = MODEL.to(device)
380
+
381
+ print(f"Model loaded successfully on {device}")
382
+
383
+ except Exception as e:
384
+ print(f"Error initializing: {str(e)}")
385
+ raise
386
+
387
+ def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
388
+ # Initialize if not already done
389
+ if MODEL is None:
390
+ try:
391
+ initialize()
392
+ except Exception as e:
393
+ return f"Failed to initialize model: {str(e)}"
394
+
395
+ try:
396
+ # Process prompt
397
+ if not prompt.strip():
398
+ return "Please enter a prompt."
399
+
400
+ # Add BOS token if needed
401
+ if not prompt.startswith(TOKENIZER.bos_token):
402
+ prompt = TOKENIZER.bos_token + prompt
403
+
404
+ # Encode prompt
405
+ encoded = TOKENIZER.encode_plus(
406
+ prompt,
407
+ add_special_tokens=True,
408
+ return_tensors="pt",
409
+ padding=True,
410
+ truncation=True,
411
+ max_length=CONFIG.max_position_embeddings
412
+ )
413
+ input_ids = encoded["input_ids"].to(MODEL.device)
414
+ attention_mask = encoded["attention_mask"].to(MODEL.device)
415
+
416
+ # Generate
417
+ with torch.no_grad():
418
+ outputs = MODEL.generate(
419
+ input_ids,
420
+ attention_mask=attention_mask,
421
+ max_length=min(max_length + len(input_ids[0]), CONFIG.max_position_embeddings),
422
+ temperature=max(0.1, min(temperature, 1.0)), # Clamp temperature
423
+ top_k=max(1, min(top_k, 100)), # Clamp top_k
424
+ do_sample=True if temperature > 0 else False,
425
+ num_return_sequences=1,
426
+ pad_token_id=TOKENIZER.pad_token_id,
427
+ eos_token_id=TOKENIZER.eos_token_id,
428
+ )
429
+
430
+ # Decode and return
431
+ generated_text = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
432
+ return generated_text.strip()
433
+
434
+ except Exception as e:
435
+ import traceback
436
+ traceback.print_exc()
437
+ return f"Error during text generation: {str(e)}"
438
+
439
+ # Create Gradio interface
440
+ iface = gr.Interface(
441
+ fn=generate_text,
442
+ inputs=[
443
+ gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=2),
444
+ gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
445
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
446
+ gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top K"),
447
+ ],
448
+ outputs=gr.Textbox(label="Generated Text", lines=5),
449
+ title="SmolLM2 Text Generator",
450
+ description="Generate text using the fine-tuned SmolLM2 model. Adjust parameters to control the generation.",
451
+ examples=[
452
+ ["Once upon a time", 100, 0.7, 50],
453
+ ["The quick brown fox", 150, 0.8, 40],
454
+ ],
455
+ allow_flagging="never"
456
+ )
457
+
458
+ # Initialize on startup
459
+ try:
460
+ initialize()
461
+ except Exception as e:
462
+ print(f"Warning: Model initialization failed: {str(e)}")
463
+ print("Model will be initialized on first request")
464
+
465
+ if __name__ == "__main__":
466
+ iface.launch()