jatingocodeo committed
Commit d370ed4 · verified · 1 Parent(s): 1ab2f15

Update app.py

Files changed (1):
  1. app.py +35 -7
app.py CHANGED
@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import math
+import os
 
 class RMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-5):
@@ -190,37 +191,56 @@ model_id = "jatingocodeo/SmolLM2"
 
 def load_model():
     try:
+        print("Loading tokenizer...")
         tokenizer = AutoTokenizer.from_pretrained(model_id)
+        print("Tokenizer loaded successfully")
+
         # Ensure the tokenizer has the necessary special tokens
         special_tokens = {
             'pad_token': '[PAD]',
             'eos_token': '</s>',
             'bos_token': '<s>'
         }
+        print("Adding special tokens...")
         tokenizer.add_special_tokens(special_tokens)
 
-        # Load model without device_map
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            torch_dtype=torch.float16,
-            pad_token_id=tokenizer.pad_token_id
+        print("Loading model configuration...")
+        config = SmolLM2Config()
+
+        print("Initializing model...")
+        model = SmolLM2ForCausalLM(config)
+
+        print("Loading model weights...")
+        state_dict = torch.load(
+            os.path.join(model_id, "pytorch_model.bin"),
+            map_location="cpu"
         )
+        model.load_state_dict(state_dict)
 
         # Move model to device manually
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"Moving model to device: {device}")
         model = model.to(device)
 
         # Resize token embeddings to match new tokenizer
+        print("Resizing token embeddings...")
         model.resize_token_embeddings(len(tokenizer))
+
+        print("Model loaded successfully!")
         return model, tokenizer
     except Exception as e:
         print(f"Error loading model: {str(e)}")
+        print(f"Error type: {type(e)}")
+        import traceback
+        traceback.print_exc()
         raise
 
 def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
     try:
+        print(f"\nGenerating text for prompt: {prompt}")
         # Load model and tokenizer (caching them for subsequent calls)
         if not hasattr(generate_text, "model"):
+            print("First call - loading model...")
             generate_text.model, generate_text.tokenizer = load_model()
 
         # Ensure the prompt is not empty
@@ -231,15 +251,17 @@ def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
         if not prompt.startswith(generate_text.tokenizer.bos_token):
             prompt = generate_text.tokenizer.bos_token + prompt
 
+        print("Encoding prompt...")
         # Encode the prompt
         input_ids = generate_text.tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=2048)
         input_ids = input_ids.to(generate_text.model.device)
 
+        print("Generating text...")
         # Generate text
         with torch.no_grad():
             output_ids = generate_text.model.generate(
                 input_ids,
-                max_length=min(max_length + len(input_ids[0]), 2048),  # Respect model's max length
+                max_length=min(max_length + len(input_ids[0]), 2048),
                 temperature=temperature,
                 top_k=top_k,
                 do_sample=True,
@@ -248,12 +270,17 @@ def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
             num_return_sequences=1
         )
 
+        print("Decoding generated text...")
         # Decode and return the generated text
         generated_text = generate_text.tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        print("Generation completed successfully!")
        return generated_text.strip()
 
     except Exception as e:
         print(f"Error during generation: {str(e)}")
+        print(f"Error type: {type(e)}")
+        import traceback
+        traceback.print_exc()
         return f"An error occurred: {str(e)}"
 
 # Create Gradio interface
@@ -280,4 +307,5 @@ iface = gr.Interface(
 )
 
 if __name__ == "__main__":
-    iface.launch()
+    print("Starting Gradio interface...")
+    iface.launch(debug=True)
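One caveat in the rewritten load_model(): os.path.join(model_id, "pytorch_model.bin") treats model_id ("jatingocodeo/SmolLM2") as a local directory, so torch.load() only works when the repo has already been cloned or downloaded next to app.py. If the checkpoint actually lives on the Hub, it would need to be fetched first. A minimal sketch, assuming the repo ships a single pytorch_model.bin and huggingface_hub is installed:

    # Fetch the checkpoint file from the Hub into the local cache, then load it.
    # weights_path points at the cached copy of pytorch_model.bin.
    from huggingface_hub import hf_hub_download
    import torch

    weights_path = hf_hub_download(
        repo_id="jatingocodeo/SmolLM2",   # same model_id as in app.py
        filename="pytorch_model.bin",
    )
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)     # model: the SmolLM2ForCausalLM instance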
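The ordering around resize_token_embeddings() also matters: the state dict is loaded into a model built from the default SmolLM2Config first, and the embedding matrix is resized afterwards, because add_special_tokens() may have grown the vocabulary beyond what the checkpoint was trained with. Resizing before load_state_dict() would make the embedding shapes mismatch. A small illustration of the pattern, assuming an HF-style tokenizer and a model that implements resize_token_embeddings() (as the app's custom class does):

    # add_special_tokens returns the number of tokens that were actually new;
    # only then does the embedding matrix need to grow to len(tokenizer).
    added = tokenizer.add_special_tokens(
        {'pad_token': '[PAD]', 'eos_token': '</s>', 'bos_token': '<s>'}
    )
    if added > 0:
        # Done after load_state_dict, so the checkpoint's embedding shape
        # still matches the freshly initialized model at load time.
        model.resize_token_embeddings(len(tokenizer))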
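In generate_text(), max_length=min(max_length + len(input_ids[0]), 2048) compensates for the fact that an HF-style max_length counts the prompt tokens as well, while capping the total at the 2048-token context window. If the custom generate() accepts it (transformers' GenerationMixin does), max_new_tokens states the same intent more directly; a sketch under that assumption:

    # Cap newly generated tokens so prompt + continuation stays within the
    # 2048-token context, expressed via max_new_tokens instead of max_length.
    prompt_len = input_ids.shape[1]
    output_ids = model.generate(
        input_ids,
        max_new_tokens=min(max_length, 2048 - prompt_len),
        temperature=temperature,
        top_k=top_k,
        do_sample=True,
    )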
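Finally, the hasattr(generate_text, "model") check is a small memoization idiom: attributes stored on the function object persist for the lifetime of the process, so the model is loaded on the first Gradio request and reused on every later one. The same pattern in isolation, with a hypothetical load_resource() as a stand-in for the expensive model load:

    # Cache an expensive resource on the function object itself.
    def load_resource():
        if not hasattr(load_resource, "cache"):
            print("First call - initializing...")
            load_resource.cache = {"model": "expensive-to-build object"}
        return load_resource.cache

    assert load_resource() is load_resource()  # built once, reused afterwards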