update

models/vsa_model.py  CHANGED  (+13 lines, -30 lines)
@@ -41,7 +41,8 @@ from lmdeploy.messages import PytorchEngineConfig
 from typing import List, Union
 
 SEARCH_MODEL_NAMES = {
-    'internlm2_5-7b-chat': 'internlm2'
+    'internlm2_5-7b-chat': 'internlm2',
+    'internlm2_5-1_8b-chat': 'internlm2'
 }
 
 
@@ -125,7 +126,7 @@ class VLM:
         load_8bit: bool = False,
         load_4bit: bool = True,
         temperature: float = 0.2,
-        max_new_tokens: int =
+        max_new_tokens: int = 1024,
     ):
         disable_torch_init()
         model_name = get_model_name_from_path(model_path)
@@ -325,6 +326,16 @@ class VisionSearchAssistant:
         self.searcher = WebSearcher(
             model_path = self.search_model
         )
+        self.grounder = VisualGrounder(
+            model_path = self.ground_model,
+            device = self.ground_device,
+        )
+        self.vlm = VLM(
+            model_path = self.vlm_model,
+            device = self.vlm_device,
+            load_4bit = self.vlm_load_4bit,
+            load_8bit = self.vlm_load_8bit
+        )
 
     def app_run(
         self,
@@ -352,10 +363,6 @@ class VisionSearchAssistant:
             raise Exception('Unsupported input image format.')
 
         # Visual Grounding
-        self.grounder = VisualGrounder(
-            model_path = self.ground_model,
-            device = self.ground_device,
-        )
         bboxes, labels, out_image = self.grounder(in_image, classes = ground_classes)
         yield out_image, 'ground'
 
@@ -370,17 +377,7 @@ class VisionSearchAssistant:
            det_images.append(in_image)
            labels.append('image')
 
-        del self.grounder
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
-
         # Visual Captioning
-        self.vlm = VLM(
-            model_path = self.vlm_model,
-            device = self.vlm_device,
-            load_4bit = self.vlm_load_4bit,
-            load_8bit = self.vlm_load_8bit
-        )
         captions = []
         for det_image, label in zip(det_images, labels):
            inp = get_caption_prompt(label, text)
@@ -414,21 +411,11 @@ class VisionSearchAssistant:
 
        queries = [text + " " + query for query in queries]
 
-        del self.vlm
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
-
        # Web Searching
        contexts = self.searcher(queries)
        yield contexts, 'search'
 
        # QA
-        self.vlm = VLM(
-            model_path = self.vlm_model,
-            device = self.vlm_device,
-            load_4bit = self.vlm_load_4bit,
-            load_8bit = self.vlm_load_8bit
-        )
        TOKEN_LIMIT = 3500
        max_length_per_context = TOKEN_LIMIT // len(contexts)
        for cid, context in enumerate(contexts):
@@ -442,7 +429,3 @@ class VisionSearchAssistant:
        print(answer)
 
        yield answer, 'answer'
-
-        del self.vlm
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
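Taken together, the commit registers an extra search model ('internlm2_5-1_8b-chat'), sets max_new_tokens to 1024, and refactors VisionSearchAssistant so that the VisualGrounder and VLM are built once in __init__ alongside the WebSearcher, instead of being constructed, deleted, and rebuilt (with torch.cuda.empty_cache()/synchronize()) inside every app_run call. Below is a minimal usage sketch, not part of the diff: the module path, constructor defaults, and app_run arguments are assumptions; only the staged generator protocol ('ground', 'search', 'answer' pairs) comes from the diff.

# Usage sketch (assumed API, not from the commit): after this change a single
# instance keeps its searcher, grounder, and VLM resident, so repeated queries
# reuse the already-loaded weights instead of reloading them per call.
from models.vsa_model import VisionSearchAssistant  # module path assumed from the file name

assistant = VisionSearchAssistant()  # loads WebSearcher, VisualGrounder, and VLM once

# app_run is a generator yielding (output, stage) pairs; the diff shows the
# 'ground', 'search', and 'answer' stages. The image path and query text here
# are placeholders.
for output, stage in assistant.app_run('example.jpg', 'What brand is this laptop?'):
    if stage == 'answer':
        print(output)

The trade-off is the usual one for this kind of refactor: peak GPU memory stays higher because all three models are resident at once, but per-request latency drops since nothing is torn down and reloaded between the grounding, captioning, search, and QA stages.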