sagar007 commited on
Commit
06e746f
·
verified ·
1 Parent(s): bd91e22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -191
app.py CHANGED
@@ -4,7 +4,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLI
4
  from PIL import Image
5
  import logging
6
  import spaces
7
- import numpy
8
 
9
  # Setup logging
10
  logging.basicConfig(level=logging.INFO)
@@ -15,91 +15,70 @@ class LLaVAPhiModel:
15
  self.model_id = model_id
16
  logging.info("Initializing LLaVA-Phi model...")
17
 
18
- # Initialize tokenizer
19
  self.tokenizer = AutoTokenizer.from_pretrained(model_id)
20
  if self.tokenizer.pad_token is None:
21
  self.tokenizer.pad_token = self.tokenizer.eos_token
22
-
23
- try:
24
- # Use CLIPProcessor directly instead of AutoProcessor
25
- self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
26
- logging.info("Successfully loaded CLIP processor")
27
- except Exception as e:
28
- logging.error(f"Failed to load CLIP processor: {str(e)}")
29
- self.processor = None
30
-
31
- # Increase history length to retain more context
32
  self.history = []
33
  self.model = None
34
  self.clip = None
 
 
 
35
 
36
  @spaces.GPU
37
  def ensure_models_loaded(self):
38
- """Ensure models are loaded in GPU context"""
39
  if self.model is None:
40
- # Improved quantization config for better quality
41
  from transformers import BitsAndBytesConfig
42
  quantization_config = BitsAndBytesConfig(
43
- load_in_8bit=True, # Changed from 4-bit to 8-bit for better quality
44
  bnb_8bit_compute_dtype=torch.float16,
45
  bnb_8bit_use_double_quant=False
46
  )
47
-
48
- try:
49
- self.model = AutoModelForCausalLM.from_pretrained(
50
- self.model_id,
51
- quantization_config=quantization_config,
52
- device_map="auto",
53
- torch_dtype=torch.bfloat16,
54
- trust_remote_code=True
55
- )
56
- self.model.config.pad_token_id = self.tokenizer.eos_token_id
57
- logging.info("Successfully loaded main model")
58
- except Exception as e:
59
- logging.error(f"Failed to load main model: {str(e)}")
60
- raise
61
 
62
  if self.clip is None:
63
- try:
64
- # Use CLIPModel directly instead of AutoModel
65
- self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
66
- logging.info("Successfully loaded CLIP model")
67
- except Exception as e:
68
- logging.error(f"Failed to load CLIP model: {str(e)}")
69
- self.clip = None
70
 
71
  @spaces.GPU
72
  def process_image(self, image):
73
- """Process image through CLIP if available"""
74
  try:
75
  self.ensure_models_loaded()
76
-
77
  if self.clip is None or self.processor is None:
78
  logging.warning("CLIP model or processor not available")
79
  return None
80
 
81
- # Convert image to correct format
82
  if isinstance(image, str):
83
  image = Image.open(image)
84
- elif isinstance(image, numpy.ndarray):
85
  image = Image.fromarray(image)
86
-
87
- # Ensure image is in RGB mode
88
  if image.mode != 'RGB':
89
  image = image.convert('RGB')
90
 
91
  with torch.no_grad():
92
- try:
93
- # Process image with error handling
94
- image_inputs = self.processor(images=image, return_tensors="pt")
95
- image_features = self.clip.get_image_features(
96
- pixel_values=image_inputs.pixel_values.to(self.device)
97
- )
98
- logging.info("Successfully processed image through CLIP")
99
- return image_features
100
- except Exception as e:
101
- logging.error(f"Error during image processing: {str(e)}")
102
- return None
103
  except Exception as e:
104
  logging.error(f"Error in process_image: {str(e)}")
105
  return None
@@ -116,82 +95,68 @@ class LLaVAPhiModel:
116
  message = "Note: Image processing is not available - continuing with text only.\n" + message
117
 
118
  prompt = f"human: {'<image>' if has_image else ''}\n{message}\ngpt:"
119
-
120
- # Include more history for better context (previous 5 turns instead of 3)
121
  context = ""
122
  for turn in self.history[-5:]:
123
  context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
124
-
125
  full_prompt = context + prompt
126
 
127
- # Increased context window
128
  inputs = self.tokenizer(
129
  full_prompt,
130
  return_tensors="pt",
131
  padding=True,
132
  truncation=True,
133
- max_length=1024 # Increased from 512
134
  )
135
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
136
 
137
  if has_image:
138
- inputs["image_features"] = image_features
139
-
140
- with torch.no_grad():
141
- # More conservative generation settings to reduce hallucinations
142
- outputs = self.model.generate(
143
- **inputs,
144
- max_new_tokens=256,
145
- min_length=20,
146
- temperature=0.3, # Reduced from 0.7 for more deterministic output
147
- do_sample=True,
148
- top_p=0.92,
149
- top_k=50,
150
- repetition_penalty=1.2, # Adjusted for more natural responses
151
- no_repeat_ngram_size=3,
152
- use_cache=True,
153
- pad_token_id=self.tokenizer.pad_token_id,
154
- eos_token_id=self.tokenizer.eos_token_id
155
  )
 
 
156
  else:
157
  prompt = f"human: {message}\ngpt:"
158
- # Include more history
159
  context = ""
160
  for turn in self.history[-5:]:
161
  context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
162
-
163
  full_prompt = context + prompt
164
 
165
- # Increased context window
166
  inputs = self.tokenizer(
167
  full_prompt,
168
  return_tensors="pt",
169
  padding=True,
170
  truncation=True,
171
- max_length=1024 # Increased from 512
172
  )
173
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
174
-
175
- with torch.no_grad():
176
- # More conservative generation settings
177
- outputs = self.model.generate(
178
- **inputs,
179
- max_new_tokens=200, # Slightly increased from 150
180
- min_length=20,
181
- temperature=0.3, # Reduced from 0.6
182
- do_sample=True,
183
- top_p=0.92,
184
- top_k=50,
185
- repetition_penalty=1.2,
186
- no_repeat_ngram_size=4,
187
- use_cache=True,
188
- pad_token_id=self.tokenizer.pad_token_id,
189
- eos_token_id=self.tokenizer.eos_token_id
190
- )
191
 
192
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
- # Clean up response
195
  if "gpt:" in response:
196
  response = response.split("gpt:")[-1].strip()
197
  if "human:" in response:
@@ -204,106 +169,18 @@ class LLaVAPhiModel:
204
 
205
  except Exception as e:
206
  logging.error(f"Error generating response: {str(e)}")
207
- logging.error(f"Full traceback:", exc_info=True)
208
  return f"Error: {str(e)}"
209
 
210
  def clear_history(self):
211
  self.history = []
212
  return None
213
 
214
- # Add new function to control generation parameters
215
- def update_generation_params(self, temperature=0.3, top_p=0.92, top_k=50, repetition_penalty=1.2):
216
- """Update generation parameters to control hallucination tendency"""
217
- self.temperature = temperature
218
- self.top_p = top_p
219
- self.top_k = top_k
220
- self.repetition_penalty = repetition_penalty
221
- return f"Generation parameters updated: temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}"
222
-
223
  def create_demo():
224
- try:
225
- model = LLaVAPhiModel()
226
-
227
- with gr.Blocks(css="footer {visibility: hidden}") as demo:
228
- gr.Markdown(
229
- """
230
- # LLaVA-Phi Demo (Optimized for Accuracy)
231
- Chat with a vision-language model that can understand both text and images.
232
- """
233
- )
234
-
235
- chatbot = gr.Chatbot(height=400)
236
- with gr.Row():
237
- with gr.Column(scale=0.7):
238
- msg = gr.Textbox(
239
- show_label=False,
240
- placeholder="Enter text and/or upload an image",
241
- container=False
242
- )
243
- with gr.Column(scale=0.15, min_width=0):
244
- clear = gr.Button("Clear")
245
- with gr.Column(scale=0.15, min_width=0):
246
- submit = gr.Button("Submit", variant="primary")
247
-
248
- image = gr.Image(type="pil", label="Upload Image (Optional)")
249
-
250
- # Add generation parameter controls
251
- with gr.Accordion("Advanced Settings", open=False):
252
- gr.Markdown("Adjust these parameters to control hallucination tendency")
253
- temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)")
254
- top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)")
255
- top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k")
256
- rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
257
- update_params = gr.Button("Update Parameters")
258
-
259
- def respond(message, chat_history, image):
260
- if not message and image is None:
261
- return chat_history
262
-
263
- response = model.generate_response(message, image)
264
- chat_history.append((message, response))
265
- return "", chat_history
266
-
267
- def clear_chat():
268
- model.clear_history()
269
- return None, None
270
-
271
- def update_params_fn(temp, top_p, top_k, rep_penalty):
272
- return model.update_generation_params(temp, top_p, top_k, rep_penalty)
273
-
274
- submit.click(
275
- respond,
276
- [msg, chatbot, image],
277
- [msg, chatbot],
278
- )
279
-
280
- clear.click(
281
- clear_chat,
282
- None,
283
- [chatbot, image],
284
- )
285
-
286
- msg.submit(
287
- respond,
288
- [msg, chatbot, image],
289
- [msg, chatbot],
290
- )
291
-
292
- update_params.click(
293
- update_params_fn,
294
- [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider],
295
- None
296
- )
297
-
298
- return demo
299
- except Exception as e:
300
- logging.error(f"Error creating demo: {str(e)}")
301
- raise
302
 
303
  if __name__ == "__main__":
304
  demo = create_demo()
305
- demo.launch(
306
- server_name="0.0.0.0",
307
- server_port=7860,
308
- share=True
309
- )
 
4
  from PIL import Image
5
  import logging
6
  import spaces
7
+ import numpy as np
8
 
9
  # Setup logging
10
  logging.basicConfig(level=logging.INFO)
 
15
  self.model_id = model_id
16
  logging.info("Initializing LLaVA-Phi model...")
17
 
 
18
  self.tokenizer = AutoTokenizer.from_pretrained(model_id)
19
  if self.tokenizer.pad_token is None:
20
  self.tokenizer.pad_token = self.tokenizer.eos_token
21
+
22
+ self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
 
 
 
 
 
 
 
23
  self.history = []
24
  self.model = None
25
  self.clip = None
26
+
27
+ # Add a linear projection layer to align CLIP features with text embeddings
28
+ self.projection = None
29
 
30
  @spaces.GPU
31
  def ensure_models_loaded(self):
 
32
  if self.model is None:
 
33
  from transformers import BitsAndBytesConfig
34
  quantization_config = BitsAndBytesConfig(
35
+ load_in_8bit=True,
36
  bnb_8bit_compute_dtype=torch.float16,
37
  bnb_8bit_use_double_quant=False
38
  )
39
+ self.model = AutoModelForCausalLM.from_pretrained(
40
+ self.model_id,
41
+ quantization_config=quantization_config,
42
+ device_map="auto",
43
+ torch_dtype=torch.bfloat16,
44
+ trust_remote_code=True
45
+ )
46
+ self.model.config.pad_token_id = self.tokenizer.eos_token_id
47
+ logging.info("Successfully loaded main model")
 
 
 
 
 
48
 
49
  if self.clip is None:
50
+ self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
51
+ logging.info("Successfully loaded CLIP model")
52
+
53
+ # Initialize projection layer (CLIP features: 512-dim, model embedding size: e.g., 2048 for Phi)
54
+ embed_dim = self.model.config.hidden_size # e.g., 2048 for Phi-1.5
55
+ clip_dim = self.clip.config.projection_dim # 512 for CLIP
56
+ self.projection = torch.nn.Linear(clip_dim, embed_dim).to(self.device)
57
 
58
  @spaces.GPU
59
  def process_image(self, image):
 
60
  try:
61
  self.ensure_models_loaded()
 
62
  if self.clip is None or self.processor is None:
63
  logging.warning("CLIP model or processor not available")
64
  return None
65
 
 
66
  if isinstance(image, str):
67
  image = Image.open(image)
68
+ elif isinstance(image, np.ndarray):
69
  image = Image.fromarray(image)
 
 
70
  if image.mode != 'RGB':
71
  image = image.convert('RGB')
72
 
73
  with torch.no_grad():
74
+ image_inputs = self.processor(images=image, return_tensors="pt")
75
+ image_features = self.clip.get_image_features(
76
+ pixel_values=image_inputs.pixel_values.to(self.device)
77
+ )
78
+ # Project image features to text embedding space
79
+ projected_features = self.projection(image_features)
80
+ logging.info("Successfully processed image through CLIP")
81
+ return projected_features
 
 
 
82
  except Exception as e:
83
  logging.error(f"Error in process_image: {str(e)}")
84
  return None
 
95
  message = "Note: Image processing is not available - continuing with text only.\n" + message
96
 
97
  prompt = f"human: {'<image>' if has_image else ''}\n{message}\ngpt:"
 
 
98
  context = ""
99
  for turn in self.history[-5:]:
100
  context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
 
101
  full_prompt = context + prompt
102
 
 
103
  inputs = self.tokenizer(
104
  full_prompt,
105
  return_tensors="pt",
106
  padding=True,
107
  truncation=True,
108
+ max_length=1024
109
  )
110
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
111
 
112
  if has_image:
113
+ # Convert input_ids to embeddings
114
+ embeddings = self.model.get_input_embeddings()(inputs["input_ids"])
115
+ # Concatenate image features with text embeddings
116
+ image_features_expanded = image_features.unsqueeze(1) # Shape: [batch, 1, embed_dim]
117
+ combined_embeddings = torch.cat([image_features_expanded, embeddings], dim=1)
118
+ inputs["inputs_embeds"] = combined_embeddings
119
+ # Update attention mask to account for the extra image token
120
+ inputs["attention_mask"] = torch.cat(
121
+ [torch.ones(inputs["attention_mask"].shape[0], 1).to(self.device),
122
+ inputs["attention_mask"]],
123
+ dim=1
 
 
 
 
 
 
124
  )
125
+ # Remove input_ids since we're using inputs_embeds
126
+ del inputs["input_ids"]
127
  else:
128
  prompt = f"human: {message}\ngpt:"
 
129
  context = ""
130
  for turn in self.history[-5:]:
131
  context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
 
132
  full_prompt = context + prompt
133
 
 
134
  inputs = self.tokenizer(
135
  full_prompt,
136
  return_tensors="pt",
137
  padding=True,
138
  truncation=True,
139
+ max_length=1024
140
  )
141
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ with torch.no_grad():
144
+ outputs = self.model.generate(
145
+ **inputs,
146
+ max_new_tokens=256,
147
+ min_length=20,
148
+ temperature=0.3,
149
+ do_sample=True,
150
+ top_p=0.92,
151
+ top_k=50,
152
+ repetition_penalty=1.2,
153
+ no_repeat_ngram_size=3,
154
+ use_cache=True,
155
+ pad_token_id=self.tokenizer.pad_token_id,
156
+ eos_token_id=self.tokenizer.eos_token_id
157
+ )
158
 
159
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
160
  if "gpt:" in response:
161
  response = response.split("gpt:")[-1].strip()
162
  if "human:" in response:
 
169
 
170
  except Exception as e:
171
  logging.error(f"Error generating response: {str(e)}")
 
172
  return f"Error: {str(e)}"
173
 
174
  def clear_history(self):
175
  self.history = []
176
  return None
177
 
 
 
 
 
 
 
 
 
 
178
  def create_demo():
179
+ model = LLaVAPhiModel()
180
+ # Rest of your Gradio setup remains the same
181
+ # ... (omitted for brevity)
182
+ return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  if __name__ == "__main__":
185
  demo = create_demo()
186
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)