yu-rp committed
Commit 69d0bde · 1 Parent(s): a531efa

add gpt box

Files changed (1):
  1. app.py +111 -21

app.py CHANGED
@@ -1,6 +1,9 @@
 import os
 import gradio as gr
 import torch
+import base64
+import requests
+from io import BytesIO
 
 from API_LLaVA.functions import get_model as llava_get_model, get_preanswer as llava_get_preanswer, from_preanswer_to_mask as llava_from_preanswer_to_mask
 from API_LLaVA.hook import hook_logger as llava_hook_logger
@@ -23,8 +26,55 @@ MARKDOWN = """
 </div>
 """
 
+def get_base64_images(image):
+    image = image.convert('RGB')
+    buffer = BytesIO()
+    image.save(buffer, format='JPEG')
+    image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
+    return image_base64
+
+def vqa(image, question, api_key):
+    base64_image = get_base64_images(image)
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {api_key}"
+    }
+
+    payload = {
+        "model": "gpt-4-turbo-2024-04-09",
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": question
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{base64_image}",
+                            "detail": "low"
+                        }
+                    }
+                ]
+            }
+        ],
+        "max_tokens": 300
+    }
+
+    response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
+    return response.json()["choices"][0]["message"]["content"]
+
+def compare(input_image, output_image, query, api_key):
+    original_response = vqa(input_image, query, api_key)
+    api_response = vqa(output_image, query, api_key)
+    return original_response, api_response
+
 def init_clip():
-    clip_model, clip_prs, clip_preprocess, _, clip_tokenizer = clip_get_model(model_name = "ViT-L-14-336", layer_index = 22, device= DEVICE)
+    clip_model, clip_prs, clip_preprocess, _, clip_tokenizer = clip_get_model(
+        model_name = "ViT-L-14-336" if torch.cuda.is_available() else "ViT-L-14",
+        layer_index = 22, device= DEVICE)
     return {"clip_model": clip_model, "clip_prs": clip_prs, "clip_preprocess": clip_preprocess, "clip_tokenizer": clip_tokenizer}
 
 def init_llava():
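A note on the new vqa() helper: it posts straight to the OpenAI chat completions endpoint and indexes choices[0] without checking the HTTP status or an error body, so a failed request surfaces as a KeyError in the UI. Below is a minimal, more defensive sketch; it is not part of this commit, and the helper name, timeout value, and error-message formatting are assumptions:

```python
import requests

def vqa_safe(image, question, api_key, timeout=60):
    # Same request as vqa() above, but with a timeout and basic error reporting.
    base64_image = get_base64_images(image)  # helper added earlier in app.py
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
    payload = {
        "model": "gpt-4-turbo-2024-04-09",
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": question},
                {"type": "image_url",
                 "image_url": {"url": f"data:image/jpeg;base64,{base64_image}", "detail": "low"}},
            ],
        }],
        "max_tokens": 300,
    }
    try:
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers, json=payload, timeout=timeout)
        body = response.json()
    except (requests.RequestException, ValueError) as exc:
        return f"Request failed: {exc}"
    if "error" in body:  # OpenAI returns an "error" object on failures such as a bad token
        return f"OpenAI error: {body['error'].get('message', body['error'])}"
    return body["choices"][0]["message"]["content"]
```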
@@ -133,21 +183,38 @@ image_output = gr.Image(
 text_query = gr.Textbox(
     label="Query",
     placeholder="Enter a query about the image",
-    lines=4,
+    lines=2,
     type="text")
 text_pre_answer = gr.Textbox(
     label="LLaVA Response",
     info = 'Only used for LLaVA-Based API. Press "Pre-Answer" to generate the response.',
     placeholder="",
-    lines=4,
+    lines=2,
     interactive=False,
     type="text")
 text_highlight_text = gr.Textbox(
     label = "Hint Text.",
-    info = "The text based on which the mask will be generated. For CLIP-Based API, it should be a substring of the query. For LLaVA-Based API, it should be a substring of the pre-answer.",
+    info = "The text based on which the mask will be generated. For LLaVA-Based API, it should be a substring of the pre-answer.",
     placeholder="Enter the hint text",
     lines=1,
     type="text")
+text_api_token = gr.Textbox(
+    label = "OpenAI API Token",
+    placeholder="Input your OpenAI API token",
+    lines=1,
+    type="text")
+text_original_image_response = gr.Textbox(
+    label="GPT Response (Original Image)",
+    placeholder="",
+    lines=2,
+    interactive=False,
+    type="text")
+text_API_image_response = gr.Textbox(
+    label="GPT Response (API-masked Image)",
+    placeholder="",
+    lines=2,
+    interactive=False,
+    type="text")
 
 radio_api_method = gr.Radio(
     ["CLIP_Based API", "LLaVA_Based API"] if torch.cuda.is_available() else ["CLIP_Based API"],
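One observation on the new token field: text_api_token uses type="text", so the OpenAI key is shown in clear text while typing. Gradio's Textbox also accepts type="password"; a possible variant is sketched below (a suggestion, not what this commit does):

```python
text_api_token = gr.Textbox(
    label = "OpenAI API Token",
    placeholder="Input your OpenAI API token",
    lines=1,
    type="password")  # masks the key in the browser; it is still sent to the OpenAI endpoint per request
```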
 
@@ -187,38 +254,56 @@ radio_interpolate_method_name = gr.Radio(
 
 generate_llava_response_button = gr.Button("Pre-Answer", interactive=False)
 generate_mask_button = gr.Button("API Go!")
+ask_gpt_button = gr.Button("GPT Go!")
 
 with gr.Blocks() as demo:
     gr.Markdown(MARKDOWN)
     state_cache = gr.State({})
     state_model = gr.State(init_clip())
     with gr.Row():
-        with gr.Column():
-            image_input.render()
-        with gr.Column():
-            image_output.render()
-    with gr.Row():
-        radio_api_method.render()
-    with gr.Row():
-        with gr.Column():
-            with gr.Row():
+        image_input.render()
+        image_output.render()
+    with gr.Accordion("Query and API Processing"):
+        with gr.Row():
+            radio_api_method.render()
+        with gr.Row(equal_height=True):
+            with gr.Column():
                 text_query.render()
-            with gr.Row():
                 generate_llava_response_button.render()
-            with gr.Row():
                 text_pre_answer.render()
-            with gr.Row():
                 text_highlight_text.render()
-        with gr.Column():
-            with gr.Row():
+            with gr.Column():
                 slider_enhance_coe.render()
-            with gr.Row():
                 slider_kernel_size.render()
-            with gr.Row():
                 radio_interpolate_method_name.render()
-            with gr.Row():
                 slider_mask_grayscale.render()
+        with gr.Row():
             generate_mask_button.render()
+    with gr.Accordion("GPT Response"):
+        text_api_token.render()
+        ask_gpt_button.render()
+        with gr.Row():
+            text_original_image_response.render()
+            text_API_image_response.render()
+    with gr.Accordion("Examples"):
+        examples_images_responses = gr.Examples(
+            [
+
+            ],
+            [
+                image_input,
+                image_output,
+                text_query,
+                text_pre_answer,
+                text_highlight_text,
+                slider_enhance_coe,
+                slider_kernel_size,
+                radio_interpolate_method_name,
+                slider_mask_grayscale,
+                text_original_image_response,
+                text_API_image_response
+            ],
+        )
 
     radio_api_method.change(
         fn=change_api_method,
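The Examples accordion is wired to eleven components, but the example list itself is left empty in this commit. For reference, each row passed as the first argument to gr.Examples has to supply one value per component, in the listed order; the sketch below is purely illustrative, with placeholder paths and values that do not come from the repo:

```python
# Hypothetical example row matching the component order above (all values are placeholders).
example_row = [
    "examples/input.jpg",              # image_input: original image
    "examples/input_api_mask.jpg",     # image_output: API-masked image
    "What is the cat looking at?",     # text_query
    "",                                # text_pre_answer (filled only for the LLaVA-based API)
    "cat",                             # text_highlight_text
    5,                                 # slider_enhance_coe (placeholder value)
    3,                                 # slider_kernel_size (placeholder value)
    "bicubic",                         # radio_interpolate_method_name (must match one of the radio choices)
    100,                               # slider_mask_grayscale (placeholder value)
    "",                                # text_original_image_response
    "",                                # text_API_image_response
]
# Passing [example_row] as the first argument to gr.Examples would populate the accordion.
```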
 
@@ -264,5 +349,10 @@ with gr.Blocks() as demo:
         ],
         outputs=[image_output, state_cache]
     )
+    ask_gpt_button.click(
+        fn=compare,
+        inputs=[image_input, image_output, text_query, text_api_token],
+        outputs=[text_original_image_response, text_API_image_response]
+    )
 
 demo.queue(max_size = 1).launch(show_error=True)
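The new button wiring simply forwards the two rendered images, the query, and the token to compare(), which calls vqa() once per image. The same comparison can be reproduced outside Gradio; in the sketch below the file names and the environment variable are assumptions, not something app.py reads:

```python
import os
from PIL import Image

# Reproduce the "GPT Go!" comparison without the UI: ask GPT-4 the same question
# about the original image and the API-masked image, then print both answers.
api_key = os.environ["OPENAI_API_KEY"]        # assumed to be set by the user
original = Image.open("example_input.jpg")    # placeholder file names
masked = Image.open("example_api_masked.jpg")

original_answer, masked_answer = compare(
    original, masked, "What is highlighted in this image?", api_key)
print("Original image  :", original_answer)
print("API-masked image:", masked_answer)
```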