# Spaces: Running on Zero
import gradio as gr | |
import spaces | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
# Cache of loaded models, keyed by model name. Each value is the
# (model, tokenizer) pair stored by load_model().
loaded_models: dict = {}
# Name of the model that generate() should use. The empty string means
# no model has been loaded yet.
current_model_name: str = ""
def load_model(model_name: str):
    """Load a causal LM and its tokenizer and make it the active model.

    The (model, tokenizer) pair is stored in the module-level
    ``loaded_models`` cache under ``model_name`` and ``current_model_name``
    is updated so that ``generate`` uses it. A name that is already cached
    is simply re-activated instead of being loaded a second time.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.
            Surrounding whitespace is ignored.

    Returns:
        A human-readable status message describing success or failure.
        Loading errors are reported in the message rather than raised,
        because the result is shown directly in the UI status box.
    """
    global loaded_models, current_model_name

    # Text typed or pasted into a textbox often carries stray whitespace;
    # normalise it and reject blank input before attempting a load.
    model_name = (model_name or "").strip()
    if not model_name:
        return "Please enter a model name."

    # Reuse the cached copy rather than loading the same weights again.
    if model_name in loaded_models:
        current_model_name = model_name
        return f"Model '{model_name}' is already loaded and is now active."

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    except Exception as e:  # UI boundary: surface any failure as a status message
        return f"Failed to load model '{model_name}': {str(e)}"

    # Only publish the model once both halves loaded successfully.
    loaded_models[model_name] = (model, tokenizer)
    current_model_name = model_name
    return f"Model '{model_name}' loaded successfully."
def _history_to_messages(history):
    """Convert chat history into a list of role/content message dicts.

    Two history layouts are accepted (assumed to be the ones Gradio's
    ChatInterface may pass -- confirm against the installed version):

    * a sequence of ``(user_text, assistant_text)`` pairs, and
    * a sequence of ``{"role": ..., "content": ...}`` dicts.

    Only non-empty string contents with role ``user`` or ``assistant`` are
    kept; anything else (e.g. file attachments, missing replies) is skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            for role, content in zip(("user", "assistant"), turn):
                if isinstance(content, str) and content:
                    messages.append({"role": role, "content": content})
    return messages


def generate(prompt, history):
    """Generate a chat reply to ``prompt`` with the currently active model.

    Args:
        prompt: The user's latest message.
        history: Earlier turns of the conversation; they are placed between
            the system prompt and ``prompt`` so the model sees the context.

    Returns:
        The decoded model reply, or an instruction to load a model first
        when no model is active.
    """
    # NOTE(review): the file imports `spaces` but this function is not wrapped
    # in `spaces.GPU`; confirm whether the deployment hardware requires it.
    if not current_model_name or current_model_name not in loaded_models:
        return "Please load a model first by entering a model name and clicking the Load Model button."

    model, tokenizer = loaded_models[current_model_name]

    # System prompt, then prior turns, then the new user message.
    messages = [
        {"role": "system", "content": "Je bent een vriendelijke, behulpzame assistent."},
        *_history_to_messages(history),
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
    )

    # The output sequence starts with the prompt tokens; keep only the
    # newly generated continuation of the single sequence in the batch.
    prompt_length = len(model_inputs.input_ids[0])
    new_tokens = generated_ids[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
# Assemble the Gradio Blocks UI: a model-loading panel followed by a chat panel.
with gr.Blocks() as demo:
    gr.Markdown("## Model Loader")
    with gr.Row():
        model_name_input = gr.Textbox(
            value="agentica-org/DeepScaleR-1.5B-Preview",
            label="Model Name",
            placeholder="Enter model name (e.g., agentica-org/DeepScaleR-1.5B-Preview)",
        )
        load_button = gr.Button("Load Model")
    load_status = gr.Textbox(label="Status", interactive=False)

    # Clicking the button sends the textbox value to load_model and shows
    # the message it returns in the status box.
    load_button.click(load_model, inputs=model_name_input, outputs=load_status)

    gr.Markdown("## Chat Interface")
    # The chat panel calls generate(prompt, history); no extra inputs are wired in.
    chat_interface = gr.ChatInterface(fn=generate)

demo.launch(share=True)