pradeep6kumar2024 committed on
Commit 3bf5e4f · 1 Parent(s): 537ecfd

updated app.py

Files changed (1)
  1. app.py +95 -48
app.py CHANGED
@@ -3,6 +3,8 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import PeftModel
 import time
+import gc
+import os
 
 # Configuration
 BASE_MODEL = "microsoft/phi-2"
@@ -17,6 +19,13 @@ class ModelWrapper:
     def load_model(self):
         if not self.loaded:
             try:
+                # Force CPU usage
+                os.environ["CUDA_VISIBLE_DEVICES"] = ""
+                device = torch.device("cpu")
+
+                # Clear memory
+                gc.collect()
+
                 print("Loading tokenizer...")
                 self.tokenizer = AutoTokenizer.from_pretrained(
                     BASE_MODEL,
@@ -28,19 +37,25 @@ class ModelWrapper:
                 print("Loading base model...")
                 base_model = AutoModelForCausalLM.from_pretrained(
                     BASE_MODEL,
-                    torch_dtype=torch.float16,
-                    device_map="auto",
+                    torch_dtype=torch.float32,  # Use float32 for CPU
+                    device_map="cpu",
                     trust_remote_code=True,
-                    use_flash_attention_2=False  # Disable flash attention if causing issues
+                    use_flash_attention_2=False,
+                    low_cpu_mem_usage=True
                 )
 
                 print("Loading LoRA adapter...")
                 self.model = PeftModel.from_pretrained(
                     base_model,
                     ADAPTER_MODEL,
-                    torch_dtype=torch.float16,
-                    device_map="auto"
+                    torch_dtype=torch.float32,  # Use float32 for CPU
+                    device_map="cpu"
                 )
+
+                # Free up memory
+                del base_model
+                gc.collect()
+
                 self.model.eval()
                 print("Model loading complete!")
                 self.loaded = True
@@ -61,9 +76,10 @@ Include:
 - Function implementation with comments
 - Example usage
 - Output demonstration
-"""
+
+Provide only the implementation, no conversation."""
         elif any(word in prompt.lower() for word in ["explain", "what is", "how does", "describe"]):
-            enhanced_prompt = f"""Below is a request for explanation. Please provide a complete, detailed response:
+            enhanced_prompt = f"""Below is a request for explanation. Provide a complete, focused response without any conversation:
 
 {prompt}
 
@@ -72,13 +88,13 @@ Your response should include:
 2. Practical examples and applications
 3. Important concepts to understand
 
-Response:"""
+End your response when the explanation is complete. Do not ask questions or engage in conversation."""
         else:
-            enhanced_prompt = f"""Below is a request. Please provide a complete, detailed response:
+            enhanced_prompt = f"""Below is a request. Provide a complete, focused response without any conversation:
 
 {prompt}
 
-Response:"""
+End your response when complete. Do not ask questions or engage in conversation."""
 
         print(f"Enhanced prompt: {enhanced_prompt}")  # Debug logging
 
@@ -89,26 +105,26 @@ Response:"""
                 truncation=True,
                 max_length=512,
                 padding=True
-            ).to(self.model.device)
+            ).to("cpu")  # Ensure CPU usage
 
-            # Generate
+            # Generate with more conservative parameters for CPU
             start_time = time.time()
             with torch.no_grad():
                 outputs = self.model.generate(
                     **inputs,
-                    max_length=max_length,
-                    min_length=50,  # Reduced minimum length requirement
-                    temperature=max(0.6, temperature),  # Ensure minimum temperature
-                    top_p=min(0.95, top_p),  # Cap top_p
+                    max_length=min(max_length, 384),  # Limit max length for CPU
+                    min_length=50,
+                    temperature=min(0.5, temperature),
+                    top_p=min(0.85, top_p),
                     do_sample=True,
                     pad_token_id=self.tokenizer.pad_token_id,
                     eos_token_id=self.tokenizer.eos_token_id,
-                    repetition_penalty=1.2,  # Increased repetition penalty
-                    no_repeat_ngram_size=3,
+                    repetition_penalty=1.3,
+                    no_repeat_ngram_size=4,
                     num_return_sequences=1,
                     early_stopping=True,
-                    num_beams=3,  # Reduced beam search
-                    length_penalty=0.7  # Encourage shorter responses
+                    num_beams=2,  # Reduced beam search for CPU
+                    length_penalty=0.6
                 )
 
             # Decode response
@@ -121,7 +137,7 @@ Response:"""
 
             print(f"After prompt removal: {response}")  # Debug logging
 
-            # Remove common closure patterns only if they appear at the very end
+            # Remove common closure patterns and conversation starters
             closures = [
                 "Best regards,",
                 "Sincerely,",
@@ -134,21 +150,52 @@ Response:"""
                 "[Student]",
                 "Let me know if you need any clarification",
                 "I hope this helps",
-                "Feel free to ask"
+                "Feel free to ask",
+                "Can you provide",
+                "Would you like",
+                "Do you want",
+                "Let me know",
+                "Please let me know",
+                "Is there anything else",
+                "Do you have any questions",
+                "Sure!",
+                "Here are some examples:"
             ]
 
+            # First remove conversation starters from the end
             for closure in closures:
                 if response.lower().endswith(closure.lower()):
                     response = response[:-(len(closure))].strip()
 
-            print(f"After closure removal: {response}")  # Debug logging
+            # Then remove any remaining conversation patterns
+            conversation_patterns = [
+                r"\?\s*$",  # Questions at the end
+                r"Sure!.*$",  # Responses starting with Sure!
+                r"Here are.*examples:?\s*$",  # Incomplete example lists
+                r"Can you.*\?\s*$",  # Questions starting with Can you
+                r"Would you.*\?\s*$",  # Questions starting with Would you
+                r"Do you.*\?\s*$",  # Questions starting with Do you
+                r"Let me know.*$",  # Let me know phrases
+                r"I hope.*$",  # I hope phrases
+                r"Feel free.*$"  # Feel free phrases
+            ]
+
+            import re
+            for pattern in conversation_patterns:
+                response = re.sub(pattern, "", response).strip()
+
+            print(f"After conversation removal: {response}")  # Debug logging
 
             # Ensure code examples are properly formatted
             if "```python" not in response and "def " in response:
                 response = "```python\n" + response + "\n```"
 
-            # More lenient validation
-            if len(response.strip()) < 20 or response.strip() == "Response:":  # Only check for very short responses
+            # More lenient validation but check for conversation markers
+            if (len(response.strip()) < 20 or
+                response.strip() == "Response:" or
+                response.strip().endswith("?") or
+                "can you" in response.lower() or
+                "let me know" in response.lower()):
                 print("Response validation failed - using fallback")  # Debug logging
 
                 if "machine learning" in prompt.lower():
@@ -194,7 +241,7 @@ result = add_numbers(num1, num2)
 print(f"The sum of {num1} and {num2} is: {result}")  # Output: The sum of 5 and 3 is: 8
 ```"""
                 else:
-                    fallback_response = "I apologize, but I couldn't generate a complete response. Please try adjusting the temperature (try 0.6-0.8) or providing more context in your prompt."
+                    fallback_response = "I apologize, but I couldn't generate a complete response. Please try using a lower temperature (0.3-0.5) for more focused output."
 
                 response = fallback_response
 
@@ -207,7 +254,7 @@ print(f"The sum of {num1} and {num2} is: {result}")  # Output: The sum of 5 and
 # Initialize model wrapper
 model_wrapper = ModelWrapper()
 
-def generate_text(prompt, max_length=512, temperature=0.7, top_p=0.9):
+def generate_text(prompt, max_length=384, temperature=0.7, top_p=0.9):  # Reduced default max_length
     """Gradio interface function"""
     try:
         if not prompt.strip():
@@ -224,7 +271,7 @@ def generate_text(prompt, max_length=512, temperature=0.7, top_p=0.9):
         print(f"Error in generate_text: {str(e)}")
         return f"Error generating response: {str(e)}\nPlease try again with a different prompt or parameters."
 
-# Create the Gradio interface
+# Create the Gradio interface with CPU-friendly defaults
 demo = gr.Interface(
     fn=generate_text,
     inputs=[
@@ -235,8 +282,8 @@ demo = gr.Interface(
         ),
         gr.Slider(
             minimum=64,
-            maximum=1024,
-            value=512,
+            maximum=512,
+            value=384,  # Reduced default
             step=64,
             label="Maximum Length",
             info="Longer values = longer responses but slower generation"
@@ -244,7 +291,7 @@ demo = gr.Interface(
         gr.Slider(
             minimum=0.1,
             maximum=1.0,
-            value=0.7,
+            value=0.5,  # Reduced default
             step=0.1,
             label="Temperature",
             info="Higher values = more creative, lower values = more focused"
@@ -252,14 +299,14 @@ demo = gr.Interface(
         gr.Slider(
             minimum=0.1,
             maximum=1.0,
-            value=0.9,
+            value=0.85,  # Adjusted default
             step=0.1,
             label="Top P",
             info="Controls diversity of word choices"
         ),
     ],
     outputs=gr.Textbox(label="Generated Response", lines=8),
-    title="Phi-2 QLoRA Fine-tuned Assistant",
+    title="Phi-2 QLoRA Fine-tuned Assistant (CPU Version)",
     description="""This is a fine-tuned version of Microsoft's Phi-2 model using QLoRA.
 The model has been trained to provide helpful responses for various tasks including coding, writing, and general assistance.
 
@@ -270,39 +317,39 @@ demo = gr.Interface(
 
 Tips:
 - For code generation, use lower temperature (0.3-0.5)
-- For creative writing, use higher temperature (0.7-0.9)
-- Adjust max length based on how long you want the response to be
+- For creative writing, use higher temperature (0.5-0.7)
+- Keep max length lower (256-384) for faster responses
 """,
     examples=[
         [
            "Write a Python function to calculate the factorial of a number and provide additional recursive function examples",
-            512,
+            384,
             0.5,
-            0.9
+            0.85
         ],
         [
            "Explain what machine learning is in simple terms and provide some real-world applications",
-            512,
-            0.7,
-            0.9
+            384,
+            0.5,
+            0.85
         ],
         [
            "Write a professional email to schedule a team meeting for next week to discuss project progress",
-            512,
-            0.7,
-            0.9
+            384,
+            0.5,
+            0.85
         ],
         [
            "Write a Python function to implement binary search algorithm with detailed comments",
-            512,
+            384,
             0.5,
-            0.9
+            0.85
         ],
         [
            "Explain the concept of object-oriented programming using a real-world analogy",
-            512,
-            0.7,
-            0.9
+            384,
+            0.5,
+            0.85
         ]
     ],
     cache_examples=False
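For reference, here is a minimal standalone sketch of the CPU-only loading pattern this commit switches to (float32 base model plus LoRA adapter via PEFT), which may help when testing outside the Space. It assumes torch, transformers, and peft are installed; the adapter repository id below is a placeholder, not the actual value of ADAPTER_MODEL in app.py.

```python
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

os.environ["CUDA_VISIBLE_DEVICES"] = ""  # hide GPUs so everything runs on CPU

BASE_MODEL = "microsoft/phi-2"
ADAPTER_MODEL = "your-username/phi2-qlora-adapter"  # placeholder adapter repo id

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float32,   # float16 is poorly supported on CPU
    device_map="cpu",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(base, ADAPTER_MODEL)
model.eval()

inputs = tokenizer("Write a Python function that reverses a string.", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64, do_sample=True, temperature=0.5)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

The sketch uses max_new_tokens only to keep the example short; the app itself caps the total length with max_length=min(max_length, 384) as shown in the diff above.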