# Vision Search Assistant: Empower Vision-Language Models as Multimodal Search Engines
# Github source: https://github.com/cnzzx/VSA-dev
# Licensed under the Apache License 2.0 [see LICENSE for details]
# Based on the LLaVA, GroundingDINO, and MindSearch code bases
# https://github.com/haotian-liu/LLaVA
# https://github.com/IDEA-Research/GroundingDINO
# https://github.com/InternLM/MindSearch
# --------------------------------------------------------
import os
import copy
import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from .vsa_prompt import COCO_CLASSES, get_caption_prompt, get_correlate_prompt, get_qa_prompt
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from datetime import datetime
from lagent.actions import ActionExecutor, BingBrowser
from lagent.llms import INTERNLM2_META, LMDeployServer, LMDeployPipeline
from lagent.schema import AgentReturn, AgentStatusCode
from .search_agent.mindsearch_agent import (
MindSearchAgent, SimpleSearchAgent, MindSearchProtocol
)
from .search_agent.mindsearch_prompt import (
FINAL_RESPONSE_CN, FINAL_RESPONSE_EN, GRAPH_PROMPT_CN, GRAPH_PROMPT_EN,
searcher_context_template_cn, searcher_context_template_en,
searcher_input_template_cn, searcher_input_template_en,
searcher_system_prompt_cn, searcher_system_prompt_en
)
from lmdeploy.messages import PytorchEngineConfig
from typing import List, Union
SEARCH_MODEL_NAMES = {
'internlm2_5-7b-chat': 'internlm2',
'internlm2_5-1_8b-chat': 'internlm2'
}
def render_bboxes(in_image: Image.Image, bboxes: np.ndarray, labels: List[str]):
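    """Draw green boxes and '[Area i]' tags onto a copy of `in_image` and return it.

    If `bboxes` is empty, the whole image is outlined and tagged as area 0.
    `labels` mirrors the grounder output but is not rendered; areas are
    identified by their index instead.
    """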
out_image = copy.deepcopy(in_image)
draw = ImageDraw.Draw(out_image)
font = ImageFont.truetype(font = 'assets/Arial.ttf', size = min(in_image.width, in_image.height) // 30)
line_width = min(in_image.width, in_image.height) // 100
for i in range(bboxes.shape[0]):
draw.rectangle((bboxes[i, 0], bboxes[i, 1], bboxes[i, 2], bboxes[i, 3]), outline=(0, 255, 0), width=line_width)
bbox = draw.textbbox((bboxes[i, 0], bboxes[i, 1]), '[Area {}] '.format(i), font=font)
draw.rectangle(bbox, fill='white')
draw.text((bboxes[i, 0], bboxes[i, 1]), '[Area {}] '.format(i), fill='black', font=font)
if bboxes.shape[0] == 0:
draw.rectangle((0, 0, in_image.width, in_image.height), outline=(0, 255, 0), width=line_width)
bbox = draw.textbbox((0, 0), '[Area {}] '.format(0), font=font)
draw.rectangle(bbox, fill='white')
draw.text((0, 0), '[Area {}] '.format(0), fill='black', font=font)
return out_image
class VisualGrounder:
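    """Open-vocabulary object detector built on Grounding DINO via `transformers`.

    Class names are joined into a '. '-separated text prompt, and detections are
    filtered with the box/text thresholds given at construction time.
    """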
def __init__(
self,
model_path: str = "IDEA-Research/grounding-dino-base",
device: str = "cuda:1",
box_threshold: float = 0.4,
text_threshold: float = 0.3,
):
self.processor = AutoProcessor.from_pretrained(model_path)
self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_path).to(device)
self.device = device
self.default_classes = COCO_CLASSES
self.box_threshold = box_threshold
self.text_threshold = text_threshold
def __call__(
self,
in_image: Image.Image,
classes: Union[List[str], None] = None,
):
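        """Detect objects belonging to `classes` (COCO classes by default) in `in_image`.

        Returns (bboxes, labels, out_image): xyxy boxes as a numpy array, the matched
        phrases, and a visualization saved to temp/ground_bbox.jpg.
        """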
# Save image.
in_image.save('temp/in_image.jpg')
# Preparation.
if classes is None:
classes = self.default_classes
text = ". ".join(classes)
inputs = self.processor(images=in_image, text=text, return_tensors="pt").to(self.device)
# Grounding.
with torch.no_grad():
outputs = self.model(**inputs)
# Postprocess
results = self.processor.post_process_grounded_object_detection(
outputs,
inputs.input_ids,
box_threshold = self.box_threshold,
text_threshold = self.text_threshold,
target_sizes=[in_image.size[::-1]]
)
bboxes = results[0]['boxes'].cpu().numpy()
labels = results[0]['labels']
print(results)
# Visualization.
out_image = render_bboxes(in_image, bboxes, labels)
out_image.save('temp/ground_bbox.jpg')
return bboxes, labels, out_image
class VLM:
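    """Thin wrapper around a LLaVA checkpoint for single-turn image-grounded queries.

    The conversation template is chosen from the model name, and every call builds
    a fresh conversation, so no chat history is kept between calls.
    """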
def __init__(
self,
model_path: str = "liuhaotian/llava-v1.6-vicuna-7b",
device: str = "cuda:2",
load_8bit: bool = False,
load_4bit: bool = True,
temperature: float = 0.2,
max_new_tokens: int = 1024,
):
disable_torch_init()
model_name = get_model_name_from_path(model_path)
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
model_path, None, model_name, load_8bit, load_4bit, device=device
)
self.device = device
if "llama-2" in model_name.lower():
self.conv_mode = "llava_llama_2"
elif "mistral" in model_name.lower():
self.conv_mode = "mistral_instruct"
elif "v1.6-34b" in model_name.lower():
self.conv_mode = "chatml_direct"
elif "v1" in model_name.lower():
self.conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
self.conv_mode = "mpt"
else:
self.conv_mode = "llava_v0"
self.temperature = temperature
self.max_new_tokens = max_new_tokens
def __call__(
self,
image: Image.Image,
text: str,
):
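        """Answer the prompt `text` about `image` and return the decoded response string."""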
image_size = image.size
image_tensor = process_images([image], self.image_processor, self.model.config)
        if isinstance(image_tensor, list):
            image_tensor = [img.to(self.device, dtype=torch.float16) for img in image_tensor]
else:
image_tensor = image_tensor.to(self.device, dtype=torch.float16)
if image is not None:
            # First (and only) turn: prepend the image token(s) to the prompt.
if self.model.config.mm_use_im_start_end:
text = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + text
else:
text = DEFAULT_IMAGE_TOKEN + '\n' + text
image = None
conv = conv_templates[self.conv_mode].copy()
conv.append_message(conv.roles[0], text)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
with torch.inference_mode():
output_ids = self.model.generate(
input_ids,
images = image_tensor,
image_sizes = [image_size],
do_sample = True if self.temperature > 0 else False,
temperature = self.temperature,
max_new_tokens = self.max_new_tokens,
streamer = None,
use_cache = True)
outputs = self.tokenizer.decode(output_ids[0]).strip()
outputs = outputs.replace('<s>', '').replace('</s>', '').replace('"', "'")
return outputs
class WebSearcher:
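    """Web search agent that answers text queries with a MindSearch agent.

    The underlying LLM is served with LMDeploy, and searches are executed through
    the lagent BingBrowser action configured with DuckDuckGo as the backend.
    """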
def __init__(
self,
model_path: str = 'internlm/internlm2_5-7b-chat',
lang: str = 'cn',
top_p: float = 0.8,
top_k: int = 1,
temperature: float = 0,
max_new_tokens: int = 8192,
repetition_penalty: float = 1.02,
max_turn: int = 10,
):
model_name = get_model_name_from_path(model_path)
if model_name in SEARCH_MODEL_NAMES:
model_name = SEARCH_MODEL_NAMES[model_name]
else:
            raise ValueError(f'Unsupported model for web searcher: {model_name}')
self.lang = lang
backend_config = PytorchEngineConfig(
max_batch_size = 1,
)
llm = LMDeployServer(
path = model_path,
model_name = model_name,
meta_template = INTERNLM2_META,
top_p = top_p,
top_k = top_k,
temperature = temperature,
max_new_tokens = max_new_tokens,
repetition_penalty = repetition_penalty,
stop_words = ['<|im_end|>'],
serve_cfg = dict(
backend_config = backend_config
)
)
# llm = LMDeployPipeline(
# path = model_path,
# model_name = model_name,
# meta_template = INTERNLM2_META,
# top_p = top_p,
# top_k = top_k,
# temperature = temperature,
# max_new_tokens = max_new_tokens,
# repetition_penalty = repetition_penalty,
# stop_words = ['<|im_end|>'],
# )
self.agent = MindSearchAgent(
llm = llm,
protocol = MindSearchProtocol(
meta_prompt = datetime.now().strftime('The current date is %Y-%m-%d.'),
interpreter_prompt = GRAPH_PROMPT_CN if lang == 'cn' else GRAPH_PROMPT_EN,
response_prompt = FINAL_RESPONSE_CN if lang == 'cn' else FINAL_RESPONSE_EN
),
searcher_cfg=dict(
llm = llm,
plugin_executor = ActionExecutor(
BingBrowser(searcher_type='DuckDuckGoSearch', topk=6)
),
protocol = MindSearchProtocol(
meta_prompt=datetime.now().strftime('The current date is %Y-%m-%d.'),
plugin_prompt=searcher_system_prompt_cn if lang == 'cn' else searcher_system_prompt_en,
),
template = dict(
input=searcher_input_template_cn if lang == 'cn' else searcher_input_template_en,
context=searcher_context_template_cn if lang == 'cn' else searcher_context_template_en)
),
max_turn = max_turn
)
def __call__(
self,
queries: List[str]
):
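        """Run the search agent on each query and return one answer string per query.

        Each answer is also written to temp/search_result_{qid}.txt.
        """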
results = []
for qid, query in enumerate(queries):
result = None
for agent_return in self.agent.stream_chat(query):
if isinstance(agent_return, AgentReturn):
if agent_return.state == AgentStatusCode.END:
result = agent_return.response
assert result is not None
with open('temp/search_result_{}.txt'.format(qid), 'w', encoding='utf-8') as wf:
wf.write(result)
results.append(result)
# for qid, query in enumerate(queries):
# result = None
# agent_return = self.agent.generate(query)
# result = agent_return.response
# assert result is not None
# with open('temp/search_result_{}.txt'.format(qid), 'w', encoding='utf-8') as wf:
# wf.write(result)
# results.append(result)
return results
class VisionSearchAssistant:
"""
Vision Search Assistant: Empower Vision-Language Models as Multimodal Search Engines
This class implements all variants of Vision Search Assistant:
* search_model: Vision Search Assistant use this model for dealing with the search process,
it corresponds to the $\mathcal{F}_{llm}(cdot)$ in the paper. You can choose the model
according to your preference.
* ground_model: The vision foundation model used in the open-vocab detection process,
it's relevant to the specific contents of the classes in the image.
* vlm_model: The main vision-language model we used in our paper is LLaVA-1.6 baseline,
It can be further improved by using advanced models. And it corresponds to
the $\mathcal{F}_{vlm}(cdot)$ in the paper.
"""
def __init__(
self,
search_model: str = "internlm/internlm2_5-1_8b-chat",
ground_model: str = "IDEA-Research/grounding-dino-tiny",
ground_device: str = "cuda:1",
vlm_model: str = "liuhaotian/llava-v1.6-vicuna-7b",
vlm_device: str = "cuda:2",
vlm_load_4bit: bool = True,
vlm_load_8bit: bool = False,
):
self.search_model = search_model
self.ground_model = ground_model
self.ground_device = ground_device
self.vlm_model = vlm_model
self.vlm_device = vlm_device
self.vlm_load_4bit = vlm_load_4bit
self.vlm_load_8bit = vlm_load_8bit
self.use_correlate = True
self.searcher = WebSearcher(
model_path = self.search_model,
lang = 'en'
)
self.grounder = VisualGrounder(
model_path = self.ground_model,
device = self.ground_device,
)
self.vlm = VLM(
model_path = self.vlm_model,
device = self.vlm_device,
load_4bit = self.vlm_load_4bit,
load_8bit = self.vlm_load_8bit
)
def app_run(
self,
image: Union[str, Image.Image, np.ndarray],
text: str,
ground_classes: List[str] = COCO_CLASSES
):
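        """Run the full pipeline (grounding, captioning, correlation, search, QA) on one image-question pair.

        This is a generator that yields intermediate results as (payload, stage) pairs,
        with stage in {'ground', 'query', 'search', 'answer'}.
        """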
# Create and clear the temporary directory.
        os.makedirs('temp', exist_ok=True)
for file in os.listdir('temp'):
os.remove(os.path.join('temp', file))
with open('temp/text.txt', 'w', encoding='utf-8') as wf:
wf.write(text)
# Load Image
if isinstance(image, str):
in_image = Image.open(image)
elif isinstance(image, Image.Image):
in_image = image
elif isinstance(image, np.ndarray):
in_image = Image.fromarray(image.astype(np.uint8))
else:
raise Exception('Unsupported input image format.')
# Visual Grounding
bboxes, labels, out_image = self.grounder(in_image, classes = ground_classes)
yield out_image, 'ground'
det_images = []
for bid, bbox in enumerate(bboxes):
crop_box = (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]))
det_image = in_image.crop(crop_box)
det_image.save('temp/debug_bbox_image_{}.jpg'.format(bid))
det_images.append(det_image)
if len(det_images) == 0: # No object detected, use the full image.
det_images.append(in_image)
labels.append('image')
# Visual Captioning
captions = []
for det_image, label in zip(det_images, labels):
inp = get_caption_prompt(label, text)
caption = self.vlm(det_image, inp)
captions.append(caption)
for cid, caption in enumerate(captions):
with open('temp/caption_{}.txt'.format(cid), 'w', encoding='utf-8') as wf:
wf.write(caption)
# Visual Correlation
if len(captions) >= 2 and self.use_correlate:
queries = []
for mid, det_image in enumerate(det_images):
caption = captions[mid]
other_captions = []
for cid in range(len(captions)):
if cid == mid:
continue
other_captions.append(captions[cid])
inp = get_correlate_prompt(caption, other_captions)
query = self.vlm(det_image, inp)
queries.append(query)
else:
queries = captions
for qid, query in enumerate(queries):
with open('temp/query_{}.txt'.format(qid), 'w', encoding='utf-8') as wf:
wf.write(query)
yield queries, 'query'
queries = [text + " " + query for query in queries]
# Web Searching
contexts = self.searcher(queries)
yield contexts, 'search'
# QA
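        # Note: the per-context limit below is applied to characters as a rough proxy for tokens.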
TOKEN_LIMIT = 3500
max_length_per_context = TOKEN_LIMIT // len(contexts)
for cid, context in enumerate(contexts):
contexts[cid] = (queries[cid] + context)[:max_length_per_context]
inp = get_qa_prompt(text, contexts)
answer = self.vlm(in_image, inp)
with open('temp/answer.txt', 'w', encoding='utf-8') as wf:
wf.write(answer)
print(answer)
yield answer, 'answer'
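
# Minimal usage sketch, assuming the default checkpoints above can be downloaded,
# the configured CUDA device exists, network access is available for the search
# backend, and 'demo.jpg' is a placeholder path for a local image. Adjust devices
# and paths to your setup.
if __name__ == '__main__':
    assistant = VisionSearchAssistant(
        ground_device='cuda:0',
        vlm_device='cuda:0',
    )
    for output, stage in assistant.app_run('demo.jpg', 'What is this object and how is it used?'):
        if stage == 'answer':
            print(output)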