cnzzx commited on
Commit
271c21d
·
1 Parent(s): c640227
Files changed (1) hide show
  1. models/vsa_model.py +3 -93
models/vsa_model.py CHANGED
@@ -298,12 +298,12 @@ class VisionSearchAssistant:
298
  self.use_correlate = True
299
 
300
  @spaces.GPU
301
- def __call__(
302
  self,
303
  image: Union[str, Image.Image, np.ndarray],
304
  text: str,
305
- ground_classes: Union[List[str], None] = None
306
- ):
307
  self.searcher = WebSearcher(
308
  model_path = self.search_model
309
  )
@@ -318,96 +318,6 @@ class VisionSearchAssistant:
318
  load_8bit = self.vlm_load_8bit
319
  )
320
 
321
- # Create and clear the temporary directory.
322
- if not os.access('temp', os.F_OK):
323
- os.makedirs('temp')
324
- for file in os.listdir('temp'):
325
- os.remove(os.path.join('temp', file))
326
-
327
- with open('temp/text.txt', 'w', encoding='utf-8') as wf:
328
- wf.write(text)
329
-
330
- # Load Image
331
- if isinstance(image, str):
332
- in_image = Image.open(image)
333
- elif isinstance(image, Image.Image):
334
- in_image = image
335
- elif isinstance(image, np.ndarray):
336
- in_image = Image.fromarray(image.astype(np.uint8))
337
- else:
338
- raise Exception('Unsupported input image format.')
339
-
340
- # Visual Grounding
341
- bboxes, labels, out_image = self.grounder(in_image, classes = ground_classes)
342
-
343
- det_images = []
344
- for bid, bbox in enumerate(bboxes):
345
- crop_box = (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]))
346
- det_image = in_image.crop(crop_box)
347
- det_image.save('temp/debug_bbox_image_{}.jpg'.format(bid))
348
- det_images.append(det_image)
349
-
350
- if len(det_images) == 0: # No object detected, use the full image.
351
- det_images.append(in_image)
352
- labels.append('image')
353
-
354
- # Visual Captioning
355
- captions = []
356
- for det_image, label in zip(det_images, labels):
357
- inp = get_caption_prompt(label, text)
358
- caption = self.vlm(det_image, inp)
359
- captions.append(caption)
360
-
361
- for cid, caption in enumerate(captions):
362
- with open('temp/caption_{}.txt'.format(cid), 'w', encoding='utf-8') as wf:
363
- wf.write(caption)
364
-
365
- # Visual Correlation
366
- if len(captions) >= 2 and self.use_correlate:
367
- queries = []
368
- for mid, det_image in enumerate(det_images):
369
- caption = captions[mid]
370
- other_captions = []
371
- for cid in range(len(captions)):
372
- if cid == mid:
373
- continue
374
- other_captions.append(captions[cid])
375
- inp = get_correlate_prompt(caption, other_captions)
376
- query = self.vlm(det_image, inp)
377
- queries.append(query)
378
- else:
379
- queries = captions
380
-
381
- for qid, query in enumerate(queries):
382
- with open('temp/query_{}.txt'.format(qid), 'w', encoding='utf-8') as wf:
383
- wf.write(query)
384
-
385
- queries = [text + " " + query for query in queries]
386
-
387
- # Web Searching
388
- contexts = self.searcher(queries)
389
-
390
- # QA
391
- TOKEN_LIMIT = 3500
392
- max_length_per_context = TOKEN_LIMIT // len(contexts)
393
- for cid, context in enumerate(contexts):
394
- contexts[cid] = (queries[cid] + context)[:max_length_per_context]
395
-
396
- inp = get_qa_prompt(text, contexts)
397
- answer = self.vlm(in_image, inp)
398
-
399
- with open('temp/answer.txt', 'w', encoding='utf-8') as wf:
400
- wf.write(answer)
401
- print(answer)
402
-
403
- return answer
404
-
405
- def app_run(
406
- self,
407
- image: Union[str, Image.Image, np.ndarray],
408
- text: str,
409
- ground_classes: List[str] = COCO_CLASSES
410
- ):
411
  # Create and clear the temporary directory.
412
  if not os.access('temp', os.F_OK):
413
  os.makedirs('temp')
 
298
  self.use_correlate = True
299
 
300
  @spaces.GPU
301
+ def app_run(
302
  self,
303
  image: Union[str, Image.Image, np.ndarray],
304
  text: str,
305
+ ground_classes: List[str] = COCO_CLASSES
306
+ ):
307
  self.searcher = WebSearcher(
308
  model_path = self.search_model
309
  )
 
318
  load_8bit = self.vlm_load_8bit
319
  )
320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  # Create and clear the temporary directory.
322
  if not os.access('temp', os.F_OK):
323
  os.makedirs('temp')