FIRE / src /model /model_llava.py
li-qing's picture
fix: dataset
489bcf5
raw
history blame
5.31 kB
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from loguru import logger
from PIL import Image
import requests
import copy
import torch
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
import spaces
from io import BytesIO
import base64
#model_path = "/scratch/TecManDep/A_Models/llava-v1.6-vicuna-7b"
#conv_template = "vicuna_v1" # Make sure you use correct chat template for different models
def load_llava_model(
    lora_checkpoint=None,
    model_path="Lin-Chen/open-llava-next-llama3-8b",
    conv_template="llama_v3",
    device_map="auto",
):
    """Load a LLaVA model through ``llava.model.builder.load_pretrained_model``.

    Args:
        lora_checkpoint: Path to a LoRA checkpoint. When ``None`` the model at
            ``model_path`` is loaded on its own; otherwise the checkpoint is
            loaded with ``model_path`` passed as its base model.
        model_path: Hugging Face hub id or local path of the base model.
        conv_template: Key into ``llava.conversation.conv_templates``. It is
            not used here, only returned, so callers build prompts with the
            chat format that matches the model. Use the correct template for
            the model (e.g. "vicuna_v1" for vicuna-based checkpoints).
        device_map: Forwarded unchanged to ``load_pretrained_model``.

    Returns:
        ``(tokenizer, model, image_processor, conv_template)``. The context
        length that the builder also returns is discarded. The model has been
        switched to eval mode and had ``tie_weights()`` called on it.
    """
    if lora_checkpoint is None:
        model_name = get_model_name_from_path(model_path)
        tokenizer, model, image_processor, _ = load_pretrained_model(
            model_path, None, model_name, device_map=device_map
        )
    else:
        # NOTE(review): the fixed model name "llava_lora" is presumably what
        # makes llava's builder take its LoRA path (checkpoint + base model);
        # verify against llava.model.builder if the checkpoint layout changes.
        tokenizer, model, image_processor, _ = load_pretrained_model(
            lora_checkpoint, model_path, "llava_lora", device_map=device_map
        )
    model.eval()
    model.tie_weights()
    logger.info("model device {}", model.device)
    return tokenizer, model, image_processor, conv_template
# Base model; used by inference() and inference_by_prompt_and_images().
tokenizer_llava, model_llava, image_processor_llava, conv_template_llava = load_llava_model(None)

# LoRA checkpoint served by inference_by_prompt_and_images_fire().
FIRE_LORA_CHECKPOINT = "checkpoints/llava-next-llama-3-8b-student-lora-merged-110224"
# The FIRE template gets its own name. Previously this load re-assigned
# conv_template_llava, silently replacing the base model's template.
(
    tokenizer_llava_fire,
    model_llava_fire,
    image_processor_llava_fire,
    conv_template_llava_fire,
) = load_llava_model(FIRE_LORA_CHECKPOINT)
# NOTE(review): only the FIRE model is moved to "cuda" explicitly; the base
# model relies on the device_map used inside load_llava_model. Confirm this
# asymmetry is intended.
model_llava_fire.to("cuda")
@spaces.GPU
def inference(image_path="assets/example.jpg", question="What is in the figure?"):
    """Smoke-test the base model on one image and print the result.

    The question is wrapped in the ``conv_template_llava`` chat format with the
    image placeholder token prepended. The model then decodes greedily for up
    to 256 new tokens.

    Args:
        image_path: Path of the image to describe.
        question: User question placed after the image token.

    Returns:
        The list of decoded strings from ``tokenizer_llava.batch_decode``.
    """
    device = "cuda"
    # The ``with`` block closes the file handle. convert() loads the pixels
    # and returns an independent image, so it remains usable after the close.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    image_tensor = process_images([image], image_processor_llava, model_llava.config)
    image_tensor = image_tensor.to(dtype=torch.float16, device=device)

    prompt = DEFAULT_IMAGE_TOKEN + question
    conv = conv_templates[conv_template_llava].copy()
    conv.append_message(conv.roles[0], prompt)
    # A None message leaves the assistant turn open for generation.
    conv.append_message(conv.roles[1], None)
    prompt_question = conv.get_prompt()

    input_ids = tokenizer_image_token(
        prompt_question, tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(device)
    image_sizes = [image.size]
    print(input_ids.shape, image_tensor.shape)

    with torch.inference_mode():
        cont = model_llava.generate(
            input_ids,
            images=image_tensor,
            image_sizes=image_sizes,
            do_sample=False,
            temperature=0,
            max_new_tokens=256,
            use_cache=True,
        )
    text_outputs = tokenizer_llava.batch_decode(cont, skip_special_tokens=True)
    print(text_outputs)
    return text_outputs
def _decode_image(image):
    """Return *image* as an RGB PIL image, decoding it first if it is a string.

    Strings are treated as base64-encoded image bytes, either raw or wrapped in
    a ``data:<mime>;base64,<payload>`` URL (the prefix is stripped).
    """
    if isinstance(image, str):
        payload = image
        if payload.startswith("data:"):
            # b64decode keeps the letters of a "data:image/...;base64," prefix
            # (they are valid base64 characters), which would corrupt the bytes.
            payload = payload.partition(",")[2]
        image = Image.open(BytesIO(base64.b64decode(payload)))
    if image.mode != "RGB":
        # inference() also feeds RGB images; this normalises palette, RGBA and
        # grayscale uploads to the same channel layout.
        image = image.convert("RGB")
    return image


def _generate_response(prompt, images, tokenizer, model, image_processor, max_new_tokens=256):
    """Run greedy multimodal generation for an already-formatted *prompt*.

    Args:
        prompt: Full prompt text, expected to contain the image placeholder
            token(s) that ``tokenizer_image_token`` maps to IMAGE_TOKEN_INDEX.
        images: PIL images and/or base64 strings (see ``_decode_image``).
            NOTE(review): an empty list is not handled here and presumably
            fails inside ``process_images``, as it did before.
        tokenizer, model, image_processor: The components of one loaded model.
        max_new_tokens: Generation length limit.

    Returns:
        The list of decoded strings from ``tokenizer.batch_decode``.
    """
    device = "cuda"
    images = [_decode_image(image) for image in images]
    image_tensor = process_images(images, image_processor, model.config)
    # process_images may hand back a list of per-image tensors instead of one
    # stacked tensor (presumably when their shapes differ), so move either form.
    if isinstance(image_tensor, list):
        image_tensor = [t.to(dtype=torch.float16, device=device) for t in image_tensor]
        tensor_shape = [t.shape for t in image_tensor]
        tensor_device = image_tensor[0].device if image_tensor else device
    else:
        image_tensor = image_tensor.to(dtype=torch.float16, device=device)
        tensor_shape = image_tensor.shape
        tensor_device = image_tensor.device
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(device)
    image_sizes = [image.size for image in images]
    logger.info("Shape: {};{}; Devices: {};{}", input_ids.shape, tensor_shape, input_ids.device, tensor_device)
    with torch.inference_mode():
        cont = model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=image_sizes,
            do_sample=False,
            temperature=0,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )
    text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
    logger.info("response={}", text_outputs)
    return text_outputs


@spaces.GPU
def inference_by_prompt_and_images(prompt, images, *, max_new_tokens=256):
    """Answer *prompt* about *images* with the base LLaVA model.

    *images* may mix PIL images and base64 strings. Returns the list of decoded
    strings from the tokenizer.
    """
    return _generate_response(
        prompt, images, tokenizer_llava, model_llava, image_processor_llava, max_new_tokens
    )


@spaces.GPU
def inference_by_prompt_and_images_fire(prompt, images, *, max_new_tokens=256):
    """Answer *prompt* about *images* with the FIRE LoRA model.

    Same contract as ``inference_by_prompt_and_images``, but it uses the model
    loaded from the FIRE checkpoint.
    """
    return _generate_response(
        prompt, images, tokenizer_llava_fire, model_llava_fire, image_processor_llava_fire, max_new_tokens
    )
if __name__ == "__main__":
    # Running the module directly performs a one-off smoke test with the base model.
    inference()