from concurrent.futures import ThreadPoolExecutor import datetime from functools import partial import os import io import json import queue import gradio as gr import random import firebase_admin from firebase_admin import db, credentials from google.oauth2 import service_account from googleapiclient.discovery import build import google_auth_httplib2 import googleapiclient from googleapiclient.http import MediaIoBaseDownload from PIL import Image import httplib2 ################################################################################################################################################# # Authentication ################################################################################################################################################# # read secret api key FIREBASE_API_KEY = os.environ['FirebaseSecret'] FIREBASE_URL = os.environ['FirebaseURL'] DRIVE_API_KEY = os.environ['DriveSecret'] SCOPES = ["https://www.googleapis.com/auth/drive"] ################################################################################################################################################# # Types ################################################################################################################################################# class Experiment(): def __init__(self, corruption, image_id, corrupted, options, selected_image=None, initialized=False): self.corruption = corruption self.image_id = image_id self.corrupted = corrupted self.options = options self.selected_image = selected_image self.initialized = initialized def to_dict(self, custom_algo=None): return { "corruption": self.corruption, "image_number": self.image_id, "dataset": "CelebA", "corrupted_filename": self.corrupted["name"], "options": [img["name"] for img in self.options], "selected_image": "None" if custom_algo else self.selected_image, "algo": custom_algo if custom_algo is not None else self.options[self.selected_image]["algo"] } def to_pil(self): return self.corrupted["pil"], [img["pil"] for img in self.options] @staticmethod def from_dict(source): return Experiment(source["name"], source["corrupted"], source["options"]) def __repr__(self): return f"Experiment(name={self.name}, corrupted={self.corrupted}, options={self.options})" def __str__(self): return f"Experiment(name={self.name}, corrupted={self.corrupted}, options={self.options})" def __eq__(self, other): return self.name == other.name and self.corrupted == other.corrupted and self.options == other.options class App(): NUM_THREADS = 16 NUM_TO_SCHEDULE = 10 def __init__(self): self.init_remote() self.init_download_thread() for _ in range(App.NUM_TO_SCHEDULE): self.q_requested.put({}) self.next_experiment() self.build_components_from_experiment() def lifespan(self, fastapi_app): yield # cancel thredpool self.executor.shutdown(wait=False) # cancel download threads for _ in range(App.NUM_THREADS): self.q_to_download.put(None) self.q_requested.put(None) def init_remote(self): def build_request(http, *args, **kwargs): new_http = google_auth_httplib2.AuthorizedHttp(self.drive_creds, http=httplib2.Http()) return googleapiclient.http.HttpRequest(new_http, *args, **kwargs) # init drive service self.drive_creds = service_account.Credentials.from_service_account_info(json.loads(DRIVE_API_KEY), scopes=SCOPES) authorized_http = google_auth_httplib2.AuthorizedHttp(self.drive_creds, http=httplib2.Http()) self.drive_service = build("drive", "v3", requestBuilder=build_request, http=authorized_http) # init firebase service self.firebase_creds = credentials.Certificate(json.loads(FIREBASE_API_KEY)) self.firebase_app = firebase_admin.initialize_app(self.firebase_creds, {'databaseURL': FIREBASE_URL}) self.firebase_data_ref = db.reference("data") def init_download_thread(self): # init download thread and queue self.q_requested = queue.Queue() self.q_to_download = queue.Queue() self.q_processed = queue.Queue() self.executor = ThreadPoolExecutor(max_workers=2*App.NUM_THREADS) for _ in range(App.NUM_THREADS): self.executor.submit(download_thread, self.drive_service, self.q_to_download, self.q_processed) self.executor.submit(schedule_downloads, self.drive_service, self.q_to_download, self.q_requested) def next_experiment(self): self.q_requested.put({}) self.current_experiment : Experiment = self.q_processed.get() def build_components_from_experiment(self): corrupted = self.current_experiment.corrupted images = self.current_experiment.options self.corrupted_component = gr.Image(value=corrupted["pil"], label=self.current_experiment.corruption, show_label=True, show_download_button=False, elem_id="corrupted") self.img_components = [ gr.Image(value=img["pil"], label=f"{i}", show_label=False, show_download_button=False, elem_id="unsel") for i, img in enumerate(images) ] selected_index = self.current_experiment.selected_image if selected_index is not None: self.img_components[selected_index] = ( gr.Image(value=images[selected_index]["pil"], label=f"{selected_index}", show_label=False, show_download_button=False, elem_id="sel") ) return [*self.img_components, self.corrupted_component] def on_select(self, evt: gr.SelectData): # SelectData is a subclass of EventData self.current_experiment.selected_image = int(evt.target.label) return self.build_components_from_experiment() def save(self, mode): print(f"Saving experiment with mode {mode}") if save_to_firebase(self.current_experiment, self.firebase_data_ref, mode=mode): self.next_experiment() self.build_components_from_experiment() return [*self.img_components, self.corrupted_component] ################################################################################################################################################# # API calls ################################################################################################################################################# def save_to_firebase(experiment: Experiment, firebase_data_ref, mode): if mode == "save" and (experiment is None or experiment.selected_image is None): gr.Warning("You must select an image before submitting") return False if mode == "skip": experiment.selected_image = None firebase_data_ref.push({ **experiment.to_dict(custom_algo=mode if mode == "skip" else None), "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), }) gr.Info("Your choice has been saved to Firebase") return True def list_folders(service): folders = [] corruptions = [ "brightness", "elastic_transform", "frost", "impulse_noise", "masking_random_color", "shot_noise", "speckle_noise", "contrast", "fog", "gaussian_noise", "latent", "masking_vline_random_color", "spatter" ] name_query = "(" + " or ".join([f"name contains '{c}'" for c in corruptions]) + ")" results = ( service.files() .list(pageSize=50, fields="nextPageToken, files(id, name)", orderBy="name", q=f"mimeType='application/vnd.google-apps.folder' and {name_query}") .execute() ) folders.extend(results.get("files", [])) while "nextPageToken" in results: page_token = results["nextPageToken"] results = ( service.files() .list(pageSize=50, fields="nextPageToken, files(id, name)", orderBy="name", q=f"mimeType='application/vnd.google-apps.folder' and {name_query}", pageToken=page_token) .execute() ) folders.extend(results.get("files", [])) return folders def list_files_in_folder(service, folder, filter_=""): files = [] results = ( service.files() .list(pageSize=50, fields="nextPageToken, files(id, name)", orderBy="name", q=f"'{folder['id']}' in parents {'and ' + filter_ if filter_ else ''}") .execute() ) files.extend(results.get("files", [])) while "nextPageToken" in results: page_token = results["nextPageToken"] results = ( service.files() .list(pageSize=50, fields="nextPageToken, files(id, name)", orderBy="name", q=f"'{folder['id']}' in parents {'and ' + filter_ if filter_ else ''}", pageToken=page_token) .execute() ) files.extend(results.get("files", [])) return files def download_file(service, file): request = service.files().get_media(fileId=file['id']) fh = io.BytesIO() downloader = MediaIoBaseDownload(fh, request) done = False while done is False: status, done = downloader.next_chunk(10) print(f"Download image: '{file['name']}' - progress: {int(status.progress() * 100)}.") return Image.open(fh) def schedule_downloads(service, q_to_download, q_requested): while True: print("Waiting for new experiment to schedule") if q_requested.get() is None: break print("Scheduling new experiment") # this should give a list of top-level folders of corruption types folders = list_folders(service) # print(f"Found {len(folders)} folders") # print(folders) # sample a random corruption from folders folder = random.choice(folders) # inside the corruption folders, there should be one folder per image id subfolders = list_files_in_folder(service, folder, filter_="mimeType='application/vnd.google-apps.folder'") # print(f"Found {len(subfolders)} subfolders") # print(subfolders) # sample a random image from subfolders subfolder = random.choice(subfolders) # list subsubfolders in the subfolder subsubfolders = list_files_in_folder(service, subfolder, filter_="mimeType='application/vnd.google-apps.folder'") # the results should be 2 subfolders: SDEdit and ODEdit odedit_subfolder = [f for f in subsubfolders if "ode" in f["name"]][0] sdedit_subfolder = [f for f in subsubfolders if "sde" in f["name"]][0] odedit_files = list_files_in_folder(service, odedit_subfolder) sdedit_files = list_files_in_folder(service, sdedit_subfolder) selected_odedit_files = random.sample(odedit_files, k=5) selected_odedit_files = [{**file, "algo": "ODEdit"} for file in selected_odedit_files] selected_sdedit_files = random.sample(sdedit_files, k=5) selected_sdedit_files = [{**file, "algo": "SDEdit"} for file in selected_sdedit_files] corrupted_file = list_files_in_folder(service, subfolder, filter_="mimeType contains 'image/'")[0] selected_files = [*selected_odedit_files, *selected_sdedit_files] experiment = Experiment(folder["name"], subfolder["name"], corrupted_file, selected_files) q_to_download.put(experiment) q_requested.task_done() print("Experiment scheduled") def download_thread(service, q_to_download, q_processed): while True: print("Waiting for experiment to download") experiment : Experiment = q_to_download.get() if experiment is None: break corrupted = experiment.corrupted print(f"Downloading file {corrupted['name']}") corrupted_pil = download_file(service, corrupted) print(f"File {corrupted['name']} downloaded") experiment.corrupted["pil"] = corrupted_pil for file in experiment.options: print(f"Downloading file {file['name']}") pil = download_file(service, file) print(f"File {file['name']} downloaded") file["pil"] = pil q_processed.put(experiment) q_to_download.task_done() print("Experiment downloaded") ################################################################################################################################################# # UI ################################################################################################################################################# css = """ #unsel {border: solid 5px transparent !important; border-radius: 15px !important} #sel {border: solid 5px #00c0ff !important; border-radius: 15px !important} #corrupted {margin-left: 5%; margin-right: 5%; padding: 0 !important} #reducedHeight {height: 10px !important} #padded {padding-left: 2%; padding-right: 2%} """ def build_demo(): app = App() with gr.Blocks(title="Unsupervised Image Editing", css=css) as demo: with gr.Row(elem_id="padded"): corrupted_component = gr.Image(label="corr", elem_id="corrupted", show_label=False, show_download_button=False) with gr.Column(scale=3, elem_id="padded"): gr.Markdown("