nikhiljais committed on
Commit 6679c19 · verified · 1 Parent(s): 74d456d

Update app.py

Files changed (1)
  1. app.py +154 -154
app.py CHANGED
@@ -1,154 +1,154 @@
- import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from peft import PeftModel
- import torch
- import os
-
- # Model configuration
- CHECKPOINT_DIR = "checkpoints"
- BASE_MODEL = "microsoft/phi-2"
-
- class Phi2Chat:
-     def __init__(self):
-         self.tokenizer = None
-         self.model = None
-         self.is_loaded = False
-         self.chat_template = """<|im_start|>user
- {prompt}\n<|im_end|>
- <|im_start|>assistant
- """
-
-     def load_model(self):
-         """Lazy loading of the model"""
-         if not self.is_loaded:
-             try:
-                 print("Loading tokenizer...")
-                 # Load tokenizer from local checkpoint
-                 self.tokenizer = AutoTokenizer.from_pretrained(
-                     os.path.join(CHECKPOINT_DIR, "tokenizer"),
-                     local_files_only=True
-                 )
-
-                 print("Loading base model...")
-                 base_model = AutoModelForCausalLM.from_pretrained(
-                     BASE_MODEL,
-                     device_map="cpu",
-                     torch_dtype=torch.float32,
-                     low_cpu_mem_usage=True
-                 )
-
-                 print("Loading fine-tuned model...")
-                 # Load adapter from local checkpoint
-                 self.model = PeftModel.from_pretrained(
-                     base_model,
-                     os.path.join(CHECKPOINT_DIR, "adapter"),
-                     local_files_only=True
-                 )
-                 self.model.eval()
-
-                 # Try to move to GPU if available
-                 if torch.cuda.is_available():
-                     try:
-                         self.model = self.model.to("cuda")
-                         print("Model moved to GPU")
-                     except Exception as e:
-                         print(f"Could not move model to GPU: {e}")
-
-                 self.is_loaded = True
-                 print("Model loading completed!")
-             except Exception as e:
-                 print(f"Error loading model: {e}")
-                 raise e
-
-     def generate_response(
-         self,
-         prompt: str,
-         max_new_tokens: int = 300,
-         temperature: float = 0.7,
-         top_p: float = 0.9
-     ) -> str:
-         if not self.is_loaded:
-             return "Model is still loading... Please try again in a moment."
-
-         try:
-             formatted_prompt = self.chat_template.format(prompt=prompt)
-             inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
-             inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
-
-             with torch.no_grad():
-                 output = self.model.generate(
-                     **inputs,
-                     max_new_tokens=max_new_tokens,
-                     temperature=temperature,
-                     top_p=top_p,
-                     do_sample=True
-                 )
-
-             response = self.tokenizer.decode(output[0], skip_special_tokens=True)
-             try:
-                 response = response.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0].strip()
-             except:
-                 response = response.split(prompt)[-1].strip()
-
-             return response
-         except Exception as e:
-             return f"Error generating response: {str(e)}"
-
- # Initialize model
- phi2_chat = Phi2Chat()
-
- def loading_message():
-     return "Loading the model... This may take a few minutes. Please wait."
-
- def chat_response(message, history):
-     # Ensure model is loaded
-     if not phi2_chat.is_loaded:
-         phi2_chat.load_model()
-     return phi2_chat.generate_response(message)
-
- # Create Gradio interface
- css = """
- .gradio-container {
-     font-family: 'IBM Plex Sans', sans-serif;
- }
- .chat-message {
-     padding: 1rem;
-     border-radius: 0.5rem;
-     margin-bottom: 1rem;
-     background: #f7f7f7;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-     gr.Markdown("# Phi-2 Fine-tuned Chat Assistant")
-     gr.Markdown("""
- This is a fine-tuned version of Microsoft's Phi-2 model using QLoRA.
- The model has been trained on the OpenAssistant dataset to improve its conversational abilities.
-
- Note: First-time loading may take a few minutes. Please be patient.
- """)
-
-     chatbot = gr.ChatInterface(
-         chat_response,
-         chatbot=gr.Chatbot(height=400),
-         textbox=gr.Textbox(
-             placeholder="Type your message here... (Model will load on first message)",
-             container=False,
-             scale=7
-         ),
-         title="Chat with Phi-2",
-         description="Have a conversation with the fine-tuned Phi-2 model",
-         theme="soft",
-         examples=[
-             "What is quantum computing?",
-             "Write a Python function to find prime numbers",
-             "Explain the concept of machine learning in simple terms"
-         ],
-         retry_btn="Retry",
-         undo_btn="Undo",
-         clear_btn="Clear",
-     )
-
- # Configure queue and launch
- demo.queue(concurrency_count=1, max_size=10)
- demo.launch()
 
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+ import torch
+ import os
+
+ # Model configuration
+ CHECKPOINT_DIR = "checkpoints"
+ BASE_MODEL = "microsoft/phi-2"
+
+ class Phi2Chat:
+     def __init__(self):
+         self.tokenizer = None
+         self.model = None
+         self.is_loaded = False
+         self.chat_template = """<|im_start|>user
+ {prompt}\n<|im_end|>
+ <|im_start|>assistant
+ """
+
+     def load_model(self):
+         """Lazy loading of the model"""
+         if not self.is_loaded:
+             try:
+                 print("Loading tokenizer...")
+                 # Load tokenizer from local checkpoint
+                 self.tokenizer = AutoTokenizer.from_pretrained(
+                     os.path.join(CHECKPOINT_DIR, "tokenizer"),
+                     local_files_only=True
+                 )
+
+                 print("Loading base model...")
+                 base_model = AutoModelForCausalLM.from_pretrained(
+                     BASE_MODEL,
+                     device_map="cpu",
+                     torch_dtype=torch.float32,
+                     low_cpu_mem_usage=True
+                 )
+
+                 print("Loading fine-tuned model...")
+                 # Load adapter from local checkpoint
+                 self.model = PeftModel.from_pretrained(
+                     base_model,
+                     os.path.join(CHECKPOINT_DIR, "adapter"),
+                     local_files_only=True
+                 )
+                 self.model.eval()
+
+                 # Try to move to GPU if available
+                 if torch.cuda.is_available():
+                     try:
+                         self.model = self.model.to("cuda")
+                         print("Model moved to GPU")
+                     except Exception as e:
+                         print(f"Could not move model to GPU: {e}")
+
+                 self.is_loaded = True
+                 print("Model loading completed!")
+             except Exception as e:
+                 print(f"Error loading model: {e}")
+                 raise e
+
+     def generate_response(
+         self,
+         prompt: str,
+         max_new_tokens: int = 300,
+         temperature: float = 0.7,
+         top_p: float = 0.9
+     ) -> str:
+         if not self.is_loaded:
+             return "Model is still loading... Please try again in a moment."
+
+         try:
+             formatted_prompt = self.chat_template.format(prompt=prompt)
+             inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
+             inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 output = self.model.generate(
+                     **inputs,
+                     max_new_tokens=max_new_tokens,
+                     temperature=temperature,
+                     top_p=top_p,
+                     do_sample=True
+                 )
+
+             response = self.tokenizer.decode(output[0], skip_special_tokens=True)
+             try:
+                 response = response.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0].strip()
+             except:
+                 response = response.split(prompt)[-1].strip()
+
+             return response
+         except Exception as e:
+             return f"Error generating response: {str(e)}"
+
+ # Initialize model
+ phi2_chat = Phi2Chat()
+
+ def loading_message():
+     return "Loading the model... This may take a few minutes. Please wait."
+
+ def chat_response(message, history):
+     # Ensure model is loaded
+     if not phi2_chat.is_loaded:
+         phi2_chat.load_model()
+     return phi2_chat.generate_response(message)
+
+ # Create Gradio interface
+ css = """
+ .gradio-container {
+     font-family: 'IBM Plex Sans', sans-serif;
+ }
+ .chat-message {
+     padding: 1rem;
+     border-radius: 0.5rem;
+     margin-bottom: 1rem;
+     background: #f7f7f7;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown("# Phi-2 Fine-tuned Chat Assistant")
+     gr.Markdown("""
+ This is a fine-tuned version of Microsoft's Phi-2 model using QLoRA.
+ The model has been trained on the OpenAssistant dataset to improve its conversational abilities.
+
+ Note: First-time loading may take a few minutes. Please be patient.
+ """)
+
+     chatbot = gr.ChatInterface(
+         fn=chat_response,
+         chatbot=gr.Chatbot(height=400),
+         textbox=gr.Textbox(
+             placeholder="Type your message here... (Model will load on first message)",
+             container=False,
+             scale=7
+         ),
+         title="Chat with Phi-2",
+         description="Have a conversation with the fine-tuned Phi-2 model",
+         theme="soft",
+         examples=[
+             "What is quantum computing?",
+             "Write a Python function to find prime numbers",
+             "Explain the concept of machine learning in simple terms"
+         ],
+         retry_btn="Retry",
+         undo_btn="Undo",
+         clear_btn="Clear",
+         concurrency_limit=1
+     )
+
+ # Launch with optimized settings
+ demo.launch(max_threads=4)