"""Gradio demo: load a Hugging Face causal LM by name, then chat with it.

Written for Hugging Face Spaces: the @spaces.GPU decorator requests GPU time
on ZeroGPU hardware and is a no-op when the script runs elsewhere.
"""

import gradio as gr
import spaces
import torch
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer


# Cache the loaded model and tokenizer, keyed by model name. With maxsize=1,
# only the most recently requested model stays resident; loading a different
# model evicts the previous one.
@lru_cache(maxsize=1)
def get_model(model_name: str):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("Cached model loaded for:", model_name)
    return model, tokenizer


@spaces.GPU
def load_model(model_name: str):
    try:
        # Warm the cache (this loads the model if it is not already cached).
        model, tokenizer = get_model(model_name)
        # Log so the caching behavior is visible in the Space logs.
        print("Loaded model:", model_name)
        return f"Model '{model_name}' loaded successfully.", model_name
    except Exception as e:
        return f"Failed to load model '{model_name}': {e}", ""


@spaces.GPU
def generate_response(prompt, chat_history, current_model_name):
    if current_model_name == "":
        return (
            "Please load a model first by entering a model name and clicking the Load Model button.",
            current_model_name,
            chat_history,
        )
    try:
        model, tokenizer = get_model(current_model_name)
    except Exception as e:
        return f"Error loading model: {e}", current_model_name, chat_history

    # Rebuild the full conversation so the model sees earlier turns,
    # not just the latest prompt.
    messages = [{"role": "system", "content": "You are a friendly, helpful assistant."}]
    for user_msg, assistant_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": prompt})

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(**inputs, max_new_tokens=512)
    # Strip the prompt tokens so only the newly generated reply is decoded.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    chat_history.append([prompt, response])
    return "", current_model_name, chat_history


with gr.Blocks() as demo:
    gr.Markdown("## Model Loader")
    with gr.Row():
        model_input = gr.Textbox(
            label="Model Name",
            value="agentica-org/DeepScaleR-1.5B-Preview",
            placeholder="Enter model name (e.g., agentica-org/DeepScaleR-1.5B-Preview)",
        )
        load_button = gr.Button("Load Model")
    status_output = gr.Textbox(label="Status", interactive=False)
    # Hidden state holding the name of the currently loaded model.
    model_state = gr.State("")
    # When the load button is clicked, update the status box and the state.
    load_button.click(fn=load_model, inputs=model_input, outputs=[status_output, model_state])

    gr.Markdown("## Chat Interface")
    chatbot = gr.Chatbot()
    prompt_box = gr.Textbox(placeholder="Enter your prompt here...")

    def chat_submit(prompt, history, current_model_name):
        output, updated_state, history = generate_response(prompt, history, current_model_name)
        # Surface status/error messages in the chat instead of silently dropping them.
        if output:
            history = history + [[prompt, output]]
        return "", updated_state, history

    # On submit, clear the prompt textbox and update the chat history and model state.
    prompt_box.submit(
        fn=chat_submit,
        inputs=[prompt_box, chatbot, model_state],
        outputs=[prompt_box, model_state, chatbot],
    )

# share=True creates a public link when run locally; Spaces ignores it.
demo.launch(share=True)
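
# --- Usage note (a sketch, not part of the app itself) ---
# Assuming this file is saved as app.py, the expected dependencies are:
#   pip install gradio spaces transformers torch accelerate
# (`accelerate` backs device_map="auto"; the `spaces` package provides the
# @spaces.GPU decorator used by ZeroGPU Spaces.) Run locally with:
#   python app.py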