# ktllc's picture
# Update app.py
# 3b53215 verified
import gradio as gr
import numpy as np
import clip
import torch
from PIL import Image
import base64
from io import BytesIO
from decimal import Decimal
# Choose the compute device up front so the CLIP weights are loaded directly onto it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the CLIP ViT-L/14 model (336px input variant) together with the
# image-preprocessing transform that clip.load returns for it.
model, preprocess = clip.load("ViT-L/14@336px", device=device)
model.to(device)
model.eval()  # inference only: put the module in evaluation mode
# Define a function to find similarity
def find_similarity(base64_image, text_input):
# Decode the base64 image to bytes
image_bytes = base64.b64decode(base64_image)
# Convert the bytes to a PIL image
image = Image.open(BytesIO(image_bytes))
# Preprocess the image
image = preprocess(image).unsqueeze(0).to(device)
# Prepare input text
text_tokens = clip.tokenize([text_input]).to(device)
# Encode image and text features
with torch.no_grad():
image_features = model.encode_image(image)
text_features = model.encode_text(text_tokens)
# Normalize features and calculate similarity
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (text_features @ image_features.T).squeeze(0).cpu().numpy()
return similarity
# Build the web UI: a multi-line box for the base64 image and a box for the
# text, both fed to find_similarity, with the score shown as a number.
similarity_inputs = [
    gr.Textbox(label="Base64 Image", lines=8),
    gr.Textbox(label="Text Input"),
]

iface = gr.Interface(
    fn=find_similarity,
    inputs=similarity_inputs,
    outputs="number",
    title="CLIP Model Image-Text Cosine Similarity",
    description="Upload a base64 image and enter text to find their cosine similarity.",
    # Recompute automatically whenever an input changes.
    live=True,
)

iface.launch()