import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import os

# Model configuration
CHECKPOINT_DIR = "checkpoints"
BASE_MODEL = "microsoft/phi-2"


class Phi2Chat:
    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.is_loaded = False
        # ChatML-style prompt template used to format each user message
        self.chat_template = """<|im_start|>user
{prompt}\n<|im_end|>
<|im_start|>assistant
"""

    def load_model(self):
        """Lazy loading of the model"""
        if not self.is_loaded:
            try:
                print("Loading tokenizer...")
                # Load tokenizer from local checkpoint
                self.tokenizer = AutoTokenizer.from_pretrained(
                    os.path.join(CHECKPOINT_DIR, "tokenizer"),
                    local_files_only=True
                )

                print("Loading base model...")
                base_model = AutoModelForCausalLM.from_pretrained(
                    BASE_MODEL,
                    device_map="cpu",
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True
                )

                print("Loading fine-tuned model...")
                # Load adapter from local checkpoint
                self.model = PeftModel.from_pretrained(
                    base_model,
                    os.path.join(CHECKPOINT_DIR, "adapter"),
                    local_files_only=True
                )
                self.model.eval()

                # Try to move to GPU if available
                if torch.cuda.is_available():
                    try:
                        self.model = self.model.to("cuda")
                        print("Model moved to GPU")
                    except Exception as e:
                        print(f"Could not move model to GPU: {e}")

                self.is_loaded = True
                print("Model loading completed!")
            except Exception as e:
                print(f"Error loading model: {e}")
                raise

    def generate_response(
        self,
        prompt: str,
        max_new_tokens: int = 300,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> str:
        """Generate a single response for the given prompt."""
        if not self.is_loaded:
            return "Model is still loading... Please try again in a moment."

        try:
            formatted_prompt = self.chat_template.format(prompt=prompt)
            inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True
                )

            response = self.tokenizer.decode(output[0], skip_special_tokens=True)
            # Strip the prompt and template markers from the decoded output
            try:
                response = response.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0].strip()
            except Exception:
                response = response.split(prompt)[-1].strip()
            return response
        except Exception as e:
            return f"Error generating response: {str(e)}"


# Initialize model
phi2_chat = Phi2Chat()


def loading_message():
    return "Loading the model... This may take a few minutes. Please wait."


def chat_response(message, history):
    # Ensure model is loaded (the first request triggers the lazy load)
    if not phi2_chat.is_loaded:
        phi2_chat.load_model()
    return phi2_chat.generate_response(message)


# Create Gradio interface
css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.chat-message {
    padding: 1rem;
    border-radius: 0.5rem;
    margin-bottom: 1rem;
    background: #f7f7f7;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Phi-2 Fine-tuned Chat Assistant")
    gr.Markdown(
        """
        This is a fine-tuned version of Microsoft's Phi-2 model using QLoRA.
        The model has been trained on the OpenAssistant dataset to improve its
        conversational abilities.

        Note: First-time loading may take a few minutes. Please be patient.
        """
    )

    chatbot = gr.ChatInterface(
        fn=chat_response,
        chatbot=gr.Chatbot(height=400),
        textbox=gr.Textbox(
            placeholder="Type your message here... (Model will load on first message)",
            container=False,
            scale=7
        ),
        title="Chat with Phi-2",
        description="Have a conversation with the fine-tuned Phi-2 model",
        theme="soft",
        examples=[
            "What is quantum computing?",
            "Write a Python function to find prime numbers",
            "Explain the concept of machine learning in simple terms"
        ],
        retry_btn="Retry",
        undo_btn="Undo",
        clear_btn="Clear",
        concurrency_limit=1
    )

# Launch with optimized settings
demo.launch(max_threads=4)