# NOTE: the three lines below were a pasted GitHub commit header
# (author / commit message / short hash) — kept as comments so the
# file remains valid Python.
# Dhairya2
# Deploy Gradio app
# 284dc33
import gradio as gr
import nltk
import torch
from nltk.corpus import stopwords
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from transformers import GitProcessor, GitForCausalLM, BlipProcessor, BlipForConditionalGeneration
# Ensure the nltk stopwords corpus is present (no-op if already downloaded).
nltk.download('stopwords')
# Load the English stop words.
# BUG FIX: the original referenced `stopwords` without importing it
# (NameError at import time); `from nltk.corpus import stopwords` is now
# in the import block above.
# NOTE(review): `stop_words` is never used anywhere below — confirm whether
# it can be removed or was meant to filter captions before scoring.
stop_words = set(stopwords.words('english'))
# Run all models on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the processor and model for GIT (COCO-finetuned captioning checkpoint);
# the model is moved to the selected device, the processor stays CPU-side.
processor_git = GitProcessor.from_pretrained("microsoft/git-base-coco")
model_git = GitForCausalLM.from_pretrained("microsoft/git-base-coco").to(device)
# Load the processor and model for BLIP (large captioning checkpoint),
# used for the second, conditional captioning pass.
processor_blip = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model_blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)
# Load the SentenceTransformer model used to score semantic similarity
# between the generated caption and the user's answer.
# NOTE(review): this model is not moved to `device`; SentenceTransformer
# picks its own device — confirm this is intended on multi-GPU hosts.
model_sentence = SentenceTransformer('sentence-transformers/stsb-roberta-base')
# Category -> prompt text used to steer caption generation; any category not
# listed here falls back to the generic prompt in generate_prompt().
category_prompts = {
    "nature": "Describe the beautiful aspects of the scene in nature. What might be happening outside of the picture?",
    "technology": "Explain how technology is being used here and what future implications it might have.",
    "kids": "Talk about the joy of childhood visible in this image, and imagine what the kids might do next."
}
# Look up the guidance prompt for a category.
def generate_prompt(image_category):
    """Return the prompt text associated with *image_category*.

    Falls back to a generic "describe the image" prompt when the category
    has no entry in ``category_prompts``.
    """
    fallback = "Describe the image in detail and predict what could be happening beyond it."
    return category_prompts.get(image_category, fallback)
# Convert a PIL image into GIT-model pixel values.
def preprocess_image(image):
    """Run the GIT processor on *image*.

    Returns a tuple ``(pixel_values, image)``: the tensor moved to the
    compute device, plus the untouched input image for display.
    """
    processed = processor_git(images=image, return_tensors="pt")
    pixel_values = processed['pixel_values'].to(device)
    return pixel_values, image
# Function to generate caption with GIT model
def generate_caption_git(image, category):
    """Generate a caption for *image* with the GIT model, steered by the
    category-specific prompt.

    Parameters
    ----------
    image : PIL.Image.Image
        Input image.
    category : str
        Key into ``category_prompts`` (unknown keys get a generic prompt).

    Returns
    -------
    tuple[str, PIL.Image.Image]
        The decoded caption and the image (for display by the caller).
    """
    # Preprocess image (pixel values already on the compute device).
    pixel_values, processed_image = preprocess_image(image)
    # Generate prompt for the category.
    prompt = generate_prompt(category)
    # Tokenize the prompt together with the image.
    inputs = processor_git(text=prompt, images=image, return_tensors="pt")
    inputs['pixel_values'] = pixel_values  # Use preprocessed pixel values
    # Move all tensors to the compute device.
    inputs = {key: val.to(device) for key, val in inputs.items()}
    # BUG FIX: the original called generate() with pixel_values only, so the
    # tokenized prompt was computed and then discarded — the category prompt
    # had no effect on the caption. Pass input_ids as the generation prefix
    # so the prompt actually conditions the output (per HF GIT usage).
    # NOTE(review): the decoded text will include the prompt prefix — confirm
    # whether downstream display should strip it.
    generated_ids = model_git.generate(
        pixel_values=inputs['pixel_values'],
        input_ids=inputs['input_ids'],
        max_length=300,
        num_beams=5,
        repetition_penalty=2.5
    )
    generated_text = processor_git.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text, processed_image
# Function to generate final caption using BLIP model
def generate_caption_blip(image, git_caption):
    """Generate the final caption with BLIP, conditioned on the GIT caption.

    Parameters
    ----------
    image : PIL.Image.Image
        Input image.
    git_caption : str
        Caption from the GIT stage, used as the conditional text prefix.

    Returns
    -------
    str
        The decoded BLIP caption.
    """
    # Processor returns pixel_values plus input_ids/attention_mask for the text.
    inputs = processor_blip(images=image, text=git_caption, return_tensors="pt").to(device)
    # BUG FIX: the original passed only inputs['pixel_values'] to generate(),
    # so the GIT caption was processed but never conditioned the output.
    # Unpacking the full processor output (input_ids + pixel_values) enables
    # BLIP's conditional captioning path.
    generated_ids = model_blip.generate(
        **inputs,
        max_length=300,        # allow long captions
        num_beams=5,
        repetition_penalty=2.5,
        length_penalty=1.5,    # encourage longer outputs
        min_length=120,        # minimum of 120 generated tokens (tokens, not words)
        early_stopping=True
    )
    final_caption = processor_blip.decode(generated_ids[0], skip_special_tokens=True)
    return final_caption
# Score how close two texts are in meaning.
def compute_semantic_similarity(generated_caption, candidate_answer):
    """Return the cosine similarity of the two texts, scaled to 0-100.

    Both strings are embedded with the module-level SentenceTransformer;
    the raw cosine similarity is multiplied by 100 for display.
    """
    emb_generated = model_sentence.encode(generated_caption, convert_to_tensor=True)
    emb_candidate = model_sentence.encode(candidate_answer, convert_to_tensor=True)
    cosine = util.pytorch_cos_sim(emb_generated, emb_candidate).item()
    # Scale the raw cosine value to a percentage-style score.
    return cosine * 100
# Gradio callback: full captioning + scoring pipeline.
def gradio_interface(image, category, candidate_answer):
    """Run GIT -> BLIP captioning, then score the user's answer.

    Returns the BLIP caption and its semantic-similarity score (0-100)
    against *candidate_answer*.
    """
    # Stage 1: initial caption from GIT (image ignored for display here).
    git_caption, _ = generate_caption_git(image, category)
    # Stage 2: refined caption from BLIP, seeded with the GIT output.
    blip_caption = generate_caption_blip(image, git_caption)
    # Stage 3: compare BLIP caption with the user's answer.
    score = compute_semantic_similarity(blip_caption, candidate_answer)
    return blip_caption, score
# Create Gradio interface widgets: image upload, category picker, free-text answer.
image_input = gr.Image(type="pil")
category_input = gr.Dropdown(choices=["nature", "technology", "kids"], label="Select Category")
candidate_input = gr.Textbox(label="Enter your answer", placeholder="Type your answer here...")
# Two text outputs: the BLIP caption and the 0-100 similarity score.
outputs = [gr.Textbox(label="BLIP Caption"),
gr.Textbox(label="Semantic Similarity Score")]
# Build and launch the Gradio interface (blocks here serving the app).
# NOTE(review): launch() runs at import time — there is no
# `if __name__ == "__main__":` guard; confirm that is intended for this deployment.
gr.Interface(fn=gradio_interface, inputs=[image_input, category_input, candidate_input], outputs=outputs, title="Image Captioning with BLIP and Semantic Similarity", description="Upload an image, select a category, and input your answer to compare with the BLIP-generated caption.").launch()