File size: 2,822 Bytes
44c4d91
e769dfe
 
 
8b1f0bb
83c9d49
 
ec40b9a
 
44c4d91
bd6741d
 
ec40b9a
bd6741d
 
 
 
 
 
 
83c9d49
ec40b9a
 
bd6741d
ec40b9a
e769dfe
 
ec40b9a
 
83c9d49
bd6741d
83c9d49
 
 
ec40b9a
e769dfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd6741d
e769dfe
bd6741d
 
e769dfe
 
 
 
 
bd6741d
 
 
 
 
 
973ac63
ec40b9a
bd6741d
 
 
 
ec40b9a
 
bd6741d
 
ec40b9a
 
e769dfe
bd6741d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Cache of every model loaded so far: maps the name that was passed to
# ``from_pretrained`` to its (model, tokenizer) pair.
loaded_models: dict = {}
# Name of the model the chat handler should use; "" until a load has succeeded.
current_model_name: str = ""

@spaces.GPU
def load_model(model_name: str) -> str:
    """Load a causal LM plus its tokenizer and make it the active chat model.

    Loaded pairs are cached in ``loaded_models``. Selecting a model that was
    loaded earlier only switches ``current_model_name`` to it, so no second
    copy of the weights is instantiated.

    Args:
        model_name: Model name or path passed to ``from_pretrained`` (e.g.
            ``"agentica-org/DeepScaleR-1.5B-Preview"``). Leading and trailing
            whitespace from the textbox is ignored.

    Returns:
        A human-readable status message for the UI. Loading errors are
        reported in this message rather than raised, so they appear in the
        Gradio status box.
    """
    # NOTE(review): on ZeroGPU Spaces, @spaces.GPU functions may execute in a
    # separate worker process; if so, the global writes below would not be
    # visible to later calls. Confirm against the `spaces` documentation.
    global current_model_name

    model_name = (model_name or "").strip()
    if not model_name:
        return "Please enter a model name."

    if model_name in loaded_models:
        # Already in memory: switch to it instead of loading a duplicate copy.
        current_model_name = model_name
        return f"Model '{model_name}' is already loaded; switched to it."

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    except Exception as e:  # UI boundary: show any load failure in the status box
        return f"Failed to load model '{model_name}': {e}"

    loaded_models[model_name] = (model, tokenizer)
    # Switch the active model only after both model and tokenizer loaded.
    current_model_name = model_name
    return f"Model '{model_name}' loaded successfully."

def _history_to_messages(history):
    """Convert Gradio chat history into chat-template ``{"role", "content"}`` messages.

    Two history formats are accepted. The first is a list of
    ``[user, assistant]`` pairs (Gradio's "tuples" format). The second is a
    list of ``{"role": ..., "content": ...}`` dicts (the "messages" format).
    Entries whose content is not plain text (``None`` placeholders, file
    attachments) are skipped. Only "user" and "assistant" roles are kept.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str):
                messages.append({"role": role, "content": content})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            for role, content in zip(("user", "assistant"), turn):
                if isinstance(content, str):
                    messages.append({"role": role, "content": content})
    return messages


@spaces.GPU
def generate(prompt, history):
    """Reply to ``prompt`` with the currently loaded model, taking prior turns into account.

    Args:
        prompt: The user's new chat message.
        history: Previous turns as supplied by ``gr.ChatInterface``, in either
            the pair format or the role/content dict format.

    Returns:
        The decoded model reply, without the prompt tokens and special tokens.
        If no model has been loaded yet, returns an instruction message instead.
    """
    model_and_tokenizer = loaded_models.get(current_model_name)
    if model_and_tokenizer is None:
        return "Please load a model first by entering a model name and clicking the Load Model button."
    model, tokenizer = model_and_tokenizer

    # System prompt (Dutch): "You are a friendly, helpful assistant."
    # NOTE: history is not truncated, so very long chats may exceed the
    # model's context window.
    messages = [
        {"role": "system", "content": "Je bent een vriendelijke, behulpzame assistent."},
        *_history_to_messages(history),
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    output_ids = model.generate(**model_inputs, max_new_tokens=512)

    # generate() returns prompt + completion; keep only the newly generated tokens.
    prompt_length = model_inputs["input_ids"].shape[-1]
    return tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True)

# Build the Gradio UI using Blocks: a model-loader section above a chat section.
with gr.Blocks() as demo:
    gr.Markdown("## Model Loader")
    with gr.Row():
        model_name_input = gr.Textbox(
            label="Model Name",
            value="agentica-org/DeepScaleR-1.5B-Preview",
            placeholder="Enter model name (e.g., agentica-org/DeepScaleR-1.5B-Preview)",
        )
        load_button = gr.Button("Load Model")
    load_status = gr.Textbox(label="Status", interactive=False)

    # Load the model named in the textbox and show load_model's status
    # message. Pressing Enter in the textbox does the same as the button.
    load_button.click(fn=load_model, inputs=model_name_input, outputs=load_status)
    model_name_input.submit(fn=load_model, inputs=model_name_input, outputs=load_status)

    gr.Markdown("## Chat Interface")
    # No additional_inputs: generate() receives only (message, history).
    chat_interface = gr.ChatInterface(fn=generate)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module has no side
    # effects. share=True requests a public share link when run locally.
    demo.launch(share=True)