import io
import logging
import tempfile

import gradio as gr
import numpy as np
import pandas as pd
import tensorflow_hub as hub
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity

from lavis.models import load_model_and_preprocess
from transformers import AutoModelForCausalLM, AutoProcessor

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load model and preprocessors for Image-Text Matching (LAVIS)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_itm, vis_processors, text_processors = load_model_and_preprocess("blip2_image_text_matching", "pretrain", device=device, is_eval=True)

# Load processor and model for Image Captioning (GIT large, fine-tuned on TextCaps)
git_processor_large_textcaps = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")
git_model_large_textcaps = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps").to(device)

# Load Universal Sentence Encoder model for textual similarity calculation
embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")
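# USE v4 maps each input string to a fixed 512-dimensional embedding,
# which is what makes the cosine-similarity comparison below meaningful.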

# Define a function to compute textual similarity between caption and statement
def compute_textual_similarity(caption, statement):
    # Convert caption and statement into sentence embeddings
    caption_embedding = embed([caption])[0].numpy()
    statement_embedding = embed([statement])[0].numpy()

    # Calculate cosine similarity between sentence embeddings
    similarity_score = cosine_similarity([caption_embedding], [statement_embedding])[0][0]
    return similarity_score
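# Illustrative behavior (hypothetical inputs, not output from a real run):
#   compute_textual_similarity("a dog running on grass", "a dog outdoors")
# returns a float in [-1, 1]; semantically close sentences score higher
# than unrelated ones.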

# Read statements from the external file 'statements.txt', skipping blank lines
with open('statements.txt', 'r') as file:
    statements = [line.strip() for line in file if line.strip()]
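# The file is assumed to hold one statement per line, e.g.:
#   a person riding a bicycle
#   a dog playing in the snow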

# Function to compute the ITM (image-text matching) score for an image-statement pair
def compute_itm_score(image, statement):
    logging.info('Starting compute_itm_score')
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    img = vis_processors["eval"](pil_image).unsqueeze(0).to(device)
    txt = text_processors["eval"](statement)
    itm_output = model_itm({"image": img, "text_input": txt}, match_head="itm")
    # Softmax over the (no-match, match) logits; keep the probability of a match
    itm_scores = torch.nn.functional.softmax(itm_output, dim=1)
    score = itm_scores[:, 1].item()
    logging.info('Finished compute_itm_score')
    return score
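# LAVIS also exposes an image-text contrastive (ITC) head; a sketch of the
# alternative call, which returns a cosine similarity instead of a match
# probability:
#   itc_score = model_itm({"image": img, "text_input": txt}, match_head="itc")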

# Function to generate a caption for the image with the GIT-TextCaps model
def generate_caption(processor, model, image):
    logging.info('Starting generate_caption')
    inputs = processor(images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    logging.info('Finished generate_caption')
    return generated_caption
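# generate() accepts the standard Hugging Face generation kwargs, so a
# beam-search variant of the call above could look like (sketch):
#   model.generate(pixel_values=inputs.pixel_values, max_length=50, num_beams=4)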

# Function to write a DataFrame to a temporary CSV file and return its path
def save_dataframe_to_csv(df):
    csv_buffer = io.StringIO()
    df.to_csv(csv_buffer, index=False)
    csv_string = csv_buffer.getvalue()

    # Save the CSV string to a temporary file
    with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv") as temp_file:
        temp_file.write(csv_string)
        temp_file_path = temp_file.name # Get the file path

    # Return the file path (no need to reopen the file with "rb" mode)
    return temp_file_path
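# Example usage with a hypothetical DataFrame:
#   path = save_dataframe_to_csv(pd.DataFrame({'Statement': ['a dog'], 'ITM Score': [0.93]}))
# The caller is responsible for deleting the temporary file when done.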

def process_images_and_statements(image):
    logging.info('Starting process_images_and_statements')
    image = np.asarray(image)

    # Generate a caption for the uploaded image
    generated_caption = generate_caption(git_processor_large_textcaps, git_model_large_textcaps, image)

    # Score each statement against the image (ITM) and against the caption (textual similarity)
    results = []
    for statement in statements:
        itm_score = compute_itm_score(image, statement)
        text_similarity = compute_textual_similarity(generated_caption, statement)
        results.append({'Statement': statement, 'ITM Score': itm_score, 'Textual Similarity': text_similarity})

    results_df = pd.DataFrame(results)
    csv_path = save_dataframe_to_csv(results_df)

    logging.info('Finished process_images_and_statements')
    return generated_caption, results_df, csv_path

# Define Gradio interface
image_input = gr.Image(type="numpy", label="Upload Image")
outputs = [
    gr.Textbox(label="Generated Caption"),
    gr.Dataframe(label="Statement Scores"),
    gr.File(label="Download Scores (CSV)"),
]

iface = gr.Interface(
    fn=process_images_and_statements,
    inputs=image_input,
    outputs=outputs,
    title="Image Captioning and Statement Matching",
    description="Upload an image to generate a caption and score each statement from statements.txt against it.",
    theme='sudeepshouche/minimalist'
)

# Launch Gradio app
iface.launch(debug=True)
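# Note: launch() also accepts share=True to expose a temporary public URL,
# which can be handy when running outside a hosted Space.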

