import base64
import copy
from io import BytesIO

import requests
import spaces
import torch
from PIL import Image

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model
#model_path = "/scratch/TecManDep/A_Models/llava-v1.6-vicuna-7b"
#conv_template = "vicuna_v1" # Make sure you use correct chat template for different models
from src.utils import (
    build_logger,
)

logger = build_logger("model_llava", "model_llava.log")
def load_llava_model(lora_checkpoint=None):
    """Load open-llava-next-llama3-8b, either the base weights or a LoRA checkpoint on top of them."""
    model_path = "Lin-Chen/open-llava-next-llama3-8b"
    conv_template = "llama_v3_student"
    model_name = get_model_name_from_path(model_path)
    device = "cuda"
    device_map = "auto"
    if lora_checkpoint is None:
        tokenizer, model, image_processor, max_length = load_pretrained_model(
            model_path, None, model_name, device_map=device_map
        )  # Add any other llava_model_args you want to pass here
    else:
        tokenizer, model, image_processor, max_length = load_pretrained_model(
            lora_checkpoint, model_path, "llava_lora", device_map=device_map
        )
    model.eval()
    model.tie_weights()
    logger.info(f"model device {model.device}")
    return tokenizer, model, image_processor, conv_template
# Load the base model and the LoRA-merged student checkpoint (used by the *_fire functions) at module load.
tokenizer_llava, model_llava, image_processor_llava, conv_template_llava = load_llava_model(None)
tokenizer_llava_fire, model_llava_fire, image_processor_llava_fire, conv_template_llava = load_llava_model(
    "checkpoints/llava-next-llama-3-8b-student-lora-merged-115124"
)
model_llava_fire.to("cuda")
@spaces.GPU
def inference():
    """Run the base model on the bundled example image with a fixed question."""
    image = Image.open("assets/example.jpg").convert("RGB")
    device = "cuda"
    image_tensor = process_images([image], image_processor_llava, model_llava.config)
    image_tensor = image_tensor.to(dtype=torch.float16, device=device)
    prompt = """<image>What is in the figure?"""
    # Build the full prompt with the conversation template.
    conv = conv_templates[conv_template_llava].copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    prompt_question = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt_question, tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
    image_sizes = [image.size]
    print(input_ids.shape, image_tensor.shape)
    with torch.inference_mode():
        cont = model_llava.generate(
            input_ids,
            images=image_tensor,
            image_sizes=image_sizes,
            do_sample=False,
            temperature=0,
            max_new_tokens=256,
            use_cache=True,
        )
    text_outputs = tokenizer_llava.batch_decode(cont, skip_special_tokens=True)
    print(text_outputs)
    return text_outputs
@spaces.GPU
def inference_by_prompt_and_images(prompt, images):
    """Run the base model on a pre-formatted prompt and a list of PIL images or base64-encoded image strings."""
    device = "cuda"
    # Decode base64-encoded images into PIL images if needed.
    if len(images) > 0 and isinstance(images[0], str):
        images = [Image.open(BytesIO(base64.b64decode(image))) for image in images]
    image_tensor = process_images(images, image_processor_llava, model_llava.config)
    image_tensor = image_tensor.to(dtype=torch.float16, device=device)
    input_ids = tokenizer_image_token(prompt, tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
    image_sizes = [image.size for image in images]
    logger.info(f"Shape: {input_ids.shape};{image_tensor.shape}; Devices: {input_ids.device};{image_tensor.device}")
    with torch.inference_mode():
        cont = model_llava.generate(
            input_ids,
            images=image_tensor,
            image_sizes=image_sizes,
            do_sample=False,
            temperature=0,
            max_new_tokens=256,
            use_cache=True,
        )
    text_outputs = tokenizer_llava.batch_decode(cont, skip_special_tokens=True)
    return text_outputs
@spaces.GPU
def inference_by_prompt_and_images_fire(prompt, images):
    """Same as inference_by_prompt_and_images, but runs the LoRA-merged student model."""
    device = "cuda"
    # Decode base64-encoded images into PIL images if needed.
    if len(images) > 0 and isinstance(images[0], str):
        images = [Image.open(BytesIO(base64.b64decode(image))) for image in images]
    image_tensor = process_images(images, image_processor_llava_fire, model_llava_fire.config)
    image_tensor = image_tensor.to(dtype=torch.float16, device=device)
    input_ids = tokenizer_image_token(prompt, tokenizer_llava_fire, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
    image_sizes = [image.size for image in images]
    logger.info(f"Shape: {input_ids.shape};{image_tensor.shape}; Devices: {input_ids.device};{image_tensor.device}")
    with torch.inference_mode():
        cont = model_llava_fire.generate(
            input_ids,
            images=[image_tensor.squeeze(dim=0)],
            image_sizes=image_sizes,
            do_sample=False,
            temperature=0,
            max_new_tokens=256,
            use_cache=True,
        )
    text_outputs = tokenizer_llava_fire.batch_decode(cont, skip_special_tokens=True)
    logger.info(f"response={text_outputs}")
    return text_outputs
if __name__ == "__main__":
    inference()
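    # A minimal sketch of the base64 path through inference_by_prompt_and_images, assuming
    # assets/example.jpg (the same file inference() reads) is available and that the caller
    # formats the prompt with the conversation template, as inference() does above.
    with open("assets/example.jpg", "rb") as f:
        encoded_image = base64.b64encode(f.read()).decode("utf-8")
    conv = conv_templates[conv_template_llava].copy()
    conv.append_message(conv.roles[0], f"{DEFAULT_IMAGE_TOKEN}\nWhat is in the figure?")
    conv.append_message(conv.roles[1], None)
    logger.info(f"base64 demo response={inference_by_prompt_and_images(conv.get_prompt(), [encoded_image])}")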