Spaces:
Runtime error
Runtime error
update
Browse files- models/vsa_model.py +3 -93
models/vsa_model.py
CHANGED
@@ -298,12 +298,12 @@ class VisionSearchAssistant:
|
|
298 |
self.use_correlate = True
|
299 |
|
300 |
@spaces.GPU
|
301 |
-
def
|
302 |
self,
|
303 |
image: Union[str, Image.Image, np.ndarray],
|
304 |
text: str,
|
305 |
-
ground_classes:
|
306 |
-
):
|
307 |
self.searcher = WebSearcher(
|
308 |
model_path = self.search_model
|
309 |
)
|
@@ -318,96 +318,6 @@ class VisionSearchAssistant:
|
|
318 |
load_8bit = self.vlm_load_8bit
|
319 |
)
|
320 |
|
321 |
-
# Create and clear the temporary directory.
|
322 |
-
if not os.access('temp', os.F_OK):
|
323 |
-
os.makedirs('temp')
|
324 |
-
for file in os.listdir('temp'):
|
325 |
-
os.remove(os.path.join('temp', file))
|
326 |
-
|
327 |
-
with open('temp/text.txt', 'w', encoding='utf-8') as wf:
|
328 |
-
wf.write(text)
|
329 |
-
|
330 |
-
# Load Image
|
331 |
-
if isinstance(image, str):
|
332 |
-
in_image = Image.open(image)
|
333 |
-
elif isinstance(image, Image.Image):
|
334 |
-
in_image = image
|
335 |
-
elif isinstance(image, np.ndarray):
|
336 |
-
in_image = Image.fromarray(image.astype(np.uint8))
|
337 |
-
else:
|
338 |
-
raise Exception('Unsupported input image format.')
|
339 |
-
|
340 |
-
# Visual Grounding
|
341 |
-
bboxes, labels, out_image = self.grounder(in_image, classes = ground_classes)
|
342 |
-
|
343 |
-
det_images = []
|
344 |
-
for bid, bbox in enumerate(bboxes):
|
345 |
-
crop_box = (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]))
|
346 |
-
det_image = in_image.crop(crop_box)
|
347 |
-
det_image.save('temp/debug_bbox_image_{}.jpg'.format(bid))
|
348 |
-
det_images.append(det_image)
|
349 |
-
|
350 |
-
if len(det_images) == 0: # No object detected, use the full image.
|
351 |
-
det_images.append(in_image)
|
352 |
-
labels.append('image')
|
353 |
-
|
354 |
-
# Visual Captioning
|
355 |
-
captions = []
|
356 |
-
for det_image, label in zip(det_images, labels):
|
357 |
-
inp = get_caption_prompt(label, text)
|
358 |
-
caption = self.vlm(det_image, inp)
|
359 |
-
captions.append(caption)
|
360 |
-
|
361 |
-
for cid, caption in enumerate(captions):
|
362 |
-
with open('temp/caption_{}.txt'.format(cid), 'w', encoding='utf-8') as wf:
|
363 |
-
wf.write(caption)
|
364 |
-
|
365 |
-
# Visual Correlation
|
366 |
-
if len(captions) >= 2 and self.use_correlate:
|
367 |
-
queries = []
|
368 |
-
for mid, det_image in enumerate(det_images):
|
369 |
-
caption = captions[mid]
|
370 |
-
other_captions = []
|
371 |
-
for cid in range(len(captions)):
|
372 |
-
if cid == mid:
|
373 |
-
continue
|
374 |
-
other_captions.append(captions[cid])
|
375 |
-
inp = get_correlate_prompt(caption, other_captions)
|
376 |
-
query = self.vlm(det_image, inp)
|
377 |
-
queries.append(query)
|
378 |
-
else:
|
379 |
-
queries = captions
|
380 |
-
|
381 |
-
for qid, query in enumerate(queries):
|
382 |
-
with open('temp/query_{}.txt'.format(qid), 'w', encoding='utf-8') as wf:
|
383 |
-
wf.write(query)
|
384 |
-
|
385 |
-
queries = [text + " " + query for query in queries]
|
386 |
-
|
387 |
-
# Web Searching
|
388 |
-
contexts = self.searcher(queries)
|
389 |
-
|
390 |
-
# QA
|
391 |
-
TOKEN_LIMIT = 3500
|
392 |
-
max_length_per_context = TOKEN_LIMIT // len(contexts)
|
393 |
-
for cid, context in enumerate(contexts):
|
394 |
-
contexts[cid] = (queries[cid] + context)[:max_length_per_context]
|
395 |
-
|
396 |
-
inp = get_qa_prompt(text, contexts)
|
397 |
-
answer = self.vlm(in_image, inp)
|
398 |
-
|
399 |
-
with open('temp/answer.txt', 'w', encoding='utf-8') as wf:
|
400 |
-
wf.write(answer)
|
401 |
-
print(answer)
|
402 |
-
|
403 |
-
return answer
|
404 |
-
|
405 |
-
def app_run(
|
406 |
-
self,
|
407 |
-
image: Union[str, Image.Image, np.ndarray],
|
408 |
-
text: str,
|
409 |
-
ground_classes: List[str] = COCO_CLASSES
|
410 |
-
):
|
411 |
# Create and clear the temporary directory.
|
412 |
if not os.access('temp', os.F_OK):
|
413 |
os.makedirs('temp')
|
|
|
298 |
self.use_correlate = True
|
299 |
|
300 |
@spaces.GPU
|
301 |
+
def app_run(
|
302 |
self,
|
303 |
image: Union[str, Image.Image, np.ndarray],
|
304 |
text: str,
|
305 |
+
ground_classes: List[str] = COCO_CLASSES
|
306 |
+
):
|
307 |
self.searcher = WebSearcher(
|
308 |
model_path = self.search_model
|
309 |
)
|
|
|
318 |
load_8bit = self.vlm_load_8bit
|
319 |
)
|
320 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
321 |
# Create and clear the temporary directory.
|
322 |
if not os.access('temp', os.F_OK):
|
323 |
os.makedirs('temp')
|