"""LLaVA-NeXT (Llama-3 8B) inference helpers for a Hugging Face Space.

Two models are loaded at import time:

* the base checkpoint ``DEFAULT_MODEL_PATH`` (``*_llava`` globals), and
* the same base with the LoRA checkpoint ``FIRE_LORA_CHECKPOINT`` applied
  (``*_llava_fire`` globals).

The public ``inference_by_prompt_and_images*`` functions run greedy generation
on a caller-formatted prompt plus a list of images. ``inference()`` is a
self-contained smoke test.
"""
import base64
import copy
from io import BytesIO

import requests
import spaces
import torch
from loguru import logger
from PIL import Image

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model

# Alternative base checkpoint (requires its matching chat template):
# model_path = "/scratch/TecManDep/A_Models/llava-v1.6-vicuna-7b"
# conv_template = "vicuna_v1"
# Make sure you use the correct chat template for different models.
DEFAULT_MODEL_PATH = "Lin-Chen/open-llava-next-llama3-8b"
DEFAULT_CONV_TEMPLATE = "llama_v3"
FIRE_LORA_CHECKPOINT = "checkpoints/llava-next-llama-3-8b-student-lora-merged-110224"

# Device that token ids and image tensors are moved to before generation.
_DEVICE = "cuda"
# Greedy-decoding settings shared by every entry point (identical to the
# arguments each function previously passed inline).
_GENERATION_KWARGS = dict(
    do_sample=False,
    temperature=0,
    max_new_tokens=256,
    use_cache=True,
)


def load_llava_model(
    lora_checkpoint=None,
    model_path=DEFAULT_MODEL_PATH,
    conv_template=DEFAULT_CONV_TEMPLATE,
):
    """Load a LLaVA model, optionally applying a LoRA checkpoint to a base model.

    Args:
        lora_checkpoint: Path to a LoRA checkpoint. When ``None`` the base
            model at *model_path* is loaded on its own.
        model_path: Base model path or Hub id. When *lora_checkpoint* is
            given it is passed to the builder as the base model.
        conv_template: Key into ``llava.conversation.conv_templates`` that
            matches *model_path*. It is returned unchanged for the caller.

    Returns:
        ``(tokenizer, model, image_processor, conv_template)``. The model has
        been switched to eval mode and had ``tie_weights()`` called on it.
    """
    device_map = "auto"
    if lora_checkpoint is None:
        model_name = get_model_name_from_path(model_path)
        # Add any other thing you want to pass in llava_model_args.
        tokenizer, model, image_processor, _max_length = load_pretrained_model(
            model_path, None, model_name, device_map=device_map
        )
    else:
        # NOTE(review): presumably LLaVA's builder selects its LoRA-loading
        # path from the "lora" substring in this model name plus a non-None
        # base; confirm against llava.model.builder.
        tokenizer, model, image_processor, _max_length = load_pretrained_model(
            lora_checkpoint, model_path, "llava_lora", device_map=device_map
        )
    model.eval()
    model.tie_weights()
    logger.info("model device {}", model.device)
    return tokenizer, model, image_processor, conv_template


tokenizer_llava, model_llava, image_processor_llava, conv_template_llava = load_llava_model(None)
# Kept under its own name rather than overwriting ``conv_template_llava``.
(
    tokenizer_llava_fire,
    model_llava_fire,
    image_processor_llava_fire,
    conv_template_llava_fire,
) = load_llava_model(FIRE_LORA_CHECKPOINT)
model_llava_fire.to("cuda")


def _decode_images(images):
    """Return *images* with every base64 string decoded into an RGB PIL image.

    Non-string entries (e.g. PIL images) are passed through untouched, so a
    list may mix both kinds. Decoded images are converted to RGB, matching
    ``inference()``, so palette/RGBA/greyscale uploads reach the image
    processor in a single, predictable mode.
    """
    return [
        Image.open(BytesIO(base64.b64decode(image))).convert("RGB")
        if isinstance(image, str)
        else image
        for image in images
    ]


def _generate_with_images(prompt, images, tokenizer, model, image_processor):
    """Run greedy generation for *prompt* conditioned on *images*.

    Args:
        prompt: Fully formatted prompt text (chat template already applied).
            Images are spliced in where ``DEFAULT_IMAGE_TOKEN`` occurs.
        images: List of PIL images and/or base64-encoded image strings.
        tokenizer, model, image_processor: Objects returned by
            ``load_llava_model``.

    Returns:
        The list produced by ``tokenizer.batch_decode``. The prompt is
        batched as a single sequence, so this holds one string.
    """
    images = _decode_images(images)
    if images and DEFAULT_IMAGE_TOKEN not in prompt:
        # tokenizer_image_token only places IMAGE_TOKEN_INDEX where the
        # placeholder appears, so the images would likely be ignored.
        logger.warning(
            "Prompt has no {} placeholder but {} image(s) were supplied",
            DEFAULT_IMAGE_TOKEN,
            len(images),
        )

    image_tensor = process_images(images, image_processor, model.config)
    image_tensor = image_tensor.to(dtype=torch.float16, device=_DEVICE)
    # unsqueeze(0) adds the batch dimension (batch of one prompt).
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(_DEVICE)
    )
    image_sizes = [image.size for image in images]
    logger.info(
        "Shape: {};{}; Devices: {};{}",
        input_ids.shape,
        image_tensor.shape,
        input_ids.device,
        image_tensor.device,
    )

    with torch.inference_mode():
        cont = model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=image_sizes,
            **_GENERATION_KWARGS,
        )
    text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
    logger.info("response={}", text_outputs)
    return text_outputs


@spaces.GPU
def inference():
    """Smoke test: ask the base model to describe ``assets/example.jpg``.

    The question is wrapped in the model's chat template. The decoded output
    list is printed and returned.
    """
    with Image.open("assets/example.jpg") as raw_image:
        image = raw_image.convert("RGB")

    # The image placeholder is required. Without it tokenizer_image_token
    # emits no IMAGE_TOKEN_INDEX, and the image never reaches the model.
    question = DEFAULT_IMAGE_TOKEN + "\nWhat is in the figure?"
    conv = conv_templates[conv_template_llava].copy()
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)
    prompt_question = conv.get_prompt()

    text_outputs = _generate_with_images(
        prompt_question, [image], tokenizer_llava, model_llava, image_processor_llava
    )
    print(text_outputs)
    return text_outputs


@spaces.GPU
def inference_by_prompt_and_images(prompt, images):
    """Generate a response with the base model.

    Args:
        prompt: Fully formatted prompt. It must already contain the chat
            template and one ``DEFAULT_IMAGE_TOKEN`` placeholder per image.
        images: List of PIL images and/or base64-encoded image strings.

    Returns:
        List of decoded output strings (one entry for the single prompt).
    """
    return _generate_with_images(
        prompt, images, tokenizer_llava, model_llava, image_processor_llava
    )


@spaces.GPU
def inference_by_prompt_and_images_fire(prompt, images):
    """Generate a response with the LoRA ("fire") model.

    Same contract as ``inference_by_prompt_and_images``, but it uses the
    model loaded from ``FIRE_LORA_CHECKPOINT``.
    """
    return _generate_with_images(
        prompt, images, tokenizer_llava_fire, model_llava_fire, image_processor_llava_fire
    )


if __name__ == "__main__":
    inference()