jatingocodeo commited on
Commit
6fa5efe
·
verified ·
1 Parent(s): eccb044

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -13
app.py CHANGED
@@ -169,17 +169,31 @@ class SmolLM2ForCausalLM(PreTrainedModel):
169
 
170
  hidden_states = self.embed_tokens(input_ids)
171
 
172
- # Create causal attention mask if none provided
173
- if attention_mask is None:
174
- attention_mask = torch.triu(
175
- torch.ones((input_ids.size(1), input_ids.size(1)), dtype=torch.bool, device=input_ids.device),
176
- diagonal=1
177
- )
178
- attention_mask = attention_mask.unsqueeze(0).unsqueeze(0)
179
- attention_mask = attention_mask * -1e4
 
 
 
 
180
 
 
 
 
 
 
 
 
 
 
 
181
  for layer in self.layers:
182
- hidden_states = layer(hidden_states, attention_mask)
183
 
184
  hidden_states = self.norm(hidden_states)
185
  logits = self.lm_head(hidden_states)
@@ -200,11 +214,13 @@ class SmolLM2ForCausalLM(PreTrainedModel):
200
  )
201
  return (loss, logits) if loss is not None else logits
202
 
203
- def prepare_inputs_for_generation(self, input_ids, **kwargs):
204
- return {
 
205
  "input_ids": input_ids,
206
- "attention_mask": kwargs.get("attention_mask", None)
207
  }
 
208
 
209
  # Register the model architecture
210
  from transformers import AutoConfig
@@ -272,17 +288,22 @@ def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
272
  input_ids = TOKENIZER.encode(prompt, return_tensors="pt", truncation=True, max_length=2048)
273
  input_ids = input_ids.to(MODEL.device)
274
 
 
 
 
275
  # Generate
276
  with torch.no_grad():
277
  output_ids = MODEL.generate(
278
  input_ids,
 
279
  max_length=min(max_length + len(input_ids[0]), 2048),
280
  temperature=temperature,
281
  top_k=top_k,
282
  do_sample=True,
283
  pad_token_id=TOKENIZER.pad_token_id,
284
  eos_token_id=TOKENIZER.eos_token_id,
285
- num_return_sequences=1
 
286
  )
287
 
288
  # Decode and return
@@ -290,6 +311,9 @@ def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
290
  return generated_text.strip()
291
 
292
  except Exception as e:
 
 
 
293
  return f"Error generating text: {str(e)}"
294
 
295
  # Initialize on startup
 
169
 
170
  hidden_states = self.embed_tokens(input_ids)
171
 
172
+ # Create causal attention mask
173
+ batch_size, seq_length = input_ids.size()
174
+ device = input_ids.device
175
+
176
+ # Create causal mask
177
+ causal_mask = torch.triu(
178
+ torch.ones((seq_length, seq_length), dtype=torch.bool, device=device),
179
+ diagonal=1
180
+ )
181
+ causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, seq_len]
182
+ causal_mask = causal_mask.expand(batch_size, 1, seq_length, seq_length)
183
+ causal_mask = causal_mask.to(dtype=torch.float32) * -1e4
184
 
185
+ # Combine with attention mask if provided
186
+ if attention_mask is not None:
187
+ # Convert attention mask to float and unsqueeze
188
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # [batch, 1, 1, seq_len]
189
+ attention_mask = attention_mask.expand(batch_size, 1, seq_length, seq_length)
190
+ attention_mask = (1.0 - attention_mask) * -1e4
191
+ # Combine masks
192
+ causal_mask = causal_mask + attention_mask
193
+
194
+ # Process through layers
195
  for layer in self.layers:
196
+ hidden_states = layer(hidden_states, causal_mask)
197
 
198
  hidden_states = self.norm(hidden_states)
199
  logits = self.lm_head(hidden_states)
 
214
  )
215
  return (loss, logits) if loss is not None else logits
216
 
217
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
218
+ # Only return what we need
219
+ inputs = {
220
  "input_ids": input_ids,
221
+ "attention_mask": attention_mask
222
  }
223
+ return inputs
224
 
225
  # Register the model architecture
226
  from transformers import AutoConfig
 
288
  input_ids = TOKENIZER.encode(prompt, return_tensors="pt", truncation=True, max_length=2048)
289
  input_ids = input_ids.to(MODEL.device)
290
 
291
+ # Create attention mask
292
+ attention_mask = torch.ones_like(input_ids)
293
+
294
  # Generate
295
  with torch.no_grad():
296
  output_ids = MODEL.generate(
297
  input_ids,
298
+ attention_mask=attention_mask,
299
  max_length=min(max_length + len(input_ids[0]), 2048),
300
  temperature=temperature,
301
  top_k=top_k,
302
  do_sample=True,
303
  pad_token_id=TOKENIZER.pad_token_id,
304
  eos_token_id=TOKENIZER.eos_token_id,
305
+ num_return_sequences=1,
306
+ use_cache=True
307
  )
308
 
309
  # Decode and return
 
311
  return generated_text.strip()
312
 
313
  except Exception as e:
314
+ print(f"Generation error details: {str(e)}")
315
+ import traceback
316
+ traceback.print_exc()
317
  return f"Error generating text: {str(e)}"
318
 
319
  # Initialize on startup