File size: 2,105 Bytes
0d08077
7d06c4c
0d08077
 
df766f8
0d08077
df766f8
 
 
 
0d08077
 
bc65b96
 
df766f8
 
 
 
 
 
 
0d08077
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d06c4c
0d08077
df766f8
 
 
 
 
 
2bcaca6
 
 
 
 
df766f8
0d08077
2e7d5a4
df766f8
 
0d08077
 
 
 
2e7d5a4
0d08077
bc65b96
 
 
df766f8
 
 
 
 
 
e651c62
bc65b96
 
 
 
0d08077
 
2e7d5a4
 
 
 
 
 
8ca734f
2e7d5a4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import functools

import gradio as gr
import torch
from PIL import Image

from model import GitBaseCocoModel, BlipBaseModel

# Registry mapping each label offered in the "Model" dropdown to the wrapper
# class that is instantiated (with a device string) to produce captions.
MODELS = dict(
	[
		("Git-Base-COCO", GitBaseCocoModel),
		("Blip Base", BlipBaseModel),
	]
)

@functools.lru_cache(maxsize=None)
def _load_model(model_name, device):
	"""
	Build the captioning model registered under ``model_name`` on ``device``.

	Memoized so each (model, device) pair is constructed once per process
	instead of on every request. The key space is bounded by the entries in
	``MODELS``, so the cache cannot grow without limit; note that every model
	that has been used stays resident in memory.
	"""
	return MODELS[model_name](device)


def generate_captions(
	image,
	num_captions,
	max_length,
	temperature,
	top_k,
	top_p,
	repetition_penalty,
	diversity_penalty,
	model_name,
	):
	"""
	Generates captions for the given image.

	-----
	Parameters:
	image: PIL.Image
		The image to generate captions for.
	num_captions: int
		The number of captions to generate (cast to int).
	max_length: int
		The maximum length of each caption (cast to int).
	temperature: float
		Sampling temperature forwarded to the model.
	top_k: int
		Top-k sampling cutoff forwarded to the model (cast to int).
	top_p: float
		Nucleus-sampling probability mass forwarded to the model.
	repetition_penalty: float
		Repetition penalty forwarded to the model.
	diversity_penalty: float
		Diversity penalty forwarded to the model.
	model_name: str
		A key of ``MODELS`` selecting which captioning model to use.

	-----
	Returns:
	str
		The generated captions, one per line.

	-----
	Raises:
	ValueError
		If no image was provided.
	KeyError
		If ``model_name`` is not a key of ``MODELS``.
	"""
	if image is None:
		raise ValueError("No image provided; please upload an image to caption.")
	if model_name not in MODELS:
		raise KeyError(
			f"Unknown model {model_name!r}; expected one of {list(MODELS)}."
		)

	device = "cuda" if torch.cuda.is_available() else "cpu"

	# Reuse an already-constructed model instead of reloading weights per request.
	# NOTE(review): assumes the model wrapper holds no per-call state — confirm in model.py.
	model = _load_model(model_name, device)

	# UI sliders may deliver floats; count/length/top-k parameters must be integers.
	captions = model.generate(
		image,
		int(max_length),
		int(num_captions),
		temperature=temperature,
		top_k=int(top_k),
		top_p=top_p,
		repetition_penalty=repetition_penalty,
		diversity_penalty=diversity_penalty,
	)

	# The single Textbox output expects one string, so put each caption on its own line.
	return "\n".join(captions)

title = "Git-Base-COCO Image Captioning"
description = "A model for generating captions for images."

# Input order must match the positional parameters of generate_captions:
# image, num_captions, max_length, temperature, top_k, top_p,
# repetition_penalty, diversity_penalty, model_name.
interface = gr.Interface(
	fn=generate_captions,
	inputs=[
		gr.inputs.Image(type="pil", label="Image"),
		gr.inputs.Slider(minimum=1, maximum=10, step=1, default=1, label="Number of Captions to Generate"),
		gr.inputs.Slider(minimum=20, maximum=100, step=5, default=50, label="Maximum Caption Length"),
		gr.inputs.Slider(minimum=0.1, maximum=10.0, step=0.1, default=1.0, label="Temperature"),
		gr.inputs.Slider(minimum=1, maximum=100, step=1, default=50, label="Top K"),
		# top_p is a cumulative probability mass and must lie in [0, 1];
		# values outside that range are rejected during generation.
		gr.inputs.Slider(minimum=0.0, maximum=1.0, step=0.05, default=1.0, label="Top P"),
		gr.inputs.Slider(minimum=1.0, maximum=10.0, step=0.1, default=1.0, label="Repetition Penalty"),
		gr.inputs.Slider(minimum=0.0, maximum=10.0, step=0.1, default=0.0, label="Diversity Penalty"),
		# Default to the first registered model so model_name is never None.
		gr.inputs.Dropdown(list(MODELS.keys()), default=next(iter(MODELS)), label="Model"),
	],
	outputs=[
		gr.outputs.Textbox(label="Caption"),
	],
	title=title,
	description=description,
	)


if __name__ == "__main__":
	# Start the Gradio server with request queueing and debug mode turned on.
	launch_options = {"enable_queue": True, "debug": True}
	interface.launch(**launch_options)