import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from functools import lru_cache
# Cache the loaded model and tokenizer based on the model name.
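# maxsize=1 keeps only the most recently requested model/tokenizer pair in memory;
# loading a different model evicts the previous one from the cache.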
@lru_cache(maxsize=1)
def get_model(model_name: str):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("Cached model loaded for:", model_name)
    return model, tokenizer
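# @spaces.GPU requests a GPU for the duration of each decorated call (used on ZeroGPU Spaces).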
@spaces.GPU
def load_model(model_name: str):
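    # Returns a status message for the UI plus the model name to store in gr.State.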
    try:
        # Call the caching function. (This will load the model if not already cached.)
        model, tokenizer = get_model(model_name)
        # Print to verify caching (will show up in the logs).
        print("Loaded model:", model_name)
        return f"Model '{model_name}' loaded successfully.", model_name
    except Exception as e:
        return f"Failed to load model '{model_name}': {str(e)}", ""
@spaces.GPU
def generate_response(prompt, chat_history, current_model_name):
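    # Returns (message, model name, updated chat history); the message is non-empty only on errors.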
    if current_model_name == "":
        return "Please load a model first by entering a model name and clicking the Load Model button.", current_model_name, chat_history
    try:
        model, tokenizer = get_model(current_model_name)
    except Exception as e:
        return f"Error loading model: {str(e)}", current_model_name, chat_history
    # Prepare conversation messages.
    # Note: only the current prompt is sent to the model; earlier turns in chat_history
    # are not replayed as context.
    messages = [
        {"role": "system", "content": "You are a friendly, helpful assistant."},
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=512
    )
    # Strip out the prompt tokens so only the newly generated text is decoded.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    chat_history.append([prompt, response])
    return "", current_model_name, chat_history
with gr.Blocks() as demo:
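    # UI: a model loader section on top, followed by a simple chat interface.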
gr.Markdown("## Model Loader")
with gr.Row():
model_input = gr.Textbox(
label="Model Name",
value="agentica-org/DeepScaleR-1.5B-Preview",
placeholder="Enter model name (e.g., agentica-org/DeepScaleR-1.5B-Preview)"
)
load_button = gr.Button("Load Model")
status_output = gr.Textbox(label="Status", interactive=False)
# Hidden state for the model name.
model_state = gr.State("")
# When the load button is clicked, update status and state.
load_button.click(fn=load_model, inputs=model_input, outputs=[status_output, model_state])
gr.Markdown("## Chat Interface")
chatbot = gr.Chatbot()
prompt_box = gr.Textbox(placeholder="Enter your prompt here...")
def chat_submit(prompt, history, current_model_name):
output, updated_state, history = generate_response(prompt, history, current_model_name)
return "", updated_state, history
# When a prompt is submitted, clear the prompt textbox and update chat history and model state.
prompt_box.submit(fn=chat_submit, inputs=[prompt_box, chatbot, model_state],
outputs=[prompt_box, model_state, chatbot])
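# share=True is mainly useful for local runs; a deployed Space is already publicly reachable.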
demo.launch(share=True)