# File size: 3,888 Bytes
# Revision: fcc02a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86

from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption

import torch
from PIL import Image, ImageOps

from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor

# Lower-case, dot-prefixed file extensions; presumably used by callers to pick image files.
img_ext = [".jpg", ".jpeg", ".png", ".webp"]


class LLaVAImageProcessor:
    """Caption images with a 4-bit quantized LLaVA v1.5 (13B) model.

    Construction only checks that the ``llava`` package is importable; the
    model itself is loaded by :meth:`load_model`, which :meth:`generate_caption`
    calls automatically on first use.
    """

    def __init__(self, device='cuda'):
        """Create an unloaded processor.

        Args:
            device: Used as the ``device_map`` of the language model and as the
                target device for the vision tower and all input tensors.

        Raises:
            ImportError: If the ``llava`` package is not installed.
        """
        try:
            from llava.model import LlavaLlamaForCausalLM
        except ImportError as err:
            # llava is not on PyPI in a usable form; tell the user how to get it.
            raise ImportError(
                "You need to manually install llava -> "
                "pip install --no-deps  git+https://github.com/haotian-liu/LLaVA.git"
            ) from err
        self.device = device
        # Populated by load_model().
        self.model: LlavaLlamaForCausalLM = None
        self.tokenizer: AutoTokenizer = None
        self.image_processor: CLIPImageProcessor = None
        self.is_loaded = False

    def load_model(self):
        """Load the quantized model, its tokenizer and its vision tower.

        Sets ``self.model``, ``self.tokenizer`` and ``self.image_processor`` and
        marks the processor as loaded.
        """
        from llava.model import LlavaLlamaForCausalLM
        from llava.utils import disable_torch_init

        # Skip default weight initialisation while building the model; the
        # weights are loaded from the checkpoint. NOTE(review): upstream LLaVA
        # implements this by patching torch.nn init process-wide — confirm that
        # is acceptable for the host process.
        disable_torch_init()

        model_path = "4bit/llava-v1.5-13b-3GB"
        # All 4-bit settings live in quantization_config. Passing load_in_4bit as
        # a separate kwarg as well is rejected by newer transformers releases.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4',
        )
        self.model = LlavaLlamaForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            device_map=self.device,
            quantization_config=quantization_config,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

        vision_tower = self.model.get_vision_tower()
        # The vision tower is not necessarily loaded together with the LLM.
        if not vision_tower.is_loaded:
            vision_tower.load_model()
        vision_tower.to(device=self.device)
        self.image_processor = vision_tower.image_processor
        self.is_loaded = True

    def generate_caption(
            self,
            image: Image.Image,
            prompt: str = default_long_prompt,
            replacements=default_replacements,
            max_new_tokens=512
    ):
        """Generate a cleaned caption for ``image``.

        Loads the model first if :meth:`load_model` has not been called yet.

        Args:
            image: PIL image to caption.
            prompt: Instruction sent to the model together with the image.
            replacements: Forwarded to ``clean_caption``.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The generated text after ``clean_caption`` post-processing.
        """
        from llava.conversation import conv_templates, SeparatorStyle
        from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
        from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria

        if not self.is_loaded:
            self.load_model()

        conv = conv_templates["llava_v0"].copy()

        # Keep tensors on the same device the model was placed on instead of
        # assuming the default CUDA device.
        image_tensor = self.image_processor.preprocess(
            [image], return_tensors='pt'
        )['pixel_values'].half().to(self.device)

        # Mirror upstream LLaVA: only wrap the image placeholder in
        # <im_start>/<im_end> when the model config says it was trained that way.
        if getattr(self.model.config, 'mm_use_im_start_end', False):
            image_prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        else:
            image_prefix = DEFAULT_IMAGE_TOKEN
        # The role is passed to append_message, so it is not repeated in the
        # message text (the previous "Human: " prefix duplicated it).
        conv.append_message(conv.roles[0], image_prefix + '\n' + prompt)
        conv.append_message(conv.roles[1], None)
        raw_prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(
            raw_prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).to(self.device)
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        stopping_criteria = KeywordsStoppingCriteria([stop_str], self.tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids, images=image_tensor, do_sample=True, temperature=0.1,
                max_new_tokens=max_new_tokens, use_cache=True,
                stopping_criteria=[stopping_criteria], top_p=0.8
            )

        # Decode only the tokens generated after the prompt.
        output = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        output = output.rsplit('</s>', 1)[0].rstrip()
        # Generation stops after the stop string has been emitted, so it can
        # trail the text; drop it so it does not leak into the caption.
        if stop_str and output.endswith(stop_str):
            output = output[:-len(stop_str)]
        return clean_caption(output.strip(), replacements=replacements)