RoboApocalypse committed on
Commit
d85d411
·
1 Parent(s): db03f5d

Add functions to generate text and image embeddings

Browse files

Add functions to generate embeddings for text and image data

Update Gradio interface to include hidden base64 encoded image input

Update Gradio interface to include hidden base64 encoded image embedding output

Files changed (1) hide show
  1. app.py +100 -15
app.py CHANGED
@@ -1,10 +1,10 @@
1
  import gradio as gr
2
  from numpy import empty
3
  import open_clip
4
- from regex import F
5
  import torch
6
- import json
7
- import PIL
 
8
 
9
  # Set device to GPU if available
10
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -20,25 +20,20 @@ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
20
  )
21
 
22
 
23
- def generate_embedding(text_data, image_data):
 
24
  """
25
- Generate embeddings for text and image data using the OpenCLIP model.
26
 
27
  Parameters
28
  ----------
29
  text_data : str or tuple of str
30
  Text data to embed.
31
- image_data : PIL.Image.Image or tuple of PIL.Image.Image
32
- Image data to embed.
33
 
34
  Returns
35
  -------
36
  text_embeddings : list of str
37
  List of text embeddings.
38
- image_embeddings : list of str
39
- List of image embeddings.
40
- similarity : list of str
41
- List of cosine similarity between text and image embeddings.
42
  """
43
 
44
  # Embed text data
@@ -53,6 +48,10 @@ def generate_embedding(text_data, image_data):
53
  if isinstance(text_data, tuple):
54
  text_data = list(text_data)
55
 
 
 
 
 
56
  # Keep track of indices of empty text strings
57
  empty_data_indices = [i for i, text in enumerate(text_data) if text == ""]
58
 
@@ -74,12 +73,30 @@ def generate_embedding(text_data, image_data):
74
  for i in empty_data_indices:
75
  text_embeddings.insert(i, "")
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  # Embed image data
78
  image_embeddings = []
79
  empty_data_indices = []
80
  if image_data:
81
  # If image_data is a single PIL image, convert to list of PIL images
82
- if isinstance(image_data, PIL.Image.Image):
83
  image_data = [image_data]
84
 
85
  # If image_data is a tuple of images, convert to list of images
@@ -108,6 +125,41 @@ def generate_embedding(text_data, image_data):
108
  for i in empty_data_indices:
109
  image_embeddings.insert(i, "")
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  # Calculate cosine similarity between text and image embeddings
112
  similarity = []
113
  empty_data_indices = []
@@ -141,7 +193,38 @@ def generate_embedding(text_data, image_data):
141
  for i in empty_data_indices:
142
  similarity.insert(i, "")
143
 
144
- return (text_embeddings, image_embeddings, similarity)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
 
147
  # Define Gradio interface
@@ -149,12 +232,14 @@ demo = gr.Interface(
149
  fn=generate_embedding,
150
  inputs=[
151
  gr.Textbox(lines=5, max_lines=5, placeholder="Enter Text Here...", label="Text to Embed"),
152
- gr.Image(height=512, type="pil", label="Image to Embed")
 
153
  ],
154
  outputs=[
155
  gr.Textbox(lines=5, max_lines=5, label="Text Embedding", autoscroll=False),
156
  gr.Textbox(lines=5, max_lines=5, label="Image Embedding", autoscroll=False),
157
- gr.Textbox(label="Cosine Similarity")
 
158
  ],
159
  title="OpenCLIP Embedding Generator",
160
  description="Generate embeddings using OpenCLIP model for text and images.",
 
1
  import gradio as gr
2
  from numpy import empty
3
  import open_clip
 
4
  import torch
5
+ import PIL.Image as Image
6
+ from io import BytesIO
7
+ import base64
8
 
9
  # Set device to GPU if available
10
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
20
  )
21
 
22
 
23
+ # Define function to generate text embeddings
24
+ def generate_text_embedding(text_data):
25
  """
26
+ Generate embeddings for text data using the OpenCLIP model.
27
 
28
  Parameters
29
  ----------
30
  text_data : str or tuple of str
31
  Text data to embed.
 
 
32
 
33
  Returns
34
  -------
35
  text_embeddings : list of str
36
  List of text embeddings.
 
 
 
 
37
  """
38
 
39
  # Embed text data
 
48
  if isinstance(text_data, tuple):
49
  text_data = list(text_data)
50
 
51
+ # If text_data is not a list of strings, raise error
52
+ if not isinstance(text_data, list):
53
+ raise TypeError("text_data must be a string or a tuple of strings.")
54
+
55
  # Keep track of indices of empty text strings
56
  empty_data_indices = [i for i, text in enumerate(text_data) if text == ""]
57
 
 
73
  for i in empty_data_indices:
74
  text_embeddings.insert(i, "")
75
 
76
+ return text_embeddings
77
+
78
+ # Define function to generate image embeddings
79
+ def generate_image_embedding(image_data):
80
+ """
81
+ Generate embeddings for image data using the OpenCLIP model.
82
+
83
+ Parameters
84
+ ----------
85
+ image_data : PIL.Image.Image or tuple of PIL.Image.Image
86
+ Image data to embed.
87
+
88
+ Returns
89
+ -------
90
+ image_embeddings : list of str
91
+ List of image embeddings.
92
+ """
93
+
94
  # Embed image data
95
  image_embeddings = []
96
  empty_data_indices = []
97
  if image_data:
98
  # If image_data is a single PIL image, convert to list of PIL images
99
+ if isinstance(image_data, Image.Image):
100
  image_data = [image_data]
101
 
102
  # If image_data is a tuple of images, convert to list of images
 
125
  for i in empty_data_indices:
126
  image_embeddings.insert(i, "")
127
 
128
+ return image_embeddings
129
+
130
+
131
+ # Define function to generate embeddings
132
+ def generate_embedding(text_data, image_data, image_data_base64):
133
+ """
134
+ Generate embeddings for text and image data using the OpenCLIP model.
135
+
136
+ Parameters
137
+ ----------
138
+ text_data : str or tuple of str
139
+ Text data to embed.
140
+ image_data : PIL.Image.Image or tuple of PIL.Image.Image
141
+ Image data to embed.
142
+ image_data_base64 : str or tuple of str
143
+ Base64 encoded image data to embed.
144
+
145
+ Returns
146
+ -------
147
+ text_embeddings : list of str
148
+ List of text embeddings.
149
+ image_embeddings : list of str
150
+ List of image embeddings.
151
+ similarity : list of str
152
+ List of cosine similarity between text and image embeddings.
153
+ image_data_base64_embeddings : str or tuple of str
154
+ List of image embeddings for base64 encoded image data.
155
+ """
156
+
157
+ # Embed text data
158
+ text_embeddings = generate_text_embedding(text_data)
159
+
160
+ # Embed image data
161
+ image_embeddings = generate_image_embedding(image_data)
162
+
163
  # Calculate cosine similarity between text and image embeddings
164
  similarity = []
165
  empty_data_indices = []
 
193
  for i in empty_data_indices:
194
  similarity.insert(i, "")
195
 
196
+ # Embed base64 encoded image data
197
+ decoded_image_data = []
198
+ if image_data_base64:
199
+ # If image_data_base64 is a string, convert to list of strings
200
+ if isinstance(image_data_base64, str):
201
+ image_data_base64 = [image_data_base64]
202
+
203
+ # If image_data_base64 is a tuple of strings, convert to list of strings
204
+ if isinstance(image_data_base64, tuple):
205
+ image_data_base64 = list(image_data_base64)
206
+
207
+ # If image_data_base64 is not a list of strings, raise error
208
+ if not isinstance(image_data_base64, list):
209
+ raise TypeError("image_data_base64 must be a string or a tuple of strings.")
210
+
211
+ # Keep track of indices of empty image strings
212
+ empty_data_indices = [i for i, img in enumerate(image_data_base64) if img == ""]
213
+
214
+ # Remove empty image strings
215
+ image_data_base64 = [img for img in image_data_base64 if img != ""]
216
+
217
+ if image_data_base64:
218
+ # Decode base64 encoded image data
219
+ decoded_image_data = [Image.open(BytesIO(base64.b64decode(img))) for img in image_data_base64]
220
+
221
+ # Insert empty strings at indices of empty image strings
222
+ for i in empty_data_indices:
223
+ decoded_image_data.insert(i, None)
224
+
225
+ image_data_base64_embeddings = generate_image_embedding(tuple(decoded_image_data))
226
+
227
+ return (text_embeddings, image_embeddings, similarity, image_data_base64_embeddings)
228
 
229
 
230
  # Define Gradio interface
 
232
  fn=generate_embedding,
233
  inputs=[
234
  gr.Textbox(lines=5, max_lines=5, placeholder="Enter Text Here...", label="Text to Embed"),
235
+ gr.Image(height=512, type="pil", label="Image to Embed"),
236
+ gr.Textbox(label="Base64 Encoded Image", visible=False)
237
  ],
238
  outputs=[
239
  gr.Textbox(lines=5, max_lines=5, label="Text Embedding", autoscroll=False),
240
  gr.Textbox(lines=5, max_lines=5, label="Image Embedding", autoscroll=False),
241
+ gr.Textbox(label="Cosine Similarity"),
242
+ gr.Textbox(label="Embedding of Base64 Encoded Images", visible=False)
243
  ],
244
  title="OpenCLIP Embedding Generator",
245
  description="Generate embeddings using OpenCLIP model for text and images.",