import gradio as gr
import torch
import open_clip
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration

# Load the Blip base model
preprocessor_blip_base = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model_blip_base = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Load the Blip large model
preprocessor_blip_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model_blip_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

# Load the GIT coco base model
preprocessor_git_base_coco = AutoProcessor.from_pretrained("microsoft/git-base-coco")
model_git_base_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

# Load the GIT coco large model
preprocessor_git_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
model_git_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")

# Load the CoCa model via open_clip
model_oc_coca, _, transform_oc_coca = open_clip.create_model_and_transforms(
	model_name="coca_ViT-L-14",
	pretrained="mscoco_finetuned_laion2B-s13B-b90k"
)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Transfer the models to the device
model_blip_base.to(device)
model_blip_large.to(device)
model_git_base_coco.to(device)
model_git_large_coco.to(device)
model_oc_coca.to(device)


def generate_caption(
	preprocessor,
	model,
	image,
	tokenizer=None,
	max_length=50,
):
	"""
	Generate a caption for the given image.

	Parameters
	----------
	preprocessor: AutoProcessor
		The preprocessor for the model.
	model: BlipForConditionalGeneration or AutoModelForCausalLM
		The captioning model to use.
	image: PIL.Image
		The image to generate a caption for.
	tokenizer: AutoTokenizer, optional
		The tokenizer used for decoding. If None, the preprocessor is used to decode.
	max_length: int, optional
		The maximum length of the generated caption.

	Returns
	-------
	str
		The generated caption.
	"""
	pixel_values = preprocessor(images=image, return_tensors="pt").pixel_values.to(device)

	generated_ids = model.generate(
		pixel_values=pixel_values,
		max_length=max_length,
	)

	if tokenizer is None:
		generated_caption = preprocessor.batch_decode(generated_ids, skip_special_tokens=True)[0]
	else:
		generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

	return generated_caption
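
# Usage sketch (illustrative only, never called by the app): shows how generate_caption can be
# used directly with one of the models loaded above. The function name and the default image
# path are hypothetical placeholders, not part of this repository.
def _example_blip_base_caption(image_path="example.jpg"):
	# Load a local image and caption it with the BLIP base model.
	image = Image.open(image_path).convert("RGB")
	return generate_caption(preprocessor_blip_base, model_blip_base, image)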


def generate_captions_clip(
	model,
	transform,
	image
):
	"""
	Generate captions for the given image using CLIP.

	-----
	Parameters
	model: VisionEncoderDecoderModel
		The CLIP model to use.
	transform: Callable
		The transform to apply to the image before passing it to the model.
	image: PIL.Image
		The image to generate captions for.

	-----
	Returns
	str
		The generated caption.
	"""
	im = transform(image).unsqueeze(0).to(device)
	with torch.no_grad(), torch.cuda.amp.autocast():
		generated = model.generate(im, seq_len=20)
	generated_caption = open_clip.decode(generated[0].detach()).split("<end_of_text>")[0].replace("<start_of_text>", "")
	return generated_caption
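
# Usage sketch (illustrative only, never called by the app): captions an image with the
# open_clip CoCa model loaded above. The function name and image path are hypothetical.
def _example_coca_caption(image_path="example.jpg"):
	# Load a local image and caption it with the CoCa model.
	image = Image.open(image_path).convert("RGB")
	return generate_captions_clip(model_oc_coca, transform_oc_coca, image)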


def generate_captions(
	image,
	max_length,
	temperature,
):
	"""
	Generate captions for the given image with every loaded model.

	Parameters
	----------
	image: PIL.Image
		The image to generate captions for.
	max_length: int
		The maximum length of the generated captions.
	temperature: float
		Sampling temperature from the UI. Currently unused, since the models decode greedily.

	Returns
	-------
	tuple of str
		Captions from the BLIP base, BLIP large, GIT base COCO, GIT large COCO, and CoCa models.
	"""
	caption_blip_base = ""
	caption_blip_large = ""
	caption_git_base_coco = ""
	caption_git_large_coco = ""
	caption_oc_coca = ""

	# Generate a caption using the BLIP base model
	try:
		caption_blip_base = generate_caption(preprocessor_blip_base, model_blip_base, image, max_length=max_length).strip()
	except Exception as e:
		print(e)

	# Generate a caption using the BLIP large model
	try:
		caption_blip_large = generate_caption(preprocessor_blip_large, model_blip_large, image, max_length=max_length).strip()
	except Exception as e:
		print(e)

	# Generate a caption using the GIT base model fine-tuned on COCO
	try:
		caption_git_base_coco = generate_caption(preprocessor_git_base_coco, model_git_base_coco, image, max_length=max_length).strip()
	except Exception as e:
		print(e)

	# Generate a caption using the GIT large model fine-tuned on COCO
	try:
		caption_git_large_coco = generate_caption(preprocessor_git_large_coco, model_git_large_coco, image, max_length=max_length).strip()
	except Exception as e:
		print(e)

	# Generate a caption using the open_clip CoCa model
	try:
		caption_oc_coca = generate_captions_clip(model_oc_coca, transform_oc_coca, image).strip()
	except Exception as e:
		print(e)

	return caption_blip_base, caption_blip_large, caption_git_base_coco, caption_git_large_coco, caption_oc_coca
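
# Usage sketch (illustrative only, never called by the app): runs all models on a local image
# outside the Gradio UI and prints the results. The function name, image path, and slider
# values are hypothetical placeholders.
def _example_all_captions(image_path="example.jpg"):
	image = Image.open(image_path).convert("RGB")
	captions = generate_captions(image, max_length=32, temperature=1.0)
	for name, caption in zip(["BLIP base", "BLIP large", "GIT base COCO", "GIT large COCO", "CoCa"], captions):
		print(f"{name}: {caption}")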


# Create the interface
iface = gr.Interface(
	fn=generate_captions,
	# Inputs: the image plus sliders for max length and temperature
	inputs=[
		gr.Image(type="pil", label="Image"),
		gr.Slider(minimum=16, maximum=64, step=2, value=32, label="Max Length"),
		gr.Slider(minimum=0.5, maximum=1.5, step=0.1, value=1.0, label="Temperature"),
	],
	# Outputs: one textbox per model, in the same order as the returned tuple
	outputs=[
		gr.Textbox(label="BLIP base"),
		gr.Textbox(label="BLIP large"),
		gr.Textbox(label="GIT base COCO"),
		gr.Textbox(label="GIT large COCO"),
		gr.Textbox(label="CoCa"),
	],
	title="Image Captioning",
	description="Generate captions for images using the BLIP base, BLIP large, GIT base COCO, GIT large COCO, and CoCa (open_clip) models.",
)

# Launch the interface with queuing enabled
iface.queue().launch(debug=True)