from typing import Union
import gradio as gr
import open_clip
import torch
import PIL.Image as Image
# Set device to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"PyTorch Device {device}")
# Load the OpenCLIP model and the necessary preprocessors
# openclip_model = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
# openclip_model = 'laion/CLIP-ViT-B-16-laion2B-s34B-b88K'
openclip_model_name = "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
openclip_model = "hf-hub:" + openclip_model_name
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
model_name=openclip_model, device=device
)
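
# The functions below use `open_clip.tokenize`, which applies the default CLIP
# tokenizer; that is the tokenizer this LAION ViT-L/14 checkpoint expects. For
# other hf-hub checkpoints, a model-specific tokenizer could be fetched instead
# (illustrative sketch, not needed here):
#     tokenizer = open_clip.get_tokenizer(openclip_model)
#     tokens = tokenizer(["a photo of a cat"])
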
# Define function to generate text embeddings
# @spaces.GPU
def generate_text_embedding(
    text_data: Union[str, tuple[str, ...]]
) -> list[Union[list[float], str]]:
"""
Generate embeddings for text data using the OpenCLIP model.
Parameters
----------
text_data : str or tuple of str
Text data to embed.
Returns
-------
    text_embeddings : list of list of float or str
        One embedding (a list of floats) per input; empty inputs are kept as
        "" placeholders at their original indices.
"""
# Embed text data
text_embeddings = []
empty_data_indices = []
if text_data:
# If text_data is a string, convert to list of strings
if isinstance(text_data, str):
text_data = [text_data]
# If text_data is a tuple of strings, convert to list of strings
if isinstance(text_data, tuple):
text_data = list(text_data)
# If text_data is not a list of strings, raise error
if not isinstance(text_data, list):
raise TypeError("text_data must be a string or a tuple of strings.")
# Keep track of indices of empty text strings
empty_data_indices = [i for i, text in enumerate(text_data) if text == ""]
# Remove empty text strings
text_data = [text for text in text_data if text != ""]
if text_data:
# Tokenize text_data and convert to tensor
text_data = open_clip.tokenize(text_data).to(device)
# Generate text embeddings
with torch.no_grad():
text_embeddings = model.encode_text(text_data)
            # Convert each embedding tensor to a plain Python list of floats
text_embeddings = [
embedding.detach().cpu().numpy().tolist()
for embedding in text_embeddings
]
# Insert empty strings at indices of empty text strings
for i in empty_data_indices:
text_embeddings.insert(i, "")
return text_embeddings
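
# Usage sketch (illustrative, not executed at import time): empty strings are
# preserved as "" placeholders, and ViT-L/14 embeddings are 768-dimensional.
#     embs = generate_text_embedding(("a photo of a cat", ""))
#     len(embs)     # 2 -- one entry per input
#     len(embs[0])  # 768
#     embs[1]       # ""
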
# Define function to generate image embeddings
def generate_image_embedding(
    image_data: Union[Image.Image, tuple[Image.Image, ...]]
) -> list[Union[list[float], str]]:
"""
Generate embeddings for image data using the OpenCLIP model.
Parameters
----------
image_data : PIL.Image.Image or tuple of PIL.Image.Image
Image data to embed.
Returns
-------
    image_embeddings : list of list of float or str
        One embedding (a list of floats) per input; None inputs are kept as
        "" placeholders at their original indices.
"""
# Embed image data
image_embeddings = []
empty_data_indices = []
if image_data:
# If image_data is a single PIL image, convert to list of PIL images
if isinstance(image_data, Image.Image):
image_data = [image_data]
# If image_data is a tuple of images, convert to list of images
if isinstance(image_data, tuple):
image_data = list(image_data)
# Keep track of indices of None images
empty_data_indices = [i for i, img in enumerate(image_data) if img is None]
# Remove None images
image_data = [img for img in image_data if img is not None]
if image_data:
# Preprocess image_data and convert to tensor
image_data = [preprocess_val(img).unsqueeze(0) for img in image_data]
image_data = torch.stack(image_data).squeeze(1).to(device)
# Generate image embeddings
with torch.no_grad():
image_embeddings = model.encode_image(image_data)
            # Convert each embedding tensor to a plain Python list of floats
image_embeddings = [
embedding.detach().cpu().numpy().tolist()
for embedding in image_embeddings
]
# Insert empty strings at indices of empty images
for i in empty_data_indices:
image_embeddings.insert(i, "")
return image_embeddings
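
# Usage sketch (illustrative): a single PIL image yields a one-element list
# containing one 768-dimensional embedding; None inputs become "" placeholders.
#     img = Image.new("RGB", (224, 224))
#     embs = generate_image_embedding(img)
#     len(embs), len(embs[0])  # (1, 768)
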
# Define function to generate embeddings
def generate_embedding(
    text_data: Union[str, tuple[str, ...]],
    image_data: Union[Image.Image, tuple[Image.Image, ...]],
) -> tuple[list, list, list[str], str]:
"""
Generate embeddings for text and image data using the OpenCLIP model.
Parameters
----------
text_data : str or tuple of str
Text data to embed.
image_data : PIL.Image.Image or tuple of PIL.Image.Image
Image data to embed.
Returns
-------
    text_embeddings : list of list of float or str
        Text embeddings, with "" placeholders for empty text inputs.
    image_embeddings : list of list of float or str
        Image embeddings, with "" placeholders for missing images.
    similarity : list of str
        Cosine similarity of each text/image pair, formatted as a percentage.
    model_name : str
        Name of the OpenCLIP model used to generate the embeddings.
"""
# Embed text data
text_embeddings = generate_text_embedding(text_data)
# Embed image data
image_embeddings = generate_image_embedding(image_data)
# Calculate cosine similarity between text and image embeddings
similarity = []
empty_data_indices = []
if text_embeddings and image_embeddings:
# Filter out embedding pairs with either empty text or image embeddings, tracking indices of empty embeddings
text_embeddings_filtered = []
image_embeddings_filtered = []
for i, (text_embedding, image_embedding) in enumerate(
zip(text_embeddings, image_embeddings)
):
if text_embedding != "" and image_embedding != "":
text_embeddings_filtered.append(text_embedding)
image_embeddings_filtered.append(image_embedding)
else:
empty_data_indices.append(i)
# Calculate cosine similarity if there are any non-empty embedding pairs
if image_embeddings_filtered and text_embeddings_filtered:
# Convert lists back to tensors for processing
text_embeddings_tensor = torch.tensor(text_embeddings_filtered)
image_embeddings_tensor = torch.tensor(image_embeddings_filtered)
            # Normalize the embeddings (not strictly required, since
            # cosine_similarity also normalizes its inputs internally)
text_embedding_norm = text_embeddings_tensor / text_embeddings_tensor.norm(
dim=-1, keepdim=True
)
image_embedding_norm = (
image_embeddings_tensor
/ image_embeddings_tensor.norm(dim=-1, keepdim=True)
)
# Calculate cosine similarity
similarity = torch.nn.functional.cosine_similarity(
text_embedding_norm, image_embedding_norm, dim=-1
)
# Convert to percentage as text
similarity = [f"{sim.item() * 100:.2f}%" for sim in similarity]
# Insert empty text strings in similarity
for i in empty_data_indices:
similarity.insert(i, "")
return (text_embeddings, image_embeddings, similarity, openclip_model_name)
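
# Usage sketch (illustrative): text and images are embedded independently and
# each pair's cosine similarity is reported as a percentage string.
#     texts, images, sims, model_name = generate_embedding(
#         "a photo of a cat", Image.new("RGB", (224, 224))
#     )
#     sims        # one percentage string per text/image pair
#     model_name  # "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
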
# Define Gradio interface
demo = gr.Interface(
fn=generate_embedding,
inputs=[
gr.Textbox(
lines=5,
max_lines=5,
placeholder="Enter Text Here...",
label="Text to Embed",
),
gr.Image(height=512, type="pil", label="Image to Embed"),
],
outputs=[
gr.Textbox(lines=5, max_lines=5, label="Text Embedding", autoscroll=False),
gr.Textbox(lines=5, max_lines=5, label="Image Embedding", autoscroll=False),
gr.Textbox(label="Cosine Similarity"),
gr.Textbox(label="Embedding Model"),
],
title="OpenCLIP Embedding Generator",
description="Generate embeddings using OpenCLIP model for text and images.",
allow_flagging="never",
batch=False,
api_name="embed",
)
# Enable queueing and launch the app
if __name__ == "__main__":
demo.queue(api_open=True).launch(show_api=True)
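
# Client usage sketch (assumes the Space is running and a recent gradio_client;
# the Space id below is a placeholder):
#     from gradio_client import Client, handle_file
#     client = Client("<user>/<space-name>")
#     text_emb, img_emb, sim, model_name = client.predict(
#         "a photo of a cat",
#         handle_file("cat.png"),  # local path or URL to an image
#         api_name="/embed",
#     )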