# Hugging Face Space: Running on Zero (ZeroGPU).

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Global dictionary to store loaded models, keyed by model name.
loaded_models = {}
# Global variable to track the currently loaded model's name.
current_model_name = ""

@spaces.GPU
def load_model(model_name: str):
    global loaded_models, current_model_name
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        loaded_models[model_name] = (model, tokenizer)
        current_model_name = model_name  # update global state
        return f"Model '{model_name}' loaded successfully."
    except Exception as e:
        return f"Failed to load model '{model_name}': {str(e)}"

@spaces.GPU
def generate(prompt, history):
    global loaded_models, current_model_name
    if current_model_name == "" or current_model_name not in loaded_models:
        return "Please load a model first by entering a model name and clicking the Load Model button."
    model, tokenizer = loaded_models[current_model_name]
    # Prepare the messages (a system prompt plus the user's latest prompt).
    messages = [
        {"role": "system", "content": "You are a friendly, helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
    )
    # Strip the prompt tokens so only the newly generated tokens are decoded.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response
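
# Note (an addition, not part of the original app): generate ignores the
# `history` argument that gr.ChatInterface passes in, so the model only sees
# the latest prompt. A minimal sketch for folding prior turns into `messages`,
# assuming tuple-style history (a list of (user_msg, bot_msg) pairs), inserted
# between the system message and the final user message:
#
#     for user_msg, bot_msg in history:
#         messages.append({"role": "user", "content": user_msg})
#         if bot_msg:
#             messages.append({"role": "assistant", "content": bot_msg})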

# Build the Gradio UI using Blocks.
with gr.Blocks() as demo:
    gr.Markdown("## Model Loader")
    with gr.Row():
        model_name_input = gr.Textbox(
            label="Model Name",
            value="agentica-org/DeepScaleR-1.5B-Preview",
            placeholder="Enter model name (e.g., agentica-org/DeepScaleR-1.5B-Preview)",
        )
        load_button = gr.Button("Load Model")
    load_status = gr.Textbox(label="Status", interactive=False)
    # When the Load Model button is clicked, load_model is called.
    load_button.click(fn=load_model, inputs=model_name_input, outputs=load_status)
gr.Markdown("## Chat Interface")
# Create the chat interface without extra_inputs.
chat_interface = gr.ChatInterface(fn=generate)
demo.launch(share=True)
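# Usage note: gr.ChatInterface calls fn with (message, history), which
# `generate` matches. share=True creates a public gradio.live link when the
# script runs locally; on Spaces the app is already hosted, so the flag has no
# effect there.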