import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import spaces
import os
import json
from huggingface_hub import login
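# Note: beyond the imports above, this Space is assumed to have `accelerate`
# installed (needed for device_map="auto") and the `spaces` package, which
# Hugging Face provides on ZeroGPU hardware. A minimal requirements.txt sketch:
#   gradio
#   torch
#   transformers
#   accelerate
#   huggingface_hub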

# Hugging Face authentication
# Space secrets are exposed as plain environment variables, so the secret named
# HF_TOKEN is read with os.getenv("HF_TOKEN") (not "Secrets.HF_TOKEN").
HF_TOKEN = os.getenv("HF_TOKEN")
try:
    if HF_TOKEN:
        login(token=HF_TOKEN)
except Exception as e:
    print(f"Error logging in to Hugging Face: {str(e)}")

# File to store model links
MODEL_FILE = "model_links.txt"


def load_model_links():
    """Load model links from file"""
    if not os.path.exists(MODEL_FILE):
        # Create default file with some example models
        with open(MODEL_FILE, "w") as f:
            f.write("meta-llama/Llama-2-7b-chat-hf\n")
    with open(MODEL_FILE, "r") as f:
        return [line.strip() for line in f.readlines() if line.strip()]
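
# model_links.txt is assumed to hold one Hub repo id per line, e.g.
# (illustrative contents only, not shipped with the Space):
#   meta-llama/Llama-2-7b-chat-hf
#   mistralai/Mistral-7B-Instruct-v0.2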


class ModelManager:
    def __init__(self):
        self.current_model = None
        self.current_tokenizer = None
        self.current_model_name = None
        # Don't initialize CUDA in __init__
        self.device = None

    def load_model(self, model_name):
        """Load model and free previous model's memory"""
        if self.current_model is not None:
            # Drop the references (rather than del) so the attributes stay
            # defined even if the next load fails, and the old weights can
            # be garbage collected.
            self.current_model = None
            self.current_tokenizer = None
            torch.cuda.empty_cache()
        try:
            self.current_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.current_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16,
                device_map="auto"  # Let accelerate decide device placement
            )
            self.current_model_name = model_name
            return f"Successfully loaded model: {model_name}"
        except Exception as e:
            return f"Error loading model: {str(e)}"
    def generate(self, prompt):
        """Tokenize the prompt and move tensors to the model's device"""
        inputs = self.current_tokenizer(prompt, return_tensors="pt")
        # With device_map="auto" the embedding layer may live on GPU, so the
        # input tensors are moved to the model's primary device explicitly.
        return inputs.to(self.current_model.device)


# Initialize model manager
model_manager = ModelManager()

# Default system message for JSON output
default_system_message = """You are a helpful AI assistant. You must ALWAYS return your response in valid JSON format.
Each response should be formatted as follows:
{
    "response": {
        "main_answer": "Your primary response here",
        "additional_details": "Any additional information or context",
        "confidence": 0.0 to 1.0,
        "tags": ["relevant", "tags", "here"]
    },
    "metadata": {
        "response_type": "type of response",
        "source": "basis of response if applicable"
    }
}
Ensure EVERY response strictly follows this JSON structure."""


# This decorator handles the GPU allocation on ZeroGPU Spaces
@spaces.GPU
def generate_response(model_name, system_instruction, user_input):
    """Generate response with GPU support and JSON formatting"""
    if model_manager.current_model_name != model_name:
        return json.dumps({"error": "Please load the model first using the 'Load Selected Model' button."}, indent=2)
    if model_manager.current_model is None:
        return json.dumps({"error": "No model loaded. Please load a model first."}, indent=2)

    prompt = f"""### Instruction:
{system_instruction}
Remember to ALWAYS format your response as valid JSON.
### Input:
{user_input}
### Response:
{{"""
    try:
        inputs = model_manager.generate(prompt)
        # Greedy decoding (do_sample=False), so no temperature is set
        meta_config = {
            "do_sample": False,
            "max_new_tokens": 512,
            "repetition_penalty": 1.1,
            "use_cache": True,
            "pad_token_id": model_manager.current_tokenizer.eos_token_id,
            "eos_token_id": model_manager.current_tokenizer.eos_token_id
        }
        generation_config = GenerationConfig(**meta_config)

        with torch.no_grad():
            outputs = model_manager.current_model.generate(
                **inputs,
                generation_config=generation_config
            )

        decoded_output = model_manager.current_tokenizer.batch_decode(
            outputs,
            skip_special_tokens=True
        )[0]

        # Keep only the text after the response marker
        assistant_response = decoded_output.split("### Response:")[-1].strip()

        try:
            # Trim anything after the final closing brace before parsing
            last_brace = assistant_response.rindex('}')
            assistant_response = assistant_response[:last_brace + 1]
            json_response = json.loads(assistant_response)
            return json.dumps(json_response, indent=2)
        except (json.JSONDecodeError, ValueError):
            return json.dumps({
                "error": "Failed to generate valid JSON",
                "raw_response": assistant_response
            }, indent=2)
    except Exception as e:
        return json.dumps({
            "error": f"Error generating response: {str(e)}",
            "details": "An unexpected error occurred during generation"
        }, indent=2)
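
# A minimal local sanity check (hypothetical model id and question, for
# illustration only; on the Space this flow is driven by the Gradio UI):
#   print(model_manager.load_model("meta-llama/Llama-2-7b-chat-hf"))
#   print(generate_response("meta-llama/Llama-2-7b-chat-hf",
#                           default_system_message,
#                           "What is the capital of France?"))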

# Gradio interface setup
with gr.Blocks() as demo:
    gr.Markdown("# Chat Interface with Model Selection (JSON Output)")

    with gr.Row():
        # Left column for inputs
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=load_model_links(),
                label="Select Model",
                info="Choose a model from the list"
            )
            load_button = gr.Button("Load Selected Model")
            model_status = gr.Textbox(label="Model Status")
            system_instruction = gr.Textbox(
                value=default_system_message,
                placeholder="Enter system instruction here...",
                label="System Instruction",
                lines=3
            )
            user_input = gr.Textbox(
                placeholder="Type your message here...",
                label="Your Message",
                lines=3
            )
            submit_btn = gr.Button("Submit")

        # Right column for bot response
        with gr.Column():
            response_display = gr.Textbox(
                label="Bot Response (JSON)",
                interactive=False,
                placeholder="Response will appear here in JSON format.",
                lines=10
            )
    # Event handlers
    load_button.click(
        fn=model_manager.load_model,
        inputs=[model_dropdown],
        outputs=[model_status]
    )
    submit_btn.click(
        fn=generate_response,
        inputs=[model_dropdown, system_instruction, user_input],
        outputs=[response_display]
    )

# Launch the app
demo.launch()