Spaces:
Runtime error
Runtime error
File size: 2,105 Bytes
0d08077 7d06c4c 0d08077 df766f8 0d08077 df766f8 0d08077 bc65b96 df766f8 0d08077 7d06c4c 0d08077 df766f8 2bcaca6 df766f8 0d08077 2e7d5a4 df766f8 0d08077 2e7d5a4 0d08077 bc65b96 df766f8 e651c62 bc65b96 0d08077 2e7d5a4 8ca734f 2e7d5a4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import gradio as gr
import torch
from PIL import Image
from model import GitBaseCocoModel, BlipBaseModel
# Registry of selectable captioning models: maps the display name offered in
# the UI's model dropdown to the model class, which generate_captions
# instantiates with a device string ("cuda" or "cpu").
MODELS = dict(
    [
        ("Git-Base-COCO", GitBaseCocoModel),
        ("Blip Base", BlipBaseModel),
    ]
)
# Cache of constructed caption models keyed by (model name, device), so that
# repeated requests reuse one instance instead of constructing a new model on
# every call. NOTE(review): assumes model instances are safe to reuse across
# requests (i.e. their generate() keeps no per-call state) — confirm in model.py.
_MODEL_CACHE = {}


def _get_model(model_name, device):
    """Return the cached ``MODELS[model_name]`` instance for ``device``, creating it on first use."""
    key = (model_name, device)
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = MODELS[model_name](device)
        _MODEL_CACHE[key] = model
    return model


def generate_captions(
    image,
    num_captions,
    max_length,
    temperature,
    top_k,
    top_p,
    repetition_penalty,
    diversity_penalty,
    model_name,
):
    """
    Generates captions for the given image.
    -----
    Parameters:
    image: PIL.Image
        The image to generate captions for.
    num_captions: int
        The number of captions to generate.
    max_length: int
        The maximum length of the caption.
    temperature: float
        Sampling temperature forwarded to the model's ``generate``.
    top_k: int
        Top-k value forwarded to the model's ``generate``.
    top_p: float
        Top-p (nucleus) value forwarded to the model's ``generate``.
    repetition_penalty: float
        Repetition penalty forwarded to the model's ``generate``.
    diversity_penalty: float
        Diversity penalty forwarded to the model's ``generate``.
    model_name: str
        Key into ``MODELS`` selecting which captioning model to use.
    -----
    Returns:
    str
        The generated captions joined with newlines.
    -----
    Raises:
    ValueError
        If ``model_name`` is not a key of ``MODELS`` (e.g. no model selected).
    """
    if model_name not in MODELS:
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of {list(MODELS)}"
        )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _get_model(model_name, device)
    captions = model.generate(
        image,
        # UI sliders may deliver whole numbers as floats; counts, lengths and
        # top-k are presumably expected as integers downstream.
        int(max_length),
        int(num_captions),
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        diversity_penalty=diversity_penalty,
    )
    # The single Textbox output shows one string, so join captions with newlines.
    return "\n".join(captions)
title = "Git-Base-COCO Image Captioning"
description = "A model for generating captions for images."
# Top-level components (gr.Image, gr.Slider, ...) replace the gr.inputs /
# gr.outputs namespaces, which were deprecated in Gradio 3.x and removed in
# Gradio 4; `value=` replaces the old `default=` keyword.
interface = gr.Interface(
    fn=generate_captions,
    # Component order must match generate_captions' positional parameters.
    inputs=[
        gr.Image(type="pil", label="Image"),
        gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Captions to Generate"),
        gr.Slider(minimum=20, maximum=100, step=5, value=50, label="Maximum Caption Length"),
        gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.0, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, step=1, value=50, label="Top K"),
        # Nucleus-sampling probability mass: only [0, 1] is meaningful. The old
        # -5..5 range allowed values that Hugging Face generation (presumably
        # what the model wraps) rejects.
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="Top P"),
        gr.Slider(minimum=1.0, maximum=10.0, step=0.1, value=1.0, label="Repetition Penalty"),
        gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=0.0, label="Diversity Penalty"),
        # Preselect the first registered model so submitting without touching
        # the dropdown does not pass None as model_name.
        gr.Dropdown(choices=list(MODELS), value=next(iter(MODELS)), label="Model"),
    ],
    outputs=[
        gr.Textbox(label="Caption"),
    ],
    title=title,
    description=description,
)
if __name__ == "__main__":
    # `launch(enable_queue=...)` was deprecated in Gradio 3.x and removed in
    # Gradio 4; calling `.queue()` (which returns the app) before `.launch()`
    # enables request queuing in both versions.
    interface.queue().launch(debug=True)