|
import gradio as gr |
|
import torch |
|
from PIL import Image |
|
import pandas as pd |
|
from lavis.models import load_model_and_preprocess |
|
from lavis.processors import load_processor |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor |
|
import tensorflow as tf |
|
import tensorflow_hub as hub |
|
import io |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import tempfile |
|
import logging |
|
|
|
|
|
# Configure the root logger for timestamped progress messages.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Pick the GPU when available. Build a torch.device in both branches so every
# downstream `.to(device)` receives a consistent type (the original mixed a
# torch.device with the bare string "cpu").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BLIP-2 image-text matching model plus its image/text preprocessors (LAVIS).
model_itm, vis_processors, text_processors = load_model_and_preprocess("blip2_image_text_matching", "pretrain", device=device, is_eval=True)

# Microsoft GiT (large, TextCaps-finetuned) captioning processor and model.
# NOTE(review): the model is left on CPU here; inputs must be moved to
# `model.device`, not the global `device`, when generating.
git_processor_large_textcaps = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")
git_model_large_textcaps = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps")

# Universal Sentence Encoder (TF Hub) used for sentence embeddings.
embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")
|
|
|
|
|
def compute_textual_similarity(caption, statement):
    """Return the cosine similarity between *caption* and *statement*.

    Both texts are embedded with the module-level Universal Sentence Encoder
    (`embed`) and compared with sklearn's cosine_similarity. The result is a
    float in [-1, 1] (in practice roughly [0, 1] for natural text).
    """
    # embed() returns a batch of embeddings; take the single row for each text.
    caption_vec, statement_vec = (embed([text])[0].numpy() for text in (caption, statement))

    # cosine_similarity expects 2-D inputs and returns a 1x1 matrix;
    # unwrap it down to the scalar score.
    return cosine_similarity([caption_vec], [statement_vec])[0][0]
|
|
|
|
|
# Load the candidate statements to score, one statement per line
# (splitlines strips the trailing newlines).
with open('statements.txt', 'r') as statements_file:
    statements = statements_file.read().splitlines()
|
|
|
|
|
def compute_itm_score(image, statement):
    """Score how well *statement* matches *image* using the BLIP-2 ITM head.

    Args:
        image: RGB image as a numpy array (H, W, 3) — presumably what Gradio
            delivers; values are cast to uint8. TODO confirm callers always
            pass RGB arrays.
        statement: candidate text to match against the image.

    Returns:
        float: probability that image and text match (softmax over the
        two ITM logits; index 1 is the "match" class).
    """
    logging.info('Starting compute_itm_score')
    # Image.fromarray(..., 'RGB') already yields an RGB image, so the
    # original's extra .convert("RGB") was redundant and is dropped.
    pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
    img = vis_processors["eval"](pil_image).unsqueeze(0).to(device)

    # Inference only: no_grad avoids building the autograd graph, which the
    # original omitted and which wastes memory on every call.
    with torch.no_grad():
        itm_output = model_itm({"image": img, "text_input": statement}, match_head="itm")
        itm_scores = torch.nn.functional.softmax(itm_output, dim=1)

    score = itm_scores[:, 1].item()
    logging.info('Finished compute_itm_score')
    return score
|
|
|
def generate_caption(processor, model, image):
    """Generate a caption for *image* with a HuggingFace captioning model.

    Args:
        processor: HF processor paired with *model* (e.g. GiT AutoProcessor).
        model: HF causal-LM captioning model.
        image: image in a format the processor accepts (numpy array / PIL).

    Returns:
        str: the decoded caption, special tokens stripped.
    """
    logging.info('Starting generate_caption')
    # Move inputs to the model's own device rather than the global `device`:
    # the GiT model is never moved to CUDA at load time, so sending inputs to
    # the global device would crash with a device mismatch whenever a GPU is
    # present.
    inputs = processor(images=image, return_tensors="pt").to(model.device)

    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)

    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    logging.info('Finished generate_caption')
    return generated_caption
|
|
|
def save_dataframe_to_csv(df):
    """Write *df* to a temporary CSV file and return the file's path.

    The file is created with delete=False so it outlives this function and
    can be served by the Gradio download widget; temp-directory cleanup is
    responsible for eventually removing it.

    Args:
        df: pandas DataFrame to serialize.

    Returns:
        str: filesystem path of the written CSV (no index column).
    """
    # Stream the CSV straight into the temp file instead of materializing the
    # whole document in an intermediate StringIO as the original did.
    # newline="" prevents doubled line endings on Windows (per the pandas
    # to_csv recommendation for open file handles).
    with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv", newline="") as temp_file:
        df.to_csv(temp_file, index=False)
        return temp_file.name
|
|
|
|
|
def process_images_and_statements(image):
    """Caption *image* and score every statement in `statements` against it.

    For each statement two signals are combined with equal weight:
    the textual similarity between the generated caption and the statement,
    and the BLIP-2 ITM score between the image and the statement.

    Args:
        image: input image as delivered by Gradio (numpy RGB array).

    Returns:
        tuple: (results DataFrame, path to a CSV copy for download).
    """
    logging.info('Starting process_images_and_statements')

    caption = generate_caption(git_processor_large_textcaps, git_model_large_textcaps, image)

    # Equal weighting of the two signals.
    weight_textual_similarity = 0.5
    weight_statement = 0.5

    results_list = []
    for statement in statements:
        # Both scores are scaled to percentages for display.
        textual_similarity_score = compute_textual_similarity(caption, statement) * 100
        itm_score_statement = compute_itm_score(image, statement) * 100
        final_score = (weight_textual_similarity * textual_similarity_score
                       + weight_statement * itm_score_statement)

        results_list.append({
            'Statement': statement,
            'Generated Caption': caption,
            'Textual Similarity Score': f"{textual_similarity_score:.2f}%",
            'ITM Score': f"{itm_score_statement:.2f}%",
            'Final Combined Score': f"{final_score:.2f}%"
        })

    # Build the frame in one shot. The original concatenated one single-row
    # DataFrame per statement, which is quadratic and raises ValueError when
    # the statements list is empty; pd.DataFrame handles both cases cleanly.
    results_df = pd.DataFrame(results_list)

    logging.info('Finished process_images_and_statements')

    csv_results = save_dataframe_to_csv(results_df)
    return results_df, csv_results
|
|
|
|
|
# Gradio UI wiring. The original used the `gr.inputs` / `gr.outputs`
# namespaces, which were deprecated in Gradio 3 and removed in Gradio 4;
# the top-level component classes below work in both generations.
image_input = gr.Image()
output_df = gr.Dataframe(type="pandas", label="Results")
output_csv = gr.File(label="Download CSV")

iface = gr.Interface(
    fn=process_images_and_statements,
    inputs=image_input,
    outputs=[output_df, output_csv],
    title="Image Captioning and Image-Text Matching",
    theme='sudeepshouche/minimalist',
    css=".output { flex-direction: column; } .output .outputs { width: 100%; }"
)

# Start the web app (blocks until the server is stopped).
iface.launch()