robgonsalves's picture
initial check-in
402ed5c verified
raw
history blame
1.4 kB
import gradio as gr
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
# Hub checkpoint id shared by the CLIP model and its matching processor.
CLIP_CHECKPOINT = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"

# Load both once at module import time; they are reused for every call.
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
def calculate_similarity(image, text_prompt):
    """Return the CLIP cosine similarity between an image and a text prompt.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input; Gradio
            passes ``None`` when nothing was uploaded.
        text_prompt: Text to compare against the image.

    Returns:
        A single-entry dict ``{"Cosine Similarity": float}`` for the Gradio
        ``Label`` output. The value is a cosine, so it lies in [-1, 1].

    Raises:
        gr.Error: if the image is missing or the prompt is empty, so the
            user sees a readable message instead of a processor traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not text_prompt or not text_prompt.strip():
        raise gr.Error("Please enter a text prompt.")

    # CLIP text encoders have a fixed maximum sequence length. Truncate long
    # prompts to the tokenizer's limit instead of failing in the model.
    inputs = processor(
        text=text_prompt,
        images=image,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Inference only: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        outputs = model(**inputs)

    # cosine_similarity L2-normalizes both inputs itself, so the embeddings
    # need no manual normalization first.
    similarity = torch.nn.functional.cosine_similarity(
        outputs.image_embeds, outputs.text_embeds, dim=-1
    )
    return {"Cosine Similarity": similarity.item()}
# Set up the Gradio interface. The components are taken from the top level of
# the gradio package: the old gr.inputs / gr.outputs namespaces were
# deprecated in Gradio 3 and removed in Gradio 4.
iface = gr.Interface(
    fn=calculate_similarity,
    inputs=[
        gr.Image(type="pil"),  # delivers a PIL image to calculate_similarity
        gr.Textbox(label="Text Prompt"),
    ],
    outputs=[gr.Label(label="Cosine Similarity")],
    title="OpenClip Cosine Similarity Calculator",
    description="Upload an image and provide a text prompt to calculate the cosine similarity.",
)

# Launch the interface locally for testing.
iface.launch()