# Page header captured from the hosting site's file view (not part of the program):
#   robgonsalves's picture
#   fixed a bug, and misc changes
#   aa3546b verified
#   raw | history blame
#   1.74 kB
import gradio as gr
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
# Load the CLIP weights and the matching pre-processor once at import time.
# Both are fetched from the same checkpoint identifier so they stay in sync.
_CLIP_CHECKPOINT = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"

model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
# Display scaling: the raw cosine value is multiplied by this boost and then by
# 100 before being shown as a percentage.
_SIMILARITY_BOOST = 3
# Bounds applied to the displayed percentage.
_MIN_DISPLAYED_PERCENT = 0.0
_MAX_DISPLAYED_PERCENT = 99.99


def calculate_similarity(image, text_prompt):
    """Return a sentence reporting how similar ``image`` and ``text_prompt`` are.

    The image and prompt are run through the module-level ``processor`` and
    ``model``; the cosine similarity of the resulting image and text
    embeddings is scaled, clamped to the displayed range and formatted.

    Args:
        image: The image to score (the interface supplies a PIL image), or
            ``None`` when nothing was uploaded.
        text_prompt: The prompt to compare against; non-string values are
            converted with ``str()``.

    Returns:
        A human-readable sentence with the similarity percentage, or a
        request to upload an image when ``image`` is ``None``.
    """
    # Without an image there is nothing to embed; answer with a message
    # instead of letting the processor/model fail.
    if image is None:
        return "Please upload an image to compare against the text prompt."

    # Ensure text_prompt is a string
    if not isinstance(text_prompt, str):
        text_prompt = str(text_prompt)

    # truncation=True cuts over-long prompts down to the tokenizer's configured
    # maximum length so they cannot exceed the text encoder's context size.
    inputs = processor(
        images=image,
        text=text_prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
        # cosine_similarity divides by the vector norms itself, so the
        # embeddings do not need to be normalised beforehand.
        cosine_similarity = torch.nn.functional.cosine_similarity(
            outputs.image_embeds, outputs.text_embeds
        )

    # Scale for display, then clamp on both sides so a negative cosine is not
    # reported as a negative percentage.
    adjusted_similarity = cosine_similarity.item() * _SIMILARITY_BOOST * 100
    clipped_similarity = min(
        max(adjusted_similarity, _MIN_DISPLAYED_PERCENT),
        _MAX_DISPLAYED_PERCENT,
    )
    return f"According to OpenCLIP, the image and the text prompt are {clipped_similarity:.2f}% similar."
# Build the UI components first, in the order they appear on the page.
image_input = gr.Image(type="pil", label="Upload Image", height=512)
prompt_input = gr.Textbox(label="Text Prompt")
result_output = gr.Text()

# Wire the components to calculate_similarity; flagging is turned off.
iface = gr.Interface(
    fn=calculate_similarity,
    inputs=[image_input, prompt_input],
    outputs=result_output,
    allow_flagging="never",
    title="OpenClip Cosine Similarity Calculator",
    description="Provide a text prompt and upload an image to calculate the cosine similarity.",
)

# Start serving; share=True requests a public link for sharing online.
iface.launch(share=True)