File size: 1,741 Bytes
39113b9 30d5af0 69aa3f2 39113b9 82eb2a3 d3bd556 30d5af0 f2e596b 30d5af0 39113b9 9f13edb 30d5af0 69aa3f2 39113b9 30d5af0 2b19e7c 39113b9 30d5af0 39113b9 30d5af0 d3bd556 fc5f1b5 625973c 2b19e7c 30d5af0 39113b9 30d5af0 c652364 39113b9 c652364 fc5f1b5 30d5af0 39113b9 30d5af0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import gradio as gr
import numpy as np
import clip
import torch
from PIL import Image
import base64
from io import BytesIO
from decimal import Decimal
# Select the compute device up front: CUDA when torch reports it, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Fetch the ViT-B/32 CLIP model together with the preprocessing callable
# that is later applied to input images.
model, preprocess = clip.load("ViT-B/32")
# Move the model onto the selected device, then put it in evaluation mode.
model.to(device)
model.eval()
def find_similarity(base64_image, text_input):
    """Score how well a base64-encoded image matches a text prompt with CLIP.

    Args:
        base64_image: Base64 text of an image file. A leading data-URI
            header (``data:image/...;base64,``) is dropped before decoding.
        text_input: The text prompt to compare against the image.

    Returns:
        A list of strings, one per text prompt (a single prompt is passed
        here), each holding the cosine similarity formatted to five
        decimal places.

    Raises:
        Whatever ``base64.b64decode``, ``Image.open`` or ``clip.tokenize``
        raise for malformed input; nothing is swallowed here.
    """
    # Pasted values frequently arrive as a full data URI. Neither ":" nor
    # "," is part of the base64 alphabet, so cutting at the first comma can
    # never damage a bare base64 payload.
    payload = base64_image.strip()
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]

    # Decode the base64 text to bytes and open them as a PIL image.
    image_bytes = base64.b64decode(payload)
    image = Image.open(BytesIO(image_bytes))

    # Apply the CLIP preprocessing and add a leading batch axis of size 1.
    image = preprocess(image).unsqueeze(0).to(device)
    # Tokenize the single text prompt.
    text_tokens = clip.tokenize([text_input]).to(device)

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text_tokens)

        # Cosine similarity is the dot product of unit-length vectors, so
        # both embeddings must be L2-normalised before the matrix product.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        similarity = (image_features @ text_features.T).squeeze(0).cpu().numpy()

    # Format every score with five decimal places.
    return [f'{float(score):.5f}' for score in similarity]
# Build the Gradio interface around find_similarity: a multi-line textbox for
# the base64 image text plus a plain text prompt, with a text output.
iface = gr.Interface(
    fn=find_similarity,
    inputs=[
        # gr.Textbox is the current component class; the legacy
        # gr.inputs.Textbox namespace is deprecated.
        gr.Textbox(label="Base64 Image", lines=8),
        "text",
    ],
    outputs="text",
    # live=True asks Gradio to re-run the function as the inputs change
    # instead of waiting for a submit click.
    live=True,
    interpretation="default",
    title="CLIP Model Image-Text Cosine Similarity",
    description="Upload a base64 image and enter text to find their cosine similarity.",
)
# Start the web app.
iface.launch()
|