import gradio as gr
from transformers import GitProcessor, GitForCausalLM, BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import torch
from sentence_transformers import SentenceTransformer, util
import nltk
from nltk.corpus import stopwords

# Ensure nltk stopwords are downloaded
nltk.download('stopwords')

# Load the stop words from nltk
stop_words = set(stopwords.words('english'))

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the processor and model for GIT
processor_git = GitProcessor.from_pretrained("microsoft/git-base-coco")
model_git = GitForCausalLM.from_pretrained("microsoft/git-base-coco").to(device)

# Load the processor and model for BLIP
processor_blip = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model_blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

# Load the SentenceTransformer model
model_sentence = SentenceTransformer('sentence-transformers/stsb-roberta-base')

# Define categories and associated prompts
category_prompts = {
    "nature": "Describe the beautiful aspects of the scene in nature. What might be happening outside of the picture?",
    "technology": "Explain how technology is being used here and what future implications it might have.",
    "kids": "Talk about the joy of childhood visible in this image, and imagine what the kids might do next."
}

# Function to generate a prompt based on the selected category
def generate_prompt(image_category):
    if image_category in category_prompts:
        return category_prompts[image_category]
    return "Describe the image in detail and predict what could be happening beyond it."

# Function to preprocess the image for the GIT model
def preprocess_image(image):
    inputs = processor_git(images=image, return_tensors="pt")
    return inputs['pixel_values'].to(device), image  # Return both pixel values and the image for display

# Function to generate a caption with the GIT model
def generate_caption_git(image, category):
    # Preprocess image
    pixel_values, processed_image = preprocess_image(image)

    # Generate prompt for the category
    prompt = generate_prompt(category)

    # Prepare inputs with the GIT processor
    inputs = processor_git(text=prompt, images=image, return_tensors="pt")
    inputs['pixel_values'] = pixel_values  # Use preprocessed pixel values

    # Move inputs to the selected device
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Generate caption
    generated_ids = model_git.generate(
        pixel_values=inputs['pixel_values'],
        max_length=300,
        num_beams=5,
        repetition_penalty=2.5
    )
    generated_text = processor_git.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text, processed_image

# Function to generate the final caption using the BLIP model
def generate_caption_blip(image, git_caption):
    # Preprocess image for BLIP
    inputs = processor_blip(images=image, text=git_caption, return_tensors="pt").to(device)

    # Generate final caption with BLIP
    generated_ids = model_blip.generate(
        inputs['pixel_values'],
        max_length=300,        # Allow for longer captions
        num_beams=5,
        repetition_penalty=2.5,
        length_penalty=1.5,    # Encourage longer outputs
        min_length=120,        # Enforce a minimum caption length (in tokens)
        early_stopping=True
    )
    final_caption = processor_blip.decode(generated_ids[0], skip_special_tokens=True)
    return final_caption

# Function to compute the semantic similarity score
def compute_semantic_similarity(generated_caption, candidate_answer):
    # Encode the sentences to get their embeddings
    generated_embedding = model_sentence.encode(generated_caption, convert_to_tensor=True)
    candidate_embedding = model_sentence.encode(candidate_answer, convert_to_tensor=True)

    # Compute the cosine similarity
    similarity_score = util.pytorch_cos_sim(generated_embedding, candidate_embedding).item()

    # Scale the similarity score to a percentage (0-100)
    return similarity_score * 100

# Gradio function
def gradio_interface(image, category, candidate_answer):
    # Generate initial caption using GIT
    git_caption, _ = generate_caption_git(image, category)

    # Generate final caption using BLIP based on the GIT output
    final_caption = generate_caption_blip(image, git_caption)

    # Compute the similarity between the BLIP caption and the candidate answer
    similarity_score = compute_semantic_similarity(final_caption, candidate_answer)

    return final_caption, f"{similarity_score:.2f}"

# Create Gradio interface components
image_input = gr.Image(type="pil")
category_input = gr.Dropdown(choices=["nature", "technology", "kids"], label="Select Category")
candidate_input = gr.Textbox(label="Enter your answer", placeholder="Type your answer here...")
outputs = [gr.Textbox(label="BLIP Caption"), gr.Textbox(label="Semantic Similarity Score")]

# Launch the Gradio interface
gr.Interface(
    fn=gradio_interface,
    inputs=[image_input, category_input, candidate_input],
    outputs=outputs,
    title="Image Captioning with BLIP and Semantic Similarity",
    description="Upload an image, select a category, and input your answer to compare with the BLIP-generated caption."
).launch()
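# --- Optional smoke test (illustrative sketch, not part of the Gradio app) ----
# A minimal example of exercising the caption + similarity pipeline directly,
# without the web UI. The filename "example.jpg" and the candidate answer below
# are hypothetical placeholders; run these lines instead of the launch() call
# above (launch() blocks until the server is closed).
#
#   image = Image.open("example.jpg").convert("RGB")
#   git_caption, _ = generate_caption_git(image, "nature")
#   final_caption = generate_caption_blip(image, git_caption)
#   score = compute_semantic_similarity(final_caption, "children playing in a park")
#   print(final_caption)
#   print(f"Similarity: {score:.2f}")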