jatingocodeo committed
Commit fee88b4 · verified · 1 Parent(s): 11705f9

Update app.py

Files changed (1)
  1. app.py +66 -33
app.py CHANGED
@@ -6,56 +6,89 @@ import torch
 model_id = "jatingocodeo/SmolLM2"
 
 def load_model():
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=torch.float16,
-        device_map="auto"
-    )
-    return model, tokenizer
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        # Ensure the tokenizer has the necessary special tokens
+        special_tokens = {
+            'pad_token': '[PAD]',
+            'eos_token': '</s>',
+            'bos_token': '<s>'
+        }
+        tokenizer.add_special_tokens(special_tokens)
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            pad_token_id=tokenizer.pad_token_id
+        )
+        # Resize token embeddings to match the extended tokenizer
+        model.resize_token_embeddings(len(tokenizer))
+        return model, tokenizer
+    except Exception as e:
+        print(f"Error loading model: {str(e)}")
+        raise
 
 def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
-    # Load model and tokenizer (caching them for subsequent calls)
-    if not hasattr(generate_text, "model"):
-        generate_text.model, generate_text.tokenizer = load_model()
-
-    # Encode the prompt
-    input_ids = generate_text.tokenizer.encode(prompt, return_tensors="pt")
-    input_ids = input_ids.to(generate_text.model.device)
+    try:
+        # Load model and tokenizer (caching them for subsequent calls)
+        if not hasattr(generate_text, "model"):
+            generate_text.model, generate_text.tokenizer = load_model()
+
+        # Ensure the prompt is not empty
+        if not prompt.strip():
+            return "Please enter a prompt."
+
+        # Add BOS token if needed
+        if not prompt.startswith(generate_text.tokenizer.bos_token):
+            prompt = generate_text.tokenizer.bos_token + prompt
+
+        # Encode the prompt
+        input_ids = generate_text.tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=2048)
+        input_ids = input_ids.to(generate_text.model.device)
+
+        # Generate text
+        with torch.no_grad():
+            output_ids = generate_text.model.generate(
+                input_ids,
+                max_length=min(max_length + len(input_ids[0]), 2048),  # Respect the model's max context length
+                temperature=temperature,
+                top_k=top_k,
+                do_sample=True,
+                pad_token_id=generate_text.tokenizer.pad_token_id,
+                eos_token_id=generate_text.tokenizer.eos_token_id,
+                num_return_sequences=1
+            )
+
+        # Decode and return the generated text
+        generated_text = generate_text.tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        return generated_text.strip()
 
-    # Generate text
-    with torch.no_grad():
-        output_ids = generate_text.model.generate(
-            input_ids,
-            max_length=max_length,
-            temperature=temperature,
-            top_k=top_k,
-            pad_token_id=generate_text.tokenizer.pad_token_id,
-            eos_token_id=generate_text.tokenizer.eos_token_id,
-            do_sample=True
-        )
-
-    # Decode and return the generated text
-    generated_text = generate_text.tokenizer.decode(output_ids[0], skip_special_tokens=True)
-    return generated_text
+    except Exception as e:
+        print(f"Error during generation: {str(e)}")
+        return f"An error occurred: {str(e)}"
 
 # Create Gradio interface
 iface = gr.Interface(
     fn=generate_text,
     inputs=[
-        gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
+        gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=2),
         gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
         gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top K"),
     ],
-    outputs=gr.Textbox(label="Generated Text"),
+    outputs=gr.Textbox(label="Generated Text", lines=5),
     title="SmolLM2 Text Generator",
-    description="Generate text using the fine-tuned SmolLM2 model",
+    description="""Generate text using the fine-tuned SmolLM2 model.
+    - Max Length: Controls the length of generated text
+    - Temperature: Controls randomness (higher = more creative)
+    - Top K: Controls diversity of word choices""",
     examples=[
         ["Once upon a time", 100, 0.7, 50],
         ["The quick brown fox", 150, 0.8, 40],
         ["In a galaxy far far away", 200, 0.9, 30],
-    ]
+    ],
+    allow_flagging="never"
)
 
 if __name__ == "__main__":
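
Note: the pivotal pairing in this commit is add_special_tokens() with resize_token_embeddings(). If add_special_tokens() actually grows the vocabulary, the new pad token id points past the model's embedding table until the embeddings are resized, so generate() with that pad_token_id would index out of range. A minimal, self-contained sketch of the pattern, assuming a stand-in checkpoint (gpt2 is used here only because it ships without a pad token; it is not part of this Space):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Stand-in checkpoint for illustration only (not this repo's model).
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# add_special_tokens() returns how many tokens were actually added;
# tokens already present in the vocab are not duplicated.
num_added = tokenizer.add_special_tokens({"pad_token": "[PAD]"})

# If the vocab grew, the embedding matrix must be resized to match,
# otherwise the new pad token id would index past the embedding table.
if num_added > 0:
    model.resize_token_embeddings(len(tokenizer))

assert model.get_input_embeddings().weight.shape[0] == len(tokenizer)

Since the heavy lifting lives in generate_text() and the interface only launches under __main__, the function can also be smoke-tested without the UI, e.g. from app import generate_text; print(generate_text("Once upon a time", 50)).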