jatingocodeo committed
Commit 6c3a55b · verified · 1 parent: 1a2e215

Update app.py

Files changed (1)
  1. app.py +227 -44
app.py CHANGED
@@ -21,13 +21,14 @@ class SmolLM2Config(PretrainedConfig):
         num_key_value_heads=3,
         hidden_act="silu",
         max_position_embeddings=2048,
-        initializer_range=0.02,
+        initializer_range=0.041666666666666664,
         rms_norm_eps=1e-5,
         use_cache=True,
         pad_token_id=None,
         bos_token_id=0,
         eos_token_id=0,
         tie_word_embeddings=True,
+        rope_theta=10000.0,
         **kwargs
     ):
         self.vocab_size = vocab_size
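A quick aside on the new default (not part of the commit): 0.041666666666666664 is exactly 1/24, i.e. 1/sqrt(576), the usual 1/sqrt(hidden_size) initializer scale if the model's hidden size is 576 (the hidden size is not shown in this hunk, so that reading is an assumption).

import math

# 1/24 == 1/sqrt(576); hidden_size = 576 is assumed, not shown in this diff
print(1 / math.sqrt(576))  # 0.041666666666666664
print(1 / 24)              # 0.041666666666666664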
@@ -41,6 +42,7 @@ class SmolLM2Config(PretrainedConfig):
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
+        self.rope_theta = rope_theta
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
@@ -64,54 +66,125 @@ class RMSNorm(nn.Module):
         x = x * torch.rsqrt(variance + self.eps)
         return self.weight * x
 
-class LlamaDecoderLayer(nn.Module):
-    def __init__(self, config):
+def precompute_rope_frequencies(dim: int, max_position_embeddings: int, theta: float = 10000.0):
+    position = torch.arange(max_position_embeddings).unsqueeze(1)  # [seq_len, 1]
+    div_term = theta ** (torch.arange(0, dim, 2).float() / dim)  # [dim/2]
+    freqs = position / div_term  # [seq_len, dim/2]
+    return freqs
+
+def apply_rotary_embeddings(x: torch.Tensor, freqs: torch.Tensor):
+    # x shape: [batch, seq_len, heads, head_dim]
+    # freqs shape: [seq_len, head_dim/2]
+    x_rot = x.float()
+
+    # Reshape freqs to match x's dimensions
+    freqs = freqs.unsqueeze(0).unsqueeze(2)  # [1, seq_len, 1, dim/2]
+
+    # Split channels for rotation
+    x1, x2 = x_rot[..., :x_rot.shape[-1]//2], x_rot[..., x_rot.shape[-1]//2:]
+
+    # Apply rotary embeddings
+    cos = torch.cos(freqs).to(x.device)
+    sin = torch.sin(freqs).to(x.device)
+
+    # Ensure broadcasting dimensions match
+    cos = cos.expand_as(x1)
+    sin = sin.expand_as(x1)
+
+    # Rotate x1 and x2
+    x1_rot = x1 * cos - x2 * sin
+    x2_rot = x2 * cos + x1 * sin
+
+    # Concatenate back
+    return torch.cat([x1_rot, x2_rot], dim=-1).to(x.dtype)
+
+class LlamaAttention(nn.Module):
+    def __init__(self, config: SmolLM2Config):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads
         self.head_dim = config.hidden_size // config.num_attention_heads
 
-        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
-        self.k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
-        self.v_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+        # Adjust projections to match head dimensions
+        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
 
-        self.mlp = nn.Sequential(
-            nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
-            nn.SiLU(),
-            nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+        # Initialize rotary embeddings
+        self.register_buffer(
+            "rope_freqs",
+            precompute_rope_frequencies(
+                self.head_dim,  # Use full head_dim for frequencies
+                config.max_position_embeddings,
+                config.rope_theta
+            ),
+            persistent=False
         )
+
+    def forward(self, hidden_states, attention_mask=None):
+        batch_size, seq_length, _ = hidden_states.size()
 
-        self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
+        # Project and reshape
+        q = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)
+        k = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
+        v = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
 
-    def forward(self, hidden_states, attention_mask=None):
-        # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        # Apply rotary embeddings
+        q = apply_rotary_embeddings(q, self.rope_freqs[:seq_length])
+        k = apply_rotary_embeddings(k, self.rope_freqs[:seq_length])
 
-        # Reshape for attention
-        batch_size, seq_length, _ = hidden_states.size()
-        q = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
-        k = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
-        v = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+        # Repeat k/v heads if num_kv_heads < num_heads
+        if self.num_kv_heads < self.num_heads:
+            k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
+            v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
 
-        # Compute attention scores
-        scale = 1.0 / math.sqrt(self.head_dim)
-        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
+        # Scaled dot-product attention
+        q = q.transpose(1, 2)  # (batch, num_heads, seq_len, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
 
         if attention_mask is not None:
-            scores = scores + attention_mask
+            attention_scores = attention_scores + attention_mask
 
-        attn_weights = F.softmax(scores, dim=-1)
-        hidden_states = torch.matmul(attn_weights, v)
+        attention_probs = F.softmax(attention_scores, dim=-1)
+        context = torch.matmul(attention_probs, v)
+
+        context = context.transpose(1, 2).contiguous()
+        context = context.view(batch_size, seq_length, -1)
 
-        # Reshape back
-        hidden_states = hidden_states.transpose(1, 2).contiguous().view(batch_size, seq_length, -1)
-        hidden_states = self.o_proj(hidden_states)
+        return self.o_proj(context)
+
+class LlamaMLP(nn.Module):
+    def __init__(self, config: SmolLM2Config):
+        super().__init__()
+        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+        self.act_fn = nn.SiLU()
+
+    def forward(self, x):
+        gate = self.act_fn(self.gate_proj(x))
+        up = self.up_proj(x)
+        return self.down_proj(gate * up)
+
+class LlamaDecoderLayer(nn.Module):
+    def __init__(self, config: SmolLM2Config):
+        super().__init__()
+        self.self_attn = LlamaAttention(config)
+        self.mlp = LlamaMLP(config)
+        self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
+
+    def forward(self, hidden_states, attention_mask=None):
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(hidden_states, attention_mask)
         hidden_states = residual + hidden_states
 
-        # MLP
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
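This hunk replaces the monolithic LlamaDecoderLayer with separate LlamaAttention (grouped-query attention plus rotary embeddings) and LlamaMLP (gated SiLU) modules. A quick sanity-check sketch for the two RoPE helpers follows; it assumes they can be imported from app.py as defined above (or pasted into a scratch script), and the 9-head / 64-dim shapes are illustrative, not taken from this diff. The rotation pairs channel i with channel i + head_dim//2, so shapes and per-head norms are unchanged.

import torch

# Assumed import path; importing app.py pulls in its other dependencies.
from app import precompute_rope_frequencies, apply_rotary_embeddings

batch, seq_len, n_heads, head_dim = 2, 16, 9, 64  # illustrative sizes (assumed)
freqs = precompute_rope_frequencies(head_dim, max_position_embeddings=2048)
q = torch.randn(batch, seq_len, n_heads, head_dim)

q_rot = apply_rotary_embeddings(q, freqs[:seq_len])
assert q_rot.shape == q.shape  # layout stays [batch, seq_len, heads, head_dim]
# Each (x1_i, x2_i) pair is rotated by an angle, so per-head norms are preserved.
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)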
@@ -125,18 +198,48 @@ class SmolLM2ForCausalLM(PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.config = config
+
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
         self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
+
+        # Add lm_head before weight tying
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
+        # Initialize weights
+        self.apply(self._init_weights)
+
+        # Tie weights if configured
         if config.tie_word_embeddings:
             self.lm_head.weight = self.embed_tokens.weight
 
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+
     def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
         hidden_states = self.embed_tokens(input_ids)
 
-        # Process through layers
+        # Create causal attention mask if none provided
+        if attention_mask is None:
+            # Create causal mask
+            seq_length = input_ids.size(1)
+            # [batch_size, 1, seq_length, seq_length]
+            causal_mask = torch.triu(
+                torch.ones((seq_length, seq_length), dtype=torch.bool, device=input_ids.device),
+                diagonal=1
+            ).unsqueeze(0).unsqueeze(0)
+            attention_mask = torch.zeros(
+                (1, 1, seq_length, seq_length),
+                dtype=hidden_states.dtype,
+                device=hidden_states.device
+            )
+            attention_mask.masked_fill_(causal_mask, float("-inf"))
+
         for layer in self.layers:
             hidden_states = layer(hidden_states, attention_mask)
 
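To make the new default mask concrete: a tiny sketch (not part of the commit) of the additive causal mask that forward() now builds when attention_mask is None, shown for a sequence length of 4. Zeros on and below the diagonal let a position attend to itself and the past; -inf above the diagonal blocks future positions.

import torch

seq_length = 4
causal = torch.triu(torch.ones((seq_length, seq_length), dtype=torch.bool), diagonal=1)
mask = torch.zeros((1, 1, seq_length, seq_length))
mask.masked_fill_(causal, float("-inf"))
print(mask[0, 0])
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])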
@@ -155,15 +258,76 @@ class SmolLM2ForCausalLM(PreTrainedModel):
             "attention_mask": kwargs.get("attention_mask", None)
         }
 
+    def generate(
+        self,
+        input_ids,
+        max_length=100,
+        temperature=0.7,
+        top_k=50,
+        do_sample=True,
+        num_return_sequences=1,
+        pad_token_id=None,
+        eos_token_id=None,
+        **kwargs
+    ):
+        cur_len = input_ids.shape[1]
+        batch_size = input_ids.shape[0]
+
+        if max_length < cur_len:
+            max_length = cur_len
+
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+
+        while cur_len < max_length:
+            # Prepare model inputs
+            model_inputs = self.prepare_inputs_for_generation(input_ids)
+
+            # Forward pass
+            with torch.no_grad():
+                outputs = self(**model_inputs)
+                next_token_logits = outputs[:, -1, :]
+
+            # Temperature scaling
+            if temperature != 1.0 and temperature > 0:
+                next_token_logits = next_token_logits / temperature
+
+            # Top-k filtering
+            if top_k > 0:
+                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
+                next_token_logits[indices_to_remove] = float('-inf')
+
+            # Sample or greedy
+            if do_sample:
+                probs = F.softmax(next_token_logits, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1)
+            else:
+                next_tokens = torch.argmax(next_token_logits, dim=-1)
+                next_tokens = next_tokens.unsqueeze(-1)
+
+            # Append next tokens
+            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+            cur_len = input_ids.shape[1]
+
+            # Early stopping if all sequences have reached the EOS token
+            if eos_token_id is not None:
+                unfinished_sequences = unfinished_sequences.mul(
+                    next_tokens.squeeze(-1).ne(eos_token_id).long()
+                )
+                if unfinished_sequences.max() == 0:
+                    break
+
+        return input_ids
+
 # Register the model
 AutoModelForCausalLM.register(SmolLM2Config, SmolLM2ForCausalLM)
 
 # Cache for model and tokenizer
 MODEL = None
 TOKENIZER = None
+CONFIG = None
 
 def initialize():
-    global MODEL, TOKENIZER
+    global MODEL, TOKENIZER, CONFIG
 
     if MODEL is None:
         print("Loading model and tokenizer...")
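The custom generate() added above does plain temperature scaling followed by top-k sampling. A toy sketch (not from the commit) of its top-k filtering step: every logit below the k-th largest is pushed to -inf, so multinomial sampling can only pick one of the top k tokens.

import torch

next_token_logits = torch.tensor([[2.0, 0.5, 1.0, 3.0, -1.0]])
top_k = 2
kth_value = torch.topk(next_token_logits, top_k)[0][..., -1, None]  # 2.0 here
next_token_logits[next_token_logits < kth_value] = float("-inf")
print(next_token_logits)  # tensor([[2., -inf, -inf, 3., -inf]])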
@@ -175,17 +339,24 @@ def initialize():
         config_path = hf_hub_download(repo_id=model_id, filename="config.json")
         with open(config_path, 'r') as f:
             config_dict = json.load(f)
-        config = SmolLM2Config(**config_dict)
+        CONFIG = SmolLM2Config(**config_dict)
 
         # Load tokenizer
         print("Loading tokenizer...")
-        TOKENIZER = AutoTokenizer.from_pretrained(model_id)
+        TOKENIZER = AutoTokenizer.from_pretrained(
+            model_id,
+            model_max_length=CONFIG.max_position_embeddings,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=True
+        )
 
-        # Add special tokens if needed
+        # Make sure we're using the correct special tokens
         special_tokens = {
-            'pad_token': '[PAD]',
-            'eos_token': '</s>',
-            'bos_token': '<s>'
+            'bos_token': '<|endoftext|>',
+            'eos_token': '<|endoftext|>',
+            'unk_token': '<|endoftext|>',
+            'pad_token': '<|endoftext|>'  # Using endoftext as pad token since it's not specified
         }
         TOKENIZER.add_special_tokens(special_tokens)
 
@@ -194,7 +365,10 @@ def initialize():
         weights_path = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
 
         # Initialize model
-        MODEL = SmolLM2ForCausalLM(config)
+        MODEL = SmolLM2ForCausalLM(CONFIG)
+
+        # Resize token embeddings to match tokenizer
+        MODEL.resize_token_embeddings(len(TOKENIZER))
 
         # Load state dict
         state_dict = torch.load(weights_path, map_location="cpu")
@@ -228,14 +402,23 @@ def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
         prompt = TOKENIZER.bos_token + prompt
 
     # Encode prompt
-    input_ids = TOKENIZER.encode(prompt, return_tensors="pt", truncation=True, max_length=2048)
-    input_ids = input_ids.to(MODEL.device)
+    encoded = TOKENIZER.encode_plus(
+        prompt,
+        add_special_tokens=True,
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=CONFIG.max_position_embeddings
+    )
+    input_ids = encoded["input_ids"].to(MODEL.device)
+    attention_mask = encoded["attention_mask"].to(MODEL.device)
 
     # Generate
     with torch.no_grad():
        outputs = MODEL.generate(
            input_ids,
-           max_length=min(max_length + len(input_ids[0]), 2048),
+           attention_mask=attention_mask,
+           max_length=min(max_length + len(input_ids[0]), CONFIG.max_position_embeddings),
            temperature=max(0.1, min(temperature, 1.0)),  # Clamp temperature
            top_k=max(1, min(top_k, 100)),  # Clamp top_k
            do_sample=True if temperature > 0 else False,
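End to end, the Space's helpers can be exercised roughly like this. A sketch under assumptions: initialize() and generate_text() keep the signatures shown in these hunks, and generate_text() decodes and returns the completion (the decoding step is not visible in this diff).

# Assumed import path and return type; see the hedges above.
from app import initialize, generate_text

initialize()  # downloads config, tokenizer and weights once, then caches them
text = generate_text(
    "Once upon a time",
    max_length=100,   # capped against CONFIG.max_position_embeddings inside
    temperature=0.7,  # clamped to [0.1, 1.0]
    top_k=50,         # clamped to [1, 100]
)
print(text)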
 
21
  num_key_value_heads=3,
22
  hidden_act="silu",
23
  max_position_embeddings=2048,
24
+ initializer_range=0.041666666666666664,
25
  rms_norm_eps=1e-5,
26
  use_cache=True,
27
  pad_token_id=None,
28
  bos_token_id=0,
29
  eos_token_id=0,
30
  tie_word_embeddings=True,
31
+ rope_theta=10000.0,
32
  **kwargs
33
  ):
34
  self.vocab_size = vocab_size
 
42
  self.initializer_range = initializer_range
43
  self.rms_norm_eps = rms_norm_eps
44
  self.use_cache = use_cache
45
+ self.rope_theta = rope_theta
46
  super().__init__(
47
  pad_token_id=pad_token_id,
48
  bos_token_id=bos_token_id,
 
66
  x = x * torch.rsqrt(variance + self.eps)
67
  return self.weight * x
68
 
69
+ def precompute_rope_frequencies(dim: int, max_position_embeddings: int, theta: float = 10000.0):
70
+ position = torch.arange(max_position_embeddings).unsqueeze(1) # [seq_len, 1]
71
+ div_term = theta ** (torch.arange(0, dim, 2).float() / dim) # [dim/2]
72
+ freqs = position / div_term # [seq_len, dim/2]
73
+ return freqs
74
+
75
+ def apply_rotary_embeddings(x: torch.Tensor, freqs: torch.Tensor):
76
+ # x shape: [batch, seq_len, heads, head_dim]
77
+ # freqs shape: [seq_len, head_dim/2]
78
+ x_rot = x.float()
79
+
80
+ # Reshape freqs to match x's dimensions
81
+ freqs = freqs.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim/2]
82
+
83
+ # Split channels for rotation
84
+ x1, x2 = x_rot[..., :x_rot.shape[-1]//2], x_rot[..., x_rot.shape[-1]//2:]
85
+
86
+ # Apply rotary embeddings
87
+ cos = torch.cos(freqs).to(x.device)
88
+ sin = torch.sin(freqs).to(x.device)
89
+
90
+ # Ensure broadcasting dimensions match
91
+ cos = cos.expand_as(x1)
92
+ sin = sin.expand_as(x1)
93
+
94
+ # Rotate x1 and x2
95
+ x1_rot = x1 * cos - x2 * sin
96
+ x2_rot = x2 * cos + x1 * sin
97
+
98
+ # Concatenate back
99
+ return torch.cat([x1_rot, x2_rot], dim=-1).to(x.dtype)
100
+
101
+ class LlamaAttention(nn.Module):
102
+ def __init__(self, config: SmolLM2Config):
103
  super().__init__()
104
  self.hidden_size = config.hidden_size
105
  self.num_heads = config.num_attention_heads
106
+ self.num_kv_heads = config.num_key_value_heads
107
  self.head_dim = config.hidden_size // config.num_attention_heads
108
 
109
+ # Adjust projections to match head dimensions
110
+ self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
111
+ self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
112
+ self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
113
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
114
 
115
+ # Initialize rotary embeddings
116
+ self.register_buffer(
117
+ "rope_freqs",
118
+ precompute_rope_frequencies(
119
+ self.head_dim, # Use full head_dim for frequencies
120
+ config.max_position_embeddings,
121
+ config.rope_theta
122
+ ),
123
+ persistent=False
124
  )
125
+
126
+ def forward(self, hidden_states, attention_mask=None):
127
+ batch_size, seq_length, _ = hidden_states.size()
128
 
129
+ # Project and reshape
130
+ q = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)
131
+ k = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
132
+ v = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
133
 
134
+ # Apply rotary embeddings
135
+ q = apply_rotary_embeddings(q, self.rope_freqs[:seq_length])
136
+ k = apply_rotary_embeddings(k, self.rope_freqs[:seq_length])
 
137
 
138
+ # Repeat k/v heads if num_kv_heads < num_heads
139
+ if self.num_kv_heads < self.num_heads:
140
+ k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
141
+ v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
 
142
 
143
+ # Scaled dot-product attention
144
+ q = q.transpose(1, 2) # (batch, num_heads, seq_len, head_dim)
145
+ k = k.transpose(1, 2)
146
+ v = v.transpose(1, 2)
147
+
148
+ attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
149
 
150
  if attention_mask is not None:
151
+ attention_scores = attention_scores + attention_mask
152
 
153
+ attention_probs = F.softmax(attention_scores, dim=-1)
154
+ context = torch.matmul(attention_probs, v)
155
+
156
+ context = context.transpose(1, 2).contiguous()
157
+ context = context.view(batch_size, seq_length, -1)
158
 
159
+ return self.o_proj(context)
160
+
161
+ class LlamaMLP(nn.Module):
162
+ def __init__(self, config: SmolLM2Config):
163
+ super().__init__()
164
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
165
+ self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
166
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
167
+ self.act_fn = nn.SiLU()
168
+
169
+ def forward(self, x):
170
+ gate = self.act_fn(self.gate_proj(x))
171
+ up = self.up_proj(x)
172
+ return self.down_proj(gate * up)
173
+
174
+ class LlamaDecoderLayer(nn.Module):
175
+ def __init__(self, config: SmolLM2Config):
176
+ super().__init__()
177
+ self.self_attn = LlamaAttention(config)
178
+ self.mlp = LlamaMLP(config)
179
+ self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
180
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
181
+
182
+ def forward(self, hidden_states, attention_mask=None):
183
+ residual = hidden_states
184
+ hidden_states = self.input_layernorm(hidden_states)
185
+ hidden_states = self.self_attn(hidden_states, attention_mask)
186
  hidden_states = residual + hidden_states
187
 
 
188
  residual = hidden_states
189
  hidden_states = self.post_attention_layernorm(hidden_states)
190
  hidden_states = self.mlp(hidden_states)
 
198
  def __init__(self, config):
199
  super().__init__(config)
200
  self.config = config
201
+
202
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
203
  self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
204
  self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
205
+
206
+ # Add lm_head before weight tying
207
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
208
 
209
+ # Initialize weights
210
+ self.apply(self._init_weights)
211
+
212
+ # Tie weights if configured
213
  if config.tie_word_embeddings:
214
  self.lm_head.weight = self.embed_tokens.weight
215
 
216
+ def _init_weights(self, module):
217
+ if isinstance(module, nn.Linear):
218
+ torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
219
+ if module.bias is not None:
220
+ torch.nn.init.zeros_(module.bias)
221
+ elif isinstance(module, nn.Embedding):
222
+ torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
223
+
224
  def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
225
  hidden_states = self.embed_tokens(input_ids)
226
 
227
+ # Create causal attention mask if none provided
228
+ if attention_mask is None:
229
+ # Create causal mask
230
+ seq_length = input_ids.size(1)
231
+ # [batch_size, 1, seq_length, seq_length]
232
+ causal_mask = torch.triu(
233
+ torch.ones((seq_length, seq_length), dtype=torch.bool, device=input_ids.device),
234
+ diagonal=1
235
+ ).unsqueeze(0).unsqueeze(0)
236
+ attention_mask = torch.zeros(
237
+ (1, 1, seq_length, seq_length),
238
+ dtype=hidden_states.dtype,
239
+ device=hidden_states.device
240
+ )
241
+ attention_mask.masked_fill_(causal_mask, float("-inf"))
242
+
243
  for layer in self.layers:
244
  hidden_states = layer(hidden_states, attention_mask)
245
 
 
258
  "attention_mask": kwargs.get("attention_mask", None)
259
  }
260
 
261
+ def generate(
262
+ self,
263
+ input_ids,
264
+ max_length=100,
265
+ temperature=0.7,
266
+ top_k=50,
267
+ do_sample=True,
268
+ num_return_sequences=1,
269
+ pad_token_id=None,
270
+ eos_token_id=None,
271
+ **kwargs
272
+ ):
273
+ cur_len = input_ids.shape[1]
274
+ batch_size = input_ids.shape[0]
275
+
276
+ if max_length < cur_len:
277
+ max_length = cur_len
278
+
279
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
280
+
281
+ while cur_len < max_length:
282
+ # Prepare model inputs
283
+ model_inputs = self.prepare_inputs_for_generation(input_ids)
284
+
285
+ # Forward pass
286
+ with torch.no_grad():
287
+ outputs = self(**model_inputs)
288
+ next_token_logits = outputs[:, -1, :]
289
+
290
+ # Temperature scaling
291
+ if temperature != 1.0 and temperature > 0:
292
+ next_token_logits = next_token_logits / temperature
293
+
294
+ # Top-k filtering
295
+ if top_k > 0:
296
+ indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
297
+ next_token_logits[indices_to_remove] = float('-inf')
298
+
299
+ # Sample or greedy
300
+ if do_sample:
301
+ probs = F.softmax(next_token_logits, dim=-1)
302
+ next_tokens = torch.multinomial(probs, num_samples=1)
303
+ else:
304
+ next_tokens = torch.argmax(next_token_logits, dim=-1)
305
+ next_tokens = next_tokens.unsqueeze(-1)
306
+
307
+ # Append next tokens
308
+ input_ids = torch.cat([input_ids, next_tokens], dim=-1)
309
+ cur_len = input_ids.shape[1]
310
+
311
+ # Early stopping if all sequences have reached the EOS token
312
+ if eos_token_id is not None:
313
+ unfinished_sequences = unfinished_sequences.mul(
314
+ next_tokens.squeeze(-1).ne(eos_token_id).long()
315
+ )
316
+ if unfinished_sequences.max() == 0:
317
+ break
318
+
319
+ return input_ids
320
+
321
  # Register the model
322
  AutoModelForCausalLM.register(SmolLM2Config, SmolLM2ForCausalLM)
323
 
324
  # Cache for model and tokenizer
325
  MODEL = None
326
  TOKENIZER = None
327
+ CONFIG = None
328
 
329
  def initialize():
330
+ global MODEL, TOKENIZER, CONFIG
331
 
332
  if MODEL is None:
333
  print("Loading model and tokenizer...")
 
339
  config_path = hf_hub_download(repo_id=model_id, filename="config.json")
340
  with open(config_path, 'r') as f:
341
  config_dict = json.load(f)
342
+ CONFIG = SmolLM2Config(**config_dict)
343
 
344
  # Load tokenizer
345
  print("Loading tokenizer...")
346
+ TOKENIZER = AutoTokenizer.from_pretrained(
347
+ model_id,
348
+ model_max_length=CONFIG.max_position_embeddings,
349
+ padding_side="left",
350
+ truncation_side="left",
351
+ trust_remote_code=True
352
+ )
353
 
354
+ # Make sure we're using the correct special tokens
355
  special_tokens = {
356
+ 'bos_token': '<|endoftext|>',
357
+ 'eos_token': '<|endoftext|>',
358
+ 'unk_token': '<|endoftext|>',
359
+ 'pad_token': '<|endoftext|>' # Using endoftext as pad token since it's not specified
360
  }
361
  TOKENIZER.add_special_tokens(special_tokens)
362
 
 
365
  weights_path = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
366
 
367
  # Initialize model
368
+ MODEL = SmolLM2ForCausalLM(CONFIG)
369
+
370
+ # Resize token embeddings to match tokenizer
371
+ MODEL.resize_token_embeddings(len(TOKENIZER))
372
 
373
  # Load state dict
374
  state_dict = torch.load(weights_path, map_location="cpu")
 
402
  prompt = TOKENIZER.bos_token + prompt
403
 
404
  # Encode prompt
405
+ encoded = TOKENIZER.encode_plus(
406
+ prompt,
407
+ add_special_tokens=True,
408
+ return_tensors="pt",
409
+ padding=True,
410
+ truncation=True,
411
+ max_length=CONFIG.max_position_embeddings
412
+ )
413
+ input_ids = encoded["input_ids"].to(MODEL.device)
414
+ attention_mask = encoded["attention_mask"].to(MODEL.device)
415
 
416
  # Generate
417
  with torch.no_grad():
418
  outputs = MODEL.generate(
419
  input_ids,
420
+ attention_mask=attention_mask,
421
+ max_length=min(max_length + len(input_ids[0]), CONFIG.max_position_embeddings),
422
  temperature=max(0.1, min(temperature, 1.0)), # Clamp temperature
423
  top_k=max(1, min(top_k, 100)), # Clamp top_k
424
  do_sample=True if temperature > 0 else False,