# caption-match / app.py
import gradio as gr
import torch
from PIL import Image
from lavis.models import load_model_and_preprocess
from transformers import AutoProcessor, AutoModelForCausalLM
# Load model and preprocessors for Image-Text Matching (LAVIS)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_itm, vis_processors, text_processors = load_model_and_preprocess("blip2_image_text_matching", "pretrain", device=device, is_eval=True)
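# Note: besides the ITM head used below, the same LAVIS model also supports
# match_head="itc", which returns an image-text cosine similarity instead of a
# match probability, e.g.:
#   itc_score = model_itm({"image": img, "text_input": txt}, match_head="itc")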
# Load model and processor for image captioning (GIT fine-tuned on TextCaps).
# GIT is a causal language model, so it must be loaded with AutoModelForCausalLM;
# CLIPModel has no generate() and cannot produce captions.
processor_caption = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")
model_caption = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps").to(device)
# Statements to score against the uploaded image with Image-Text Matching
statements = [
"cartoon, figurine, or toy",
"appears to be for children",
"includes children",
"is sexual",
"depicts a child or portrays objects, images, or cartoon figures that primarily appeal to persons below the legal purchase age",
"uses the name of or depicts Santa Claus",
'promotes alcohol use as a "rite of passage" to adulthood',
]
txts = [text_processors["eval"](statement) for statement in statements]
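# The statements are pre-processed once at startup so each request only pays
# for the image forward passes, not for re-tokenizing the same text.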
# Compute Image-Text Matching (ITM) scores for all statements
def compute_itm_scores(image):
    pil_image = Image.fromarray(image.astype("uint8"), "RGB")
    img = vis_processors["eval"](pil_image).unsqueeze(0).to(device)
    results = []
    for statement, txt in zip(statements, txts):
        itm_output = model_itm({"image": img, "text_input": txt}, match_head="itm")
        # The ITM head returns logits over (no match, match); softmax turns them into probabilities
        itm_scores = torch.nn.functional.softmax(itm_output, dim=1)
        score = itm_scores[:, 1].item()
        results.append(f'The image and "{statement}" are matched with a probability of {score:.3%}')
    return "\n".join(results)
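# Quick local smoke test (hypothetical file name; the Gradio input delivers the
# same kind of RGB numpy array):
#   import numpy as np
#   arr = np.array(Image.open("example.jpg").convert("RGB"))
#   print(compute_itm_scores(arr))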
# Generate an image caption using GIT (TextCaps)
def generate_image_captions(image):
    pil_image = Image.fromarray(image.astype("uint8"), "RGB")
    # GIT takes only pixel values for unconditional captioning; no text prompt is needed
    inputs = processor_caption(images=pil_image, return_tensors="pt").to(device)
    outputs = model_caption.generate(pixel_values=inputs.pixel_values, max_length=50)
    caption = processor_caption.batch_decode(outputs, skip_special_tokens=True)[0]
    return caption
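# generate() defaults to greedy decoding; if captions come out truncated or
# repetitive, max_length and num_beams are the usual knobs to adjust.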
# Main function: run captioning and image-text matching on the same image
def process_images_and_statements(image):
    # Generate an image caption using GIT (TextCaps)
    captions = generate_image_captions(image)
    # Compute ITM scores for the predefined statements using LAVIS
    itm_scores = compute_itm_scores(image)
    # Combine the caption and the ITM scores into a single text output
    return "Image Captions:\n" + captions + "\n\nITM Scores:\n" + itm_scores
# Gradio interface (gr.inputs / gr.outputs are deprecated; use gr.Image / gr.Textbox directly)
image_input = gr.Image()
output = gr.Textbox(label="Results")
iface = gr.Interface(
    fn=process_images_and_statements,
    inputs=image_input,
    outputs=output,
    title="Image Captioning and Image-Text Matching",
)
iface.launch()