cdnuts committed on
Commit
75afda4
·
verified ·
1 Parent(s): 5e7db30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -23
app.py CHANGED
@@ -1,6 +1,11 @@
1
  import json
 
 
 
 
2
 
3
  import gradio as gr
 
4
  from PIL import Image
5
  import safetensors.torch
6
  import spaces
@@ -10,6 +15,8 @@ import torch
10
  from torchvision.transforms import transforms
11
  from torchvision.transforms import InterpolationMode
12
  import torchvision.transforms.functional as TF
 
 
13
 
14
  torch.set_grad_enabled(False)
15
 
@@ -132,12 +139,11 @@ for idx, tag in enumerate(allowed_tags):
132
 
133
  sorted_tag_score = {}
134
 
135
- @spaces.GPU(duration=5)
136
  def run_classifier(image, threshold):
137
  global sorted_tag_score
138
  img = image.convert('RGB')
139
- tensor = transform(img).unsqueeze(0)
140
- tensor = tensor.to(device)
141
  with torch.no_grad():
142
  logits = model(tensor)
143
  probabilities = torch.nn.functional.sigmoid(logits[0])
@@ -156,7 +162,84 @@ def create_tags(threshold):
156
  filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
157
  text_no_impl = ", ".join(filtered_tag_score.keys())
158
  return text_no_impl, filtered_tag_score
 
 
 
 
 
 
 
 
 
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
  with gr.Blocks(css=".output-class { display: none; }") as demo:
162
  gr.Markdown("""
@@ -165,25 +248,44 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
165
 
166
  This tagger is the result of joint efforts between members of the RedRocket team. Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
167
  """)
168
- with gr.Row():
169
- with gr.Column():
170
- image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
171
- threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold")
172
- with gr.Column():
173
- tag_string = gr.Textbox(label="Tag String")
174
- label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
175
-
176
- image_input.upload(
177
- fn=run_classifier,
178
- inputs=[image_input, threshold_slider],
179
- outputs=[tag_string, label_box]
180
- )
181
-
182
- threshold_slider.input(
183
- fn=create_tags,
184
- inputs=[threshold_slider],
185
- outputs=[tag_string, label_box]
186
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  if __name__ == "__main__":
189
- demo.launch()
 
1
  import json
2
+ import os
3
+ import zipfile
4
+ from io import BytesIO
5
+ from tempfile import NamedTemporaryFile
6
 
7
  import gradio as gr
8
+ import pandas as pd
9
  from PIL import Image
10
  import safetensors.torch
11
  import spaces
 
15
  from torchvision.transforms import transforms
16
  from torchvision.transforms import InterpolationMode
17
  import torchvision.transforms.functional as TF
18
+ from torch.utils.data import Dataset, DataLoader
19
+
20
 
21
  torch.set_grad_enabled(False)
22
 
 
139
 
140
  sorted_tag_score = {}
141
 
142
+ @spaces.GPU(duration=9)
143
  def run_classifier(image, threshold):
144
  global sorted_tag_score
145
  img = image.convert('RGB')
146
+ tensor = transform(img).unsqueeze(0).to(device)
 
147
  with torch.no_grad():
148
  logits = model(tensor)
149
  probabilities = torch.nn.functional.sigmoid(logits[0])
 
162
  filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
163
  text_no_impl = ", ".join(filtered_tag_score.keys())
164
  return text_no_impl, filtered_tag_score
165
+
166
+
167
class ImageDataset(Dataset):
    """Dataset that loads image files from disk and applies a transform.

    Each item is a ``(transformed_image, file_basename)`` pair so that
    predictions can be matched back to the file they came from.
    """

    def __init__(self, image_files, transform):
        """Store the image paths and the callable applied to each image.

        Args:
            image_files: Sequence of paths to image files.
            transform: Callable applied to the RGB ``PIL.Image`` of each file.
        """
        self.image_files = image_files
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the image at ``idx`` and return ``(transform(img), basename)``."""
        img_path = self.image_files[idx]
        # Image.open is lazy and keeps the file handle; the context manager
        # releases it as soon as the RGB copy has been materialised.
        with Image.open(img_path) as source:
            img = source.convert('RGB')
        return self.transform(img), os.path.basename(img_path)
179
+
180
+
181
@spaces.GPU(duration=299)
def process_images(images, threshold):
    """Tag a list of image files in batches.

    Args:
        images: Paths of the image files to classify.
        threshold: Minimum sigmoid score a tag must exceed to be kept.

    Returns:
        A list with one ``(filename, tags, tag_score)`` tuple per image, where
        ``filename`` is the file's basename, ``tags`` is the kept tag names
        joined by ``", "`` and ``tag_score`` maps each kept tag to its score.
    """
    dataset = ImageDataset(images, transform)
    dataloader = DataLoader(dataset, batch_size=64, num_workers=0, pin_memory=True, drop_last=False)

    all_results = []

    # One no_grad context covers the whole loop; nesting a second one per
    # batch adds nothing.
    with torch.no_grad():
        for batch, filenames in dataloader:
            logits = model(batch.to(device))
            # Move the scores to the CPU once per batch so that reading the
            # individual values below does not trigger a device sync per tag.
            probabilities = torch.sigmoid(logits).cpu()

            for filename, prob in zip(filenames, probabilities):
                indices = torch.where(prob > threshold)[0]
                scores = prob[indices].tolist()
                names = [allowed_tags[i] for i in indices.tolist()]

                tag_score = dict(zip(names, scores))
                all_results.append((filename, ", ".join(names), tag_score))

    return all_results
211
+
212
def is_valid_image(file_path):
    """Return True if Pillow can open and verify ``file_path`` as an image.

    Any error raised while opening or verifying the file (missing file,
    directory, unrecognised or corrupt data) is treated as "not an image".
    """
    try:
        with Image.open(file_path) as img:
            img.verify()
    except Exception:
        # Deliberately broad: verify() may raise several exception types.
        # Unlike a bare ``except:``, this still lets KeyboardInterrupt and
        # SystemExit propagate.
        return False
    return True
219
+
220
def process_zip(zip_file, threshold):
    """Tag every image in an uploaded ZIP archive.

    Args:
        zip_file: The uploaded archive, either a path string or an object
            whose ``name`` attribute is the path. ``None`` means no upload.
        threshold: Minimum score a tag must exceed to be kept.

    Returns:
        ``(zip_path, dataframe)`` where ``zip_path`` is the path of a new ZIP
        holding one ``<image stem>.txt`` tag file per image and ``dataframe``
        has the columns ``Image`` and ``Tags``; ``(None, None)`` when
        ``zip_file`` is ``None``.
    """
    import shutil
    import tempfile

    if zip_file is None:
        return None, None

    # Accept both a plain path and a file wrapper exposing ``.name``.
    zip_path = getattr(zip_file, "name", zip_file)

    # Use a fresh directory per call and always remove it afterwards, so files
    # from an earlier upload are never picked up by a later request.
    temp_dir = tempfile.mkdtemp(prefix="temp_images_")
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(temp_dir)

        # Walk recursively: archives made by zipping a folder keep the images
        # one level down, which a flat listdir() would miss entirely.
        all_files = sorted(
            os.path.join(root, name)
            for root, _dirs, names in os.walk(temp_dir)
            for name in names
        )
        image_files = [f for f in all_files if is_valid_image(f)]
        # Skip the model call when there is nothing to classify.
        results = process_images(image_files, threshold) if image_files else []
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)

    # delete=False keeps the archive on disk for download; the context manager
    # still closes our handle to it.
    with NamedTemporaryFile(delete=False, suffix=".zip") as temp_file:
        with zipfile.ZipFile(temp_file, "w") as zip_ref:
            for image_name, text_no_impl, _ in results:
                # splitext keeps inner dots ("a.b.png" -> "a.b") and copes
                # with names that have no extension.
                stem = os.path.splitext(image_name)[0]
                zip_ref.writestr(stem + ".txt", text_no_impl)

    df = pd.DataFrame([(name, tags) for name, tags, _ in results], columns=['Image', 'Tags'])

    return temp_file.name, df
243
 
244
  with gr.Blocks(css=".output-class { display: none; }") as demo:
245
  gr.Markdown("""
 
248
 
249
  This tagger is the result of joint efforts between members of the RedRocket team. Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
250
  """)
251
+
252
+ with gr.Tabs():
253
+ with gr.TabItem("Single Image"):
254
+ with gr.Row():
255
+ with gr.Column():
256
+ image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
257
+ threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold")
258
+ with gr.Column():
259
+ tag_string = gr.Textbox(label="Tag String")
260
+ label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
261
+
262
+ image_input.upload(
263
+ fn=run_classifier,
264
+ inputs=[image_input, threshold_slider],
265
+ outputs=[tag_string, label_box]
266
+ )
267
+
268
+ threshold_slider.input(
269
+ fn=create_tags,
270
+ inputs=[threshold_slider],
271
+ outputs=[tag_string, label_box]
272
+ )
273
+
274
+ with gr.TabItem("Multiple Images"):
275
+ with gr.Row():
276
+ with gr.Column():
277
+ zip_input = gr.File(label="Upload ZIP file", file_types=['.zip'])
278
+ multi_threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold")
279
+ process_button = gr.Button("Process Images")
280
+ with gr.Column():
281
+ zip_output = gr.File(label="Download Tagged Text Files (ZIP)")
282
+ dataframe_output = gr.Dataframe(label="Image Tags Summary")
283
+
284
+ process_button.click(
285
+ fn=process_zip,
286
+ inputs=[zip_input, multi_threshold_slider],
287
+ outputs=[zip_output, dataframe_output]
288
+ )
289
 
290
  if __name__ == "__main__":
291
+ demo.queue().launch()