tommymarto commited on
Commit
dbd36d0
·
1 Parent(s): 33a4bc0

updated demo

Browse files
Files changed (1) hide show
  1. demo.py +139 -320
demo.py CHANGED
@@ -1,21 +1,21 @@
1
- from concurrent.futures import ThreadPoolExecutor
2
  import datetime
3
  from functools import partial
4
- import os
5
- import io
6
  import json
7
- import queue
8
- import gradio as gr
9
  import random
 
 
10
  import firebase_admin
11
  from firebase_admin import db, credentials
12
- from google.oauth2 import service_account
13
- from googleapiclient.discovery import build
14
- import google_auth_httplib2
15
- import googleapiclient
16
- from googleapiclient.http import MediaIoBaseDownload
17
- from PIL import Image
18
- import httplib2
 
 
19
 
20
 
21
  #################################################################################################################################################
@@ -26,347 +26,166 @@ import httplib2
26
  # read secret api key
27
  FIREBASE_API_KEY = os.environ['FirebaseSecret']
28
  FIREBASE_URL = os.environ['FirebaseURL']
29
- DRIVE_API_KEY = os.environ['DriveSecret']
30
 
31
- SCOPES = ["https://www.googleapis.com/auth/drive"]
 
 
 
32
 
33
 
34
- #################################################################################################################################################
35
- # Types
36
- #################################################################################################################################################
37
 
38
 
39
- class Experiment():
40
- def __init__(self, corruption, image_id, corrupted, options, selected_image=None, initialized=False):
41
- self.corruption = corruption
42
- self.image_id = image_id
43
- self.corrupted = corrupted
44
- self.options = options
45
- self.selected_image = selected_image
46
- self.initialized = initialized
47
-
48
- def to_dict(self, custom_algo=None):
49
- return {
50
- "corruption": self.corruption,
51
- "image_number": self.image_id,
52
- "dataset": "CelebA",
53
- "corrupted_filename": self.corrupted["name"],
54
- "options": [img["name"] for img in self.options],
55
- "selected_image": "None" if custom_algo else self.selected_image,
56
- "algo": custom_algo if custom_algo is not None else self.options[self.selected_image]["algo"]
57
- }
58
 
59
- def to_pil(self):
60
- return self.corrupted["pil"], [img["pil"] for img in self.options]
61
-
62
- @staticmethod
63
- def from_dict(source):
64
- return Experiment(source["name"], source["corrupted"], source["options"])
65
-
66
- def __repr__(self):
67
- return f"Experiment(name={self.name}, corrupted={self.corrupted}, options={self.options})"
68
-
69
- def __str__(self):
70
- return f"Experiment(name={self.name}, corrupted={self.corrupted}, options={self.options})"
71
-
72
- def __eq__(self, other):
73
- return self.name == other.name and self.corrupted == other.corrupted and self.options == other.options
 
 
 
 
 
 
 
 
 
 
 
74
 
75
-
76
- class App():
77
-
78
- NUM_THREADS = 16
79
- NUM_TO_SCHEDULE = 10
80
-
81
- def __init__(self):
82
- self.init_remote()
83
- self.init_download_thread()
84
-
85
- for _ in range(App.NUM_TO_SCHEDULE):
86
- self.q_requested.put({})
87
-
88
- self.next_experiment()
89
- self.build_components_from_experiment()
90
-
91
-
92
- def lifespan(self, fastapi_app):
93
- yield
94
- # cancel thredpool
95
- self.executor.shutdown(wait=False)
96
- # cancel download threads
97
- for _ in range(App.NUM_THREADS):
98
- self.q_to_download.put(None)
99
- self.q_requested.put(None)
100
-
101
- def init_remote(self):
102
-
103
- def build_request(http, *args, **kwargs):
104
- new_http = google_auth_httplib2.AuthorizedHttp(self.drive_creds, http=httplib2.Http())
105
- return googleapiclient.http.HttpRequest(new_http, *args, **kwargs)
106
-
107
- # init drive service
108
- self.drive_creds = service_account.Credentials.from_service_account_info(json.loads(DRIVE_API_KEY), scopes=SCOPES)
109
- authorized_http = google_auth_httplib2.AuthorizedHttp(self.drive_creds, http=httplib2.Http())
110
- self.drive_service = build("drive", "v3", requestBuilder=build_request, http=authorized_http)
111
-
112
- # init firebase service
113
- self.firebase_creds = credentials.Certificate(json.loads(FIREBASE_API_KEY))
114
- self.firebase_app = firebase_admin.initialize_app(self.firebase_creds, {'databaseURL': FIREBASE_URL})
115
- self.firebase_data_ref = db.reference("data")
116
-
117
- def init_download_thread(self):
118
- # init download thread and queue
119
- self.q_requested = queue.Queue()
120
- self.q_to_download = queue.Queue()
121
- self.q_processed = queue.Queue()
122
- self.executor = ThreadPoolExecutor(max_workers=2*App.NUM_THREADS)
123
- for _ in range(App.NUM_THREADS):
124
- self.executor.submit(download_thread, self.drive_service, self.q_to_download, self.q_processed)
125
- self.executor.submit(schedule_downloads, self.drive_service, self.q_to_download, self.q_requested)
126
-
127
- def next_experiment(self):
128
- self.q_requested.put({})
129
- self.current_experiment : Experiment = self.q_processed.get()
130
-
131
- def build_components_from_experiment(self):
132
- corrupted = self.current_experiment.corrupted
133
- images = self.current_experiment.options
134
-
135
- self.corrupted_component = gr.Image(value=corrupted["pil"], label=self.current_experiment.corruption, show_label=True, show_download_button=False, elem_id="corrupted")
136
- self.img_components = [
137
- gr.Image(value=img["pil"], label=f"{i}", show_label=False, show_download_button=False, elem_id="unsel")
138
- for i, img in enumerate(images)
139
- ]
140
-
141
- selected_index = self.current_experiment.selected_image
142
- if selected_index is not None:
143
- self.img_components[selected_index] = (
144
- gr.Image(value=images[selected_index]["pil"], label=f"{selected_index}", show_label=False, show_download_button=False, elem_id="sel")
145
- )
146
-
147
- return [*self.img_components, self.corrupted_component]
148
 
149
- def on_select(self, evt: gr.SelectData): # SelectData is a subclass of EventData
150
- self.current_experiment.selected_image = int(evt.target.label)
151
- return self.build_components_from_experiment()
152
-
153
- def save(self, mode):
154
- print(f"Saving experiment with mode {mode}")
155
- if save_to_firebase(self.current_experiment, self.firebase_data_ref, mode=mode):
156
- self.next_experiment()
157
- self.build_components_from_experiment()
158
- return [*self.img_components, self.corrupted_component]
159
-
160
-
161
- #################################################################################################################################################
162
- # API calls
163
- #################################################################################################################################################
164
-
165
- def save_to_firebase(experiment: Experiment, firebase_data_ref, mode):
166
- if mode == "save" and (experiment is None or experiment.selected_image is None):
167
- gr.Warning("You must select an image before submitting")
168
- return False
169
- if mode == "skip":
170
- experiment.selected_image = None
171
-
172
- firebase_data_ref.push({
173
- **experiment.to_dict(custom_algo=mode if mode == "skip" else None),
174
- "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
175
- })
176
-
177
- gr.Info("Your choice has been saved to Firebase")
178
- return True
179
-
180
- def list_folders(service):
181
- folders = []
182
-
183
- corruptions = [
184
- "brightness",
185
- "elastic_transform",
186
- "frost",
187
- "impulse_noise",
188
- "masking_random_color",
189
- "shot_noise",
190
- "speckle_noise",
191
- "contrast",
192
- "fog",
193
- "gaussian_noise",
194
- "latent",
195
- "masking_vline_random_color",
196
- "spatter"
197
  ]
198
-
199
- name_query = "(" + " or ".join([f"name contains '{c}'" for c in corruptions]) + ")"
200
-
201
- results = (
202
- service.files()
203
- .list(pageSize=50, fields="nextPageToken, files(id, name)", orderBy="name", q=f"mimeType='application/vnd.google-apps.folder' and {name_query}")
204
- .execute()
205
- )
206
- folders.extend(results.get("files", []))
207
-
208
- while "nextPageToken" in results:
209
- page_token = results["nextPageToken"]
210
- results = (
211
- service.files()
212
- .list(pageSize=50, fields="nextPageToken, files(id, name)", orderBy="name", q=f"mimeType='application/vnd.google-apps.folder' and {name_query}", pageToken=page_token)
213
- .execute()
214
- )
215
- folders.extend(results.get("files", []))
216
-
217
- return folders
218
-
219
- def list_files_in_folder(service, folder, filter_=""):
220
- files = []
221
-
222
- results = (
223
- service.files()
224
- .list(pageSize=50, fields="nextPageToken, files(id, name)", orderBy="name", q=f"'{folder['id']}' in parents {'and ' + filter_ if filter_ else ''}")
225
- .execute()
226
  )
227
- files.extend(results.get("files", []))
228
-
229
- while "nextPageToken" in results:
230
- page_token = results["nextPageToken"]
231
- results = (
232
- service.files()
233
- .list(pageSize=50, fields="nextPageToken, files(id, name)", orderBy="name", q=f"'{folder['id']}' in parents {'and ' + filter_ if filter_ else ''}", pageToken=page_token)
234
- .execute()
235
- )
236
- files.extend(results.get("files", []))
237
-
238
- return files
239
-
240
- def download_file(service, file):
241
- request = service.files().get_media(fileId=file['id'])
242
- fh = io.BytesIO()
243
- downloader = MediaIoBaseDownload(fh, request)
244
- done = False
245
- while done is False:
246
- status, done = downloader.next_chunk(10)
247
- print(f"Download image: '{file['name']}' - progress: {int(status.progress() * 100)}.")
248
-
249
- return Image.open(fh)
250
-
251
-
252
- def schedule_downloads(service, q_to_download, q_requested):
253
- while True:
254
- print("Waiting for new experiment to schedule")
255
- if q_requested.get() is None:
256
- break
257
 
258
- print("Scheduling new experiment")
259
- # this should give a list of top-level folders of corruption types
260
- folders = list_folders(service)
261
- # print(f"Found {len(folders)} folders")
262
- # print(folders)
 
263
 
264
- # sample a random corruption from folders
265
- folder = random.choice(folders)
 
 
 
266
 
267
- # inside the corruption folders, there should be one folder per image id
268
- subfolders = list_files_in_folder(service, folder, filter_="mimeType='application/vnd.google-apps.folder'")
269
- # print(f"Found {len(subfolders)} subfolders")
270
- # print(subfolders)
271
 
272
- # sample a random image from subfolders
273
- subfolder = random.choice(subfolders)
274
 
275
- # list subsubfolders in the subfolder
276
- subsubfolders = list_files_in_folder(service, subfolder, filter_="mimeType='application/vnd.google-apps.folder'")
277
 
278
- # the results should be 2 subfolders: SDEdit and ODEdit
279
- odedit_subfolder = [f for f in subsubfolders if "ode" in f["name"]][0]
280
- sdedit_subfolder = [f for f in subsubfolders if "sde" in f["name"]][0]
281
 
282
- odedit_files = list_files_in_folder(service, odedit_subfolder)
283
- sdedit_files = list_files_in_folder(service, sdedit_subfolder)
284
-
285
- selected_odedit_files = random.sample(odedit_files, k=5)
286
- selected_odedit_files = [{**file, "algo": "ODEdit"} for file in selected_odedit_files]
287
 
288
- selected_sdedit_files = random.sample(sdedit_files, k=5)
289
- selected_sdedit_files = [{**file, "algo": "SDEdit"} for file in selected_sdedit_files]
290
 
291
- corrupted_file = list_files_in_folder(service, subfolder, filter_="mimeType contains 'image/'")[0]
292
-
293
- selected_files = [*selected_odedit_files, *selected_sdedit_files]
294
-
295
- experiment = Experiment(folder["name"], subfolder["name"], corrupted_file, selected_files)
296
 
297
- q_to_download.put(experiment)
298
- q_requested.task_done()
299
- print("Experiment scheduled")
300
 
301
- def download_thread(service, q_to_download, q_processed):
302
- while True:
303
- print("Waiting for experiment to download")
304
- experiment : Experiment = q_to_download.get()
305
- if experiment is None:
306
- break
307
-
308
- corrupted = experiment.corrupted
309
- print(f"Downloading file {corrupted['name']}")
310
- corrupted_pil = download_file(service, corrupted)
311
- print(f"File {corrupted['name']} downloaded")
312
- experiment.corrupted["pil"] = corrupted_pil
313
 
314
- for file in experiment.options:
315
- print(f"Downloading file {file['name']}")
316
- pil = download_file(service, file)
317
- print(f"File {file['name']} downloaded")
318
- file["pil"] = pil
 
 
319
 
320
- q_processed.put(experiment)
321
- q_to_download.task_done()
322
- print("Experiment downloaded")
323
-
324
-
325
- #################################################################################################################################################
326
- # UI
327
- #################################################################################################################################################
328
 
 
329
 
330
  css = """
331
- #unsel {border: solid 5px transparent !important; border-radius: 15px !important}
332
- #sel {border: solid 5px #00c0ff !important; border-radius: 15px !important}
333
- #corrupted {margin-left: 5%; margin-right: 5%; padding: 0 !important}
334
  #reducedHeight {height: 10px !important}
335
  #padded {padding-left: 2%; padding-right: 2%}
336
  """
337
 
338
- def build_demo():
339
- app = App()
340
-
341
- with gr.Blocks(title="Unsupervised Image Editing", css=css) as demo:
342
 
343
- with gr.Row(elem_id="padded"):
344
- corrupted_component = gr.Image(label="corr", elem_id="corrupted", show_label=False, show_download_button=False)
345
- with gr.Column(scale=3, elem_id="padded"):
346
- gr.Markdown("<div style='width: 100%'><h1 style='text-align: center; display: inline-block; width: 100%'>The sample on the left is a corrupted image</h1></div>")
347
- gr.Markdown("<div style='width: 100%'><h3 style='text-align: center; display: inline-block; width: 100%'>Below are decorrupted versions sampled from various models. Click on the picture you like best</h3></div>")
348
- btn_skip = gr.Button("I have no preference")
349
- btn_submit = gr.Button("Submit preference")
350
 
351
- img_components = []
 
352
  with gr.Row():
353
- for i, img in enumerate(app.img_components[:5]):
354
- img_components.append(gr.Image(label=f"{i}", elem_id="unsel", show_label=False, show_download_button=False))
355
-
356
- with gr.Row():
357
- for i, img in enumerate(app.img_components[5:]):
358
- img_components.append(gr.Image(label=f"{i+5}", elem_id="unsel", show_label=False, show_download_button=False))
359
-
360
- btn_skip.click(partial(app.save, mode="skip"), None, [*img_components, corrupted_component])
361
- btn_submit.click(partial(app.save, mode="save"), None, [*img_components, corrupted_component])
362
- for img in img_components:
363
- img.select(app.on_select, None, img_components, show_progress="hidden")
364
-
365
- demo.load(app.build_components_from_experiment, inputs=None, outputs=[*img_components, corrupted_component])
366
 
367
- return demo, app
 
 
 
368
 
 
369
 
370
- if __name__ == "__main__":
371
- demo, app = build_demo()
372
- demo.launch(share=False, show_api=False, app_kwargs={"lifespan": app.lifespan})
 
 
1
  import datetime
2
  from functools import partial
 
 
3
  import json
4
+ from pathlib import Path
 
5
  import random
6
+ import gradio as gr
7
+ import os
8
  import firebase_admin
9
  from firebase_admin import db, credentials
10
+
11
+
12
+ ##################################################################
13
+ # Constants
14
+ ##################################################################
15
+
16
+
17
+ NUMBER_OF_IMAGES_PER_ROW = 7
18
+ NUMBER_OF_ROWS = 2
19
 
20
 
21
  #################################################################################################################################################
 
26
  # read secret api key
27
  FIREBASE_API_KEY = os.environ['FirebaseSecret']
28
  FIREBASE_URL = os.environ['FirebaseURL']
29
+ DATASET = os.environ['Dataset']
30
 
31
+ # init firebase service
32
+ firebase_creds = credentials.Certificate(json.loads(FIREBASE_API_KEY))
33
+ firebase_app = firebase_admin.initialize_app(firebase_creds, {'databaseURL': FIREBASE_URL})
34
+ firebase_data_ref = db.reference("data")
35
 
36
 
37
+ ##################################################################
38
+ # Data Layer
39
+ ##################################################################
40
 
41
 
42
+ class Experiment(dict):
43
+ def __init__(self, dataset, corruption, image_id, corrupted, options, selected_image=None):
44
+ super().__init__(
45
+ dataset=dataset,
46
+ corruption=corruption,
47
+ image_id=image_id,
48
+ corrupted=corrupted,
49
+ options=options,
50
+ selected_image=selected_image,
51
+ )
 
 
 
 
 
 
 
 
 
52
 
53
+ def experiment_to_dict(experiment, skip=False):
54
+ info = {
55
+ # experiment info
56
+ "dataset": experiment["dataset"],
57
+ "corruption": experiment["corruption"],
58
+ "image_number": experiment["image_id"],
59
+
60
+ # chosen image set info
61
+ "corrupted_filename": experiment["corrupted"]["name"],
62
+ "options": [img["name"] for img in experiment["options"]],
63
+ }
64
+
65
+ if skip:
66
+ info = {
67
+ **info,
68
+ # selected image info
69
+ "selected_image": "None",
70
+ "selected_algo": "None",
71
+ }
72
+ else:
73
+ info = {
74
+ **info,
75
+ # selected image info
76
+ "selected_image": experiment["options"][experiment["selected_image"]]["name"],
77
+ "selected_algo": experiment["options"][experiment["selected_image"]]["algo"],
78
+ }
79
 
80
+ return info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
+ def generate_new_experiment() -> Experiment:
83
+ corruption = random.choice(list(Path(f"./images/{DATASET}").glob("*")))
84
+ image_id = random.choice(list(corruption.glob("*")))
85
+ imgs_to_sample = (NUMBER_OF_IMAGES_PER_ROW * NUMBER_OF_ROWS) // 2
86
+
87
+ corrupted_image = {"name": str(random.choice(list(image_id.glob("*corrupted*"))))}
88
+ corrupted_ending_index = corrupted_image["name"].split(".")[0].split("_")[-1]
89
+ sdedit_images = [
90
+ {"name": str(img), "algo": f"SDEdit"}
91
+ for img in random.sample(list((image_id / "sde").glob(f"*_{corrupted_ending_index}*")), imgs_to_sample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  ]
93
+ odedit_images = [
94
+ {"name": str(img), "algo": f"ODEdit"}
95
+ for img in random.sample(list((image_id / "ode").glob(f"*_{corrupted_ending_index}*")), imgs_to_sample)
96
+ ]
97
+ total_images = sdedit_images + odedit_images
98
+
99
+ return Experiment(
100
+ DATASET,
101
+ corruption.name,
102
+ image_id.name,
103
+ corrupted_image,
104
+ total_images,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
+ def save(experiment, corrupted_component, *img_components, mode):
108
+ if mode == "save" and (experiment is None or experiment["selected_image"] is None):
109
+ gr.Warning("You must select an image before submitting")
110
+ return [experiment, corrupted_component, *img_components]
111
+ if mode == "skip":
112
+ experiment["selected_image"] = None
113
 
114
+ dict_to_save = {
115
+ **experiment_to_dict(experiment, skip=(mode=="skip")),
116
+ "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
117
+ }
118
+ firebase_data_ref.push(dict_to_save)
119
 
120
+ print("=====================")
121
+ print(dict_to_save)
122
+ print("=====================")
 
123
 
124
+ gr.Info("Your choice has been saved to Firebase")
125
+ return next()
126
 
 
 
127
 
128
+ ##################################################################
129
+ # UI Layer
130
+ ##################################################################
131
 
 
 
 
 
 
132
 
133
+ def next():
134
+ new_experiment = generate_new_experiment()
135
 
136
+ new_img_components = [
137
+ gr.Image(value=img["name"], label=f"{i}", elem_id="unsel", show_label=False, show_download_button=False, show_share_button=False, interactive=False)
138
+ for i, img in enumerate(new_experiment["options"])
139
+ ]
140
+ new_corrupted_component = gr.Image(value=new_experiment["corrupted"]["name"], label="corr", elem_id="corrupted", show_label=False, show_download_button=False, show_share_button=False, interactive=False)
141
 
142
+ return [new_experiment, new_corrupted_component, *new_img_components]
 
 
143
 
144
+ def on_select(evt: gr.SelectData, experiment, *img_components): # SelectData is a subclass of EventData
145
+ new_selected = int(evt.target.label)
 
 
 
 
 
 
 
 
 
 
146
 
147
+ new_img_components = [
148
+ gr.Image(value=img["name"], label=f"{i}", elem_id="unsel", show_label=False, show_download_button=False, show_share_button=False, interactive=False)
149
+ for i, img in enumerate(experiment["options"])
150
+ ]
151
+ new_img_components[new_selected] = (
152
+ gr.Image(value=experiment["options"][new_selected]["name"], label=f"{new_selected}", elem_id="sel", show_label=False, show_download_button=False, show_share_button=False, interactive=False)
153
+ )
154
 
155
+ experiment["selected_image"] = int(evt.target.label)
 
 
 
 
 
 
 
156
 
157
+ return [experiment, *new_img_components]
158
 
159
  css = """
160
+ #unsel {border: solid 5px transparent !important; border-radius: 15px !important; draggable: false}
161
+ #sel {border: solid 5px #00c0ff !important; border-radius: 15px !important; draggable: false}
162
+ #corrupted {margin-left: 5%; margin-right: 5%; padding: 0 !important; draggable: false}
163
  #reducedHeight {height: 10px !important}
164
  #padded {padding-left: 2%; padding-right: 2%}
165
  """
166
 
167
+ with gr.Blocks(title="Unsupervised Image Editing", css=css) as demo:
168
+ experiment = gr.State(generate_new_experiment())
 
 
169
 
170
+ with gr.Row(elem_id="padded"):
171
+ corrupted_component = gr.Image(label="corr", elem_id="corrupted", show_label=False, show_download_button=False, show_share_button=False, interactive=False)
172
+ with gr.Column(scale=3, elem_id="padded"):
173
+ gr.Markdown("<div style='width: 100%'><h1 style='text-align: center; display: inline-block; width: 100%'>The sample on the left is a corrupted image</h1></div>")
174
+ gr.Markdown("<div style='width: 100%'><h3 style='text-align: center; display: inline-block; width: 100%'>Below are decorrupted versions sampled from various models. Click on the picture you like best.<br/>⚠️Do not pay attention to the background⚠️</h3></div>")
175
+ btn_skip = gr.Button("I have no preference")
176
+ btn_submit = gr.Button("Submit preference")
177
 
178
+ img_components = []
179
+ for row in range(NUMBER_OF_ROWS):
180
  with gr.Row():
181
+ for col in range(NUMBER_OF_IMAGES_PER_ROW):
182
+ img_components.append(gr.Image(label=f"{row * NUMBER_OF_IMAGES_PER_ROW + col}", elem_id="unsel", show_label=False, show_download_button=False, show_share_button=False, interactive=False))
 
 
 
 
 
 
 
 
 
 
 
183
 
184
+ btn_skip.click(partial(save, mode="skip"), [experiment, corrupted_component, *img_components], [experiment, corrupted_component, *img_components])
185
+ btn_submit.click(partial(save, mode="save"), [experiment, corrupted_component, *img_components], [experiment, corrupted_component, *img_components])
186
+ for img in img_components:
187
+ img.select(on_select, [experiment, *img_components], [experiment, *img_components], show_progress="hidden")
188
 
189
+ demo.load(next, None, [experiment, corrupted_component, *img_components])
190
 
191
+ demo.launch()