import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import pickle
import gradio as gr

class ImageDataset(Dataset):
    def __init__(self, image_dir, processor):
        self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Convert to RGB so grayscale/RGBA images don't break the CLIP processor
        image = Image.open(self.image_paths[idx]).convert("RGB")
        return self.processor(images=image, return_tensors="pt")['pixel_values'][0]

def get_and_save_clip_embeddings(image_dir, output_file, batch_size=32, device='cuda'):
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    dataset = ImageDataset(image_dir, processor)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    all_embeddings = []
    image_paths = []

    model.eval()
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            batch = batch.to(device)
            embeddings = model.get_image_features(pixel_values=batch)
            all_embeddings.append(embeddings.cpu().numpy())
            # shuffle=False, so these slices stay aligned with the embeddings above
            start_idx = batch_idx * batch_size
            end_idx = start_idx + len(batch)
            image_paths.extend(dataset.image_paths[start_idx:end_idx])

    all_embeddings = np.concatenate(all_embeddings)

    with open(output_file, 'wb') as f:
        pickle.dump({'embeddings': all_embeddings, 'image_paths': image_paths}, f)

# image_dir = "dataset/"
# output_file = "image_embeddings.pkl"
# batch_size = 32
# device = "cuda" if torch.cuda.is_available() else "cpu"

# get_and_save_clip_embeddings(image_dir, output_file, batch_size, device)


# APP

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

with open('image_embeddings.pkl', 'rb') as f:
    data = pickle.load(f)
    image_embeddings = data['embeddings']
    image_names = data['image_paths']

def cosine_similarity(a, b):
    a = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return np.dot(a, b.T)
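
# Example (assuming the ViT-B/32 checkpoint, whose embeddings are 512-dimensional):
# a (1, 512) text embedding against (N, 512) image embeddings yields a (1, N)
# matrix of cosine similarities, one score per indexed image.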

def find_similar_images(text):
    inputs = processor(text=[text], return_tensors="pt", padding=True)
    with torch.no_grad():
        text_embedding = model.get_text_features(**inputs).cpu().numpy()

    similarities = cosine_similarity(text_embedding, image_embeddings)
    # Rank all images by similarity to the query and keep the 4 best matches
    top_indices = np.argsort(similarities[0])[::-1][:4]
    top_images = [image_names[i] for i in top_indices]
    
    return top_images
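
# Example query (assumes image_embeddings.pkl was built from the dataset/ folder above):
# find_similar_images("a photo of a shark")  # -> paths of the 4 best-matching images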



text_input = gr.Textbox(label="Input text", placeholder="Enter an image description")
imgs_output = gr.Gallery(label="Top 4 most similar images")

intf = gr.Interface(
    fn=find_similar_images,
    inputs=text_input,
    outputs=imgs_output,
)
extra_text = gr.Markdown("""
The dataset contains images of dogs, rabbits, sharks, and deer.
Displaying the images might take a couple of seconds.
""")
with gr.Blocks() as app:
    intf.render()
    extra_text.render()
app.launch(share=True)