# batik/get_llava_response.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextStreamer
import numpy as np
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from PIL import Image
import requests
from io import BytesIO
from tqdm import tqdm
import warnings

warnings.filterwarnings('ignore')

REPO_NAME = 'ncoria/llava-lora-vicuna-clip-5-epochs-merge'


def load_image(image_file):
    # Load an image from a URL or a local path and force RGB.
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def load_llava_checkpoint(model_path: str):
    # Load a merged LLaVA checkpoint via the official builder (4-bit, CUDA).
    model_name = get_model_name_from_path(model_path)
    return load_pretrained_model(model_path, None, model_name, load_4bit=True, device="cuda")


def load_llava_checkpoint_hf(model_path):
    # Load the merged checkpoint directly through Hugging Face Transformers in 4-bit NF4.
    # The BitsAndBytesConfig below already requests 4-bit loading, so the legacy
    # `load_in_4bit` kwarg is not passed separately.
    kwargs = {"device_map": "auto"}
    kwargs['quantization_config'] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4'
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
    # Register the multimodal special tokens expected by the model config.
    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))
    # Ensure the vision tower is loaded before grabbing its image processor.
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map="auto")
    image_processor = vision_tower.image_processor
    return tokenizer, model, image_processor
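
# Illustrative note (an assumption, not part of the original script): either loader
# can be used to obtain the objects expected by get_llava_response below, e.g.
#
#     tokenizer, model, image_processor = load_llava_checkpoint_hf(REPO_NAME)
#
# load_llava_checkpoint() additionally returns the model's context length.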


def get_llava_response(user_prompts: list[str],
                       images: list,
                       sys_prompt: str,
                       tokenizer,
                       model,
                       image_processor,
                       model_path=REPO_NAME,
                       stream_output=True):
"""
This function returns the response from the given model. It creates a one turn conversation in which
the only content is a system prompt and the given user message applied to each image.
Parameters:
----------
user_prompt : str
The prompt sent by the user.
images : str
List of images from file.
sys_prompt : str
The prompt that sets the tone for the conversation.
model_path : str
The path to the merged checkpoint or base model.
Returns:
--------
"""
    # set up and load model
    model_name = get_model_name_from_path(model_path)
    temperature = 0.2  # default
    max_new_tokens = 512  # default

    # determine conversation type
    if "llama-2" in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "mistral" in model_name.lower():
        conv_mode = "mistral_instruct"
    elif "v1.6-34b" in model_name.lower():
        conv_mode = "chatml_direct"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    # run clean conversation for each image
    llm_outputs = []
    for i, img in tqdm(enumerate(images)):
        # set up clean conversation
        conv = conv_templates[conv_mode].copy()
        if "mpt" in model_name.lower():
            roles = ('user', 'assistant')
        else:
            roles = conv.roles
        conv.system = sys_prompt

        # load image
        # image = load_image("../images/mouse.png") # previous method
        if isinstance(img, np.ndarray) and len(img.shape) == 2:
            img = Image.fromarray(img, 'L')
        elif isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        image = img.convert('RGB')
        image_size = image.size
        # NOTE: image is simply PIL Image (.convert('RGB')), no need for temp files!
        # Similar operation in model_worker.py
        image_tensor = process_images([image], image_processor, model.config)
        if type(image_tensor) is list:
            image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
        else:
            image_tensor = image_tensor.to(model.device, dtype=torch.float16)

        # execute conversation
        inp = user_prompts[i]
        if image is not None:
            # first message
            if model.config.mm_use_im_start_end:
                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
            else:
                inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
            image = None
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(prompt,
                                          tokenizer,
                                          IMAGE_TOKEN_INDEX,
                                          return_tensors='pt').unsqueeze(0).to(model.device)
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        if stream_output:
            streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        else:
            streamer = None

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=[image_size],
                do_sample=True if temperature > 0 else False,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                streamer=streamer,
                use_cache=True)
        outputs = tokenizer.decode(output_ids[0]).strip()
        llm_outputs.append(outputs)
    return llm_outputs
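

# --- Usage sketch (added for illustration; not part of the original script). ---
# Assumes a CUDA GPU with enough memory for the 4-bit merged checkpoint and that
# "example.png" is replaced with a real image path; the prompts are placeholders.
if __name__ == "__main__":
    tokenizer, model, image_processor = load_llava_checkpoint_hf(REPO_NAME)
    example_image = load_image("example.png")  # placeholder path
    responses = get_llava_response(
        user_prompts=["Describe what is happening in this image."],
        images=[example_image],
        sys_prompt="You are a helpful vision-language assistant.",
        tokenizer=tokenizer,
        model=model,
        image_processor=image_processor,
        model_path=REPO_NAME,
        stream_output=False,
    )
    print(responses[0])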