import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from functools import lru_cache
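# Assumed requirements for this Space: gradio, spaces, transformers, torch,
# and accelerate (accelerate is needed for device_map="auto" below).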

# Cache the loaded model and tokenizer based on the model name.
@lru_cache(maxsize=2)
def get_model(model_name: str):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("Cached model loaded for:", model_name)
    return model, tokenizer

def load_model(model_name: str):
    try:
        # Call the caching function. (This will load the model if not already cached.)
        model, tokenizer = get_model(model_name)
        # Print to verify caching (will show up in the logs).
        print("Loaded model:", model_name)
        return f"Model '{model_name}' loaded successfully.", model_name
    except Exception as e:
        return f"Failed to load model '{model_name}': {str(e)}", ""

# spaces.GPU requests ZeroGPU hardware for the duration of this call
# (this is what the `import spaces` above is for on a ZeroGPU Space).
@spaces.GPU
def generate_response(prompt, chat_history, current_model_name):
    if current_model_name == "":
        return "Please load a model first by entering a model name and clicking the Load Model button.", current_model_name, chat_history
    try:
        model, tokenizer = get_model(current_model_name)
    except Exception as e:
        return f"Error loading model: {str(e)}", current_model_name, chat_history
    # Prepare conversation messages.
    messages = [
        {"role": "system", "content": "You are a friendly, helpful assistant."},
        {"role": "user", "content": prompt}
    ]
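    # Note: only the latest prompt is sent to the model; earlier turns in
    # chat_history are shown in the UI but not included in `messages`.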
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=512
    )
    # Strip out the prompt tokens.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    chat_history.append([prompt, response])
    return "", current_model_name, chat_history

with gr.Blocks() as demo:
    gr.Markdown("## Model Loader")
    with gr.Row():
        model_input = gr.Textbox(
            label="Model Name",
            value="agentica-org/DeepScaleR-1.5B-Preview",
            placeholder="Enter model name (e.g., agentica-org/DeepScaleR-1.5B-Preview)"
        )
        load_button = gr.Button("Load Model")
    status_output = gr.Textbox(label="Status", interactive=False)
    # Hidden state for the model name.
    model_state = gr.State("")
    # When the load button is clicked, update status and state.
    load_button.click(fn=load_model, inputs=model_input, outputs=[status_output, model_state])

    gr.Markdown("## Chat Interface")
    chatbot = gr.Chatbot()
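    # The default (tuple-format) Chatbot expects a list of [user_message, assistant_message]
    # pairs, which matches how generate_response appends to chat_history.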
    prompt_box = gr.Textbox(placeholder="Enter your prompt here...")

    def chat_submit(prompt, history, current_model_name):
        # Pass the first return value through: it is "" on success (clearing the
        # prompt box) and an error message otherwise.
        output, updated_state, history = generate_response(prompt, history, current_model_name)
        return output, updated_state, history
    # When a prompt is submitted, clear the prompt textbox and update chat history and model state.
    prompt_box.submit(fn=chat_submit, inputs=[prompt_box, chatbot, model_state],
                      outputs=[prompt_box, model_state, chatbot])

demo.launch(share=True)
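# share=True creates a temporary public link when running locally;
# on Hugging Face Spaces it is unnecessary and is ignored.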