# batik/get_llava_response.py
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import numpy as np
from huggingface_hub import whoami
import llava
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from PIL import Image
import requests
from io import BytesIO
from transformers import TextStreamer
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
REPO_NAME = 'ncoria/llava-lora-vicuna-clip-5-epochs-merge'
def load_image(image_file):
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
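# Usage sketch (illustrative only, not part of the original pipeline): load_image accepts either
# a URL or a local file path, e.g.
#   img = load_image("https://example.com/mouse.png")
#   img = load_image("../images/mouse.png")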
def load_llava_checkpoint(model_path: str):
    model_name = get_model_name_from_path(model_path)
    return load_pretrained_model(model_path, None, model_name, load_4bit=True, device="cuda")
def load_llava_checkpoint_hf(model_path, hf_token):
    # confirm the token is valid before downloading the checkpoint
    user = whoami(token=hf_token)
    # load the weights in 4-bit NF4 with fp16 compute to reduce GPU memory usage
    kwargs = {"device_map": "auto"}
    kwargs['load_in_4bit'] = True
    kwargs['quantization_config'] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4'
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, token=hf_token, **kwargs)
    # register LLaVA's special image tokens if the model config expects them
    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))
    # load the vision tower and grab its image processor
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map="auto")
    image_processor = vision_tower.image_processor
    return tokenizer, model, image_processor
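# Minimal usage sketch (assumes a Hugging Face token with access to REPO_NAME; the token string
# below is a placeholder):
#   tokenizer, model, image_processor = load_llava_checkpoint_hf(REPO_NAME, hf_token="hf_...")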
def get_llava_response(user_prompts: list[str],
                       images: list,
                       sys_prompt: str,
                       tokenizer,
                       model,
                       image_processor,
                       model_path=REPO_NAME,
                       stream_output=True):
"""
This function returns the response from the given model. It creates a one turn conversation in which
the only content is a system prompt and the given user message applied to each image.
Parameters:
----------
user_prompt : str
The prompt sent by the user.
images : str
List of images from file.
sys_prompt : str
The prompt that sets the tone for the conversation.
model_path : str
The path to the merged checkpoint or base model.
Returns:
--------
"""
    # set up and load model
    model_name = get_model_name_from_path(model_path)
    temperature = 0.2  # default
    max_new_tokens = 512  # default

    # determine conversation type
    if "llama-2" in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "mistral" in model_name.lower():
        conv_mode = "mistral_instruct"
    elif "v1.6-34b" in model_name.lower():
        conv_mode = "chatml_direct"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"
    # run clean conversation for each image
    llm_outputs = []
    for i, img in tqdm(enumerate(images), total=len(images)):
        # set up clean conversation
        conv = conv_templates[conv_mode].copy()
        if "mpt" in model_name.lower():
            roles = ('user', 'assistant')
        else:
            roles = conv.roles
        conv.system = sys_prompt

        # load image
        # image = load_image("../images/mouse.png") # previous method
        if isinstance(img, np.ndarray) and len(img.shape) == 2:
            img = Image.fromarray(img, 'L')  # grayscale array -> PIL Image
        elif isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        image = img.convert('RGB')
        image_size = image.size
        # NOTE: image is simply PIL Image (.convert('RGB')), no need for temp files!
        # Similar operation in model_worker.py
        image_tensor = process_images([image], image_processor, model.config)
        if type(image_tensor) is list:
            image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
        else:
            image_tensor = image_tensor.to(model.device, dtype=torch.float16)
        # execute conversation
        inp = user_prompts[i]
        if image is not None:
            # first message
            if model.config.mm_use_im_start_end:
                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
            else:
                inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
            image = None
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(prompt,
                                          tokenizer,
                                          IMAGE_TOKEN_INDEX,
                                          return_tensors='pt').unsqueeze(0).to(model.device)
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        if stream_output:
            streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        else:
            streamer = None
        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=[image_size],
                do_sample=True if temperature > 0 else False,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                streamer=streamer,
                use_cache=True)
        outputs = tokenizer.decode(output_ids[0]).strip()
        llm_outputs.append(outputs)
    return llm_outputs
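
if __name__ == "__main__":
    # Minimal command-line sketch (illustrative, not part of the original pipeline): the argument
    # names and the default system prompt below are assumptions. It loads the merged checkpoint
    # from the Hub, reads a single image from a path or URL, and prints the model's response.
    parser = argparse.ArgumentParser(description="Query a LLaVA checkpoint about a single image.")
    parser.add_argument("--image", required=True, help="Path or URL of the input image.")
    parser.add_argument("--prompt", required=True, help="User prompt to pair with the image.")
    parser.add_argument("--sys-prompt", default="You are a helpful assistant.",
                        help="System prompt that sets the tone of the conversation.")
    parser.add_argument("--hf-token", required=True, help="Hugging Face access token.")
    args = parser.parse_args()

    tokenizer, model, image_processor = load_llava_checkpoint_hf(REPO_NAME, args.hf_token)
    image = load_image(args.image)
    responses = get_llava_response([args.prompt], [image], args.sys_prompt,
                                   tokenizer, model, image_processor)
    print(responses[0])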