# Hugging Face Space: image captioning (GIT + BLIP) with semantic-similarity answer scoring.
# NOTE(review): the Space reported "Build error" — caused by the missing
# `from nltk.corpus import stopwords` import needed by the stop_words line below.
import gradio as gr
import nltk
import torch
from nltk.corpus import stopwords  # FIX: was missing — stopwords.words() raised NameError
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from transformers import (
    BlipForConditionalGeneration,
    BlipProcessor,
    GitForCausalLM,
    GitProcessor,
)

# Ensure the nltk stopword corpus is available (no-op if already downloaded).
nltk.download('stopwords')

# English stop words. NOTE(review): `stop_words` is not used anywhere in this
# file — kept for backward compatibility in case other code imports it.
stop_words = set(stopwords.words('english'))

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the processor and model for GIT (image captioning checkpoint fine-tuned on COCO).
processor_git = GitProcessor.from_pretrained("microsoft/git-base-coco")
model_git = GitForCausalLM.from_pretrained("microsoft/git-base-coco").to(device)
# Load the processor and model for BLIP (large image-captioning checkpoint).
processor_blip = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model_blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)
# Sentence-embedding model used to score candidate answers against the caption.
# NOTE(review): runs on CPU unless a device is passed — presumably intentional; confirm.
model_sentence = SentenceTransformer('sentence-transformers/stsb-roberta-base')
# Category-specific prompt templates used by the GIT captioning step.
category_prompts = {
    "nature": "Describe the beautiful aspects of the scene in nature. What might be happening outside of the picture?",
    "technology": "Explain how technology is being used here and what future implications it might have.",
    "kids": "Talk about the joy of childhood visible in this image, and imagine what the kids might do next."
}


def generate_prompt(image_category):
    """Return the prompt associated with *image_category*.

    Unknown categories fall back to a generic description prompt.
    """
    default_prompt = "Describe the image in detail and predict what could be happening beyond it."
    return category_prompts.get(image_category, default_prompt)
def preprocess_image(image):
    """Run the GIT processor on *image* and move its pixel values to the device.

    Returns a (pixel_values, image) pair so callers can also display the image.
    """
    processed = processor_git(images=image, return_tensors="pt")
    pixel_values = processed['pixel_values'].to(device)
    return pixel_values, image
def generate_caption_git(image, category):
    """Generate a caption for *image* with the GIT model.

    Returns a (caption, image) tuple; the image is passed through unchanged
    so the caller can display it.
    """
    # Preprocess the image ONCE. (FIX: the original ran the processor a second
    # time with text+image and then immediately overwrote its pixel_values with
    # these, so the second call was pure dead work — removed.)
    pixel_values, processed_image = preprocess_image(image)
    # BUG(review): the category prompt is built here but generate() below only
    # receives pixel_values, so the prompt never influences the caption. Making
    # it effective would require passing tokenized input_ids to generate();
    # left as-is to preserve current output. TODO: confirm intent.
    prompt = generate_prompt(category)
    generated_ids = model_git.generate(
        pixel_values=pixel_values,
        max_length=300,
        num_beams=5,
        repetition_penalty=2.5
    )
    generated_text = processor_git.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text, processed_image
# Generate the final caption with the BLIP model, using the GIT caption as text input.
def generate_caption_blip(image, git_caption):
    # Preprocess image (and GIT caption as text prompt) for BLIP.
    # NOTE(review): only pixel_values are passed to generate() below — the
    # tokenized git_caption (inputs['input_ids']) is never used, so the GIT
    # caption likely does not condition the output. TODO: confirm intent.
    inputs = processor_blip(images=image, text=git_caption, return_tensors="pt").to(device)
    # Beam-search generation tuned to favor longer captions.
    generated_ids = model_blip.generate(
        inputs['pixel_values'],
        max_length=300, # upper bound on generated tokens
        num_beams=5,
        repetition_penalty=2.5,
        length_penalty=1.5, # >1 biases beam search toward longer sequences
        min_length=120, # minimum generated length in TOKENS (not words)
        early_stopping=True
    )
    final_caption = processor_blip.decode(generated_ids[0], skip_special_tokens=True)
    return final_caption
def compute_semantic_similarity(generated_caption, candidate_answer):
    """Return the cosine similarity between the two texts, scaled to 0-100."""
    caption_emb = model_sentence.encode(generated_caption, convert_to_tensor=True)
    answer_emb = model_sentence.encode(candidate_answer, convert_to_tensor=True)
    # Cosine similarity in [-1, 1], scaled to a percentage-like score.
    cosine = util.pytorch_cos_sim(caption_emb, answer_emb).item()
    return 100 * cosine
def gradio_interface(image, category, candidate_answer):
    """Gradio handler: caption the image (GIT then BLIP) and score the answer.

    Returns the BLIP caption and its 0-100 similarity to *candidate_answer*.
    """
    git_caption, _ = generate_caption_git(image, category)
    final_caption = generate_caption_blip(image, git_caption)
    score = compute_semantic_similarity(final_caption, candidate_answer)
    return final_caption, score
# --- Gradio UI wiring ---
image_input = gr.Image(type="pil")
category_input = gr.Dropdown(choices=["nature", "technology", "kids"], label="Select Category")
candidate_input = gr.Textbox(label="Enter your answer", placeholder="Type your answer here...")
# Two outputs: the BLIP caption and the 0-100 similarity score.
outputs = [gr.Textbox(label="BLIP Caption"),
           gr.Textbox(label="Semantic Similarity Score")]
# Build and launch the Gradio app (blocks until the server is stopped).
gr.Interface(fn=gradio_interface, inputs=[image_input, category_input, candidate_input], outputs=outputs, title="Image Captioning with BLIP and Semantic Similarity", description="Upload an image, select a category, and input your answer to compare with the BLIP-generated caption.").launch()