|
import contextlib
import datetime
import io
import json
import os
import queue
import random
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import firebase_admin
import google_auth_httplib2
import googleapiclient
import gradio as gr
import httplib2
from firebase_admin import db, credentials
from google.oauth2 import service_account
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Secrets and endpoints are read from environment variables at import time;
# a missing variable raises KeyError before the app is built.
# FirebaseSecret holds service-account JSON (parsed with json.loads and passed
# to firebase_admin credentials.Certificate in App.init_remote).
FIREBASE_API_KEY = os.environ['FirebaseSecret']

# Realtime Database URL passed as 'databaseURL' to firebase_admin.initialize_app.
FIREBASE_URL = os.environ['FirebaseURL']

# Google service-account JSON for Drive access (parsed with json.loads in App.init_remote).
DRIVE_API_KEY = os.environ['DriveSecret']



# OAuth scope requested for the Drive service-account credentials (full Drive access).
SCOPES = ["https://www.googleapis.com/auth/drive"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Experiment():
    """A single comparison trial: one corrupted image and its candidate restorations.

    Instances are created by ``schedule_downloads`` with Drive file metadata and
    filled with decoded images by ``download_thread``.

    Attributes:
        name: experiment folder name on Drive.
        corrupted: Drive file dict (``"id"``, ``"name"``; a ``"pil"`` image is
            added once downloaded).
        options: list of Drive file dicts, each tagged with ``"algo"`` and given
            a ``"pil"`` image once downloaded.
        selected_image: index into ``options`` picked by the user, or None.
        initialized: stored as given; not read by the code shown in this file.
    """

    def __init__(self, name, corrupted, options, selected_image=None, initialized=False):
        self.name = name
        self.corrupted = corrupted
        self.options = options
        self.selected_image = selected_image
        self.initialized = initialized

    def to_dict(self):
        """Return the JSON-serializable record that is pushed to Firebase.

        Images are represented by their Drive file names. ``"algo"`` is the
        algorithm tag of the selected option, or None when nothing has been
        selected yet (previously ``options[None]`` raised TypeError).
        """
        selected = self.selected_image
        return {
            "experiment_name": self.name,
            "corrupted": self.corrupted["name"],
            "options": [img["name"] for img in self.options],
            "selected_image": selected,
            "algo": None if selected is None else self.options[selected]["algo"],
        }

    def to_pil(self):
        """Return ``(corrupted_image, [option_images])``; requires the images to be downloaded."""
        return self.corrupted["pil"], [img["pil"] for img in self.options]

    @classmethod
    def from_dict(cls, source):
        """Build an Experiment from a dict with ``"name"``, ``"corrupted"`` and ``"options"`` keys.

        NOTE: this is not the inverse of ``to_dict``. ``to_dict`` writes
        ``"experiment_name"`` and flattens the images to file names.
        """
        return cls(source["name"], source["corrupted"], source["options"])

    def __repr__(self):
        # str() falls back to __repr__, so no separate __str__ is needed.
        return f"Experiment(name={self.name}, corrupted={self.corrupted}, options={self.options})"

    def __eq__(self, other):
        """Compare name, corrupted file and options; the selection state is ignored."""
        if not isinstance(other, Experiment):
            # Let Python fall back to its default comparison instead of raising AttributeError.
            return NotImplemented
        return self.name == other.name and self.corrupted == other.corrupted and self.options == other.options

    # Defining __eq__ already makes instances unhashable; this line states it explicitly.
    __hash__ = None
|
|
|
|
|
class App():
    """Gradio app state: the current experiment, its UI components and the prefetch workers.

    Construction connects to Google Drive and Firebase, starts the background
    scheduler/download workers, queues ``NUM_TO_SCHEDULE`` prefetch requests and
    blocks until the first experiment has been downloaded.

    NOTE(review): ``build_demo`` creates a single App, so ``current_experiment``
    is shared by every browser session. Concurrent users see and overwrite each
    other's selection. Per-session state (e.g. ``gr.State``) would isolate them.
    """

    # Number of download_thread workers (the executor additionally runs one scheduler).
    NUM_THREADS = 8

    # Experiment requests queued at startup so several are prefetched ahead of use.
    NUM_TO_SCHEDULE = 8

    def __init__(self):
        self.init_remote()
        self.init_download_thread()

        for _ in range(App.NUM_TO_SCHEDULE):
            self.q_requested.put({})

        # Blocks until the background workers deliver the first experiment.
        self.next_experiment()
        self.build_components_from_experiment()

    @contextlib.asynccontextmanager
    async def lifespan(self, fastapi_app):
        """ASGI lifespan handler: on server shutdown, stop the background workers.

        This is an async context-manager factory, the form Starlette documents
        for ``lifespan``. A bare sync generator (the previous form) is only
        accepted by Starlette through a deprecated wrapper and cannot be entered
        with ``async with`` by code that wraps user lifespans.
        NOTE(review): confirm how the installed Gradio version forwards
        ``app_kwargs["lifespan"]``.

        Args:
            fastapi_app: the application being served; unused.
        """
        yield

        # Don't wait: the workers are blocked in queue.get() and only exit
        # once they receive a None sentinel, which is sent below.
        self.executor.shutdown(wait=False)

        # One sentinel per download worker. The extra q_requested sentinels are
        # harmless because the single scheduler stops at the first one.
        for _ in range(App.NUM_THREADS):
            self.q_to_download.put(None)
            self.q_requested.put(None)

    def init_remote(self):
        """Create the Drive v3 client and the Firebase Realtime Database reference.

        Credentials come from the service-account JSON in ``DRIVE_API_KEY`` and
        ``FIREBASE_API_KEY``.
        """

        def build_request(http, *args, **kwargs):
            # Ignore the shared ``http`` and give every request its own authorized
            # httplib2.Http. httplib2 objects are not thread-safe, and this Drive
            # service is used concurrently by the worker threads.
            new_http = google_auth_httplib2.AuthorizedHttp(self.drive_creds, http=httplib2.Http())
            return googleapiclient.http.HttpRequest(new_http, *args, **kwargs)

        self.drive_creds = service_account.Credentials.from_service_account_info(json.loads(DRIVE_API_KEY), scopes=SCOPES)
        authorized_http = google_auth_httplib2.AuthorizedHttp(self.drive_creds, http=httplib2.Http())
        self.drive_service = build("drive", "v3", requestBuilder=build_request, http=authorized_http)

        self.firebase_creds = credentials.Certificate(json.loads(FIREBASE_API_KEY))
        self.firebase_app = firebase_admin.initialize_app(self.firebase_creds, {'databaseURL': FIREBASE_URL})
        # save_to_firebase pushes submitted choices under this "data" node.
        self.firebase_data_ref = db.reference("data")

    def init_download_thread(self):
        """Create the work queues and start the background workers.

        Pipeline: ``q_requested`` (request tokens) -> schedule_downloads ->
        ``q_to_download`` (Experiment with Drive metadata) -> NUM_THREADS x
        download_thread -> ``q_processed`` (Experiment with PIL images attached).
        """
        self.q_requested = queue.Queue()
        self.q_to_download = queue.Queue()
        self.q_processed = queue.Queue()
        # Room for 2 * NUM_THREADS workers; NUM_THREADS + 1 long-running tasks are submitted.
        self.executor = ThreadPoolExecutor(max_workers=2*App.NUM_THREADS)
        for _ in range(App.NUM_THREADS):
            self.executor.submit(download_thread, self.drive_service, self.q_to_download, self.q_processed)
        self.executor.submit(schedule_downloads, self.drive_service, self.q_to_download, self.q_requested)

    def next_experiment(self):
        """Request one more prefetch, then block until a downloaded experiment is available."""
        self.q_requested.put({})
        self.current_experiment : Experiment = self.q_processed.get()

    def build_components_from_experiment(self):
        """Rebuild the gr.Image components for ``current_experiment``.

        The selected option (if any) gets elem_id "sel" and the others get
        "unsel"; the module-level CSS styles these ids.

        Returns:
            ``[*option_images, corrupted_image]``, the output order used by the
            event listeners in build_demo.
        """
        corrupted = self.current_experiment.corrupted
        images = self.current_experiment.options

        self.corrupted_component = gr.Image(value=corrupted["pil"], label="corr", show_label=True, show_download_button=False, elem_id="padded")
        self.img_components = [
            gr.Image(value=img["pil"], label=f"{i}", show_label=True, show_download_button=False, elem_id="unsel")
            for i, img in enumerate(images)
        ]

        selected_index = self.current_experiment.selected_image
        if selected_index is not None:
            # Re-create the chosen image with the highlighted "sel" style.
            self.img_components[selected_index] = (
                gr.Image(value=images[selected_index]["pil"], label=f"{selected_index}", show_label=True, show_download_button=False, elem_id="sel")
            )

        return [*self.img_components, self.corrupted_component]

    def on_select(self, evt: gr.SelectData):
        """Mark the clicked option as selected and return the refreshed components.

        Relies on each option image's label being its index as a string
        ("0", "1", ...), which is how build_demo labels them.
        """
        self.current_experiment.selected_image = int(evt.target.label)
        return self.build_components_from_experiment()

    def save(self):
        """Push the current choice to Firebase and, on success, advance to the next experiment.

        If nothing is selected, save_to_firebase warns the user and the current
        experiment stays on screen.
        """
        if save_to_firebase(self.current_experiment, self.firebase_data_ref):
            # Blocks until the prefetch workers have a downloaded experiment ready.
            self.next_experiment()
        self.build_components_from_experiment()
        return [*self.img_components, self.corrupted_component]
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_to_firebase(experiment, firebase_data_ref):
    """Record the user's choice in Firebase, or warn if no image has been picked.

    Args:
        experiment: the Experiment being answered (may be None).
        firebase_data_ref: Firebase Realtime Database reference to push under.

    Returns:
        True when a record was pushed, False when nothing was selected.
    """
    has_selection = experiment is not None and experiment.selected_image is not None
    if not has_selection:
        gr.Warning("You must select an image before submitting")
        return False

    record = experiment.to_dict()
    # Server-local wall-clock time (naive datetime), second precision.
    record["timestamp"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    firebase_data_ref.push(record)

    gr.Info("Your choice has been saved to Firebase")
    return True
|
|
|
def _list_all_files(service, query):
    """Return every Drive file matching ``query``, following ``nextPageToken`` pagination.

    Args:
        service: Drive v3 service resource.
        query: Drive search query (the ``q`` parameter).

    Returns:
        List of ``{"id": ..., "name": ...}`` dicts, ordered by name.
    """
    files = []
    page_token = None
    while True:
        # googleapiclient drops keyword arguments whose value is None, so the
        # first request goes out without a pageToken.
        results = (
            service.files()
            .list(
                pageSize=50,
                fields="nextPageToken, files(id, name)",
                orderBy="name",
                q=query,
                pageToken=page_token,
            )
            .execute()
        )
        files.extend(results.get("files", []))
        page_token = results.get("nextPageToken")
        if not page_token:
            return files


def list_folders(service):
    """Return all non-trashed Drive folders whose name contains 'Experiment'.

    Drive includes trashed items in list results unless the query excludes them.
    """
    return _list_all_files(
        service,
        "mimeType='application/vnd.google-apps.folder' and name contains 'Experiment' and trashed = false",
    )


def list_files_in_folder(service, folder, filter_=""):
    """Return all non-trashed files directly inside ``folder``.

    Args:
        service: Drive v3 service resource.
        folder: Drive file dict with at least an ``"id"`` key.
        filter_: optional extra Drive query clause, ANDed with the parent constraint.

    Returns:
        List of ``{"id": ..., "name": ...}`` dicts, ordered by name.
    """
    clauses = [f"'{folder['id']}' in parents", "trashed = false"]
    if filter_:
        # Parenthesize so an 'or' inside the filter cannot escape the parent/trash constraints.
        clauses.append(f"({filter_})")
    return _list_all_files(service, " and ".join(clauses))
|
|
|
def download_file(service, file):
    """Download a Drive file's content and decode it as a PIL image.

    Args:
        service: Drive v3 service resource.
        file: Drive file dict with at least ``"id"`` and ``"name"`` keys.

    Returns:
        A fully decoded ``PIL.Image.Image``.
    """
    request = service.files().get_media(fileId=file['id'])
    buffer = io.BytesIO()
    downloader = MediaIoBaseDownload(buffer, request)
    done = False
    while not done:
        # The positional argument is num_retries for transient HTTP errors.
        status, done = downloader.next_chunk(10)
        print(f"Download image: '{file['name']}' - progress: {int(status.progress() * 100)}.")

    buffer.seek(0)
    image = Image.open(buffer)
    # Image.open is lazy. Decode now so this worker thread, not the UI thread
    # that later renders the image, pays the decoding cost and surfaces
    # corrupt-file errors where they are handled.
    image.load()
    return image
|
|
|
|
|
def _build_random_experiment(service):
    """Pick a random experiment folder on Drive and sample candidate images from it.

    The chosen folder must contain an "ODEdit" and an "SDEdit" subfolder with at
    least five files each, plus at least one image directly inside it. The first
    image by name is used as the corrupted input.

    Returns:
        An Experiment whose options are five ODEdit files followed by five SDEdit
        files, each tagged with ``"algo"``.

    Raises:
        IndexError: no experiment folder, a missing ODEdit/SDEdit subfolder, or
            no image in the chosen folder.
        ValueError: a subfolder holds fewer than five files.
    """
    folders = list_folders(service)
    folder = random.choice(folders)

    subfolders = list_files_in_folder(service, folder, filter_="mimeType='application/vnd.google-apps.folder'")
    odedit_subfolder = [subfolder for subfolder in subfolders if "ODEdit" in subfolder["name"]][0]
    sdedit_subfolder = [subfolder for subfolder in subfolders if "SDEdit" in subfolder["name"]][0]

    odedit_files = list_files_in_folder(service, odedit_subfolder)
    sdedit_files = list_files_in_folder(service, sdedit_subfolder)

    selected_odedit_files = [{**file, "algo": "ODEdit"} for file in random.sample(odedit_files, k=5)]
    selected_sdedit_files = [{**file, "algo": "SDEdit"} for file in random.sample(sdedit_files, k=5)]

    corrupted_file = list_files_in_folder(service, folder, filter_="mimeType contains 'image/'")[0]

    # NOTE(review): options are always ODEdit-then-SDEdit, and build_demo places
    # the first five on the top row, so each algorithm always occupies the same
    # row. In a preference study that fixed position may bias choices. Consider
    # random.shuffle(selected_files); the chosen algorithm is recorded via "algo".
    selected_files = [*selected_odedit_files, *selected_sdedit_files]
    return Experiment(folder["name"], corrupted_file, selected_files)


def schedule_downloads(service, q_to_download, q_requested, retry_delay=1.0):
    """Worker loop: turn each request from ``q_requested`` into a scheduled Experiment.

    Runs until a None sentinel is received. Each built Experiment (with Drive
    metadata only) is put on ``q_to_download`` for the download workers.

    Args:
        service: Drive v3 service resource.
        q_to_download: queue receiving the scheduled Experiments.
        q_requested: queue of request tokens; None stops the loop.
        retry_delay: seconds to wait before re-queuing a request that failed.
    """
    while True:
        print("Waiting for new experiment to schedule")
        request = q_requested.get()
        if request is None:
            break

        print("Scheduling new experiment")
        try:
            experiment = _build_random_experiment(service)
        except Exception:
            # Previously any error here escaped the loop and silently ended this
            # thread, the only scheduler. Later requests were then never served
            # and App.next_experiment blocked forever. Log the error and retry
            # the request (a new folder is drawn at random). The pause keeps a
            # persistent failure from hammering the Drive API.
            traceback.print_exc()
            time.sleep(retry_delay)
            q_requested.put(request)
        else:
            q_to_download.put(experiment)
            print("Experiment scheduled")
        finally:
            q_requested.task_done()
|
|
|
def download_thread(service, q_to_download, q_processed):
    """Worker loop: download the images of each scheduled Experiment.

    Attaches a ``"pil"`` image to ``experiment.corrupted`` and to every option
    dict, then puts the Experiment on ``q_processed``. Runs until a None
    sentinel is received.

    Args:
        service: Drive v3 service resource.
        q_to_download: queue of Experiments to download; None stops the loop.
        q_processed: queue receiving Experiments whose images are ready.
    """
    while True:
        print("Waiting for experiment to download")
        experiment : Experiment = q_to_download.get()
        if experiment is None:
            break

        try:
            for file in [experiment.corrupted, *experiment.options]:
                print(f"Downloading file {file['name']}")
                file["pil"] = download_file(service, file)
                print(f"File {file['name']} downloaded")
        except Exception:
            # Previously any download error ended this worker thread for good;
            # after NUM_THREADS such errors nothing was downloaded anymore. Log
            # the error and drop the experiment so the worker keeps serving the queue.
            traceback.print_exc()
            print(f"Dropping experiment {experiment.name!r}: download failed")
        else:
            q_processed.put(experiment)
            print("Experiment downloaded")
        finally:
            q_to_download.task_done()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Custom styles keyed by the elem_id values set on components in this file:
#   #unsel / #sel   - transparent vs. highlighted (#00c0ff) border for
#                     unselected vs. selected candidate images
#   #padded         - margins around the corrupted reference image
#   #paddedRight    - right margin for the heading Markdown blocks
css = """

#unsel {border: solid 5px transparent !important; border-radius: 15px !important}

#sel {border: solid 5px #00c0ff !important; border-radius: 15px !important}

#padded {margin-left: 25% !important; margin-right: 5% !important}

#paddedRight {margin-right: 5% !important}

"""
|
|
|
def build_demo():
    """Construct the App and the Gradio Blocks UI bound to it.

    Layout: the corrupted reference image with instructions and a Submit
    button, then the candidate images in two rows (first five, then the rest).
    Each candidate is labelled with its option index, which App.on_select
    parses back with int().

    Returns:
        ``(demo, app)``: the ``gr.Blocks`` interface and its backing App.
    """
    app = App()

    with gr.Blocks(title="Unsupervised Image Editing", css=css) as demo:

        with gr.Row():
            corrupted_component = gr.Image(label="corr", elem_id="padded")
            with gr.Column(scale=3):
                gr.Markdown("<div style='width: 100%'><h1 style='text-align: center; display: inline-block; width: 100%'>The sample on the left is a corrupted image</h1></div>", elem_id="paddedRight")
                gr.Markdown("<div style='width: 100%'><h3 style='text-align: center; display: inline-block; width: 100%'>Below are decorrupted versions sampled from various models. Click on the picture you like best</h3></div>", elem_id="paddedRight")
                btn = gr.Button("Submit")
        gr.Markdown("<hr>")

        num_options = len(app.img_components)
        img_components = []
        with gr.Row():
            for i in range(min(5, num_options)):
                img_components.append(gr.Image(label=f"{i}", elem_id="unsel"))
        with gr.Row():
            for i in range(5, num_options):
                img_components.append(gr.Image(label=f"{i}", elem_id="unsel"))

        # Every App handler returns [*option_images, corrupted_image], so all
        # listeners share this output list. on_select previously listed only the
        # option images, leaving one returned value without a target component.
        outputs = [*img_components, corrupted_component]

        btn.click(app.save, None, outputs)
        for img in img_components:
            img.select(app.on_select, None, outputs, show_progress="hidden")

        demo.load(app.build_components_from_experiment, inputs=None, outputs=outputs)

    return demo, app
|
|
|
|
|
if __name__ == "__main__":
    # Build the UI, then hand the App's lifespan handler to the underlying web
    # app so the background workers receive their stop sentinels on shutdown.
    demo, app = build_demo()
    launch_options = dict(share=False, show_api=False, app_kwargs={"lifespan": app.lifespan})
    demo.launch(**launch_options)