import argparse
import warnings
from io import BytesIO

import numpy as np
import requests
import torch
from huggingface_hub import whoami
from PIL import Image
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextStreamer,
)

import llava  # importing llava registers its model classes with transformers' Auto* machinery
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IMAGE_PATCH_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path

warnings.filterwarnings('ignore')

# Default merged checkpoint (LLaVA LoRA fine-tune merged into its base weights).
REPO_NAME = 'ncoria/llava-lora-vicuna-clip-5-epochs-merge'


def load_image(image_file):
    """Load an image from a local path or an http(s) URL and convert it to RGB."""
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def load_llava_checkpoint(model_path: str):
    """Load a checkpoint through llava's own builder, 4-bit quantized on CUDA."""
    model_name = get_model_name_from_path(model_path)
    return load_pretrained_model(model_path, None, model_name, load_4bit=True, device="cuda")


def load_llava_checkpoint_hf(model_path):
    """Load a merged LLaVA checkpoint directly through Hugging Face transformers.

    Returns a (tokenizer, model, image_processor) triple, with the model
    quantized to 4-bit NF4 and its vision tower loaded alongside it.
    """
    # The BitsAndBytesConfig fully specifies the 4-bit quantization, so no
    # separate load_in_4bit kwarg is passed (recent transformers rejects both at once).
    kwargs = {
        "device_map": "auto",
        "quantization_config": BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4',
        ),
    }
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

    # Register LLaVA's special image tokens with the tokenizer so the prompt
    # templates can reference them, then resize the embeddings to match.
    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    # Make sure the vision tower is loaded; its processor handles image preprocessing.
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map="auto")
    image_processor = vision_tower.image_processor
    return tokenizer, model, image_processor
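# Usage sketch for the two loaders above (illustrative only, not executed on import).
# Either loader yields objects that plug into get_llava_response() below; note that
# llava's builder additionally returns the model's context length.
#
#     tokenizer, model, image_processor = load_llava_checkpoint_hf(REPO_NAME)
#     # or, via llava's builder:
#     tokenizer, model, image_processor, context_len = load_llava_checkpoint(REPO_NAME)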
def get_llava_response(user_prompts: list[str],
                       images: list,
                       sys_prompt: str,
                       tokenizer,
                       model,
                       image_processor,
                       model_path=REPO_NAME,
                       stream_output=True):
    """
    Return the given model's response for each (prompt, image) pair.

    For every image, this creates a fresh one-turn conversation whose only
    content is the system prompt and the corresponding user message.

    Parameters
    ----------
    user_prompts : list[str]
        The prompts sent by the user, one per image.
    images : list
        Images to query, as PIL images or numpy arrays.
    sys_prompt : str
        The prompt that sets the tone for the conversation.
    tokenizer, model, image_processor
        Objects returned by one of the checkpoint loaders above.
    model_path : str
        The path to the merged checkpoint or base model; used only to pick
        the conversation template.
    stream_output : bool
        If True, stream generated tokens to stdout as they are produced.

    Returns
    -------
    list[str]
        One decoded response per input image.
    """
    # set up model parameters
    model_name = get_model_name_from_path(model_path)
    temperature = 0.2  # default
    max_new_tokens = 512  # default

    # determine conversation type from the underlying base model
    if "llama-2" in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "mistral" in model_name.lower():
        conv_mode = "mistral_instruct"
    elif "v1.6-34b" in model_name.lower():
        conv_mode = "chatml_direct"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    # run a clean conversation for each image
    llm_outputs = []
    for i, img in tqdm(enumerate(images), total=len(images)):
        # set up a clean conversation
        conv = conv_templates[conv_mode].copy()
        if "mpt" in model_name.lower():
            roles = ('user', 'assistant')
        else:
            roles = conv.roles
        conv.system = sys_prompt

        # normalize the input to an RGB PIL image; grayscale and color numpy
        # arrays are converted first (no temp files needed)
        # image = load_image("../images/mouse.png")  # previous method
        if isinstance(img, np.ndarray) and len(img.shape) == 2:
            img = Image.fromarray(img, 'L')
        elif isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        image = img.convert('RGB')
        image_size = image.size

        # preprocess the image (same operation as in model_worker.py)
        image_tensor = process_images([image], image_processor, model.config)
        if isinstance(image_tensor, list):
            image_tensor = [t.to(model.device, dtype=torch.float16) for t in image_tensor]
        else:
            image_tensor = image_tensor.to(model.device, dtype=torch.float16)

        # build the single user turn, prepending the image token(s) to the message
        inp = user_prompts[i]
        if model.config.mm_use_im_start_end:
            inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
        else:
            inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(
            prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).to(model.device)
        # stop string for this template (kept for parity with llava's CLI;
        # generation also stops on the EOS token)
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]

        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) if stream_output else None

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=[image_size],
                do_sample=temperature > 0,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                streamer=streamer,
                use_cache=True)

        outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
        llm_outputs.append(outputs)
    return llm_outputs
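

# Minimal end-to-end sketch, assuming a CUDA GPU with enough memory for the
# 4-bit model. The image path and prompts below are placeholders, not values
# from the original setup.
if __name__ == "__main__":
    tokenizer, model, image_processor = load_llava_checkpoint_hf(REPO_NAME)
    image = load_image("example.jpg")  # hypothetical local file; http(s) URLs also work
    responses = get_llava_response(
        user_prompts=["Describe this image."],
        images=[image],
        sys_prompt="You are a helpful assistant that describes images accurately.",
        tokenizer=tokenizer,
        model=model,
        image_processor=image_processor,
        stream_output=True,
    )
    print(responses[0])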