import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from functools import lru_cache

# Cache the loaded model and tokenizer based on the model name.
@lru_cache(maxsize=1)
def get_model(model_name: str):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("Cached model loaded for:", model_name)
    return model, tokenizer

@spaces.GPU
def load_model(model_name: str):
    try:
        # Call the caching function. (This will load the model if not already cached.)
        model, tokenizer = get_model(model_name)
        # Print to verify caching (will show up in the logs).
        print("Loaded model:", model_name)
        return f"Model '{model_name}' loaded successfully.", model_name
    except Exception as e:
        return f"Failed to load model '{model_name}': {str(e)}", ""

@spaces.GPU
def generate_response(prompt, chat_history, current_model_name):
    if current_model_name == "":
        return "Please load a model first by entering a model name and clicking the Load Model button.", current_model_name, chat_history
    try:
        model, tokenizer = get_model(current_model_name)
    except Exception as e:
        return f"Error loading model: {str(e)}", current_model_name, chat_history

    # Prepare conversation messages.
    messages = [
        {"role": "system", "content": "You are a friendly, helpful assistant."},
        {"role": "user", "content": prompt}
    ]
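    # Render the messages with the model's chat template and append the generation prompt.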
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
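    # Tokenize the rendered prompt and move the input tensors to the model's device.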
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=512
    )
    # Strip out the prompt tokens.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
    ]
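    # Decode only the generated continuation, skipping special tokens.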
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    chat_history.append([prompt, response])
    return "", current_model_name, chat_history

with gr.Blocks() as demo:
    gr.Markdown("## Model Loader")
    with gr.Row():
        model_input = gr.Textbox(
            label="Model Name",
            value="agentica-org/DeepScaleR-1.5B-Preview",
            placeholder="Enter model name (e.g., agentica-org/DeepScaleR-1.5B-Preview)"
        )
        load_button = gr.Button("Load Model")
    status_output = gr.Textbox(label="Status", interactive=False)
    # Hidden state for the model name.
    model_state = gr.State("")

    # When the load button is clicked, update status and state.
    load_button.click(fn=load_model, inputs=model_input, outputs=[status_output, model_state])

    gr.Markdown("## Chat Interface")
    chatbot = gr.Chatbot()
    prompt_box = gr.Textbox(placeholder="Enter your prompt here...")
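    # Bridge between the UI and generate_response: the status message is discarded
    # and the prompt textbox is cleared after each submission.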
    def chat_submit(prompt, history, current_model_name):
        output, updated_state, history = generate_response(prompt, history, current_model_name)
        return "", updated_state, history
    # When a prompt is submitted, clear the prompt textbox and update chat history and model state.
    prompt_box.submit(fn=chat_submit, inputs=[prompt_box, chatbot, model_state],
                      outputs=[prompt_box, model_state, chatbot])

demo.launch(share=True)