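# Gradio app that generates OpenCLIP text and image embeddings and reports their cosine similarity.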
import gradio as gr
import open_clip
import torch
import PIL.Image
# Set device to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the OpenCLIP model and the necessary preprocessors
# openclip_model = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
# openclip_model = 'laion/CLIP-ViT-B-16-laion2B-s34B-b88K'
openclip_model = 'laion/CLIP-ViT-L-14-laion2B-s32B-b82K'
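# The 'hf-hub:' prefix tells open_clip to load the checkpoint from the Hugging Face Hub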
openclip_model = 'hf-hub:' + openclip_model
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    model_name=openclip_model,
    device=device
)
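# preprocess_val is the inference-time image transform; preprocess_train is not used below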
def generate_embedding(text_data, image_data):
    """
    Generate embeddings for text and image data using the OpenCLIP model.

    Parameters
    ----------
    text_data : str or tuple of str
        Text data to embed.
    image_data : PIL.Image.Image or tuple of PIL.Image.Image
        Image data to embed.

    Returns
    -------
    text_embeddings : list
        One embedding (list of floats) per text input; "" for empty inputs.
    image_embeddings : list
        One embedding (list of floats) per image input; "" for missing images.
    similarity : list of str
        Cosine similarity of each text/image pair, formatted as a percentage.
    """
    # Embed text data
    text_embeddings = []
    empty_data_indices = []
    if text_data:
        # If text_data is a string, convert to list of strings
        if isinstance(text_data, str):
            text_data = [text_data]
        # If text_data is a tuple of strings, convert to list of strings
        if isinstance(text_data, tuple):
            text_data = list(text_data)
        # Keep track of indices of empty text strings
        empty_data_indices = [i for i, text in enumerate(text_data) if text == ""]
        # Remove empty text strings
        text_data = [text for text in text_data if text != ""]
        if text_data:
            # Tokenize text_data and convert to tensor
            text_data = open_clip.tokenize(text_data).to(device)
            # Generate text embeddings
            with torch.no_grad():
                text_embeddings = model.encode_text(text_data)
            # Convert embeddings to nested lists of floats
            text_embeddings = [embedding.detach().cpu().numpy().tolist() for embedding in text_embeddings]
        # Insert empty strings at indices of empty text strings
        for i in empty_data_indices:
            text_embeddings.insert(i, "")
    # Embed image data
    image_embeddings = []
    empty_data_indices = []
    if image_data:
        # If image_data is a single PIL image, convert to list of PIL images
        if isinstance(image_data, PIL.Image.Image):
            image_data = [image_data]
        # If image_data is a tuple of images, convert to list of images
        if isinstance(image_data, tuple):
            image_data = list(image_data)
        # Keep track of indices of None images
        empty_data_indices = [i for i, img in enumerate(image_data) if img is None]
        # Remove None images
        image_data = [img for img in image_data if img is not None]
        if image_data:
            # Preprocess image_data and convert to tensor
            image_data = [preprocess_val(img).unsqueeze(0) for img in image_data]
            image_data = torch.stack(image_data).squeeze(1).to(device)
            # Generate image embeddings
            with torch.no_grad():
                image_embeddings = model.encode_image(image_data)
            # Convert embeddings to nested lists of floats
            image_embeddings = [embedding.detach().cpu().numpy().tolist() for embedding in image_embeddings]
        # Insert empty strings at indices of missing images
        for i in empty_data_indices:
            image_embeddings.insert(i, "")
    # Calculate cosine similarity between text and image embeddings
    similarity = []
    empty_data_indices = []
    if text_embeddings and image_embeddings:
        # Filter out pairs with either an empty text or image embedding, tracking their indices
        text_embeddings_filtered = []
        image_embeddings_filtered = []
        for i, (text_embedding, image_embedding) in enumerate(zip(text_embeddings, image_embeddings)):
            if text_embedding != "" and image_embedding != "":
                text_embeddings_filtered.append(text_embedding)
                image_embeddings_filtered.append(image_embedding)
            else:
                empty_data_indices.append(i)
        # Calculate cosine similarity if there are any non-empty embedding pairs
        if image_embeddings_filtered and text_embeddings_filtered:
            # Convert lists back to tensors for processing
            text_embeddings_tensor = torch.tensor(text_embeddings_filtered)
            image_embeddings_tensor = torch.tensor(image_embeddings_filtered)
            # Normalize the embeddings
            text_embedding_norm = text_embeddings_tensor / text_embeddings_tensor.norm(dim=-1, keepdim=True)
            image_embedding_norm = image_embeddings_tensor / image_embeddings_tensor.norm(dim=-1, keepdim=True)
            # Calculate cosine similarity
            similarity = torch.nn.functional.cosine_similarity(text_embedding_norm, image_embedding_norm, dim=-1)
            # Convert to percentage as text
            similarity = [f"{sim.item() * 100:.2f}%" for sim in similarity]
        # Insert empty strings in similarity for skipped pairs
        for i in empty_data_indices:
            similarity.insert(i, "")
    return (text_embeddings, image_embeddings, similarity)
# Define Gradio interface
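# With batch=True, Gradio passes each input as a list (a batch) and expects list outputs,
# which is why generate_embedding accepts and returns sequences.
# api_name="embed" exposes the function at the /embed endpoint of the auto-generated API.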
demo = gr.Interface(
    fn=generate_embedding,
    inputs=[
        gr.Textbox(lines=5, max_lines=5, placeholder="Enter Text Here...", label="Text to Embed"),
        gr.Image(height=512, type="pil", label="Image to Embed")
    ],
    outputs=[
        gr.Textbox(lines=5, max_lines=5, label="Text Embedding", autoscroll=False),
        gr.Textbox(lines=5, max_lines=5, label="Image Embedding", autoscroll=False),
        gr.Textbox(label="Cosine Similarity")
    ],
    title="OpenCLIP Embedding Generator",
    description="Generate text and image embeddings with an OpenCLIP model.",
    allow_flagging="never",
    batch=True,
    api_name="embed"
)
# Enable queueing and launch the app
if __name__ == "__main__":
    demo.queue().launch(show_api=True)
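
# Minimal client-side sketch (not part of the app) for calling the /embed endpoint,
# assuming the Space is reachable at the URL below and gradio_client is installed:
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   text_emb, image_emb, sim = client.predict(
#       "a photo of a cat",
#       handle_file("cat.jpg"),
#       api_name="/embed",
#   )
#   print(sim)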