import gradio as gr import torch from PIL import Image import pandas as pd from lavis.models import load_model_and_preprocess from lavis.processors import load_processor from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor import tensorflow as tf import tensorflow_hub as hub import io from sklearn.metrics.pairwise import cosine_similarity import tempfile # Add this import import logging # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # Load model and preprocessors for Image-Text Matching (LAVIS) device = torch.device("cuda") if torch.cuda.is_available() else "cpu" model_itm, vis_processors, text_processors = load_model_and_preprocess("blip2_image_text_matching", "pretrain", device=device, is_eval=True) # Load tokenizer and model for Image Captioning (TextCaps) git_processor_large_textcaps = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps") git_model_large_textcaps = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps") # Load Universal Sentence Encoder model for textual similarity calculation embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4") # Define a function to compute textual similarity between caption and statement def compute_textual_similarity(caption, statement): # Convert caption and statement into sentence embeddings caption_embedding = embed([caption])[0].numpy() statement_embedding = embed([statement])[0].numpy() # Calculate cosine similarity between sentence embeddings similarity_score = cosine_similarity([caption_embedding], [statement_embedding])[0][0] return similarity_score # List of statements for Image-Text Matching statements = [ "contains or features a cartoon, figurine, or toy", "appears to be for children", "includes children", "sexual", "nudity", "depicts a child or portrays objects, images, or cartoon figures that primarily appeal to persons below the legal purchase age", "uses the name of or depicts Santa Claus", 'promotes alcohol use as a "rite of passage" to adulthood', "uses brand identification—including logos, trademarks, or names—on clothing, toys, games, game equipment, or other items intended for use primarily by persons below the legal purchase age", "portrays persons in a state of intoxication or in any way suggests that intoxication is socially acceptable conduct", "makes curative or therapeutic claims, except as permitted by law", "makes claims or representations that individuals can attain social, professional, educational, or athletic success or status due to beverage alcohol consumption", "degrades the image, form, or status of women, men, or of any ethnic group, minority, sexual orientation, religious affiliation, or other such group?", "uses lewd or indecent images or language", "employs religion or religious themes?", "relies upon sexual prowess or sexual success as a selling point for the brand", "uses graphic or gratuitous nudity, overt sexual activity, promiscuity, or sexually lewd or indecent images or language", "associates with anti-social or dangerous behavior", "depicts illegal activity of any kind", 'uses the term "spring break" or sponsors events or activities that use the term "spring break," unless those events or activities are located at a licensed retail establishment', ] # Function to compute ITM scores for the image-statement pair def compute_itm_score(image, statement): logging.info('Starting compute_itm_score') pil_image = Image.fromarray(image.astype('uint8'), 'RGB') img = vis_processors["eval"](pil_image.convert("RGB")).unsqueeze(0).to(device) # Pass the statement text directly to model_itm itm_output = model_itm({"image": img, "text_input": statement}, match_head="itm") itm_scores = torch.nn.functional.softmax(itm_output, dim=1) score = itm_scores[:, 1].item() logging.info('Finished compute_itm_score') return score def generate_caption(processor, model, image): logging.info('Starting generate_caption') inputs = processor(images=image, return_tensors="pt").to(device) generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50) generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] logging.info('Finished generate_caption') return generated_caption def save_dataframe_to_csv(df): csv_buffer = io.StringIO() df.to_csv(csv_buffer, index=False) csv_string = csv_buffer.getvalue() # Save the CSV string to a temporary file with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv") as temp_file: temp_file.write(csv_string) temp_file_path = temp_file.name # Get the file path # Return the file path (no need to reopen the file with "rb" mode) return temp_file_path # Main function to perform image captioning and image-text matching def process_images_and_statements(image): logging.info('Starting process_images_and_statements') # Generate image caption for the uploaded image using git-large-r-textcaps caption = generate_caption(git_processor_large_textcaps, git_model_large_textcaps, image) # Define weights for combining textual similarity score and image-statement ITM score (adjust as needed) weight_textual_similarity = 0.5 weight_statement = 0.5 # Initialize an empty DataFrame with column names results_df = pd.DataFrame(columns=['Statement', 'Generated Caption', 'Textual Similarity Score', 'ITM Score', 'Final Combined Score']) # Loop through each predefined statement for statement in statements: # Compute textual similarity between caption and statement textual_similarity_score = (compute_textual_similarity(caption, statement) * 100) # Multiply by 100 # Compute ITM score for the image-statement pair itm_score_statement = (compute_itm_score(image, statement) * 100) # Multiply by 100 # Combine the two scores using a weighted average final_score = ((weight_textual_similarity * textual_similarity_score) + (weight_statement * itm_score_statement)) # Append the result to the DataFrame with formatted percentage values results_df = results_df.append({ 'Statement': statement, 'Generated Caption': caption, # Include the generated caption 'Textual Similarity Score': f"{textual_similarity_score:.2f}%", # Format as percentage with two decimal places 'ITM Score': f"{itm_score_statement:.2f}%", # Format as percentage with two decimal places 'Final Combined Score': f"{final_score:.2f}%" # Format as percentage with two decimal places }, ignore_index=True) logging.info('Finished process_images_and_statements') # Save results_df to a CSV file csv_results = save_dataframe_to_csv(results_df) # Return both the DataFrame and the CSV data for the Gradio interface return results_df, csv_results # <--- Return results_df and csv_results # Gradio interface image_input = gr.inputs.Image() output_df = gr.outputs.Dataframe(type="pandas", label="Results") output_csv = gr.outputs.File(label="Download CSV") iface = gr.Interface( fn=process_images_and_statements, inputs=image_input, outputs=[output_df, output_csv], # Include both the DataFrame and CSV file outputs title="Image Captioning and Image-Text Matching", theme='sudeepshouche/minimalist', css=".output { flex-direction: column; } .output .outputs { width: 100%; }" # Custom CSS ) iface.launch()