pradeep6kumar2024 committed
Commit a27324e · 1 Parent(s): 3bf5e4f

updated app.py

Files changed (1)
  1. app.py +76 -187
app.py CHANGED
@@ -5,11 +5,17 @@ from peft import PeftModel
 import time
 import gc
 import os
+import psutil
 
 # Configuration
 BASE_MODEL = "microsoft/phi-2"
 ADAPTER_MODEL = "pradeep6kumar2024/phi2-qlora-assistant"
 
+# Memory monitoring
+def get_memory_usage():
+    process = psutil.Process(os.getpid())
+    return process.memory_info().rss / (1024 * 1024)  # MB
+
 class ModelWrapper:
     def __init__(self):
         self.model = None
@@ -26,6 +32,8 @@ class ModelWrapper:
             # Clear memory
             gc.collect()
 
+            print(f"Memory before loading: {get_memory_usage():.2f} MB")
+
             print("Loading tokenizer...")
             self.tokenizer = AutoTokenizer.from_pretrained(
                 BASE_MODEL,
@@ -34,21 +42,26 @@
             )
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
+            print(f"Memory after tokenizer: {get_memory_usage():.2f} MB")
+
             print("Loading base model...")
             base_model = AutoModelForCausalLM.from_pretrained(
                 BASE_MODEL,
-                torch_dtype=torch.float32,  # Use float32 for CPU
+                torch_dtype=torch.float32,
                 device_map="cpu",
                 trust_remote_code=True,
                 use_flash_attention_2=False,
-                low_cpu_mem_usage=True
+                low_cpu_mem_usage=True,
+                offload_folder="offload"
             )
 
+            print(f"Memory after base model: {get_memory_usage():.2f} MB")
+
             print("Loading LoRA adapter...")
             self.model = PeftModel.from_pretrained(
                 base_model,
                 ADAPTER_MODEL,
-                torch_dtype=torch.float32,  # Use float32 for CPU
+                torch_dtype=torch.float32,
                 device_map="cpu"
             )
 
@@ -56,6 +69,8 @@
             del base_model
             gc.collect()
 
+            print(f"Memory after adapter: {get_memory_usage():.2f} MB")
+
             self.model.eval()
             print("Model loading complete!")
             self.loaded = True
@@ -63,188 +78,79 @@ class ModelWrapper:
             print(f"Error during model loading: {str(e)}")
             raise
 
-    def generate_response(self, prompt, max_length=512, temperature=0.7, top_p=0.9):
+    def generate_response(self, prompt, max_length=256, temperature=0.7, top_p=0.9):
         if not self.loaded:
             self.load_model()
 
         try:
-            # Enhance prompt for better completion
+            # Use shorter prompts to save memory
             if "function" in prompt.lower() and "python" in prompt.lower():
-                enhanced_prompt = f"""Write a Python function with the following requirements:
-{prompt}
-Include:
-- Function implementation with comments
-- Example usage
-- Output demonstration
-
-Provide only the implementation, no conversation."""
+                enhanced_prompt = f"""Write Python function: {prompt}"""
             elif any(word in prompt.lower() for word in ["explain", "what is", "how does", "describe"]):
-                enhanced_prompt = f"""Below is a request for explanation. Provide a complete, focused response without any conversation:
-
-{prompt}
-
-Your response should include:
-1. A clear explanation in simple terms
-2. Practical examples and applications
-3. Important concepts to understand
-
-End your response when the explanation is complete. Do not ask questions or engage in conversation."""
+                enhanced_prompt = f"""Explain briefly: {prompt}"""
             else:
-                enhanced_prompt = f"""Below is a request. Provide a complete, focused response without any conversation:
-
-{prompt}
-
-End your response when complete. Do not ask questions or engage in conversation."""
+                enhanced_prompt = prompt
 
-            print(f"Enhanced prompt: {enhanced_prompt}")  # Debug logging
+            print(f"Enhanced prompt: {enhanced_prompt}")
 
-            # Tokenize input
+            # Tokenize input with shorter max length
             inputs = self.tokenizer(
                 enhanced_prompt,
                 return_tensors="pt",
                 truncation=True,
-                max_length=512,
+                max_length=256,  # Reduced for memory
                 padding=True
-            ).to("cpu")  # Ensure CPU usage
+            ).to("cpu")
 
-            # Generate with more conservative parameters for CPU
+            # Generate with minimal parameters
             start_time = time.time()
             with torch.no_grad():
                 outputs = self.model.generate(
                     **inputs,
-                    max_length=min(max_length, 384),  # Limit max length for CPU
-                    min_length=50,
+                    max_length=min(max_length, 256),  # Strict limit
+                    min_length=10,  # Reduced minimum
                     temperature=min(0.5, temperature),
                     top_p=min(0.85, top_p),
                     do_sample=True,
                     pad_token_id=self.tokenizer.pad_token_id,
                     eos_token_id=self.tokenizer.eos_token_id,
-                    repetition_penalty=1.3,
-                    no_repeat_ngram_size=4,
+                    repetition_penalty=1.2,
+                    no_repeat_ngram_size=3,
                     num_return_sequences=1,
                     early_stopping=True,
-                    num_beams=2,  # Reduced beam search for CPU
+                    num_beams=1,  # Greedy decoding to save memory
                     length_penalty=0.6
                 )
 
             # Decode response
             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-            print(f"Raw response: {response}")  # Debug logging
 
             # Clean up the response
             if response.startswith(enhanced_prompt):
                 response = response[len(enhanced_prompt):].strip()
 
-            print(f"After prompt removal: {response}")  # Debug logging
-
-            # Remove common closure patterns and conversation starters
-            closures = [
-                "Best regards,",
-                "Sincerely,",
-                "Thanks,",
-                "Thank you,",
-                "Regards,",
-                "Assistant:",
-                "Human:",
-                "[Your Name]",
-                "[Student]",
-                "Let me know if you need any clarification",
-                "I hope this helps",
-                "Feel free to ask",
-                "Can you provide",
-                "Would you like",
-                "Do you want",
-                "Let me know",
-                "Please let me know",
-                "Is there anything else",
-                "Do you have any questions",
-                "Sure!",
-                "Here are some examples:"
-            ]
-
-            # First remove conversation starters from the end
-            for closure in closures:
-                if response.lower().endswith(closure.lower()):
-                    response = response[:-(len(closure))].strip()
-
-            # Then remove any remaining conversation patterns
-            conversation_patterns = [
-                r"\?\s*$",  # Questions at the end
-                r"Sure!.*$",  # Responses starting with Sure!
-                r"Here are.*examples:?\s*$",  # Incomplete example lists
-                r"Can you.*\?\s*$",  # Questions starting with Can you
-                r"Would you.*\?\s*$",  # Questions starting with Would you
-                r"Do you.*\?\s*$",  # Questions starting with Do you
-                r"Let me know.*$",  # Let me know phrases
-                r"I hope.*$",  # I hope phrases
-                r"Feel free.*$"  # Feel free phrases
-            ]
-
-            import re
-            for pattern in conversation_patterns:
-                response = re.sub(pattern, "", response).strip()
-
-            print(f"After conversation removal: {response}")  # Debug logging
+            # Basic cleanup only
+            response = response.replace("Human:", "").replace("Assistant:", "")
 
             # Ensure code examples are properly formatted
             if "```python" not in response and "def " in response:
                 response = "```python\n" + response + "\n```"
 
-            # More lenient validation but check for conversation markers
-            if (len(response.strip()) < 20 or
-                response.strip() == "Response:" or
-                response.strip().endswith("?") or
-                "can you" in response.lower() or
-                "let me know" in response.lower()):
-                print("Response validation failed - using fallback")  # Debug logging
-
-                if "machine learning" in prompt.lower():
-                    fallback_response = """Machine learning is a branch of artificial intelligence that enables computers to learn and improve from experience without being explicitly programmed. Think of it like teaching a child:
-
-1. Simple Explanation:
-- Instead of giving strict rules, we show the computer many examples
-- The computer finds patterns in these examples
-- It uses these patterns to make decisions about new situations
-
-2. Real-World Applications:
-- Email Spam Detection: Learning to identify unwanted emails based on past examples
-- Netflix Recommendations: Suggesting movies based on what you've watched
-- Face Recognition: Unlocking your phone by learning your facial features
-- Virtual Assistants: Siri and Alexa understanding and responding to voice commands
-- Medical Diagnosis: Helping doctors identify diseases in medical images
-- Fraud Detection: Banks identifying suspicious transactions
-
-3. Key Benefits:
-- Automation of complex tasks
-- More accurate predictions over time
-- Ability to handle large amounts of data
-- Continuous improvement through learning
-
-Machine learning is transforming industries by automating tasks that once required human intelligence, making processes more efficient and enabling new possibilities in technology."""
-                elif "function" in prompt.lower():
+            # Simple validation
+            if len(response.strip()) < 10:
+                if "function" in prompt.lower():
                     fallback_response = """```python
 def add_numbers(a, b):
-    '''
-    Add two numbers and return the result
-    Args:
-        a: first number
-        b: second number
-    Returns:
-        sum of a and b
-    '''
     return a + b
-
-# Example usage
-num1 = 5
-num2 = 3
-result = add_numbers(num1, num2)
-print(f"The sum of {num1} and {num2} is: {result}")  # Output: The sum of 5 and 3 is: 8
 ```"""
                 else:
-                    fallback_response = "I apologize, but I couldn't generate a complete response. Please try using a lower temperature (0.3-0.5) for more focused output."
+                    fallback_response = "I apologize, but I couldn't generate a response. Please try with a simpler prompt."
 
                 response = fallback_response
 
+            # Clear memory after generation
+            gc.collect()
+
             generation_time = time.time() - start_time
             return response, generation_time
         except Exception as e:
@@ -254,7 +160,7 @@ print(f"The sum of {num1} and {num2} is: {result}")  # Output: The sum of 5 and
 # Initialize model wrapper
 model_wrapper = ModelWrapper()
 
-def generate_text(prompt, max_length=384, temperature=0.7, top_p=0.9):  # Reduced default max_length
+def generate_text(prompt, max_length=256, temperature=0.5, top_p=0.85):
     """Gradio interface function"""
     try:
         if not prompt.strip():
@@ -269,91 +175,74 @@ def generate_text(prompt, max_length=384, temperature=0.7, top_p=0.9):  # Reduce
         return f"Generated in {gen_time:.2f} seconds:\n\n{response}"
     except Exception as e:
         print(f"Error in generate_text: {str(e)}")
-        return f"Error generating response: {str(e)}\nPlease try again with a different prompt or parameters."
+        return f"Error generating response: {str(e)}\nPlease try again with a shorter prompt."
 
-# Create the Gradio interface with CPU-friendly defaults
+# Create a very lightweight Gradio interface
 demo = gr.Interface(
     fn=generate_text,
     inputs=[
        gr.Textbox(
            label="Enter your prompt",
            placeholder="Type your prompt here...",
-            lines=4
+            lines=3
        ),
        gr.Slider(
            minimum=64,
-            maximum=512,
-            value=384,  # Reduced default
-            step=64,
+            maximum=256,
+            value=192,
+            step=32,
            label="Maximum Length",
-            info="Longer values = longer responses but slower generation"
+            info="Keep this low for CPU"
        ),
        gr.Slider(
            minimum=0.1,
-            maximum=1.0,
-            value=0.5,  # Reduced default
+            maximum=0.7,
+            value=0.4,
            step=0.1,
            label="Temperature",
-            info="Higher values = more creative, lower values = more focused"
+            info="Lower is better for CPU"
        ),
        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.85,  # Adjusted default
+            minimum=0.5,
+            maximum=0.9,
+            value=0.8,
            step=0.1,
            label="Top P",
-            info="Controls diversity of word choices"
+            info="Controls diversity"
        ),
    ],
-    outputs=gr.Textbox(label="Generated Response", lines=8),
-    title="Phi-2 QLoRA Fine-tuned Assistant (CPU Version)",
-    description="""This is a fine-tuned version of Microsoft's Phi-2 model using QLoRA.
-    The model has been trained to provide helpful responses for various tasks including coding, writing, and general assistance.
-
-    Example tasks:
-    - Writing Python functions and explaining code
-    - Explaining technical concepts in simple terms
-    - Drafting professional emails and documents
+    outputs=gr.Textbox(label="Generated Response", lines=6),
+    title="Phi-2 QLoRA Assistant (CPU-Optimized)",
+    description="""This is a lightweight CPU version of the fine-tuned Phi-2 model.
 
     Tips:
-    - For code generation, use lower temperature (0.3-0.5)
-    - For creative writing, use higher temperature (0.5-0.7)
-    - Keep max length lower (256-384) for faster responses
+    - Keep prompts short and specific
+    - Use lower maximum length (128-192) for faster responses
+    - Use lower temperature (0.3-0.5) for more reliable responses
    """,
    examples=[
        [
-            "Write a Python function to calculate the factorial of a number and provide additional recursive function examples",
-            384,
-            0.5,
-            0.85
-        ],
-        [
-            "Explain what machine learning is in simple terms and provide some real-world applications",
-            384,
-            0.5,
-            0.85
-        ],
-        [
-            "Write a professional email to schedule a team meeting for next week to discuss project progress",
-            384,
-            0.5,
-            0.85
+            "Write a Python function to calculate factorial",
+            192,
+            0.4,
+            0.8
        ],
        [
-            "Write a Python function to implement binary search algorithm with detailed comments",
-            384,
-            0.5,
-            0.85
+            "Explain machine learning simply",
+            192,
+            0.4,
+            0.8
        ],
        [
-            "Explain the concept of object-oriented programming using a real-world analogy",
-            384,
-            0.5,
-            0.85
+            "Write a short email to schedule a meeting",
+            192,
+            0.4,
+            0.8
        ]
    ],
    cache_examples=False
 )
 
 if __name__ == "__main__":
+    demo.queue(concurrency_count=1)  # Limit concurrency
     demo.launch()
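For readers who want to try the CPU-only loading path from this commit outside the Gradio Space, here is a minimal sketch. It assumes `torch`, `transformers`, `peft`, and `psutil` are installed and that the adapter repo is accessible; the helper name `log_memory` and the sample prompt are illustrative and not part of the commit.

```python
# Minimal sketch mirroring the CPU loading and memory logging added in this commit.
import gc
import os

import psutil
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL = "microsoft/phi-2"
ADAPTER_MODEL = "pradeep6kumar2024/phi2-qlora-assistant"


def log_memory(stage: str) -> None:
    # Resident set size of the current process, in MB (same metric as get_memory_usage()).
    rss_mb = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
    print(f"{stage}: {rss_mb:.2f} MB")


log_memory("Before loading")

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float32,   # float32 on CPU, as in the commit
    device_map="cpu",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)
del base_model
gc.collect()
model.eval()
log_memory("After loading")

prompt = "Write a Python function to calculate factorial"  # sample prompt, illustrative
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_length=256,          # strict limit, matching the commit's CPU settings
        num_beams=1,             # greedy decoding to save memory
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```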