tommymarto commited on
Commit
33a4bc0
·
1 Parent(s): fea6c63

demo update

Browse files
Files changed (1) hide show
  1. demo.py +74 -40
demo.py CHANGED
@@ -18,14 +18,11 @@ from PIL import Image
18
  import httplib2
19
 
20
 
21
-
22
  #################################################################################################################################################
23
  # Authentication
24
  #################################################################################################################################################
25
 
26
 
27
-
28
-
29
  # read secret api key
30
  FIREBASE_API_KEY = os.environ['FirebaseSecret']
31
  FIREBASE_URL = os.environ['FirebaseURL']
@@ -40,20 +37,23 @@ SCOPES = ["https://www.googleapis.com/auth/drive"]
40
 
41
 
42
  class Experiment():
43
- def __init__(self, name, corrupted, options, selected_image=None, initialized=False):
44
- self.name = name
 
45
  self.corrupted = corrupted
46
  self.options = options
47
  self.selected_image = selected_image
48
  self.initialized = initialized
49
 
50
- def to_dict(self):
51
  return {
52
- "experiment_name": self.name,
53
- "corrupted": self.corrupted["name"],
 
 
54
  "options": [img["name"] for img in self.options],
55
- "selected_image": self.selected_image,
56
- "algo": self.options[self.selected_image]["algo"]
57
  }
58
 
59
  def to_pil(self):
@@ -75,8 +75,8 @@ class Experiment():
75
 
76
  class App():
77
 
78
- NUM_THREADS = 8
79
- NUM_TO_SCHEDULE = 8
80
 
81
  def __init__(self):
82
  self.init_remote()
@@ -132,16 +132,16 @@ class App():
132
  corrupted = self.current_experiment.corrupted
133
  images = self.current_experiment.options
134
 
135
- self.corrupted_component = gr.Image(value=corrupted["pil"], label="corr", show_label=True, show_download_button=False, elem_id="padded")
136
  self.img_components = [
137
- gr.Image(value=img["pil"], label=f"{i}", show_label=True, show_download_button=False, elem_id="unsel")
138
  for i, img in enumerate(images)
139
  ]
140
 
141
  selected_index = self.current_experiment.selected_image
142
  if selected_index is not None:
143
  self.img_components[selected_index] = (
144
- gr.Image(value=images[selected_index]["pil"], label=f"{selected_index}", show_label=True, show_download_button=False, elem_id="sel")
145
  )
146
 
147
  return [*self.img_components, self.corrupted_component]
@@ -150,8 +150,9 @@ class App():
150
  self.current_experiment.selected_image = int(evt.target.label)
151
  return self.build_components_from_experiment()
152
 
153
- def save(self):
154
- if save_to_firebase(self.current_experiment, self.firebase_data_ref):
 
155
  self.next_experiment()
156
  self.build_components_from_experiment()
157
  return [*self.img_components, self.corrupted_component]
@@ -161,13 +162,15 @@ class App():
161
  # API calls
162
  #################################################################################################################################################
163
 
164
- def save_to_firebase(experiment, firebase_data_ref):
165
- if experiment is None or experiment.selected_image is None:
166
  gr.Warning("You must select an image before submitting")
167
  return False
 
 
168
 
169
  firebase_data_ref.push({
170
- **experiment.to_dict(),
171
  "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
172
  })
173
 
@@ -177,9 +180,27 @@ def save_to_firebase(experiment, firebase_data_ref):
177
  def list_folders(service):
178
  folders = []
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  results = (
181
  service.files()
182
- .list(pageSize=50, fields="nextPageToken, files(id, name)", orderBy="name", q="mimeType='application/vnd.google-apps.folder' and name contains 'Experiment'")
183
  .execute()
184
  )
185
  folders.extend(results.get("files", []))
@@ -188,7 +209,7 @@ def list_folders(service):
188
  page_token = results["nextPageToken"]
189
  results = (
190
  service.files()
191
- .list(pageSize=50, fields="nextPageToken, files(id, name)", orderBy="name", q="mimeType='application/vnd.google-apps.folder' and name contains 'Experiment'", pageToken=page_token)
192
  .execute()
193
  )
194
  folders.extend(results.get("files", []))
@@ -235,17 +256,28 @@ def schedule_downloads(service, q_to_download, q_requested):
235
  break
236
 
237
  print("Scheduling new experiment")
 
238
  folders = list_folders(service)
 
 
239
 
240
- # sample a random folder from folders
241
  folder = random.choice(folders)
242
 
243
- # list subfolders in the folder
244
  subfolders = list_files_in_folder(service, folder, filter_="mimeType='application/vnd.google-apps.folder'")
 
 
 
 
 
 
 
 
245
 
246
  # the results should be 2 subfolders: SDEdit and ODEdit
247
- odedit_subfolder = [subfolder for subfolder in subfolders if "ODEdit" in subfolder["name"]][0]
248
- sdedit_subfolder = [subfolder for subfolder in subfolders if "SDEdit" in subfolder["name"]][0]
249
 
250
  odedit_files = list_files_in_folder(service, odedit_subfolder)
251
  sdedit_files = list_files_in_folder(service, sdedit_subfolder)
@@ -256,11 +288,11 @@ def schedule_downloads(service, q_to_download, q_requested):
256
  selected_sdedit_files = random.sample(sdedit_files, k=5)
257
  selected_sdedit_files = [{**file, "algo": "SDEdit"} for file in selected_sdedit_files]
258
 
259
- corrupted_file = list_files_in_folder(service, folder, filter_="mimeType contains 'image/'")[0]
260
 
261
  selected_files = [*selected_odedit_files, *selected_sdedit_files]
262
 
263
- experiment = Experiment(folder["name"], corrupted_file, selected_files)
264
 
265
  q_to_download.put(experiment)
266
  q_requested.task_done()
@@ -298,8 +330,9 @@ def download_thread(service, q_to_download, q_processed):
298
  css = """
299
  #unsel {border: solid 5px transparent !important; border-radius: 15px !important}
300
  #sel {border: solid 5px #00c0ff !important; border-radius: 15px !important}
301
- #padded {margin-left: 25% !important; margin-right: 5% !important}
302
- #paddedRight {margin-right: 5% !important}
 
303
  """
304
 
305
  def build_demo():
@@ -307,24 +340,25 @@ def build_demo():
307
 
308
  with gr.Blocks(title="Unsupervised Image Editing", css=css) as demo:
309
 
310
- with gr.Row():
311
- corrupted_component = gr.Image(label="corr", elem_id="padded")
312
- with gr.Column(scale=3):
313
- gr.Markdown("<div style='width: 100%'><h1 style='text-align: center; display: inline-block; width: 100%'>The sample on the left is a corrupted image</h1></div>", elem_id="paddedRight")
314
- gr.Markdown("<div style='width: 100%'><h3 style='text-align: center; display: inline-block; width: 100%'>Below are decorrupted versions sampled from various models. Click on the picture you like best</h3></div>", elem_id="paddedRight")
315
- btn = gr.Button("Submit")
316
- gr.Markdown("<hr>")
317
 
318
  img_components = []
319
  with gr.Row():
320
  for i, img in enumerate(app.img_components[:5]):
321
- img_components.append(gr.Image(label=f"{i}", elem_id="unsel"))
322
 
323
  with gr.Row():
324
  for i, img in enumerate(app.img_components[5:]):
325
- img_components.append(gr.Image(label=f"{i+5}", elem_id="unsel"))
326
 
327
- btn.click(app.save, None, [*img_components, corrupted_component])
 
328
  for img in img_components:
329
  img.select(app.on_select, None, img_components, show_progress="hidden")
330
 
 
18
  import httplib2
19
 
20
 
 
21
  #################################################################################################################################################
22
  # Authentication
23
  #################################################################################################################################################
24
 
25
 
 
 
26
  # read secret api key
27
  FIREBASE_API_KEY = os.environ['FirebaseSecret']
28
  FIREBASE_URL = os.environ['FirebaseURL']
 
37
 
38
 
39
  class Experiment():
40
+ def __init__(self, corruption, image_id, corrupted, options, selected_image=None, initialized=False):
41
+ self.corruption = corruption
42
+ self.image_id = image_id
43
  self.corrupted = corrupted
44
  self.options = options
45
  self.selected_image = selected_image
46
  self.initialized = initialized
47
 
48
+ def to_dict(self, custom_algo=None):
49
  return {
50
+ "corruption": self.corruption,
51
+ "image_number": self.image_id,
52
+ "dataset": "CelebA",
53
+ "corrupted_filename": self.corrupted["name"],
54
  "options": [img["name"] for img in self.options],
55
+ "selected_image": "None" if custom_algo else self.selected_image,
56
+ "algo": custom_algo if custom_algo is not None else self.options[self.selected_image]["algo"]
57
  }
58
 
59
  def to_pil(self):
 
75
 
76
  class App():
77
 
78
+ NUM_THREADS = 16
79
+ NUM_TO_SCHEDULE = 10
80
 
81
  def __init__(self):
82
  self.init_remote()
 
132
  corrupted = self.current_experiment.corrupted
133
  images = self.current_experiment.options
134
 
135
+ self.corrupted_component = gr.Image(value=corrupted["pil"], label=self.current_experiment.corruption, show_label=True, show_download_button=False, elem_id="corrupted")
136
  self.img_components = [
137
+ gr.Image(value=img["pil"], label=f"{i}", show_label=False, show_download_button=False, elem_id="unsel")
138
  for i, img in enumerate(images)
139
  ]
140
 
141
  selected_index = self.current_experiment.selected_image
142
  if selected_index is not None:
143
  self.img_components[selected_index] = (
144
+ gr.Image(value=images[selected_index]["pil"], label=f"{selected_index}", show_label=False, show_download_button=False, elem_id="sel")
145
  )
146
 
147
  return [*self.img_components, self.corrupted_component]
 
150
  self.current_experiment.selected_image = int(evt.target.label)
151
  return self.build_components_from_experiment()
152
 
153
+ def save(self, mode):
154
+ print(f"Saving experiment with mode {mode}")
155
+ if save_to_firebase(self.current_experiment, self.firebase_data_ref, mode=mode):
156
  self.next_experiment()
157
  self.build_components_from_experiment()
158
  return [*self.img_components, self.corrupted_component]
 
162
  # API calls
163
  #################################################################################################################################################
164
 
165
+ def save_to_firebase(experiment: Experiment, firebase_data_ref, mode):
166
+ if mode == "save" and (experiment is None or experiment.selected_image is None):
167
  gr.Warning("You must select an image before submitting")
168
  return False
169
+ if mode == "skip":
170
+ experiment.selected_image = None
171
 
172
  firebase_data_ref.push({
173
+ **experiment.to_dict(custom_algo=mode if mode == "skip" else None),
174
  "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
175
  })
176
 
 
180
  def list_folders(service):
181
  folders = []
182
 
183
+ corruptions = [
184
+ "brightness",
185
+ "elastic_transform",
186
+ "frost",
187
+ "impulse_noise",
188
+ "masking_random_color",
189
+ "shot_noise",
190
+ "speckle_noise",
191
+ "contrast",
192
+ "fog",
193
+ "gaussian_noise",
194
+ "latent",
195
+ "masking_vline_random_color",
196
+ "spatter"
197
+ ]
198
+
199
+ name_query = "(" + " or ".join([f"name contains '{c}'" for c in corruptions]) + ")"
200
+
201
  results = (
202
  service.files()
203
+ .list(pageSize=50, fields="nextPageToken, files(id, name)", orderBy="name", q=f"mimeType='application/vnd.google-apps.folder' and {name_query}")
204
  .execute()
205
  )
206
  folders.extend(results.get("files", []))
 
209
  page_token = results["nextPageToken"]
210
  results = (
211
  service.files()
212
+ .list(pageSize=50, fields="nextPageToken, files(id, name)", orderBy="name", q=f"mimeType='application/vnd.google-apps.folder' and {name_query}", pageToken=page_token)
213
  .execute()
214
  )
215
  folders.extend(results.get("files", []))
 
256
  break
257
 
258
  print("Scheduling new experiment")
259
+ # this should give a list of top-level folders of corruption types
260
  folders = list_folders(service)
261
+ # print(f"Found {len(folders)} folders")
262
+ # print(folders)
263
 
264
+ # sample a random corruption from folders
265
  folder = random.choice(folders)
266
 
267
+ # inside the corruption folders, there should be one folder per image id
268
  subfolders = list_files_in_folder(service, folder, filter_="mimeType='application/vnd.google-apps.folder'")
269
+ # print(f"Found {len(subfolders)} subfolders")
270
+ # print(subfolders)
271
+
272
+ # sample a random image from subfolders
273
+ subfolder = random.choice(subfolders)
274
+
275
+ # list subsubfolders in the subfolder
276
+ subsubfolders = list_files_in_folder(service, subfolder, filter_="mimeType='application/vnd.google-apps.folder'")
277
 
278
  # the results should be 2 subfolders: SDEdit and ODEdit
279
+ odedit_subfolder = [f for f in subsubfolders if "ode" in f["name"]][0]
280
+ sdedit_subfolder = [f for f in subsubfolders if "sde" in f["name"]][0]
281
 
282
  odedit_files = list_files_in_folder(service, odedit_subfolder)
283
  sdedit_files = list_files_in_folder(service, sdedit_subfolder)
 
288
  selected_sdedit_files = random.sample(sdedit_files, k=5)
289
  selected_sdedit_files = [{**file, "algo": "SDEdit"} for file in selected_sdedit_files]
290
 
291
+ corrupted_file = list_files_in_folder(service, subfolder, filter_="mimeType contains 'image/'")[0]
292
 
293
  selected_files = [*selected_odedit_files, *selected_sdedit_files]
294
 
295
+ experiment = Experiment(folder["name"], subfolder["name"], corrupted_file, selected_files)
296
 
297
  q_to_download.put(experiment)
298
  q_requested.task_done()
 
330
  css = """
331
  #unsel {border: solid 5px transparent !important; border-radius: 15px !important}
332
  #sel {border: solid 5px #00c0ff !important; border-radius: 15px !important}
333
+ #corrupted {margin-left: 5%; margin-right: 5%; padding: 0 !important}
334
+ #reducedHeight {height: 10px !important}
335
+ #padded {padding-left: 2%; padding-right: 2%}
336
  """
337
 
338
  def build_demo():
 
340
 
341
  with gr.Blocks(title="Unsupervised Image Editing", css=css) as demo:
342
 
343
+ with gr.Row(elem_id="padded"):
344
+ corrupted_component = gr.Image(label="corr", elem_id="corrupted", show_label=False, show_download_button=False)
345
+ with gr.Column(scale=3, elem_id="padded"):
346
+ gr.Markdown("<div style='width: 100%'><h1 style='text-align: center; display: inline-block; width: 100%'>The sample on the left is a corrupted image</h1></div>")
347
+ gr.Markdown("<div style='width: 100%'><h3 style='text-align: center; display: inline-block; width: 100%'>Below are decorrupted versions sampled from various models. Click on the picture you like best</h3></div>")
348
+ btn_skip = gr.Button("I have no preference")
349
+ btn_submit = gr.Button("Submit preference")
350
 
351
  img_components = []
352
  with gr.Row():
353
  for i, img in enumerate(app.img_components[:5]):
354
+ img_components.append(gr.Image(label=f"{i}", elem_id="unsel", show_label=False, show_download_button=False))
355
 
356
  with gr.Row():
357
  for i, img in enumerate(app.img_components[5:]):
358
+ img_components.append(gr.Image(label=f"{i+5}", elem_id="unsel", show_label=False, show_download_button=False))
359
 
360
+ btn_skip.click(partial(app.save, mode="skip"), None, [*img_components, corrupted_component])
361
+ btn_submit.click(partial(app.save, mode="save"), None, [*img_components, corrupted_component])
362
  for img in img_components:
363
  img.select(app.on_select, None, img_components, show_progress="hidden")
364