import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import os
# Model configuration
CHECKPOINT_DIR = "checkpoints"
BASE_MODEL = "microsoft/phi-2"
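# Expected local layout, inferred from the loading code below:
#   checkpoints/tokenizer/  - tokenizer saved via save_pretrained()
#   checkpoints/adapter/    - LoRA/QLoRA adapter saved via save_pretrained()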


class Phi2Chat:
    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.is_loaded = False
        # ChatML-style single-turn prompt template
        self.chat_template = (
            "<|im_start|>user\n"
            "{prompt}\n<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
    def load_model(self):
        """Lazily load the tokenizer, base model, and fine-tuned adapter."""
        if not self.is_loaded:
            try:
                print("Loading tokenizer...")
                # Load tokenizer from local checkpoint
                self.tokenizer = AutoTokenizer.from_pretrained(
                    os.path.join(CHECKPOINT_DIR, "tokenizer"),
                    local_files_only=True
                )
                print("Loading base model...")
                base_model = AutoModelForCausalLM.from_pretrained(
                    BASE_MODEL,
                    device_map="cpu",
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True
                )
                print("Loading fine-tuned model...")
                # Load adapter from local checkpoint
                self.model = PeftModel.from_pretrained(
                    base_model,
                    os.path.join(CHECKPOINT_DIR, "adapter"),
                    local_files_only=True
                )
                self.model.eval()
                # Try to move to GPU if available
                if torch.cuda.is_available():
                    try:
                        self.model = self.model.to("cuda")
                        print("Model moved to GPU")
                    except Exception as e:
                        print(f"Could not move model to GPU: {e}")
                self.is_loaded = True
                print("Model loading completed!")
            except Exception as e:
                print(f"Error loading model: {e}")
                raise
    def generate_response(
        self,
        prompt: str,
        max_new_tokens: int = 300,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> str:
        """Generate a single sampled reply for `prompt`."""
        if not self.is_loaded:
            return "Model is still loading... Please try again in a moment."
        try:
            formatted_prompt = self.chat_template.format(prompt=prompt)
            inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    # Phi-2 has no pad token; reuse EOS to avoid warnings
                    pad_token_id=self.tokenizer.eos_token_id
                )
            response = self.tokenizer.decode(output[0], skip_special_tokens=True)
            # Strip the prompt/template so only the assistant's reply remains;
            # fall back to splitting on the raw prompt in case the ChatML
            # markers were removed as special tokens during decoding.
            if "<|im_start|>assistant" in response:
                response = response.split("<|im_start|>assistant")[-1]
            else:
                response = response.split(prompt)[-1]
            return response.split("<|im_end|>")[0].strip()
        except Exception as e:
            return f"Error generating response: {str(e)}"

# Initialize model
phi2_chat = Phi2Chat()


def loading_message():
    # Currently unused helper; kept for a potential loading indicator.
    return "Loading the model... This may take a few minutes. Please wait."


def chat_response(message, history):
    # Ensure model is loaded (the first call triggers the lazy load)
    if not phi2_chat.is_loaded:
        phi2_chat.load_model()
    return phi2_chat.generate_response(message)
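
# gr.ChatInterface calls chat_response(message, history) on every turn;
# `history` is ignored here because the prompt template is single-turn.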

# Create Gradio interface
css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.chat-message {
    padding: 1rem;
    border-radius: 0.5rem;
    margin-bottom: 1rem;
    background: #f7f7f7;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Phi-2 Fine-tuned Chat Assistant")
    gr.Markdown("""
    This is a fine-tuned version of Microsoft's Phi-2 model using QLoRA.
    The model has been trained on the OpenAssistant dataset to improve its conversational abilities.

    Note: first-time loading may take a few minutes. Please be patient.
    """)

    # Button-text kwargs and concurrency_limit follow the Gradio 4.x
    # ChatInterface API; later major versions may not accept them.
    chatbot = gr.ChatInterface(
        fn=chat_response,
        chatbot=gr.Chatbot(height=400),
        textbox=gr.Textbox(
            placeholder="Type your message here... (Model will load on first message)",
            container=False,
            scale=7
        ),
        title="Chat with Phi-2",
        description="Have a conversation with the fine-tuned Phi-2 model",
        theme="soft",
        examples=[
            "What is quantum computing?",
            "Write a Python function to find prime numbers",
            "Explain the concept of machine learning in simple terms"
        ],
        retry_btn="Retry",
        undo_btn="Undo",
        clear_btn="Clear",
        concurrency_limit=1
    )

# Launch with optimized settings
demo.launch(max_threads=4)