import os

from PIL import Image
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    BlipConfig,
    BlipTextConfig,
    BlipVisionConfig,
)
import torch

# ComfyUI runtime helpers: torch device selection and model path resolution
import model_management
import folder_paths


class BLIPImg2Txt:
    def __init__(
        self,
        conditional_caption: str,
        min_words: int,
        max_words: int,
        temperature: float,
        repetition_penalty: float,
        search_beams: int,
        model_id: str = "Salesforce/blip-image-captioning-large",
        custom_model_path: str = None,
    ):
        self.conditional_caption = conditional_caption
        self.model_id = model_id
        self.custom_model_path = custom_model_path

        # Prefer an explicit local model path; otherwise resolve the model id
        # through ComfyUI's registered "blip" model folder.
        if self.custom_model_path and os.path.exists(self.custom_model_path):
            self.model_path = self.custom_model_path
        else:
            self.model_path = folder_paths.get_full_path("blip", model_id)

        # A temperature far from 1.0 signals that the user wants stochastic
        # sampling; near 1.0, fall back to deterministic (beam) search.
        if temperature > 1.1 or temperature < 0.90:
            do_sample = True
            num_beams = 1  # beam search is not used when sampling
        else:
            do_sample = False
            num_beams = search_beams if search_beams > 1 else 1

        self.text_config_kwargs = {
            "do_sample": do_sample,
            "max_length": max_words,
            "min_length": min_words,
            "repetition_penalty": repetition_penalty,
            "padding": "max_length",
        }
        if do_sample:
            # temperature only takes effect when sampling is enabled
            self.text_config_kwargs["temperature"] = temperature
        else:
            # num_beams only takes effect when not sampling
            self.text_config_kwargs["num_beams"] = num_beams

    def generate_caption(self, image: Image.Image) -> str:
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Load from the resolved local path when it exists; otherwise let
        # transformers download the model from the Hub by id.
        if self.model_path and os.path.exists(self.model_path):
            model_path = self.model_path
            local_files_only = True
        else:
            model_path = self.model_id
            local_files_only = False

        processor = BlipProcessor.from_pretrained(model_path, local_files_only=local_files_only)

        # Bake the generation settings into the text config so that
        # model.generate() picks them up without extra kwargs.
        config_text = BlipTextConfig.from_pretrained(model_path, local_files_only=local_files_only)
        config_text.update(self.text_config_kwargs)
        config_vision = BlipVisionConfig.from_pretrained(model_path, local_files_only=local_files_only)
        config = BlipConfig.from_text_vision_configs(config_text, config_vision)

        model = BlipForConditionalGeneration.from_pretrained(
            model_path,
            config=config,
            torch_dtype=torch.float16,
            local_files_only=local_files_only,
        ).to(model_management.get_torch_device())

        # The conditional caption acts as a text prompt that the model
        # continues, steering the generated caption.
        inputs = processor(
            image,
            self.conditional_caption,
            return_tensors="pt",
        ).to(model_management.get_torch_device(), torch.float16)

        with torch.no_grad():
            out = model.generate(**inputs)
        ret = processor.decode(out[0], skip_special_tokens=True)

        # Release VRAM eagerly; the model is reloaded on every call.
        del model
        torch.cuda.empty_cache()
        return ret
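

# Minimal usage sketch (illustrative, not part of the node's API): assumes a
# local test image at "example.png" (hypothetical path) and that the BLIP
# weights can be fetched from the Hub. Parameter values are arbitrary examples.
if __name__ == "__main__":
    captioner = BLIPImg2Txt(
        conditional_caption="a photograph of",
        min_words=5,
        max_words=30,
        temperature=1.0,  # near 1.0 -> deterministic beam search
        repetition_penalty=1.4,
        search_beams=4,
    )
    print(captioner.generate_caption(Image.open("example.png")))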