import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import os
# Model configuration
CHECKPOINT_DIR = "checkpoints"
BASE_MODEL = "microsoft/phi-2"
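# Expected local checkpoint layout (inferred from the paths used below,
# not stated explicitly in the original app):
#   checkpoints/
#       tokenizer/   <- tokenizer files saved after fine-tuning
#       adapter/     <- QLoRA adapter weights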
class Phi2Chat:
    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.is_loaded = False
        # ChatML-style prompt template used by the fine-tune
        self.chat_template = """<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant
"""
    def load_model(self):
        """Lazily load the tokenizer, base model, and fine-tuned adapter."""
        if not self.is_loaded:
            try:
                print("Loading tokenizer...")
                # Load tokenizer from the local checkpoint
                self.tokenizer = AutoTokenizer.from_pretrained(
                    os.path.join(CHECKPOINT_DIR, "tokenizer"),
                    local_files_only=True
                )
                print("Loading base model...")
                base_model = AutoModelForCausalLM.from_pretrained(
                    BASE_MODEL,
                    device_map="cpu",
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True
                )
                print("Loading fine-tuned adapter...")
                # Load the QLoRA adapter from the local checkpoint
                self.model = PeftModel.from_pretrained(
                    base_model,
                    os.path.join(CHECKPOINT_DIR, "adapter"),
                    local_files_only=True
                )
                self.model.eval()
                # Move to GPU if one is available
                if torch.cuda.is_available():
                    try:
                        self.model = self.model.to("cuda")
                        print("Model moved to GPU")
                    except Exception as e:
                        print(f"Could not move model to GPU: {e}")
                self.is_loaded = True
                print("Model loading completed!")
            except Exception as e:
                print(f"Error loading model: {e}")
                raise  # re-raise with the original traceback
    def generate_response(
        self,
        prompt: str,
        max_new_tokens: int = 300,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> str:
        if not self.is_loaded:
            return "Model is still loading... Please try again in a moment."
        try:
            formatted_prompt = self.chat_template.format(prompt=prompt)
            inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    # Phi-2's tokenizer has no pad token, so fall back to EOS
                    pad_token_id=self.tokenizer.eos_token_id
                )
            response = self.tokenizer.decode(output[0], skip_special_tokens=True)
            # Keep only the assistant's reply, stripping the prompt scaffolding
            if "<|im_start|>assistant" in response:
                response = response.split("<|im_start|>assistant")[-1].split("<|im_end|>")[0].strip()
            else:
                response = response.split(prompt)[-1].strip()
            return response
        except Exception as e:
            return f"Error generating response: {str(e)}"
# Initialize model (weights are loaded lazily on the first message)
phi2_chat = Phi2Chat()

def loading_message():
    return "Loading the model... This may take a few minutes. Please wait."

def chat_response(message, history):
    # Ensure the model is loaded before generating
    if not phi2_chat.is_loaded:
        phi2_chat.load_model()
    return phi2_chat.generate_response(message)
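# Optional streaming variant (a sketch, not wired into the UI below).
# gr.ChatInterface also accepts generator functions, and transformers'
# TextIteratorStreamer yields decoded text as generate() produces tokens.
# The name chat_response_streaming is ours, not from the original app.
from threading import Thread
from transformers import TextIteratorStreamer

def chat_response_streaming(message, history):
    if not phi2_chat.is_loaded:
        phi2_chat.load_model()
    streamer = TextIteratorStreamer(
        phi2_chat.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    inputs = phi2_chat.tokenizer(
        phi2_chat.chat_template.format(prompt=message), return_tensors="pt"
    ).to(phi2_chat.model.device)
    # Run generation in a background thread; the streamer feeds this generator
    Thread(
        target=phi2_chat.model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=300,
            do_sample=True,
            pad_token_id=phi2_chat.tokenizer.eos_token_id,
        ),
    ).start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial  # Gradio re-renders the growing partial reply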
# Create Gradio interface
css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.chat-message {
    padding: 1rem;
    border-radius: 0.5rem;
    margin-bottom: 1rem;
    background: #f7f7f7;
}
"""
# The theme is set on the outer Blocks, since a ChatInterface nested inside
# a Blocks context takes its theme from the parent
with gr.Blocks(css=css, theme="soft") as demo:
    gr.Markdown("# Phi-2 Fine-tuned Chat Assistant")
    gr.Markdown("""
    This is a fine-tuned version of Microsoft's Phi-2 model using QLoRA.
    The model has been trained on the OpenAssistant dataset to improve its conversational abilities.

    Note: first-time loading may take a few minutes. Please be patient.
    """)
    chatbot = gr.ChatInterface(
        fn=chat_response,
        chatbot=gr.Chatbot(height=400),
        textbox=gr.Textbox(
            placeholder="Type your message here... (Model will load on first message)",
            container=False,
            scale=7
        ),
        title="Chat with Phi-2",
        description="Have a conversation with the fine-tuned Phi-2 model",
        examples=[
            "What is quantum computing?",
            "Write a Python function to find prime numbers",
            "Explain the concept of machine learning in simple terms"
        ],
        # These button kwargs exist in Gradio 4.x (they were removed in 5.x)
        retry_btn="Retry",
        undo_btn="Undo",
        clear_btn="Clear",
        concurrency_limit=1
    )
# Launch with optimized settings
demo.launch(max_threads=4)
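# For local testing you could pass explicit server settings instead, e.g.
# demo.launch(server_name="0.0.0.0", server_port=7860); these are standard
# Gradio launch kwargs, and the values here are illustrative, not from the Space.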