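"""BLIP image-to-text captioning helper for ComfyUI.

Builds a BLIP generation config from the node's inputs and runs
conditional captioning on a PIL image with Salesforce's BLIP
captioning model.
"""
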
import os
from typing import Optional

from PIL import Image
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    BlipConfig,
    BlipTextConfig,
    BlipVisionConfig,
)

import torch

# ComfyUI runtime modules: device selection and registered model folders.
import model_management
import folder_paths

class BLIPImg2Txt:
    def __init__(
        self,
        conditional_caption: str,
        min_words: int,
        max_words: int,
        temperature: float,
        repetition_penalty: float,
        search_beams: int,
        model_id: str = "Salesforce/blip-image-captioning-large",
        custom_model_path: Optional[str] = None,
    ):
        self.conditional_caption = conditional_caption
        self.model_id = model_id
        self.custom_model_path = custom_model_path

        # Prefer an explicit local checkpoint; otherwise resolve model_id
        # through the "blip" folder registered with ComfyUI's folder_paths.
        if self.custom_model_path and os.path.exists(self.custom_model_path):
            self.model_path = self.custom_model_path
        else:
            self.model_path = folder_paths.get_full_path("blip", model_id)

        # Temperatures well away from 1.0 only take effect when sampling, so
        # enable sampling for those and use greedy/beam search otherwise.
        if temperature < 0.9 or temperature > 1.1:
            do_sample = True
            num_beams = 1  # beam search and sampling are not combined here
        else:
            do_sample = False
            num_beams = search_beams if search_beams > 1 else 1

        # NOTE: min_length/max_length count tokens, not words.
        self.text_config_kwargs = {
            "do_sample": do_sample,
            "max_length": max_words,
            "min_length": min_words,
            "repetition_penalty": repetition_penalty,
            "padding": "max_length",
        }
        if do_sample:
            # temperature is only honored when sampling is enabled
            self.text_config_kwargs["temperature"] = temperature
        else:
            # num_beams is only honored by beam search
            self.text_config_kwargs["num_beams"] = num_beams

    def generate_caption(self, image: Image.Image) -> str:
        # BLIP expects 3-channel RGB input.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Load from the resolved local checkpoint when it exists; otherwise
        # fall back to downloading model_id from the Hugging Face Hub.
        if self.model_path and os.path.exists(self.model_path):
            model_path = self.model_path
            local_files_only = True
        else:
            model_path = self.model_id
            local_files_only = False

        processor = BlipProcessor.from_pretrained(model_path, local_files_only=local_files_only)

        # Push the generation settings into the text config, then rebuild the
        # combined BLIP config from its text and vision halves.
        config_text = BlipTextConfig.from_pretrained(model_path, local_files_only=local_files_only)
        config_text.update(self.text_config_kwargs)
        config_vision = BlipVisionConfig.from_pretrained(model_path, local_files_only=local_files_only)
        config = BlipConfig.from_text_vision_configs(config_text, config_vision)

        # Load weights in fp16 and move them to ComfyUI's preferred device.
        model = BlipForConditionalGeneration.from_pretrained(
            model_path,
            config=config,
            torch_dtype=torch.float16,
            local_files_only=local_files_only
        ).to(model_management.get_torch_device())

        # Preprocess the image and tokenize the conditional caption, matching
        # the model's device and fp16 dtype.
        inputs = processor(
            image,
            self.conditional_caption,
            return_tensors="pt",
        ).to(model_management.get_torch_device(), torch.float16)

        with torch.no_grad():
            out = model.generate(**inputs)
            ret = processor.decode(out[0], skip_special_tokens=True)

        # Drop the model reference and release cached GPU memory.
        del model
        torch.cuda.empty_cache()

        return ret
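
# --- Usage sketch (not part of the node) ------------------------------------
# A minimal example of how this class might be driven, assuming it runs in an
# environment where ComfyUI's model_management/folder_paths modules are
# importable, and "example.jpg" is a hypothetical local image path.
if __name__ == "__main__":
    captioner = BLIPImg2Txt(
        conditional_caption="a photograph of",
        min_words=5,
        max_words=30,
        temperature=1.0,  # within [0.9, 1.1], so beam search is used
        repetition_penalty=1.4,
        search_beams=3,
        # custom_model_path="/path/to/local/blip" would skip the Hub download
    )
    print(captioner.generate_caption(Image.open("example.jpg")))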