from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates

from PIL import Image
import torch
import spaces
from io import BytesIO
import base64

# model_path = "/scratch/TecManDep/A_Models/llava-v1.6-vicuna-7b"
# conv_template = "vicuna_v1"  # Make sure you use the correct chat template for different models
from src.utils import (
    build_logger,
)

logger = build_logger("model_llava", "model_llava.log")


def load_llava_model(lora_checkpoint=None):
    """Load the base LLaVA-NeXT model or, if a checkpoint path is given, its LoRA fine-tune."""
    model_path = "Lin-Chen/open-llava-next-llama3-8b"
    conv_template = "llama_v3_student"
    model_name = get_model_name_from_path(model_path)
    device_map = "auto"
    if lora_checkpoint is None:
        tokenizer, model, image_processor, max_length = load_pretrained_model(
            model_path, None, model_name, device_map=device_map)  # pass any extra llava_model_args here
    else:
        tokenizer, model, image_processor, max_length = load_pretrained_model(
            lora_checkpoint, model_path, "llava_lora", device_map=device_map)

    model.eval()
    model.tie_weights()
    logger.info(f"model device {model.device}")
    return tokenizer, model, image_processor, conv_template

# Load both models eagerly at import time: the base model and the LoRA fine-tuned
# variant used by the *_fire inference function below.
tokenizer_llava, model_llava, image_processor_llava, conv_template_llava = load_llava_model(None)
tokenizer_llava_fire, model_llava_fire, image_processor_llava_fire, conv_template_llava = load_llava_model("checkpoints/llava-next-llama-3-8b-student-lora-merged-115124")
model_llava_fire.to("cuda")

@spaces.GPU
def inference():
    """Run the base model on the bundled example image with a fixed prompt."""
    image = Image.open("assets/example.jpg").convert("RGB")
    device = "cuda"
    image_tensor = process_images([image], image_processor_llava, model_llava.config)
    image_tensor = image_tensor.to(dtype=torch.float16, device=device)

    prompt = """<image>What is in the figure?"""
    conv = conv_templates[conv_template_llava].copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    prompt_question = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt_question, tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
    image_sizes = [image.size]
    logger.info(f"Shape: {input_ids.shape};{image_tensor.shape}")
    with torch.inference_mode():
        cont = model_llava.generate(
            input_ids,
            images=image_tensor,
            image_sizes=image_sizes,
            do_sample=False,
            temperature=0,
            max_new_tokens=256,
            use_cache=True
        )
    text_outputs = tokenizer_llava.batch_decode(cont, skip_special_tokens=True)
    logger.info(f"response={text_outputs}")
    return text_outputs


@spaces.GPU
def inference_by_prompt_and_images(prompt, images):
    """Run the base model on a pre-formatted prompt and a list of PIL images or base64-encoded image strings."""
    device = "cuda"
    if len(images) > 0 and isinstance(images[0], str):
        # Decode base64-encoded image strings into PIL images.
        images = [Image.open(BytesIO(base64.b64decode(image))) for image in images]
    image_tensor = process_images(images, image_processor_llava, model_llava.config)
    image_tensor = image_tensor.to(dtype=torch.float16, device=device)
    input_ids = tokenizer_image_token(prompt, tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
    image_sizes = [image.size for image in images]
    logger.info(f"Shape: {input_ids.shape};{image_tensor.shape}; Devices: {input_ids.device};{image_tensor.device}")
    with torch.inference_mode():
        cont = model_llava.generate(
            input_ids,
            images=image_tensor,
            image_sizes=image_sizes,
            do_sample=False,
            temperature=0,
            max_new_tokens=256,
            use_cache=True
        )
    text_outputs = tokenizer_llava.batch_decode(cont, skip_special_tokens=True)
    
    return text_outputs

@spaces.GPU
def inference_by_prompt_and_images_fire(prompt, images):
    """Same as inference_by_prompt_and_images, but runs the LoRA fine-tuned model (model_llava_fire)."""
    device = "cuda"
    if len(images) > 0 and isinstance(images[0], str):
        # Decode base64-encoded image strings into PIL images.
        images = [Image.open(BytesIO(base64.b64decode(image))) for image in images]
    image_tensor = process_images(images, image_processor_llava_fire, model_llava_fire.config)
    image_tensor = image_tensor.to(dtype=torch.float16, device=device)
    input_ids = tokenizer_image_token(prompt, tokenizer_llava_fire, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
    image_sizes = [image.size for image in images]
    logger.info(f"Shape: {input_ids.shape};{image_tensor.shape}; Devices: {input_ids.device};{image_tensor.device}")
    with torch.inference_mode():
        cont = model_llava_fire.generate(
            input_ids,
            images=[image_tensor.squeeze(dim=0)],
            image_sizes=image_sizes,
            do_sample=False,
            temperature=0,
            max_new_tokens=256,
            use_cache=True
        )
    text_outputs = tokenizer_llava_fire.batch_decode(cont, skip_special_tokens=True)
    logger.info(f"response={text_outputs}")
    return text_outputs
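
# Both inference_by_prompt_and_images* functions accept either PIL images or
# base64-encoded image strings. The helper below is a minimal sketch, not part of the
# original interface (the name encode_image_to_base64 is assumed); it simply mirrors
# the base64.b64decode call used above to prepare inputs for those functions.
def encode_image_to_base64(image, image_format="PNG"):
    """Serialize a PIL image to a base64 string accepted by the inference functions above."""
    buffer = BytesIO()
    image.save(buffer, format=image_format)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")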

if __name__ == "__main__":
    inference()
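
    # Usage sketch (assumes assets/example.jpg exists, as in inference() above): build
    # the prompt with the same conversation template, base64-encode the image with the
    # hypothetical helper above, and call inference_by_prompt_and_images.
    example_image = Image.open("assets/example.jpg").convert("RGB")
    conv = conv_templates[conv_template_llava].copy()
    conv.append_message(conv.roles[0], "<image>What is in the figure?")
    conv.append_message(conv.roles[1], None)
    response = inference_by_prompt_and_images(conv.get_prompt(), [encode_image_to_base64(example_image)])
    logger.info(f"example response={response}")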