TakiTakiTa's picture
Update app.py
ec40b9a verified
raw
history blame
2.82 kB
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Cache of everything loaded so far: model name -> (model, tokenizer).
loaded_models: dict[str, tuple] = {}

# Name of the model the chat handler should use; the empty string means
# that no model has been loaded yet.
current_model_name: str = ""
@spaces.GPU
def load_model(model_name: str):
    """Load a causal LM plus tokenizer and make it the active chat model.

    The (model, tokenizer) pair is stored in ``loaded_models`` under
    ``model_name`` and ``current_model_name`` is pointed at it. A name that
    is already cached is simply re-activated instead of being loaded again.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.
            Surrounding whitespace is ignored.

    Returns:
        A human-readable status string (success or failure); loading errors
        are reported through the string rather than raised.
    """
    # NOTE(review): on ZeroGPU hardware, functions decorated with
    # ``spaces.GPU`` may execute in a separate worker process, in which case
    # the global state written here would not be visible to later calls --
    # confirm against the Spaces ZeroGPU documentation for this deployment.
    global current_model_name

    model_name = (model_name or "").strip()
    if not model_name:
        # Fail with a clear message instead of handing an empty identifier
        # to from_pretrained.
        return "Failed to load model '': no model name was provided."

    if model_name in loaded_models:
        # Already cached: switch to it without reloading the weights (a
        # reload would briefly hold two copies of the model in memory).
        current_model_name = model_name
        return f"Model '{model_name}' loaded successfully."

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    except Exception as e:
        # UI boundary: surface any loading problem in the status textbox.
        return f"Failed to load model '{model_name}': {e}"

    # Only publish the pair once both halves loaded successfully.
    loaded_models[model_name] = (model, tokenizer)
    current_model_name = model_name
    return f"Model '{model_name}' loaded successfully."
def _history_to_messages(history):
    """Convert Gradio chat history into chat-template message dicts.

    Accepts both history layouts: a list of ``(user, assistant)`` pairs and
    a list of ``{"role": ..., "content": ...}`` dicts. Entries whose content
    is not plain text are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str):
                messages.append({"role": role, "content": content})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            user_text, assistant_text = turn
            if isinstance(user_text, str):
                messages.append({"role": "user", "content": user_text})
            if isinstance(assistant_text, str):
                messages.append({"role": "assistant", "content": assistant_text})
    return messages


@spaces.GPU
def generate(prompt, history):
    """Produce the assistant's reply for the chat interface.

    Args:
        prompt: The user's latest message.
        history: Earlier turns of the conversation as supplied by
            ``gr.ChatInterface``; they are replayed to the model so that it
            keeps the conversational context.

    Returns:
        The decoded reply text, or a hint asking the user to load a model
        when none is active.
    """
    if current_model_name == "" or current_model_name not in loaded_models:
        return "Please load a model first by entering a model name and clicking the Load Model button."

    model, tokenizer = loaded_models[current_model_name]

    # System prompt, then the previous turns, then the new user message.
    messages = [
        {"role": "system", "content": "Je bent een vriendelijke, behulpzame assistent."},
        *_history_to_messages(history),
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
    )

    # generate() returns prompt + completion; drop the prompt tokens so that
    # only the newly generated text is decoded.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Assemble the page: a model-loader section on top, the chat below it.
with gr.Blocks() as demo:
    gr.Markdown("## Model Loader")

    with gr.Row():
        model_name_input = gr.Textbox(
            value="agentica-org/DeepScaleR-1.5B-Preview",
            label="Model Name",
            placeholder="Enter model name (e.g., agentica-org/DeepScaleR-1.5B-Preview)",
        )
        load_button = gr.Button("Load Model")

    # Read-only field that shows the message returned by load_model.
    load_status = gr.Textbox(interactive=False, label="Status")

    # Clicking the button loads the named model and reports the outcome.
    load_button.click(
        fn=load_model,
        inputs=[model_name_input],
        outputs=[load_status],
    )

    gr.Markdown("## Chat Interface")

    # Plain chat interface; generate() is given only the message and history.
    chat_interface = gr.ChatInterface(fn=generate)

demo.launch(share=True)