# caption-match / app.py
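"""Gradio app that captions an uploaded image with a GIT model fine-tuned on TextCaps
and scores the image against a fixed list of statements with BLIP-2 image-text matching (LAVIS)."""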
import gradio as gr
import torch
from PIL import Image
from lavis.models import load_model_and_preprocess
from transformers import AutoProcessor, AutoModelForCausalLM

# Load model and preprocessors for Image-Text Matching (LAVIS BLIP-2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_itm, vis_processors, text_processors = load_model_and_preprocess(
    "blip2_image_text_matching", "pretrain", device=device, is_eval=True
)

# Load model and processor for Image Captioning (GIT fine-tuned on TextCaps).
# GIT is a causal language model conditioned on image features, so it is loaded with
# AutoModelForCausalLM/AutoProcessor rather than CLIP classes.
model_caption = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps").to(device)
processor_caption = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")
# List of statements for Image-Text Matching
statements = [
    "cartoon, figurine, or toy",
    "appears to be for children",
    "includes children",
    "is sexual",
    "depicts a child or portrays objects, images, or cartoon figures that primarily appeal to persons below the legal purchase age",
    "uses the name of or depicts Santa Claus",
    'promotes alcohol use as a "rite of passage" to adulthood',
]
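# Pre-process every statement once with the LAVIS text processor so the per-image matching loop does not repeat this work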
txts = [text_processors["eval"](statement) for statement in statements]
# Function to compute Image-Text Matching (ITM) scores for all statements
def compute_itm_scores(image):
    pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
    img = vis_processors["eval"](pil_image.convert("RGB")).unsqueeze(0).to(device)
    results = []
    for i, statement in enumerate(statements):
        txt = txts[i]
        # The ITM head returns logits over (no match, match); softmax column 1 is the match probability
        itm_output = model_itm({"image": img, "text_input": txt}, match_head="itm")
        itm_scores = torch.nn.functional.softmax(itm_output, dim=1)
        score = itm_scores[:, 1].item()
        result_text = f'The image and "{statement}" are matched with a probability of {score:.3%}'
        results.append(result_text)
    output = "\n".join(results)
    return output
# Function to generate an image caption with the GIT TextCaps model
def generate_image_captions(image):
    pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
    # The processor converts the PIL image to pixel_values; generate() then decodes a caption autoregressively
    pixel_values = processor_caption(images=pil_image, return_tensors="pt").pixel_values.to(device)
    generated_ids = model_caption.generate(pixel_values=pixel_values, max_length=50)
    caption = processor_caption.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return caption
# Main function to perform image captioning and image-text matching
def process_images_and_statements(image):
    # Generate image captions using TextCaps
    captions = generate_image_captions(image)
    # Compute ITM scores for predefined statements using LAVIS
    itm_scores = compute_itm_scores(image)
    # Combine image captions and ITM scores into the output
    output = "Image Captions:\n" + captions + "\n\nITM Scores:\n" + itm_scores
    return output
image_input = gr.Image()
output = gr.Textbox(label="Results")
iface = gr.Interface(fn=process_images_and_statements, inputs=image_input, outputs=output, title="Image Captioning and Image-Text Matching")
iface.launch()