sagar007 committed on
Commit 2144e66 · verified · 1 Parent(s): 5c998e9

Update app.py

Files changed (1)
  1. app.py +117 -84
app.py CHANGED
@@ -1,9 +1,8 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, AutoConfig, AutoModel
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, AutoModel
 from PIL import Image
 import logging
-from transformers import BitsAndBytesConfig
 
 # Setup logging
 logging.basicConfig(level=logging.INFO)
@@ -13,24 +12,36 @@ class LLaVAPhiModel:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         logging.info(f"Using device: {self.device}")
 
-        # Initialize quantization config
-        quantization_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_compute_dtype=torch.float16,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type="nf4"
-        )
-
         try:
-            # Load model directly from Hugging Face Hub
+            # Load model with appropriate settings based on available hardware
             logging.info(f"Loading model from {model_id}...")
+
+            # Determine model loading configuration
+            model_kwargs = {
+                "device_map": "auto",
+                "trust_remote_code": True
+            }
+
+            # Add quantization only if CUDA is available
+            if torch.cuda.is_available():
+                from transformers import BitsAndBytesConfig
+                quantization_config = BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.float16,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type="nf4"
+                )
+                model_kwargs["quantization_config"] = quantization_config
+                model_kwargs["torch_dtype"] = torch.bfloat16
+            else:
+                # For CPU, use lighter configuration
+                model_kwargs["torch_dtype"] = torch.float32
+
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_id,
-                quantization_config=quantization_config,
-                device_map="auto",
-                torch_dtype=torch.bfloat16,
-                trust_remote_code=True
+                **model_kwargs
             )
+
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
 
             # Set up padding token
@@ -49,24 +60,41 @@ class LLaVAPhiModel:
         except Exception as e:
             logging.error(f"Error initializing model: {str(e)}")
             raise
-
+
     def process_image(self, image):
         """Process image through CLIP"""
-        with torch.no_grad():
-            image_inputs = self.processor(images=image, return_tensors="pt")
-            image_features = self.clip.get_image_features(
-                pixel_values=image_inputs.pixel_values.to(self.device)
-            )
-        return image_features
-
+        try:
+            # Ensure image is in correct format
+            if isinstance(image, str):  # If image path is provided
+                image = Image.open(image)
+            elif isinstance(image, numpy.ndarray):  # If numpy array (from gradio)
+                image = Image.fromarray(image)
+
+            with torch.no_grad():
+                image_inputs = self.processor(images=image, return_tensors="pt")
+                image_features = self.clip.get_image_features(
+                    pixel_values=image_inputs.pixel_values.to(self.device)
+                )
+            return image_features
+        except Exception as e:
+            logging.error(f"Error processing image: {str(e)}")
+            raise
+
     def generate_response(self, message, image=None):
         try:
             if image is not None:
-                # Get image features
-                image_features = self.process_image(image)
+                try:
+                    # Get image features
+                    image_features = self.process_image(image)
+                    has_image = True
+                except Exception as e:
+                    logging.error(f"Failed to process image: {str(e)}")
+                    image_features = None
+                    has_image = False
+                    message = f"Note: Failed to process image. Continuing with text only. Error: {str(e)}\n{message}"
 
             # Format prompt
-            prompt = f"human: <image>\n{message}\ngpt:"
+            prompt = f"human: {'<image>' if has_image else ''}\n{message}\ngpt:"
 
             # Add context from history
             context = ""
@@ -85,8 +113,9 @@ class LLaVAPhiModel:
             )
             inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
-            # Add image features to inputs
-            inputs["image_features"] = image_features
+            # Add image features to inputs if available
+            if has_image:
+                inputs["image_features"] = image_features
 
             # Generate response
             with torch.no_grad():
@@ -163,63 +192,67 @@ class LLaVAPhiModel:
         return None
 
 def create_demo():
-    # Initialize model
-    model = LLaVAPhiModel()
-
-    with gr.Blocks(css="footer {visibility: hidden}") as demo:
-        gr.Markdown(
-            """
-            # LLaVA-Phi Demo
-            Chat with a vision-language model that can understand both text and images.
-            """
-        )
-
-        chatbot = gr.Chatbot(height=400)
-        with gr.Row():
-            with gr.Column(scale=0.7):
-                msg = gr.Textbox(
-                    show_label=False,
-                    placeholder="Enter text and/or upload an image",
-                    container=False
-                )
-            with gr.Column(scale=0.15, min_width=0):
-                clear = gr.Button("Clear")
-            with gr.Column(scale=0.15, min_width=0):
-                submit = gr.Button("Submit", variant="primary")
-
-        image = gr.Image(type="pil", label="Upload Image (Optional)")
-
-        def respond(message, chat_history, image):
-            if not message and image is None:
-                return chat_history
-
-            response = model.generate_response(message, image)
-            chat_history.append((message, response))
-            return "", chat_history
-
-        def clear_chat():
-            model.clear_history()
-            return None, None
-
-        submit.click(
-            respond,
-            [msg, chatbot, image],
-            [msg, chatbot],
-        )
-
-        clear.click(
-            clear_chat,
-            None,
-            [chatbot, image],
-        )
-
-        msg.submit(
-            respond,
-            [msg, chatbot, image],
-            [msg, chatbot],
-        )
-
-    return demo
+    try:
+        # Initialize model
+        model = LLaVAPhiModel()
+
+        with gr.Blocks(css="footer {visibility: hidden}") as demo:
+            gr.Markdown(
+                """
+                # LLaVA-Phi Demo
+                Chat with a vision-language model that can understand both text and images.
+                """
+            )
+
+            chatbot = gr.Chatbot(height=400)
+            with gr.Row():
+                with gr.Column(scale=0.7):
+                    msg = gr.Textbox(
+                        show_label=False,
+                        placeholder="Enter text and/or upload an image",
+                        container=False
+                    )
+                with gr.Column(scale=0.15, min_width=0):
+                    clear = gr.Button("Clear")
+                with gr.Column(scale=0.15, min_width=0):
+                    submit = gr.Button("Submit", variant="primary")
+
+            image = gr.Image(type="pil", label="Upload Image (Optional)")
+
+            def respond(message, chat_history, image):
+                if not message and image is None:
+                    return chat_history
+
+                response = model.generate_response(message, image)
+                chat_history.append((message, response))
+                return "", chat_history
+
+            def clear_chat():
+                model.clear_history()
+                return None, None
+
+            submit.click(
+                respond,
+                [msg, chatbot, image],
+                [msg, chatbot],
+            )
+
+            clear.click(
+                clear_chat,
+                None,
+                [chatbot, image],
+            )
+
+            msg.submit(
+                respond,
+                [msg, chatbot, image],
+                [msg, chatbot],
+            )
+
+        return demo
+    except Exception as e:
+        logging.error(f"Error creating demo: {str(e)}")
+        raise
 
 if __name__ == "__main__":
     demo = create_demo()
@@ -227,4 +260,4 @@ if __name__ == "__main__":
         server_name="0.0.0.0",
         server_port=7860,
         share=True
-    )
+    )
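Note on the reworked process_image: the new type check references numpy.ndarray, but the import block at the top of app.py (fully visible in the first hunk) never imports numpy, and none of the changed hunks adds such an import. If a NumPy array reaches this method, evaluating numpy.ndarray raises a NameError before any conversion runs; since the demo creates its image component with type="pil", that branch is unlikely to trigger, but the name still has to resolve. A minimal sketch of the guard with the import in place; to_pil is an illustrative helper name, not part of the commit, and it assumes numpy is not imported anywhere outside the shown hunks:

import numpy as np
from PIL import Image

def to_pil(image):
    # Illustrative helper (not in the commit): normalize the inputs the app
    # may receive, i.e. a file path, a NumPy array, or an already-loaded PIL image.
    if isinstance(image, str):
        return Image.open(image)
    if isinstance(image, np.ndarray):
        return Image.fromarray(image)
    return image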
 
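Note on the reworked generate_response: has_image is assigned only inside the if image is not None: branch, while the rebuilt prompt reads {'<image>' if has_image else ''}. In the lines shown, the def line, the try:, and the if statement are consecutive, so nothing initializes the flag first and a text-only call (image left as None) would hit the prompt f-string with has_image unbound. A minimal sketch of a safer ordering, written as a standalone function for illustration rather than as the committed method; build_prompt and the process_image parameter are placeholder names:

import logging

def build_prompt(message, image=None, process_image=None):
    # Illustrative sketch (not the committed code): default the flag and the
    # features before branching so text-only calls never touch an unbound name.
    image_features = None
    has_image = False
    if image is not None:
        try:
            image_features = process_image(image)  # stands in for self.process_image
            has_image = True
        except Exception as e:
            logging.error(f"Failed to process image: {str(e)}")
            message = f"Note: Failed to process image. Continuing with text only. Error: {str(e)}\n{message}"
    prompt = f"human: {'<image>' if has_image else ''}\n{message}\ngpt:"
    return prompt, image_features

With the defaults in place, the later `if has_image: inputs["image_features"] = image_features` block from the diff also works unchanged for text-only requests.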