# demo.py — commit 6e0bc6b "Update demo.py" by tommymarto (verified)
import datetime
from functools import partial
import json
from pathlib import Path
import random
import gradio as gr
import os
import firebase_admin
from firebase_admin import db, credentials
##################################################################
# Constants
##################################################################
# Shape of the grid of candidate restorations shown under the corrupted image.
# The grid has NUMBER_OF_ROWS * NUMBER_OF_IMAGES_PER_ROW image components, while
# generate_new_experiment draws (that product // 2) SDEdit samples plus the same
# number of ODEdit samples. Keep the product even: an odd product yields one
# fewer option than there are image components.
NUMBER_OF_ROWS: int = 2
NUMBER_OF_IMAGES_PER_ROW: int = 7
#################################################################################################################################################
# Authentication
#################################################################################################################################################
# Deployment configuration read from the environment (a KeyError here means the
# variable is not set):
#   FirebaseSecret - JSON credentials handed to credentials.Certificate
#                    (despite the constant's name, this is parsed as JSON, not used as a bare key)
#   FirebaseURL    - Firebase Realtime Database URL
#   Dataset        - sub-directory of ./images that experiments are sampled from
FIREBASE_API_KEY = os.environ['FirebaseSecret']
FIREBASE_URL = os.environ['FirebaseURL']
DATASET = os.environ['Dataset']

# init firebase service
firebase_creds = credentials.Certificate(json.loads(FIREBASE_API_KEY))
try:
    # Reuse the default app when this module is executed again in the same
    # process (e.g. Gradio's reload mode); a second initialize_app() call would
    # raise ValueError because the default app already exists.
    firebase_app = firebase_admin.get_app()
except ValueError:
    # No default app yet: this is the normal first-start path.
    firebase_app = firebase_admin.initialize_app(firebase_creds, {'databaseURL': FIREBASE_URL})
# All annotation records are pushed under the "data" node (see save()).
firebase_data_ref = db.reference("data")
##################################################################
# Data Layer
##################################################################
class Experiment(dict):
    """A single annotation trial, represented as a plain dict.

    Keys, in insertion order:
        dataset         - dataset name the trial was drawn from
        corruption      - name of the corruption directory
        image_id        - name of the image directory
        corrupted       - {"name": <path>} of the corrupted reference image
        options         - list of {"name": <path>, "algo": <algorithm>} candidates
        selected_image  - index into ``options`` chosen by the user, or None
    """

    def __init__(self, dataset, corruption, image_id, corrupted, options, selected_image=None):
        keys = ("dataset", "corruption", "image_id", "corrupted", "options", "selected_image")
        values = (dataset, corruption, image_id, corrupted, options, selected_image)
        # dict(iterable_of_pairs) preserves the key order given above.
        super().__init__(zip(keys, values))
def experiment_to_dict(experiment, skip=False):
    """Flatten an Experiment into the record stored in Firebase.

    Args:
        experiment: mapping with the Experiment keys (dataset, corruption,
            image_id, corrupted, options, selected_image).
        skip: when True, record "no preference": both selection fields are the
            string "None" instead of being looked up.

    Returns:
        A new dict with keys dataset, corruption, image_number,
        corrupted_filename, options (list of option paths), selected_image and
        selected_algo, in that order.

    Raises:
        TypeError / IndexError: when ``skip`` is False and ``selected_image`` is
            None or not a valid index into ``options``.
    """
    record = {
        # experiment info
        "dataset": experiment["dataset"],
        "corruption": experiment["corruption"],
        "image_number": experiment["image_id"],
        # chosen image set info
        "corrupted_filename": experiment["corrupted"]["name"],
        "options": [option["name"] for option in experiment["options"]],
    }

    if skip:
        # The literal string "None" (not a null) is stored for skipped trials.
        record["selected_image"] = "None"
        record["selected_algo"] = "None"
    else:
        picked = experiment["options"][experiment["selected_image"]]
        record["selected_image"] = picked["name"]
        record["selected_algo"] = picked["algo"]

    return record
def generate_new_experiment() -> Experiment:
    """Randomly assemble a new trial from the on-disk image tree.

    Expected layout (as implied by the globs below):
        ./images/<DATASET>/<any>/<corruption>/<image_id>/
            *corrupted*  - corrupted reference image(s); one is picked
            sde/*        - SDEdit restorations
            ode/*        - ODEdit restorations

    Only corruption directories whose name is in ``wanted_corruptions`` are
    eligible. ``(NUMBER_OF_IMAGES_PER_ROW * NUMBER_OF_ROWS) // 2`` files are
    sampled without replacement from each of ``sde`` and ``ode``, then the
    combined list is shuffled so the algorithm cannot be inferred from position.

    Returns:
        A fresh Experiment whose ``selected_image`` is None.

    Raises:
        IndexError: if no eligible corruption directory, image directory or
            ``*corrupted*`` file exists.
        ValueError: if ``sde`` or ``ode`` contains fewer files than must be sampled.
    """
    wanted_corruptions = {
        "spatter", "impulse_noise", "speckle_noise", "gaussian_noise",
        "pixelate", "jpeg_compression", "elastic_transform",
    }
    per_algo = (NUMBER_OF_IMAGES_PER_ROW * NUMBER_OF_ROWS) // 2

    # Glob order is filesystem-dependent; sorting makes a seeded `random`
    # produce reproducible picks.
    corruption = random.choice(sorted(
        d for d in Path(f"./images/{DATASET}").glob("*/*")
        if d.is_dir() and d.name in wanted_corruptions
    ))
    # Only directories can hold the images. Picking a stray file here (e.g. a
    # .DS_Store) would leave every glob below empty and crash the draw.
    image_id = random.choice(sorted(d for d in corruption.glob("*") if d.is_dir()))

    corrupted_image = {"name": str(random.choice(sorted(image_id.glob("*corrupted*"))))}

    def _sample(subdir, algo):
        # Draw `per_algo` distinct files from image_id/<subdir>, tagged with `algo`.
        pool = sorted((image_id / subdir).glob("*"))
        return [{"name": str(img), "algo": algo} for img in random.sample(pool, per_algo)]

    total_images = _sample("sde", "SDEdit") + _sample("ode", "ODEdit")
    random.shuffle(total_images)

    return Experiment(
        DATASET,
        corruption.name,
        image_id.name,
        corrupted_image,
        total_images,
    )
def save(experiment, corrupted_component, *img_components, mode):
    """Record the current trial in Firebase, then start the next one.

    Args:
        experiment: the current Experiment held in gr.State.
        corrupted_component, *img_components: current component values; they are
            only echoed back unchanged when a "save" is rejected.
        mode: "save" to record the selected option, "skip" to record that the
            user had no preference.

    Returns:
        On a rejected "save" (nothing selected): the unchanged inputs, after a
        warning toast. Otherwise: the result of next(), i.e. a new experiment
        and its components.
    """
    if mode == "save":
        # Refuse to submit a preference the user never made.
        if experiment is None or experiment["selected_image"] is None:
            gr.Warning("You must select an image before submitting")
            return [experiment, corrupted_component, *img_components]

    skipped = mode == "skip"
    if skipped:
        # Forget any tile clicked before the user chose to skip.
        experiment["selected_image"] = None

    record = dict(
        experiment_to_dict(experiment, skip=skipped),
        timestamp=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    )
    firebase_data_ref.push(record)

    # Echo the stored record to stdout for debugging.
    separator = "====================="
    print(separator)
    print(record)
    print(separator)

    gr.Info("Your choice has been saved to Firebase")
    return next()
##################################################################
# UI Layer
##################################################################
def next():
    """Start a new trial and build the components that display it.

    Each option tile is labelled with its index in ``options``; on_select parses
    that label back to learn which tile was clicked.

    Returns:
        [experiment, corrupted-image component, *option-image components], in
        the output order wired to the buttons and to demo.load.
    """
    # NOTE: this shadows the builtin next(); the name is kept because save() and
    # demo.load refer to it.
    fresh = generate_new_experiment()
    display_opts = dict(
        show_label=False,
        show_download_button=False,
        show_share_button=False,
        interactive=False,
    )

    tiles = []
    for position, option in enumerate(fresh["options"]):
        tiles.append(gr.Image(value=option["name"], label=f"{position}", elem_id="unsel", **display_opts))

    reference = gr.Image(value=fresh["corrupted"]["name"], label="corr", elem_id="corrupted", **display_opts)
    return [fresh, reference, *tiles]
def on_select(evt: gr.SelectData, experiment, *img_components):  # SelectData is a subclass of EventData
    """Mark the clicked option tile as the user's choice.

    Args:
        evt: select event; ``evt.target`` is the clicked gr.Image, whose label
            holds its index into ``experiment["options"]``.
        experiment: the current Experiment; its ``selected_image`` is updated
            in place.
        *img_components: current tile values (required by the event wiring,
            not read here).

    Returns:
        [experiment, *tiles]: every tile is rebuilt with elem_id "unsel", except
        the clicked one, which gets "sel" (highlighted by the page CSS).
    """
    chosen = int(evt.target.label)
    options = experiment["options"]

    def make_tile(idx, css_id):
        # Same image settings as the initial grid; only the CSS id varies.
        return gr.Image(value=options[idx]["name"], label=f"{idx}", elem_id=css_id, show_label=False, show_download_button=False, show_share_button=False, interactive=False)

    tiles = [make_tile(idx, "unsel") for idx in range(len(options))]
    tiles[chosen] = make_tile(chosen, "sel")

    experiment["selected_image"] = chosen
    return [experiment, *tiles]
# Page styles for the Blocks app.
#   #unsel / #sel  - option tiles; the selected tile gets a coloured border
#   #corrupted     - the corrupted reference image
#   #padded        - horizontal padding for the top row and instruction column
# CSS has no `draggable` property (it is an HTML attribute), so image dragging
# is disabled with `-webkit-user-drag` on the <img> elements instead. It is
# honoured by WebKit/Blink browsers; others ignore it harmlessly.
css = """
#unsel {border: solid 5px transparent !important; border-radius: 15px !important}
#sel {border: solid 5px #00c0ff !important; border-radius: 15px !important}
#corrupted {margin-left: 5%; margin-right: 5%; padding: 0 !important}
#unsel img, #sel img, #corrupted img {-webkit-user-drag: none; user-select: none}
#reducedHeight {height: 10px !important}
#padded {padding-left: 2%; padding-right: 2%}
"""
# Annotation page: the corrupted image and the instructions/buttons sit side by
# side on top, with a NUMBER_OF_ROWS x NUMBER_OF_IMAGES_PER_ROW grid of candidate
# restorations below. Component creation order defines the layout, and events
# are registered in a fixed order, so keep the statement sequence as is.
with gr.Blocks(title="Unsupervised Image Editing", css=css) as demo:
    # Per-session state: seeded with one experiment, replaced by demo.load on page load.
    experiment = gr.State(generate_new_experiment())

    with gr.Row(elem_id="padded"):
        corrupted_component = gr.Image(label="corr", elem_id="corrupted", show_label=False, show_download_button=False, show_share_button=False, interactive=False)
        with gr.Column(scale=3, elem_id="padded"):
            gr.Markdown("<div style='width: 100%'><h1 style='text-align: center; display: inline-block; width: 100%'>The sample on the left is a corrupted image</h1></div>")
            gr.Markdown("<div style='width: 100%'><h3 style='text-align: center; display: inline-block; width: 100%'>Below are decorrupted versions sampled from various models. Click on the picture you like best.<br/>⚠️Do not pay attention to the background. Consider first fidelity, then quality⚠️</h3></div>")
            btn_skip = gr.Button("I have no preference")
            btn_submit = gr.Button("Submit preference")

    # Each tile's label is its flat index in the grid; on_select parses it back.
    img_components = []
    for grid_row in range(NUMBER_OF_ROWS):
        with gr.Row():
            for grid_col in range(NUMBER_OF_IMAGES_PER_ROW):
                tile_index = grid_row * NUMBER_OF_IMAGES_PER_ROW + grid_col
                img_components.append(
                    gr.Image(
                        label=f"{tile_index}",
                        elem_id="unsel",
                        show_label=False,
                        show_download_button=False,
                        show_share_button=False,
                        interactive=False,
                    )
                )

    # Submit/skip both read and refresh the whole page state.
    page_state = [experiment, corrupted_component, *img_components]
    for button, button_mode in ((btn_skip, "skip"), (btn_submit, "save")):
        button.click(partial(save, mode=button_mode), list(page_state), list(page_state))

    # Clicking any tile re-renders the grid with that tile highlighted.
    for tile in img_components:
        tile.select(on_select, [experiment, *img_components], [experiment, *img_components], show_progress="hidden")

    demo.load(next, None, [experiment, corrupted_component, *img_components])

demo.launch()