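"""Gradio chat demo for a QLoRA fine-tuned Phi-2 model.

Loads the microsoft/phi-2 base model plus a local PEFT (LoRA) adapter from the
checkpoints/ directory and serves it through a simple gr.ChatInterface.
"""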
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import os

# Model configuration
CHECKPOINT_DIR = "checkpoints"
BASE_MODEL = "microsoft/phi-2"
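# The fine-tuned tokenizer and LoRA adapter are expected under
# checkpoints/tokenizer and checkpoints/adapter (see load_model below).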

class Phi2Chat:
    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.is_loaded = False
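        # ChatML-style prompt template (assumed to match the format used during fine-tuning)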
        self.chat_template = """<|im_start|>user
{prompt}\n<|im_end|>
<|im_start|>assistant
"""

    def load_model(self):
        """Lazy loading of the model"""
        if not self.is_loaded:
            try:
                print("Loading tokenizer...")
                # Load tokenizer from local checkpoint
                self.tokenizer = AutoTokenizer.from_pretrained(
                    os.path.join(CHECKPOINT_DIR, "tokenizer"),
                    local_files_only=True
                )
                
                print("Loading base model...")
                base_model = AutoModelForCausalLM.from_pretrained(
                    BASE_MODEL,
                    device_map="cpu",
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True
                )
                
                print("Loading fine-tuned model...")
                # Load adapter from local checkpoint
                self.model = PeftModel.from_pretrained(
                    base_model,
                    os.path.join(CHECKPOINT_DIR, "adapter"),
                    local_files_only=True
                )
                self.model.eval()
                
                # Try to move to GPU if available
                if torch.cuda.is_available():
                    try:
                        self.model = self.model.to("cuda")
                        print("Model moved to GPU")
                    except Exception as e:
                        print(f"Could not move model to GPU: {e}")
                
                self.is_loaded = True
                print("Model loading completed!")
            except Exception as e:
                print(f"Error loading model: {e}")
                raise e

    def generate_response(
        self,
        prompt: str,
        max_new_tokens: int = 300,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> str:
        if not self.is_loaded:
            return "Model is still loading... Please try again in a moment."
        
        try:
            formatted_prompt = self.chat_template.format(prompt=prompt)
            inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            
            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    # Phi-2's tokenizer has no dedicated pad token; fall back to EOS
                    pad_token_id=self.tokenizer.eos_token_id
                )
            
            response = self.tokenizer.decode(output[0], skip_special_tokens=True)
            # Keep only the assistant turn; if the ChatML markers were stripped by
            # skip_special_tokens, fall back to removing the echoed prompt instead
            if "<|im_start|>assistant" in response:
                response = response.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0].strip()
            else:
                response = response.split(prompt)[-1].strip()
            
            return response
        except Exception as e:
            return f"Error generating response: {str(e)}"

# Initialize model
phi2_chat = Phi2Chat()

def loading_message():
    return "Loading the model... This may take a few minutes. Please wait."

def chat_response(message, history):
    # Ensure model is loaded
    if not phi2_chat.is_loaded:
        phi2_chat.load_model()
    return phi2_chat.generate_response(message)

# Create Gradio interface
css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.chat-message {
    padding: 1rem;
    border-radius: 0.5rem;
    margin-bottom: 1rem;
    background: #f7f7f7;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Phi-2 Fine-tuned Chat Assistant")
    gr.Markdown("""
    This is a fine-tuned version of Microsoft's Phi-2 model using QLoRA.
    The model has been trained on the OpenAssistant dataset to improve its conversational abilities.
    
    Note: First-time loading may take a few minutes. Please be patient.
    """)
    
    chatbot = gr.ChatInterface(
        fn=chat_response,
        chatbot=gr.Chatbot(height=400),
        textbox=gr.Textbox(
            placeholder="Type your message here... (Model will load on first message)",
            container=False,
            scale=7
        ),
        title="Chat with Phi-2",
        description="Have a conversation with the fine-tuned Phi-2 model",
        theme="soft",
        examples=[
            "What is quantum computing?",
            "Write a Python function to find prime numbers",
            "Explain the concept of machine learning in simple terms"
        ],
        retry_btn="Retry",
        undo_btn="Undo",
        clear_btn="Clear",
        concurrency_limit=1
    )

# Launch with optimized settings
demo.launch(max_threads=4)
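
# Assumed local usage: run `python app.py` with the checkpoints/ directory alongside
# this file; Gradio prints the local server URL on startup.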