# Hugging Face Spaces page status captured with this file: Runtime error
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Identifier of the fine-tuned SmolLM2 checkpoint; handed to `from_pretrained`
# for both the tokenizer and the model.
model_id = "jatingocodeo/SmolLM2"
def load_model():
    """Load the causal language model and tokenizer identified by ``model_id``.

    Half precision is requested only when a CUDA device is available. On a
    CPU-only host the weights are loaded as float32 instead, because float16
    inference is poorly supported by PyTorch on CPU and can fail at
    generation time.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Pick the dtype from the hardware instead of hard-coding float16.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",
    )
    return model, tokenizer
def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
    """Generate a sampled continuation of ``prompt``.

    The model and tokenizer are loaded on the first call that needs them and
    cached as attributes of this function, so later calls reuse them.

    Args:
        prompt: Text to continue. ``None``, an empty string or a
            whitespace-only string yields ``""`` without loading the model.
        max_length: Forwarded to ``generate`` as ``max_length`` (in
            transformers this bound includes the prompt tokens). Coerced to
            ``int`` because UI sliders may deliver floats.
        temperature: Sampling temperature, coerced to ``float``.
        top_k: Number of top-probability tokens sampled from, coerced to
            ``int``.

    Returns:
        str: The decoded output sequence with special tokens removed.
    """
    # Nothing to continue: avoid handing an empty tensor to generate().
    if not prompt or not prompt.strip():
        return ""

    # Load model and tokenizer once, caching them for subsequent calls.
    if not hasattr(generate_text, "model"):
        generate_text.model, generate_text.tokenizer = load_model()
    model = generate_text.model
    tokenizer = generate_text.tokenizer

    # Encode the prompt and move it to the device the model lives on.
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    input_ids = input_ids.to(model.device)
    # A single, unpadded prompt attends to every token; passing the mask
    # explicitly keeps generate() from having to guess it.
    attention_mask = torch.ones_like(input_ids)

    # Fall back to EOS when the tokenizer defines no dedicated pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Generate text.
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=int(max_length),
            temperature=float(temperature),
            top_k=int(top_k),
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=True,
        )

    # Decode and return the generated text.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Create the Gradio interface. The input components are listed in the same
# order as generate_text's parameters: prompt, max_length, temperature, top_k.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top K"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmolLM2 Text Generator",
    description="Generate text using the fine-tuned SmolLM2 model",
    # Each example row supplies one value per input component, in order.
    examples=[
        ["Once upon a time", 100, 0.7, 50],
        ["The quick brown fox", 150, 0.8, 40],
        ["In a galaxy far far away", 200, 0.9, 30],
    ],
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    iface.launch()