cnzzx committed on
Commit 7b85afa
1 Parent(s): f16f78e
Files changed (1)
  models/vsa_model.py +13 -30
models/vsa_model.py CHANGED
@@ -41,7 +41,8 @@ from lmdeploy.messages import PytorchEngineConfig
 from typing import List, Union
 
 SEARCH_MODEL_NAMES = {
-    'internlm2_5-7b-chat': 'internlm2'
+    'internlm2_5-7b-chat': 'internlm2',
+    'internlm2_5-1_8b-chat': 'internlm2'
 }
 
 
@@ -125,7 +126,7 @@ class VLM:
         load_8bit: bool = False,
         load_4bit: bool = True,
         temperature: float = 0.2,
-        max_new_tokens: int = 2000,
+        max_new_tokens: int = 1024,
     ):
         disable_torch_init()
         model_name = get_model_name_from_path(model_path)
@@ -325,6 +326,16 @@ class VisionSearchAssistant:
         self.searcher = WebSearcher(
             model_path = self.search_model
         )
+        self.grounder = VisualGrounder(
+            model_path = self.ground_model,
+            device = self.ground_device,
+        )
+        self.vlm = VLM(
+            model_path = self.vlm_model,
+            device = self.vlm_device,
+            load_4bit = self.vlm_load_4bit,
+            load_8bit = self.vlm_load_8bit
+        )
 
     def app_run(
         self,
@@ -352,10 +363,6 @@ class VisionSearchAssistant:
             raise Exception('Unsupported input image format.')
 
         # Visual Grounding
-        self.grounder = VisualGrounder(
-            model_path = self.ground_model,
-            device = self.ground_device,
-        )
         bboxes, labels, out_image = self.grounder(in_image, classes = ground_classes)
         yield out_image, 'ground'
 
@@ -370,17 +377,7 @@ class VisionSearchAssistant:
                 det_images.append(in_image)
                 labels.append('image')
 
-        del self.grounder
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
-
         # Visual Captioning
-        self.vlm = VLM(
-            model_path = self.vlm_model,
-            device = self.vlm_device,
-            load_4bit = self.vlm_load_4bit,
-            load_8bit = self.vlm_load_8bit
-        )
         captions = []
         for det_image, label in zip(det_images, labels):
             inp = get_caption_prompt(label, text)
@@ -414,21 +411,11 @@ class VisionSearchAssistant:
 
         queries = [text + " " + query for query in queries]
 
-        del self.vlm
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
-
         # Web Searching
         contexts = self.searcher(queries)
         yield contexts, 'search'
 
         # QA
-        self.vlm = VLM(
-            model_path = self.vlm_model,
-            device = self.vlm_device,
-            load_4bit = self.vlm_load_4bit,
-            load_8bit = self.vlm_load_8bit
-        )
         TOKEN_LIMIT = 3500
         max_length_per_context = TOKEN_LIMIT // len(contexts)
         for cid, context in enumerate(contexts):
@@ -442,7 +429,3 @@ class VisionSearchAssistant:
             print(answer)
 
         yield answer, 'answer'
-
-        del self.vlm
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
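
In effect, this commit hoists the VisualGrounder and VLM construction out of app_run and into the constructor, next to the WebSearcher, and drops the per-stage del / torch.cuda.empty_cache() / torch.cuda.synchronize() teardown between pipeline stages. A minimal sketch of the trade-off follows; HeavyModel is a hypothetical stand-in for the real model classes, not part of this repo:

import torch

class HeavyModel:
    """Hypothetical stand-in for VisualGrounder / VLM."""
    def __init__(self):
        self.weights = torch.zeros(1)  # pretend this loads large GPU weights

    def __call__(self, x):
        return x

# Old pattern (before this commit): load, use, and free each model per stage.
def run_stage(x):
    model = HeavyModel()               # reload cost paid on every call
    out = model(x)
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()       # return cached blocks to the allocator
        torch.cuda.synchronize()       # wait for pending kernels to finish
    return out

# New pattern (after this commit): construct once, keep resident, reuse.
class Assistant:
    def __init__(self):
        self.model = HeavyModel()      # loaded once at startup

    def app_run(self, x):
        return self.model(x)           # no reload latency on repeated runs

Keeping all the models resident raises peak GPU memory, since the grounder, VLM, and searcher now coexist, but it removes the reload latency on every app_run call; the smaller internlm2_5-1_8b-chat search model registered in the same commit, and the reduced max_new_tokens, plausibly help offset that larger footprint.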