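# NOTE: the six lines below are web-page residue from the file viewer this
# script was copied from; they are not part of the program.
#   ktllc's picture
#   Update app.py
#   39113b9
#   raw
#   history blame
#   1.38 kB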
import base64
from io import BytesIO

import clip
import gradio as gr
import numpy as np
import torch
from PIL import Image
# Pick the compute device first, then load the CLIP ViT-B/32 checkpoint
# together with its matching image preprocessing transform.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32")
# Move the weights to the chosen device and switch to inference mode.
model.to(device)
model.eval()
# Define a function to find similarity
def find_similarity(base64_image, text_input):
    """Return the CLIP cosine similarity between an image and a text prompt.

    Args:
        base64_image: Base64-encoded contents of an image file.
        text_input: Text prompt to compare against the image.

    Returns:
        A NumPy array holding the similarity score for the single
        image/text pair (the leading batch axis is removed by
        ``squeeze(0)``).

    Raises:
        binascii.Error: If ``base64_image`` is not correctly padded base64.
        OSError: If the decoded bytes cannot be opened as an image by PIL.
    """
    # Decode the base64 payload and open it as a PIL image from memory.
    image_bytes = base64.b64decode(base64_image)
    image = Image.open(BytesIO(image_bytes))

    # Apply the model's own preprocessing and add a batch dimension.
    image_tensor = preprocess(image).unsqueeze(0).to(device)

    # Tokenize the prompt as a batch of one.
    text_tokens = clip.tokenize([text_input]).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        image_features = model.encode_image(image_tensor)
        text_features = model.encode_text(text_tokens)

        # L2-normalize both embeddings so that the dot product below is a
        # true cosine similarity rather than an unbounded raw dot product.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    similarity = (image_features @ text_features.T).squeeze(0).cpu().numpy()
    return similarity
# Multi-line text box that receives the base64-encoded image.
base64_image_box = gr.inputs.Textbox(lines=8, label="Base64 Image")

# Wire find_similarity into a Gradio UI: (base64 image, text) -> number.
iface = gr.Interface(
    fn=find_similarity,
    inputs=[base64_image_box, "text"],
    outputs="number",
    title="CLIP Model Image-Text Cosine Similarity",
    description="Upload a base64 image and enter text to find their cosine similarity.",
    live=True,
    interpretation="default",
)

# Start the web server for the interface.
iface.launch()