import gradio as gr
import open_clip
import torch
import PIL.Image as Image
from io import BytesIO
import base64
# Set device to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the OpenCLIP model and the necessary preprocessors
# openclip_model = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
# openclip_model = 'laion/CLIP-ViT-B-16-laion2B-s34B-b88K'
openclip_model = 'laion/CLIP-ViT-L-14-laion2B-s32B-b82K'
openclip_model = 'hf-hub:' + openclip_model
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    model_name=openclip_model,
    device=device
)
# Define function to generate text embeddings
def generate_text_embedding(text_data):
"""
Generate embeddings for text data using the OpenCLIP model.
Parameters
----------
text_data : str or tuple of str
Text data to embed.
Returns
-------
text_embeddings : list of str
List of text embeddings.
"""
# Embed text data
text_embeddings = []
empty_data_indices = []
if text_data:
# If text_data is a string, convert to list of strings
if isinstance(text_data, str):
text_data = [text_data]
# If text_data is a tuple of strings, convert to list of strings
if isinstance(text_data, tuple):
text_data = list(text_data)
# If text_data is not a list of strings, raise error
if not isinstance(text_data, list):
raise TypeError("text_data must be a string or a tuple of strings.")
# Keep track of indices of empty text strings
empty_data_indices = [i for i, text in enumerate(text_data) if text == ""]
# Remove empty text strings
text_data = [text for text in text_data if text != ""]
if text_data:
# Tokenize text_data and convert to tensor
text_data = open_clip.tokenize(text_data).to(device)
# Generate text embeddings
with torch.no_grad():
text_embeddings = model.encode_text(text_data)
# Convert embeddings to list of strings
text_embeddings = [embedding.detach().cpu().numpy().tolist() for embedding in text_embeddings]
# Insert empty strings at indices of empty text strings
for i in empty_data_indices:
text_embeddings.insert(i, "")
return text_embeddings
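
# Illustrative sketch (comments only, not executed): the call below shows the
# placeholder behaviour for empty strings; the numeric values are made up, but for
# this ViT-L-14 checkpoint each embedding is a 768-dimensional vector.
#
#   generate_text_embedding(("a photo of a cat", "", "a photo of a dog"))
#   # -> [[-0.01, 0.23, ...], "", [0.05, -0.11, ...]]
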
# Define function to generate image embeddings
def generate_image_embedding(image_data):
"""
Generate embeddings for image data using the OpenCLIP model.
Parameters
----------
image_data : PIL.Image.Image or tuple of PIL.Image.Image
Image data to embed.
Returns
-------
image_embeddings : list of str
List of image embeddings.
"""
# Embed image data
image_embeddings = []
empty_data_indices = []
if image_data:
# If image_data is a single PIL image, convert to list of PIL images
if isinstance(image_data, Image.Image):
image_data = [image_data]
# If image_data is a tuple of images, convert to list of images
if isinstance(image_data, tuple):
image_data = list(image_data)
# Keep track of indices of None images
empty_data_indices = [i for i, img in enumerate(image_data) if img is None]
# Remove None images
image_data = [img for img in image_data if img is not None]
if image_data:
# Preprocess image_data and convert to tensor
image_data = [preprocess_val(img).unsqueeze(0) for img in image_data]
image_data = torch.stack(image_data).squeeze(1).to(device)
# Generate image embeddings
with torch.no_grad():
image_embeddings = model.encode_image(image_data)
# Convert embeddings to list of strings
image_embeddings = [embedding.detach().cpu().numpy().tolist() for embedding in image_embeddings]
# Insert empty strings at indices of empty images
for i in empty_data_indices:
image_embeddings.insert(i, "")
return image_embeddings
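
# Illustrative sketch (comments only, not executed), assuming a hypothetical local
# file 'cat.jpg'; missing (None) images come back as "" placeholders:
#
#   img = Image.open("cat.jpg")
#   generate_image_embedding((img, None))
#   # -> [[0.02, -0.14, ...], ""]
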
# Define function to generate embeddings
def generate_embedding(text_data, image_data, image_data_base64):
"""
Generate embeddings for text and image data using the OpenCLIP model.
Parameters
----------
text_data : str or tuple of str
Text data to embed.
image_data : PIL.Image.Image or tuple of PIL.Image.Image
Image data to embed.
image_data_base64 : str or tuple of str
Base64 encoded image data to embed.
Returns
-------
text_embeddings : list of str
List of text embeddings.
image_embeddings : list of str
List of image embeddings.
similarity : list of str
List of cosine similarity between text and image embeddings.
image_data_base64_embeddings : str or tuple of str
List of image embeddings for base64 encoded image data.
"""
# Embed text data
text_embeddings = generate_text_embedding(text_data)
# Embed image data
image_embeddings = generate_image_embedding(image_data)
# Calculate cosine similarity between text and image embeddings
similarity = []
empty_data_indices = []
if text_embeddings and image_embeddings:
# Filter out embedding pairs with either empty text or image embeddings, tracking indices of empty embeddings
text_embeddings_filtered = []
image_embeddings_filtered = []
for i, (text_embedding, image_embedding) in enumerate(zip(text_embeddings, image_embeddings)):
if text_embedding != "" and image_embedding != "":
text_embeddings_filtered.append(text_embedding)
image_embeddings_filtered.append(image_embedding)
else:
empty_data_indices.append(i)
# Calculate cosine similarity if there are any non-empty embedding pairs
if image_embeddings_filtered and text_embeddings_filtered:
# Convert lists back to tensors for processing
text_embeddings_tensor = torch.tensor(text_embeddings_filtered)
image_embeddings_tensor = torch.tensor(image_embeddings_filtered)
# Normalize the embeddings
text_embedding_norm = text_embeddings_tensor / text_embeddings_tensor.norm(dim=-1, keepdim=True)
image_embedding_norm = image_embeddings_tensor / image_embeddings_tensor.norm(dim=-1, keepdim=True)
# Calculate cosine similarity
similarity = torch.nn.functional.cosine_similarity(text_embedding_norm, image_embedding_norm, dim=-1)
# Convert to percentage as text
similarity = [f"{sim.item() * 100:.2f}%" for sim in similarity]
# Insert empty text strings in similarity
for i in empty_data_indices:
similarity.insert(i, "")
# Embed base64 encoded image data
decoded_image_data = []
if image_data_base64:
# If image_data_base64 is a string, convert to list of strings
if isinstance(image_data_base64, str):
image_data_base64 = [image_data_base64]
# If image_data_base64 is a tuple of strings, convert to list of strings
if isinstance(image_data_base64, tuple):
image_data_base64 = list(image_data_base64)
# If image_data_base64 is not a list of strings, raise error
if not isinstance(image_data_base64, list):
raise TypeError("image_data_base64 must be a string or a tuple of strings.")
# Keep track of indices of empty image strings
empty_data_indices = [i for i, img in enumerate(image_data_base64) if img == ""]
# Remove empty image strings
image_data_base64 = [img for img in image_data_base64 if img != ""]
if image_data_base64:
# Decode base64 encoded image data
decoded_image_data = [Image.open(BytesIO(base64.b64decode(img))) for img in image_data_base64]
# Insert empty strings at indices of empty image strings
for i in empty_data_indices:
decoded_image_data.insert(i, None)
image_data_base64_embeddings = generate_image_embedding(tuple(decoded_image_data))
return (text_embeddings, image_embeddings, similarity, image_data_base64_embeddings)
# Define Gradio interface
demo = gr.Interface(
    fn=generate_embedding,
    inputs=[
        gr.Textbox(lines=5, max_lines=5, placeholder="Enter Text Here...", label="Text to Embed"),
        gr.Image(height=512, type="pil", label="Image to Embed"),
        gr.Textbox(label="Base64 Encoded Image", visible=False)
    ],
    outputs=[
        gr.Textbox(lines=5, max_lines=5, label="Text Embedding", autoscroll=False),
        gr.Textbox(lines=5, max_lines=5, label="Image Embedding", autoscroll=False),
        gr.Textbox(label="Cosine Similarity"),
        gr.Textbox(label="Embedding of Base64 Encoded Images", visible=False)
    ],
    title="OpenCLIP Embedding Generator",
    description="Generate embeddings for text and images using the OpenCLIP model.",
    allow_flagging="never",
    batch=True,
    api_name="embed"
)
# Enable queueing and launch the app
if __name__ == "__main__":
    demo.queue().launch(show_api=True)
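
# Illustrative client-side sketch (comments only, not executed). Assumptions: the
# Space is reachable at <space-url> (placeholder), the gradio_client package is
# installed, and the argument order matches the inputs defined above.
#
#   from gradio_client import Client
#   client = Client("<space-url>")
#   with open("cat.jpg", "rb") as f:
#       img_b64 = base64.b64encode(f.read()).decode("utf-8")
#   text_emb, img_emb, sim, b64_emb = client.predict(
#       "a photo of a cat",  # text_data
#       None,                # image_data
#       img_b64,             # image_data_base64
#       api_name="/embed",
#   )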