File size: 3,493 Bytes
44c4d91
e769dfe
 
 
01ccf7c
8b1f0bb
01ccf7c
 
 
 
 
 
 
 
 
 
 
44c4d91
bd6741d
 
 
01ccf7c
 
 
 
 
bd6741d
01ccf7c
e769dfe
 
01ccf7c
 
 
 
 
 
 
 
 
e769dfe
01ccf7c
e769dfe
 
 
 
 
 
 
01ccf7c
e769dfe
 
01ccf7c
e769dfe
 
01ccf7c
e769dfe
01ccf7c
 
e769dfe
 
01ccf7c
 
e769dfe
bd6741d
 
 
01ccf7c
bd6741d
973ac63
ec40b9a
bd6741d
 
01ccf7c
 
 
 
 
 
bd6741d
 
01ccf7c
 
 
 
 
 
 
 
 
 
e769dfe
bd6741d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from functools import lru_cache

# Cache the loaded model and tokenizer based on the model name.
@lru_cache(maxsize=1)
def get_model(model_name: str):
    """Load and return the ``(model, tokenizer)`` pair for ``model_name``.

    The result is memoized with ``lru_cache(maxsize=1)``: repeated calls with
    the same name reuse the loaded objects, while a call with a different
    name replaces the single cached entry.

    Args:
        model_name: Identifier handed to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # bfloat16 weights, with device placement delegated to device_map="auto".
    load_options = {"torch_dtype": torch.bfloat16, "device_map": "auto"}
    causal_lm = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
    chat_tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Only printed on a cache miss, i.e. when the loading work really ran.
    print("Cached model loaded for:", model_name)
    return causal_lm, chat_tokenizer

@spaces.GPU
def load_model(model_name: str):
    """Load ``model_name`` into the cache and report the outcome.

    Args:
        model_name: Model identifier typed by the user. Surrounding
            whitespace is ignored.

    Returns:
        A ``(status_message, loaded_name)`` tuple. ``loaded_name`` is the
        normalized model name on success and ``""`` on failure, so callers
        can use it as the "currently loaded model" state.
    """
    # Pasted names frequently carry stray whitespace; a blank name would only
    # produce an opaque error from the loader, so reject it up front.
    model_name = (model_name or "").strip()
    if not model_name:
        return "Please enter a model name before clicking the Load Model button.", ""

    try:
        # Called for its side effect of populating the cache; the returned
        # objects are fetched again by whoever needs them.
        get_model(model_name)
    except Exception as e:
        # UI boundary: report the failure as a status string instead of raising.
        return f"Failed to load model '{model_name}': {str(e)}", ""

    # Shows up in the logs to confirm the load path was exercised.
    print("Loaded model:", model_name)
    return f"Model '{model_name}' loaded successfully.", model_name

@spaces.GPU
def generate_response(prompt, chat_history, current_model_name):
    """Generate a reply to ``prompt`` with the currently loaded model.

    Only the system message and the current prompt are sent to the model;
    earlier turns in ``chat_history`` are not included in the request.

    Args:
        prompt: The user's message.
        chat_history: Sequence of ``[user, assistant]`` pairs, or ``None`` /
            empty when there is no history yet. It is not modified in place.
        current_model_name: Name passed to ``get_model``; a falsy value means
            no model has been loaded.

    Returns:
        A ``(message, current_model_name, chat_history)`` tuple. ``message``
        is ``""`` on success and an error description otherwise. The returned
        history is a new list, extended with ``[prompt, response]`` on
        success.
    """
    # Work on a copy: the incoming history may be None, and the caller's list
    # should not be mutated behind its back.
    history = list(chat_history) if chat_history else []

    if not current_model_name:
        return "Please load a model first by entering a model name and clicking the Load Model button.", current_model_name, history
    try:
        model, tokenizer = get_model(current_model_name)
    except Exception as e:
        return f"Error loading model: {str(e)}", current_model_name, history

    # Prepare conversation messages.
    messages = [
        {"role": "system", "content": "You are a friendly, helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    output_ids = model.generate(
        **inputs,
        max_new_tokens=512,
    )
    # generate() returns prompt + continuation; keep only the new tokens.
    new_token_ids = [
        full_ids[len(prompt_ids):]
        for prompt_ids, full_ids in zip(inputs.input_ids, output_ids)
    ]
    response = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]

    history.append([prompt, response])
    return "", current_model_name, history

# UI layout: a model-loader row on top, a chat area below.
with gr.Blocks() as demo:
    gr.Markdown("## Model Loader")
    with gr.Row():
        model_input = gr.Textbox(
            label="Model Name",
            value="agentica-org/DeepScaleR-1.5B-Preview",
            placeholder="Enter model name (e.g., agentica-org/DeepScaleR-1.5B-Preview)",
        )
        load_button = gr.Button("Load Model")
    status_output = gr.Textbox(label="Status", interactive=False)
    # Hidden state holding the loaded model's name; "" until a load succeeds.
    model_state = gr.State("")

    # When the load button is clicked, update status and state.
    load_button.click(fn=load_model, inputs=model_input, outputs=[status_output, model_state])

    gr.Markdown("## Chat Interface")
    chatbot = gr.Chatbot()
    prompt_box = gr.Textbox(placeholder="Enter your prompt here...")

    def chat_submit(prompt, history, current_model_name):
        """Run one chat turn and return ``(prompt_box, model_state, chatbot)`` values.

        ``generate_response`` reports problems (no model loaded, load failure)
        through its first return value. That message is shown in the chat as
        the reply to the prompt so the user gets feedback instead of a
        silently cleared textbox.
        """
        message, updated_state, history = generate_response(prompt, history, current_model_name)
        if message:
            # Build a new list; history may be None before the first turn.
            history = list(history or []) + [[prompt, message]]
        return "", updated_state, history

    # When a prompt is submitted, clear the prompt textbox and update chat history and model state.
    prompt_box.submit(fn=chat_submit, inputs=[prompt_box, chatbot, model_state],
                      outputs=[prompt_box, model_state, chatbot])

# NOTE(review): share=True asks Gradio for a public share link; confirm this
# is wanted for the environment this app is deployed in.
demo.launch(share=True)