from CircumSpect.vqa.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
from CircumSpect.vqa.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from CircumSpect.vqa.conversation_obj import conv_templates_obj, SeparatorStyle_obj
from CircumSpect.vqa.conversation_vqa import conv_templates, SeparatorStyle
from transformers import AutoTokenizer, BitsAndBytesConfig
from CircumSpect.vqa.utils import disable_torch_init
from Perceptrix.streamer import TextStreamer
from CircumSpect.vqa.model import *
from utils import setup_device
from io import BytesIO
from PIL import Image
import requests
import torch
import os

device = setup_device()


def load_image(image_file):
    """Load an image from a local path or an http(s) URL and convert it to RGB."""
    if image_file.startswith('http') or image_file.startswith('https'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


disable_torch_init()

# The checkpoint can be overridden with the VLM_MODEL environment variable;
# otherwise the bundled CRYSTAL-vision weights are used.
model_name = os.environ.get('VLM_MODEL')
model_path = "models/CRYSTAL-vision" if model_name is None else model_name
model_base = None
conv_mode = None  # set to force a conversation template; None = auto-infer from the model name
temperature = 0.2
max_new_tokens = 512

model_name = get_model_name_from_path(model_path)

image_processor = None

# 4-bit quantization with double quantization and fp16 compute keeps the
# vision-language model within a modest GPU memory budget.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16
)

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
model = LlavaMPTForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float32 if str(device) == "cpu" else torch.float16,
    quantization_config=bnb_config,
    offload_folder="offloads",
)

# Register the extra multimodal tokens the checkpoint expects.
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
if mm_use_im_patch_token:
    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
    tokenizer.add_tokens(
        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))

vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
vision_tower.to(device=device, dtype=torch.float16)
image_processor = vision_tower.image_processor

if hasattr(model.config, "max_sequence_length"):
    context_len = model.config.max_sequence_length
else:
    context_len = 2048

# Auto-infer the conversation template from the model name.
if 'llama-2' in model_name.lower():
    inferred_conv_mode = "llava_llama_2"
elif "v1" in model_name.lower():
    inferred_conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
    inferred_conv_mode = "mpt"
else:
    inferred_conv_mode = "llava_v0"

if conv_mode is not None and inferred_conv_mode != conv_mode:
    print('[WARNING] the auto inferred conversation mode is {}, while `conv_mode` is {}, using {}'.format(
        inferred_conv_mode, conv_mode, conv_mode))
else:
    conv_mode = inferred_conv_mode

conv = conv_templates[conv_mode].copy()
if "mpt" in model_name.lower():
    roles = ('User', 'Assistant')
else:
    roles = conv.roles

streamer = TextStreamer(tokenizer, skip_prompt=True,
                        skip_special_tokens=True, save_file="vlm-reply.txt")

def answer_question(question, image_file):
    """Answer a free-form question about an image, streaming the reply as it is generated."""
    conv = conv_templates[conv_mode].copy()
    inp = question
    image = load_image(image_file)
    if str(device) == "cpu":
        image_tensor = image_processor.preprocess(image, return_tensors='pt')[
            'pixel_values'].to(device)
    else:
        image_tensor = image_processor.preprocess(image, return_tensors='pt')[
            'pixel_values'].half().to(device)
    print(f"{roles[1]}: ", end="")

    if image is not None:
        # First message: prepend the image token(s) to the prompt.
        if model.config.mm_use_im_start_end:
            inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + \
                DEFAULT_IM_END_TOKEN + '\n' + inp
        else:
            inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
        conv.append_message(conv.roles[0], inp)
        image = None
    else:
        # Later messages: plain text only.
        conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(
        keywords, tokenizer, input_ids)

    # Clear the streamed-reply file before generating a new answer.
    with open("./database/vlm-reply.txt", 'w') as clear_file:
        clear_file.write("")

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.2,
            max_new_tokens=1024,
            streamer=streamer,
            use_cache=True,
            stopping_criteria=[stopping_criteria])

    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
    conv.messages[-1][-1] = outputs
    return outputs


conv_obj = conv_templates_obj[conv_mode].copy()
if "mpt" in model_name.lower():
    roles = ('User', 'Assistant')
else:
    roles = conv_obj.roles


def find_object_description(question, image_file):
    """Produce a short object description (low temperature, small token budget)."""
    conv_obj = conv_templates_obj[conv_mode].copy()
    inp = question
    image = load_image(image_file)
    if str(device) == "cpu":
        image_tensor = image_processor.preprocess(image, return_tensors='pt')[
            'pixel_values'].to(device)
    else:
        image_tensor = image_processor.preprocess(image, return_tensors='pt')[
            'pixel_values'].half().to(device)
    print(f"{roles[1]}: ", end="")

    if image is not None:
        # First message: prepend the image token(s) to the prompt.
        if model.config.mm_use_im_start_end:
            inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + \
                DEFAULT_IM_END_TOKEN + '\n' + inp
        else:
            inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
        conv_obj.append_message(conv_obj.roles[0], inp)
        image = None
    else:
        # Later messages: plain text only.
        conv_obj.append_message(conv_obj.roles[0], inp)
    conv_obj.append_message(conv_obj.roles[1], None)
    prompt = conv_obj.get_prompt()

    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
    stop_str = conv_obj.sep if conv_obj.sep_style != SeparatorStyle_obj.TWO else conv_obj.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(
        keywords, tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.1,
            max_new_tokens=32,
            streamer=streamer,
            use_cache=True,
            stopping_criteria=[stopping_criteria])

    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
    conv_obj.messages[-1][-1] = outputs
    return outputs


if __name__ == "__main__":
    print("RUNNING TEST\n\tTest Image: https://llava-vl.github.io/static/images/view.jpg\n\tPrompt: Describe this image")
    find_object_description("Describe this image",
                            "https://llava-vl.github.io/static/images/view.jpg")