import traceback

import gradio as gr
import open_clip
import torch
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration
# Load the Blip base model
preprocessor_blip_base = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model_blip_base = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
# Load the Blip large model
preprocessor_blip_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model_blip_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
# Load the GIT coco base model
preprocessor_git_base_coco = AutoProcessor.from_pretrained("microsoft/git-base-coco")
model_git_base_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
# Load the GIT coco large model
preprocessor_git_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
model_git_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
# Load the open_clip CoCa model (CLIP-style encoder with a captioning head)
model_oc_coca, _, transform_oc_coca = open_clip.create_model_and_transforms(
model_name="coca_ViT-L-14",
pretrained="mscoco_finetuned_laion2B-s13B-b90k"
)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Transfer the models to the device
model_blip_base.to(device)
model_blip_large.to(device)
model_git_base_coco.to(device)
model_git_large_coco.to(device)
model_oc_coca.to(device)
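# The app only runs inference, so put every model in eval mode (disables dropout).
model_blip_base.eval()
model_blip_large.eval()
model_git_base_coco.eval()
model_git_large_coco.eval()
model_oc_coca.eval()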
def generate_caption(
    preprocessor,
    model,
    image,
    tokenizer=None,
    max_length=50,
    temperature=1.0,
):
    """
    Generate a caption for the given image.

    Parameters
    ----------
    preprocessor: AutoProcessor
        The preprocessor for the model.
    model: BlipForConditionalGeneration or AutoModelForCausalLM
        The captioning model to use.
    image: PIL.Image
        The image to generate a caption for.
    tokenizer: AutoTokenizer
        The tokenizer to use. If None, the preprocessor is also used for decoding.
    max_length: int
        The maximum length of the generated caption.
    temperature: float
        Sampling temperature. Values other than 1.0 enable sampling; at 1.0 the
        default greedy decoding is kept.

    Returns
    -------
    str
        The generated caption.
    """
    pixel_values = preprocessor(images=image, return_tensors="pt").pixel_values.to(device)
    generate_kwargs = {"max_length": max_length}
    # Temperature only affects generation when sampling is enabled.
    if temperature is not None and temperature != 1.0:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature
    generated_ids = model.generate(pixel_values=pixel_values, **generate_kwargs)
    # Decode with the dedicated tokenizer if one was given, otherwise with the preprocessor.
    if tokenizer is None:
        generated_caption = preprocessor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_caption
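# Example usage (assuming `img` is a PIL.Image loaded elsewhere):
#   caption = generate_caption(preprocessor_blip_base, model_blip_base, img, max_length=32)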
def generate_captions_clip(
    model,
    transform,
    image
):
    """
    Generate a caption for the given image using the open_clip CoCa model.

    Parameters
    ----------
    model: open_clip CoCa model
        The CoCa model to use.
    transform: Callable
        The transform to apply to the image before passing it to the model.
    image: PIL.Image
        The image to generate a caption for.

    Returns
    -------
    str
        The generated caption.
    """
    im = transform(image).unsqueeze(0).to(device)
    # Autocast is only enabled on CUDA; on CPU it would have no effect and only emit warnings.
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=device == "cuda"):
        generated = model.generate(im, seq_len=20)
    # Strip the special start/end tokens from the decoded text.
    generated_caption = open_clip.decode(generated[0].detach()).split("<end_of_text>")[0].replace("<start_of_text>", "")
    return generated_caption
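# Example usage (again assuming `img` is a PIL.Image loaded elsewhere):
#   caption = generate_captions_clip(model_oc_coca, transform_oc_coca, img)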
def generate_captions(
    image,
    max_length,
    temperature,
):
    """
    Generate captions for the given image with every loaded model.

    Parameters
    ----------
    image: PIL.Image
        The image to generate captions for.
    max_length: int
        The maximum length of the generated captions.
    temperature: float
        Sampling temperature for the captioning models (takes effect when not 1.0).

    Returns
    -------
    tuple of str
        The captions from the BLIP base, BLIP large, GIT base COCO,
        GIT large COCO, and CoCa models.
    """
    caption_blip_base = ""
    caption_blip_large = ""
    caption_git_base_coco = ""
    caption_git_large_coco = ""
    caption_oc_coca = ""
    # Generate a caption with the BLIP base model
    try:
        caption_blip_base = generate_caption(
            preprocessor_blip_base, model_blip_base, image,
            max_length=max_length, temperature=temperature,
        ).strip()
    except Exception:
        traceback.print_exc()
    # Generate a caption with the BLIP large model
    try:
        caption_blip_large = generate_caption(
            preprocessor_blip_large, model_blip_large, image,
            max_length=max_length, temperature=temperature,
        ).strip()
    except Exception:
        traceback.print_exc()
    # Generate a caption with the GIT base COCO model
    try:
        caption_git_base_coco = generate_caption(
            preprocessor_git_base_coco, model_git_base_coco, image,
            max_length=max_length, temperature=temperature,
        ).strip()
    except Exception:
        traceback.print_exc()
    # Generate a caption with the GIT large COCO model
    try:
        caption_git_large_coco = generate_caption(
            preprocessor_git_large_coco, model_git_large_coco, image,
            max_length=max_length, temperature=temperature,
        ).strip()
    except Exception:
        traceback.print_exc()
    # Generate a caption with the CoCa (open_clip) model
    try:
        caption_oc_coca = generate_captions_clip(model_oc_coca, transform_oc_coca, image).strip()
    except Exception:
        traceback.print_exc()
    return caption_blip_base, caption_blip_large, caption_git_base_coco, caption_git_large_coco, caption_oc_coca
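# Example usage (assuming `img` is a PIL.Image loaded elsewhere):
#   captions = generate_captions(img, max_length=32, temperature=1.0)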
# Create the interface
iface = gr.Interface(
    fn=generate_captions,
    # Define the inputs: the image, a slider for max length, and a slider for temperature
    inputs=[
        gr.Image(type="pil", label="Image"),
        gr.Slider(minimum=16, maximum=64, step=2, value=32, label="Max Length"),
        gr.Slider(minimum=0.5, maximum=1.5, step=0.1, value=1.0, label="Temperature"),
    ],
    # Define the outputs: one textbox per model
    outputs=[
        gr.Textbox(label="BLIP base"),
        gr.Textbox(label="BLIP large"),
        gr.Textbox(label="GIT base COCO"),
        gr.Textbox(label="GIT large COCO"),
        gr.Textbox(label="CoCa (open_clip)"),
    ],
    title="Image Captioning",
    description="Generate captions for images using the BLIP base, BLIP large, GIT base COCO, GIT large COCO, and CoCa models.",
)
# Queue requests and launch the interface
iface.queue()
iface.launch(debug=True)