import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import numpy as np
from huggingface_hub import whoami
import llava
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from PIL import Image
import requests
from io import BytesIO
from transformers import TextStreamer
from tqdm import tqdm
import warnings

warnings.filterwarnings('ignore')

REPO_NAME = 'ncoria/llava-lora-vicuna-clip-5-epochs-merge'

def load_image(image_file):
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image

def load_llava_checkpoint(model_path: str):
    model_name = get_model_name_from_path(model_path)
    return load_pretrained_model(model_path, None, model_name, load_4bit=True, device="cuda")

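# Note: load_llava_checkpoint goes through LLaVA's own builder (load_pretrained_model) in
# 4-bit mode, while load_llava_checkpoint_hf below loads the merged checkpoint directly
# through transformers with an explicit 4-bit NF4 BitsAndBytesConfig.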
def load_llava_checkpoint_hf(model_path):
    # Load the merged checkpoint from the Hugging Face Hub with 4-bit (NF4) quantization.
    kwargs = {"device_map": "auto"}
    kwargs['load_in_4bit'] = True
    kwargs['quantization_config'] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4'
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
    # Register the extra image tokens expected by the multimodal checkpoint.
    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))
    # Make sure the vision tower is loaded before grabbing its image processor.
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map="auto")
    image_processor = vision_tower.image_processor
    return tokenizer, model, image_processor

def get_llava_response(user_prompts: list[str],
                       images: list,
                       sys_prompt: str,
                       tokenizer,
                       model,
                       image_processor,
                       model_path=REPO_NAME,
                       stream_output=True):
    """
    Return the model's response for each image. For every image, a fresh one-turn
    conversation is built from the system prompt and the corresponding user prompt.

    Parameters
    ----------
    user_prompts : list[str]
        The prompts sent by the user, one per image.
    images : list
        List of images (PIL Images or numpy arrays).
    sys_prompt : str
        The system prompt that sets the tone for the conversation.
    tokenizer, model, image_processor
        Objects returned by one of the checkpoint loaders above.
    model_path : str
        The path to the merged checkpoint or base model (used to pick the
        conversation template).
    stream_output : bool
        If True, stream generated tokens to stdout during generation.

    Returns
    -------
    list[str]
        The decoded model response for each image.
    """
    # set up and load model
    model_name = get_model_name_from_path(model_path)
    temperature = 0.2  # default
    max_new_tokens = 512  # default

    # determine conversation type
    if "llama-2" in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "mistral" in model_name.lower():
        conv_mode = "mistral_instruct"
    elif "v1.6-34b" in model_name.lower():
        conv_mode = "chatml_direct"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    # run clean conversation for each image
    llm_outputs = []
    for i, img in tqdm(enumerate(images)):
        # set up clean conversation
        conv = conv_templates[conv_mode].copy()
        if "mpt" in model_name.lower():
            roles = ('user', 'assistant')
        else:
            roles = conv.roles
        conv.system = sys_prompt

        # load image
        # image = load_image("../images/mouse.png")  # previous method
        if isinstance(img, np.ndarray) and len(img.shape) == 2:
            img = Image.fromarray(img, 'L')
        elif isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        image = img.convert('RGB')
        image_size = image.size
        # NOTE: image is simply PIL Image (.convert('RGB')), no need for temp files!
        # Similar operation in model_worker.py
        image_tensor = process_images([image], image_processor, model.config)
        if type(image_tensor) is list:
            image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
        else:
            image_tensor = image_tensor.to(model.device, dtype=torch.float16)
        # execute conversation
        inp = user_prompts[i]
        if image is not None:
            # first message: prepend the image token(s) to the user prompt
            if model.config.mm_use_im_start_end:
                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
            else:
                inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
            image = None
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(prompt,
                                          tokenizer,
                                          IMAGE_TOKEN_INDEX,
                                          return_tensors='pt').unsqueeze(0).to(model.device)
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        if stream_output:
            streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        else:
            streamer = None
        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=[image_size],
                do_sample=True if temperature > 0 else False,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                streamer=streamer,
                use_cache=True)
        outputs = tokenizer.decode(output_ids[0]).strip()
        llm_outputs.append(outputs)
    return llm_outputs
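

# Minimal usage sketch (not part of the original pipeline). The image path and the prompts
# below are hypothetical placeholders; the helpers above are used exactly as defined in this
# file, with streaming disabled so the responses are simply collected and printed.
if __name__ == "__main__":
    tokenizer, model, image_processor = load_llava_checkpoint_hf(REPO_NAME)
    demo_image = load_image("example.png")  # hypothetical local file
    responses = get_llava_response(
        user_prompts=["Describe what is happening in this image."],
        images=[demo_image],
        sys_prompt="You are a helpful assistant that carefully describes images.",
        tokenizer=tokenizer,
        model=model,
        image_processor=image_processor,
        model_path=REPO_NAME,
        stream_output=False,
    )
    print(responses[0])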