Spaces:
Sleeping
Sleeping
File size: 6,726 Bytes
c784516 ed29c11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import numpy as np
from huggingface_hub import whoami
import llava
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from PIL import Image
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
from tqdm import tqdm
import warnings
# Suppress every Python warning for the remainder of the process.
warnings.filterwarnings('ignore')
# Default checkpoint identifier; used as the default `model_path` of get_llava_response.
REPO_NAME = 'ncoria/llava-lora-vicuna-clip-5-epochs-merge'
def load_image(image_file, timeout=30):
    """
    Load an image from a local path or an HTTP(S) URL and return it in RGB mode.

    Parameters:
    ----------
    image_file : str
        Local file path, or a URL starting with 'http://' or 'https://'.
    timeout : float
        Seconds to wait for the server when `image_file` is a URL.

    Returns:
    --------
    PIL.Image.Image
        The loaded image converted to 'RGB'.

    Raises:
    -------
    requests.RequestException
        If the download times out or the server answers with an error status.
    """
    if image_file.startswith(('http://', 'https://')):
        # Without a timeout, requests can wait forever on an unresponsive host.
        response = requests.get(image_file, timeout=timeout)
        # Fail on 4xx/5xx here instead of handing an error page to PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert('RGB')
    return Image.open(image_file).convert('RGB')
def load_llava_checkpoint(model_path: str):
    """
    Load a checkpoint through LLaVA's `load_pretrained_model`, in 4-bit on CUDA.

    The model name is derived from `model_path`, and `None` is passed as the
    builder's second positional argument. The builder's return value is handed
    back unchanged.
    """
    return load_pretrained_model(
        model_path,
        None,
        get_model_name_from_path(model_path),
        load_4bit=True,
        device="cuda",
    )
def load_llava_checkpoint_hf(model_path):
    """
    Load a 4-bit quantized checkpoint through the Hugging Face Auto classes.

    After loading, the tokenizer is extended with the image patch and/or image
    start/end tokens as requested by the model config, the token embeddings are
    resized to the new vocabulary size, and the vision tower is loaded if it
    is not already.

    Parameters:
    ----------
    model_path : str
        Path or hub identifier passed to `from_pretrained`.

    Returns:
    --------
    tuple
        (tokenizer, model, image_processor), where the image processor is taken
        from the model's vision tower.
    """
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    # 4-bit loading is requested through quantization_config only. Passing the
    # `load_in_4bit` kwarg alongside it is redundant, and recent transformers
    # releases reject the combination with a ValueError.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        device_map="auto",
        quantization_config=quantization_config,
    )

    # Missing config attributes fall back to: no start/end tokens, patch token on.
    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    # Keep the embedding matrix in step with the (possibly enlarged) vocabulary.
    model.resize_token_embeddings(len(tokenizer))

    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map="auto")
    image_processor = vision_tower.image_processor
    return tokenizer, model, image_processor
def _select_conv_mode(model_name):
    """Return the conversation-template key that matches a model name."""
    lowered = model_name.lower()
    # Order matters: "v1.6-34b" must be tested before the broader "v1".
    for marker, conv_mode in (
        ("llama-2", "llava_llama_2"),
        ("mistral", "mistral_instruct"),
        ("v1.6-34b", "chatml_direct"),
        ("v1", "llava_v1"),
        ("mpt", "mpt"),
    ):
        if marker in lowered:
            return conv_mode
    return "llava_v0"


def _to_rgb_image(img):
    """Turn a NumPy array or PIL image into an RGB PIL image."""
    if isinstance(img, np.ndarray):
        # 2-D arrays are read as single-channel ('L') images.
        img = Image.fromarray(img, 'L') if img.ndim == 2 else Image.fromarray(img)
    return img.convert('RGB')


def get_llava_response(user_prompts: list[str],
                       images: list,
                       sys_prompt: str,
                       tokenizer,
                       model,
                       image_processor,
                       model_path = REPO_NAME,
                       stream_output = True,
                       temperature = 0.2,
                       max_new_tokens = 512):
    """
    Run one fresh, single-turn conversation per image and collect the model's replies.

    For image number i, a new conversation is created from the template chosen
    for the model name, its system prompt is set to `sys_prompt`, and a single
    user message is sent: the image token followed by `user_prompts[i]`.

    Parameters:
    ----------
    user_prompts : list[str]
        One prompt per image; `user_prompts[i]` is paired with `images[i]`.
    images : list
        Images as PIL images or NumPy arrays (2-D arrays are read as greyscale).
    sys_prompt : str
        The prompt that sets the tone for the conversation.
    tokenizer, model, image_processor
        The loaded tokenizer, model and image processor to run inference with.
    model_path : str
        Only used to derive the model name that selects the conversation template.
    stream_output : bool
        When true, generated text is streamed to stdout through a TextStreamer.
    temperature : float
        Sampling temperature; sampling is enabled only when it is greater than 0.
    max_new_tokens : int
        Upper bound on the number of tokens generated per image.

    Returns:
    --------
    list[str]
        One entry per image: `tokenizer.decode` of the first generated sequence,
        stripped of surrounding whitespace.

    Raises:
    -------
    IndexError
        If `user_prompts` has fewer entries than `images`.
    """
    conv_mode = _select_conv_mode(get_model_name_from_path(model_path))

    llm_outputs = []
    # Wrap the images (not the enumerate object) so tqdm can show a total.
    for i, img in enumerate(tqdm(images)):
        # Start from a clean template so no history leaks between images.
        conv = conv_templates[conv_mode].copy()
        conv.system = sys_prompt

        image = _to_rgb_image(img)
        image_tensor = process_images([image], image_processor, model.config)
        if isinstance(image_tensor, list):
            image_tensor = [t.to(model.device, dtype=torch.float16) for t in image_tensor]
        else:
            image_tensor = image_tensor.to(model.device, dtype=torch.float16)

        # The image placeholder always precedes the user's text in the message.
        if model.config.mm_use_im_start_end:
            image_marker = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        else:
            image_marker = DEFAULT_IMAGE_TOKEN
        conv.append_message(conv.roles[0], image_marker + '\n' + user_prompts[i])
        # An empty assistant turn leaves the prompt open for generation.
        conv.append_message(conv.roles[1], None)

        input_ids = tokenizer_image_token(conv.get_prompt(),
                                          tokenizer,
                                          IMAGE_TOKEN_INDEX,
                                          return_tensors='pt').unsqueeze(0).to(model.device)

        if stream_output:
            streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        else:
            streamer = None

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=[image.size],
                do_sample=temperature > 0,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                streamer=streamer,
                use_cache=True)

        llm_outputs.append(tokenizer.decode(output_ids[0]).strip())
    return llm_outputs
|