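"""Image captioning and answer-scoring demo.

Pipeline: a category-specific prompt guides the GIT model (microsoft/git-base-coco)
to produce an initial caption, BLIP (Salesforce/blip-image-captioning-large) refines
it into the final caption, and a SentenceTransformer scores the semantic similarity
between that caption and a user-provided answer. The whole flow is exposed through
a Gradio interface.
"""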

import gradio as gr
from transformers import GitProcessor, GitForCausalLM, BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import torch
from sentence_transformers import SentenceTransformer, util
import nltk
from nltk.corpus import stopwords

# Ensure nltk stopwords are downloaded
nltk.download('stopwords')

# Load the stop words from nltk (not referenced further in this script)
stop_words = set(stopwords.words('english'))

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the processor and model for GIT
processor_git = GitProcessor.from_pretrained("microsoft/git-base-coco")
model_git = GitForCausalLM.from_pretrained("microsoft/git-base-coco").to(device)

# Load the processor and model for BLIP
processor_blip = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model_blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

# Load the SentenceTransformer model
model_sentence = SentenceTransformer('sentence-transformers/stsb-roberta-base')
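# stsb-roberta-base is a sentence-embedding model tuned for semantic textual
# similarity, which suits comparing a generated caption against a user answer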

# Define categories and associated prompts
category_prompts = {
    "nature": "Describe the beautiful aspects of the scene in nature. What might be happening outside of the picture?",
    "technology": "Explain how technology is being used here and what future implications it might have.",
    "kids": "Talk about the joy of childhood visible in this image, and imagine what the kids might do next."
}
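# Any category not listed above falls back to the generic prompt in generate_prompt()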

# Function to generate prompt based on category
def generate_prompt(image_category):
    if image_category in category_prompts:
        return category_prompts[image_category]
    else:
        return "Describe the image in detail and predict what could be happening beyond it."

# Function to preprocess image
def preprocess_image(image):
    inputs = processor_git(images=image, return_tensors="pt")
    return inputs['pixel_values'].to(device), image  # Return both pixel values and image for display

# Function to generate caption with GIT model
def generate_caption_git(image, category):
    # Preprocess image
    pixel_values, processed_image = preprocess_image(image)

    # Generate prompt for the category
    prompt = generate_prompt(category)

    # Generate caption with GIT model
    inputs = processor_git(text=prompt, images=image, return_tensors="pt")
    inputs['pixel_values'] = pixel_values  # Use preprocessed pixel values

    # Move inputs to GPU
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Generate caption
    generated_ids = model_git.generate(
        pixel_values=inputs['pixel_values'],
        max_length=300,
        num_beams=5,
        repetition_penalty=2.5
    )
    generated_text = processor_git.decode(generated_ids[0], skip_special_tokens=True)

    return generated_text, processed_image

# Function to generate final caption using BLIP model
def generate_caption_blip(image, git_caption):
    # Preprocess image for BLIP
    inputs = processor_blip(images=image, text=git_caption, return_tensors="pt").to(device)

    # Generate final caption with BLIP
    generated_ids = model_blip.generate(
        inputs['pixel_values'],
        max_length=300,  # Increase to allow for longer captions
        num_beams=5,
        repetition_penalty=2.5,
        length_penalty=1.5,  # Encourage longer outputs
        min_length=120,  # Set minimum length to ensure at least 150 words
        early_stopping=True
    )
    final_caption = processor_blip.decode(generated_ids[0], skip_special_tokens=True)

    return final_caption

# Function to compute semantic similarity score
def compute_semantic_similarity(generated_caption, candidate_answer):
    # Encode the sentences to get their embeddings
    generated_embedding = model_sentence.encode(generated_caption, convert_to_tensor=True)
    candidate_embedding = model_sentence.encode(candidate_answer, convert_to_tensor=True)

    # Compute the cosine similarity
    similarity_score = util.pytorch_cos_sim(generated_embedding, candidate_embedding).item()

    # Scale the cosine similarity to a percentage-style score (it can dip slightly
    # below 0 for highly dissimilar sentences); round for cleaner display
    return round(similarity_score * 100, 2)
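# Illustrative behaviour: near-paraphrases (e.g. "a dog runs on the beach" vs.
# "a dog is running along the shore") should score high, unrelated sentences low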

# Gradio function
def gradio_interface(image, category, candidate_answer):
    # Generate initial caption using GIT
    git_caption, _ = generate_caption_git(image, category)

    # Generate final caption using BLIP based on GIT output
    final_caption = generate_caption_blip(image, git_caption)

    # Compute the similarity between BLIP caption and candidate answer
    similarity_score = compute_semantic_similarity(final_caption, candidate_answer)

    return final_caption, similarity_score

# Create Gradio interface
image_input = gr.Image(type="pil")
category_input = gr.Dropdown(choices=["nature", "technology", "kids"], label="Select Category")
candidate_input = gr.Textbox(label="Enter your answer", placeholder="Type your answer here...")

outputs = [gr.Textbox(label="BLIP Caption"),
           gr.Textbox(label="Semantic Similarity Score")]

# Launch the Gradio interface
gr.Interface(
    fn=gradio_interface,
    inputs=[image_input, category_input, candidate_input],
    outputs=outputs,
    title="Image Captioning with BLIP and Semantic Similarity",
    description="Upload an image, select a category, and input your answer to compare with the BLIP-generated caption."
).launch()
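# Note: launch(share=True) can be used instead to expose a temporary public URL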