jatingocodeo committed
Commit ae00973 · verified
1 Parent(s): 66d83b1

Update app.py

Files changed (1)
  1. app.py +140 -64
app.py CHANGED
@@ -5,6 +5,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 import math
 import os
+import sys
+import transformers
 
 class RMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-5):
@@ -191,98 +193,172 @@ model_id = "jatingocodeo/SmolLM2"
 
 def load_model():
     try:
-        print("Loading tokenizer...")
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        print("Tokenizer loaded successfully")
+        print("\n=== Starting model loading process ===")
+        print(f"Model ID: {model_id}")
 
-        # Ensure the tokenizer has the necessary special tokens
-        special_tokens = {
-            'pad_token': '[PAD]',
-            'eos_token': '</s>',
-            'bos_token': '<s>'
-        }
-        print("Adding special tokens...")
-        tokenizer.add_special_tokens(special_tokens)
+        print("\n1. Loading tokenizer...")
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(model_id)
+            print("✓ Tokenizer loaded successfully")
+            print(f"Tokenizer type: {type(tokenizer)}")
+            print(f"Vocabulary size: {len(tokenizer)}")
+        except Exception as e:
+            print(f"× Error loading tokenizer: {str(e)}")
+            raise
 
-        print("Loading model from Hugging Face Hub...")
-        # Create config first
-        config = SmolLM2Config(
-            pad_token_id=tokenizer.pad_token_id,
-            bos_token_id=tokenizer.bos_token_id,
-            eos_token_id=tokenizer.eos_token_id
-        )
+        print("\n2. Adding special tokens...")
+        try:
+            special_tokens = {
+                'pad_token': '[PAD]',
+                'eos_token': '</s>',
+                'bos_token': '<s>'
+            }
+            num_added = tokenizer.add_special_tokens(special_tokens)
+            print(f"✓ Added {num_added} special tokens")
+            print(f"Special tokens: {tokenizer.special_tokens_map}")
+        except Exception as e:
+            print(f"× Error adding special tokens: {str(e)}")
+            raise
 
-        # Load model from Hub
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            config=config,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-            trust_remote_code=True,
-            low_cpu_mem_usage=True
-        )
+        print("\n3. Creating model configuration...")
+        try:
+            config = SmolLM2Config(
+                pad_token_id=tokenizer.pad_token_id,
+                bos_token_id=tokenizer.bos_token_id,
+                eos_token_id=tokenizer.eos_token_id
+            )
+            print("✓ Configuration created successfully")
+            print(f"Config: {config}")
+        except Exception as e:
+            print(f"× Error creating configuration: {str(e)}")
+            raise
+
+        print("\n4. Loading model from Hub...")
+        try:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                config=config,
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True,
+                local_files_only=False  # Force download from Hub
+            )
+            print("✓ Model loaded successfully")
+            print(f"Model type: {type(model)}")
+        except Exception as e:
+            print(f"× Error loading model: {str(e)}")
+            print("Attempting to print model files in Hub repo...")
+            from huggingface_hub import list_repo_files
+            try:
+                files = list_repo_files(model_id)
+                print(f"Files in repo: {files}")
+            except Exception as hub_e:
+                print(f"Error listing repo files: {str(hub_e)}")
+            raise
 
-        # Move model to device manually
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        print(f"Moving model to device: {device}")
-        model = model.to(device)
+        print("\n5. Moving model to device...")
+        try:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            print(f"Selected device: {device}")
+            model = model.to(device)
+            print("✓ Model moved to device successfully")
+        except Exception as e:
+            print(f"× Error moving model to device: {str(e)}")
+            raise
 
-        # Resize token embeddings to match new tokenizer
-        print("Resizing token embeddings...")
-        model.resize_token_embeddings(len(tokenizer))
+        print("\n6. Resizing token embeddings...")
+        try:
+            old_size = model.get_input_embeddings().weight.shape[0]
+            model.resize_token_embeddings(len(tokenizer))
+            new_size = model.get_input_embeddings().weight.shape[0]
+            print(f"✓ Token embeddings resized from {old_size} to {new_size}")
+        except Exception as e:
+            print(f"× Error resizing token embeddings: {str(e)}")
+            raise
 
-        print("Model loaded successfully!")
+        print("\n=== Model loading completed successfully! ===")
         return model, tokenizer
+
     except Exception as e:
-        print(f"Error loading model: {str(e)}")
-        print(f"Error type: {type(e)}")
+        print("\n!!! ERROR IN MODEL LOADING !!!")
+        print(f"Error type: {type(e).__name__}")
+        print(f"Error message: {str(e)}")
+        print("\nFull traceback:")
         import traceback
         traceback.print_exc()
+        print("\nAdditional debug info:")
+        print(f"Python version: {sys.version}")
+        print(f"PyTorch version: {torch.__version__}")
+        print(f"Transformers version: {transformers.__version__}")
+        print(f"CUDA available: {torch.cuda.is_available()}")
+        if torch.cuda.is_available():
+            print(f"CUDA version: {torch.version.cuda}")
         raise
 
 def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
     try:
-        print(f"\nGenerating text for prompt: {prompt}")
-        # Load model and tokenizer (caching them for subsequent calls)
+        print("\n=== Starting text generation ===")
+        print(f"Input prompt: {prompt}")
+        print(f"Parameters: max_length={max_length}, temperature={temperature}, top_k={top_k}")
+
         if not hasattr(generate_text, "model"):
-            print("First call - loading model...")
+            print("\n1. First call - loading model...")
             generate_text.model, generate_text.tokenizer = load_model()
 
-        # Ensure the prompt is not empty
         if not prompt.strip():
+            print("× Empty prompt received")
             return "Please enter a prompt."
 
-        # Add BOS token if needed
+        print("\n2. Processing prompt...")
        if not prompt.startswith(generate_text.tokenizer.bos_token):
             prompt = generate_text.tokenizer.bos_token + prompt
+            print("Added BOS token to prompt")
 
-        print("Encoding prompt...")
-        # Encode the prompt
-        input_ids = generate_text.tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=2048)
-        input_ids = input_ids.to(generate_text.model.device)
+        print("\n3. Encoding prompt...")
+        try:
+            input_ids = generate_text.tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=2048)
+            print(f"Encoded shape: {input_ids.shape}")
+            input_ids = input_ids.to(generate_text.model.device)
+            print("✓ Encoding successful")
+        except Exception as e:
+            print(f"× Error encoding prompt: {str(e)}")
+            raise
 
-        print("Generating text...")
-        # Generate text
-        with torch.no_grad():
-            output_ids = generate_text.model.generate(
-                input_ids,
-                max_length=min(max_length + len(input_ids[0]), 2048),
-                temperature=temperature,
-                top_k=top_k,
-                do_sample=True,
-                pad_token_id=generate_text.tokenizer.pad_token_id,
-                eos_token_id=generate_text.tokenizer.eos_token_id,
-                num_return_sequences=1
-            )
+        print("\n4. Generating text...")
+        try:
+            with torch.no_grad():
+                output_ids = generate_text.model.generate(
+                    input_ids,
+                    max_length=min(max_length + len(input_ids[0]), 2048),
+                    temperature=temperature,
+                    top_k=top_k,
+                    do_sample=True,
+                    pad_token_id=generate_text.tokenizer.pad_token_id,
+                    eos_token_id=generate_text.tokenizer.eos_token_id,
+                    num_return_sequences=1
+                )
+            print(f"Generation shape: {output_ids.shape}")
+        except Exception as e:
+            print(f"× Error during generation: {str(e)}")
+            raise
+
+        print("\n5. Decoding output...")
+        try:
+            generated_text = generate_text.tokenizer.decode(output_ids[0], skip_special_tokens=True)
+            print("✓ Decoding successful")
+            print(f"Output length: {len(generated_text)}")
+        except Exception as e:
+            print(f"× Error decoding output: {str(e)}")
+            raise
 
-        print("Decoding generated text...")
-        # Decode and return the generated text
-        generated_text = generate_text.tokenizer.decode(output_ids[0], skip_special_tokens=True)
-        print("Generation completed successfully!")
+        print("\n=== Generation completed successfully! ===")
         return generated_text.strip()
 
     except Exception as e:
-        print(f"Error during generation: {str(e)}")
-        print(f"Error type: {type(e)}")
+        print("\n!!! ERROR IN TEXT GENERATION !!!")
+        print(f"Error type: {type(e).__name__}")
+        print(f"Error message: {str(e)}")
+        print("\nFull traceback:")
         import traceback
         traceback.print_exc()
         return f"An error occurred: {str(e)}"
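The commit only touches load_model() and generate_text(); the UI wiring in app.py is outside this hunk. If this app.py backs a Gradio demo (as the generate_text(prompt, max_length, temperature, top_k) signature suggests), a minimal sketch of how the updated function could be hooked up is shown below; the component names, slider ranges, and defaults are illustrative assumptions, not code from this commit.

# Hypothetical wiring sketch -- not part of this commit. Assumes a Gradio app;
# labels, ranges, and defaults below are illustrative only.
import gradio as gr

demo = gr.Interface(
    fn=generate_text,  # updated in this commit; loads the model lazily on first call
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(10, 500, value=100, step=1, label="Max length"),
        gr.Slider(0.1, 1.5, value=0.7, label="Temperature"),
        gr.Slider(1, 100, value=50, step=1, label="Top-k"),
    ],
    outputs=gr.Textbox(label="Generated text"),
    title="SmolLM2 text generation",
)

if __name__ == "__main__":
    demo.launch()

Because generate_text caches the model and tokenizer as function attributes, the expensive load_model() path runs only on the first request of the process; later calls reuse the cached objects.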