# Source: tommymarto — "updated demo" (revision 8191154, 7.92 kB)
import datetime
from functools import partial
import json
from pathlib import Path
import random
import gradio as gr
import os
import firebase_admin
from firebase_admin import db, credentials
##################################################################
# Constants
##################################################################
# Candidate grid shown to the user: NUMBER_OF_ROWS rows of NUMBER_OF_IMAGES_PER_ROW images.
# generate_new_experiment() fills half of the grid from each of two algorithms,
# so keep the product of these two values even.
NUMBER_OF_IMAGES_PER_ROW, NUMBER_OF_ROWS = 7, 2
#################################################################################################################################################
# Authentication
#################################################################################################################################################
# All configuration comes from environment variables; a missing one raises KeyError
# at import time. FirebaseSecret holds the service-account credentials as a JSON string.
FIREBASE_API_KEY, FIREBASE_URL, DATASET = (
    os.environ[variable] for variable in ("FirebaseSecret", "FirebaseURL", "Dataset")
)

# Initialise the default Firebase app and keep a handle on the "data" node,
# which is where save() appends each recorded choice.
firebase_creds = credentials.Certificate(json.loads(FIREBASE_API_KEY))
firebase_app = firebase_admin.initialize_app(
    credential=firebase_creds,
    options={'databaseURL': FIREBASE_URL},
)
firebase_data_ref = db.reference("data")
##################################################################
# Data Layer
##################################################################
class Experiment(dict):
    """A dict subclass holding one labelling trial; instances live in the app's ``gr.State``.

    Keys, in insertion order:
        dataset: name of the dataset the images come from.
        corruption: name of the corruption directory that was sampled.
        image_id: name of the image directory that was sampled.
        corrupted: dict whose ``name`` is the path of the corrupted reference image.
        options: list of candidate dicts, each with ``name`` (path) and ``algo``.
        selected_image: index into ``options`` chosen by the user, or ``None``.
    """

    def __init__(self, dataset, corruption, image_id, corrupted, options, selected_image=None):
        entries = (
            ("dataset", dataset),
            ("corruption", corruption),
            ("image_id", image_id),
            ("corrupted", corrupted),
            ("options", options),
            ("selected_image", selected_image),
        )
        super().__init__(entries)
def experiment_to_dict(experiment, skip=False):
    """Flatten an experiment into the record that save() pushes to Firebase.

    Args:
        experiment: mapping with ``dataset``, ``corruption``, ``image_id``,
            ``corrupted`` (dict with ``name``), ``options`` (list of dicts with
            ``name``/``algo``) and, unless ``skip`` is true, ``selected_image``
            (an index into ``options``).
        skip: when true, the selection fields are recorded as the string "None"
            and ``selected_image`` is not read.

    Returns:
        A new dict with keys dataset, corruption, image_number, corrupted_filename,
        options, selected_image and selected_algo, in that order.
    """
    record = {
        # experiment info
        "dataset": experiment["dataset"],
        "corruption": experiment["corruption"],
        "image_number": experiment["image_id"],
        # chosen image set info
        "corrupted_filename": experiment["corrupted"]["name"],
        "options": [candidate["name"] for candidate in experiment["options"]],
    }

    if skip:
        picked_name = picked_algo = "None"
    else:
        picked = experiment["options"][experiment["selected_image"]]
        picked_name, picked_algo = picked["name"], picked["algo"]

    # selected image info
    record["selected_image"] = picked_name
    record["selected_algo"] = picked_algo
    return record
def generate_new_experiment() -> Experiment:
    """Sample a new trial from the on-disk image tree for ``DATASET``.

    Picks a random corruption directory two levels below ``./images/<DATASET>``,
    a random image directory inside it, and one ``*corrupted*`` file from that
    image directory. The trailing ``_<index>`` of the corrupted file's name selects
    the matching candidates in the ``sde`` and ``ode`` sub-directories. Half of the
    grid (``NUMBER_OF_IMAGES_PER_ROW * NUMBER_OF_ROWS // 2``) is sampled from each,
    and the combined list is shuffled.

    Returns:
        An ``Experiment`` with ``selected_image`` unset (``None``).

    Raises:
        IndexError: if a directory level has nothing to choose from.
        ValueError: if ``sde`` or ``ode`` holds fewer matching images than needed.
    """
    dataset_root = Path(f"./images/{DATASET}")
    corruption = random.choice([d for d in dataset_root.glob("*/*") if d.is_dir()])
    # Only directories hold per-image files. A stray file here (e.g. .DS_Store)
    # would otherwise be picked and make every glob below come up empty.
    image_id = random.choice([d for d in corruption.glob("*") if d.is_dir()])
    imgs_to_sample = (NUMBER_OF_IMAGES_PER_ROW * NUMBER_OF_ROWS) // 2

    corrupted_path = random.choice(list(image_id.glob("*corrupted*")))
    corrupted_image = {"name": str(corrupted_path)}
    # Parse the index from the file name only. Splitting the full path on "."
    # broke as soon as any parent directory name contained a dot.
    corrupted_ending_index = corrupted_path.name.split(".")[0].split("_")[-1]

    total_images = []
    for subdir, algo in (("sde", "SDEdit"), ("ode", "ODEdit")):
        # NOTE(review): "*_1*" also matches "_10", "_11", ...; confirm the file
        # naming scheme makes this prefix match unambiguous.
        matches = list((image_id / subdir).glob(f"*_{corrupted_ending_index}*"))
        total_images += [
            {"name": str(img), "algo": algo}
            for img in random.sample(matches, imgs_to_sample)
        ]
    random.shuffle(total_images)

    return Experiment(
        DATASET,
        corruption.name,
        image_id.name,
        corrupted_image,
        total_images,
    )
def save(experiment, corrupted_component, *img_components, mode):
    """Record the current trial in Firebase and advance to a new one.

    ``mode`` is ``"save"`` (submit the selected candidate) or ``"skip"`` (record
    "no preference"). Submitting without a selection shows a warning and returns
    the inputs unchanged. Otherwise, the flattened trial plus a local-time
    timestamp is pushed to ``firebase_data_ref`` and echoed to stdout. The function
    then returns ``next()``'s fresh experiment and image components.
    """
    if mode == "save":
        if experiment is None or experiment["selected_image"] is None:
            gr.Warning("You must select an image before submitting")
            return [experiment, corrupted_component, *img_components]
    elif mode == "skip":
        # A skip carries no selection; clear any tile the user may have clicked.
        experiment["selected_image"] = None

    record = experiment_to_dict(experiment, skip=(mode == "skip"))
    record["timestamp"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    firebase_data_ref.push(record)

    divider = "====================="
    print(divider, record, divider, sep="\n")

    gr.Info("Your choice has been saved to Firebase")
    return next()
##################################################################
# UI Layer
##################################################################
# NOTE: this shadows the builtin next(); the name is kept because save() and
# demo.load refer to it.
def next():
    """Sample a fresh experiment and build the components that display it.

    Returns ``[experiment, corrupted_image, *candidate_images]``, the order used by
    the event handlers' output lists. Candidate tiles are labelled with their index
    into ``experiment["options"]`` and styled as unselected (``elem_id="unsel"``).
    """
    fresh = generate_new_experiment()
    display = dict(show_label=False, show_download_button=False, show_share_button=False, interactive=False)

    candidates = []
    for position, option in enumerate(fresh["options"]):
        candidates.append(gr.Image(value=option["name"], label=f"{position}", elem_id="unsel", **display))

    reference = gr.Image(value=fresh["corrupted"]["name"], label="corr", elem_id="corrupted", **display)
    return [fresh, reference, *candidates]
def on_select(evt: gr.SelectData, experiment, *img_components):  # SelectData is a subclass of EventData
    """Highlight the clicked candidate and store it as the current selection.

    The clicked component's label is its index into ``experiment["options"]``.
    Every tile is rebuilt as unselected except that one, which gets
    ``elem_id="sel"``. ``img_components`` is part of the handler's input list
    but is not read.

    Returns:
        ``[experiment, *tiles]`` with ``experiment["selected_image"]`` updated.
    """
    display = dict(show_label=False, show_download_button=False, show_share_button=False, interactive=False)
    chosen = int(evt.target.label)
    options = experiment["options"]

    tiles = [
        gr.Image(value=option["name"], label=f"{idx}", elem_id="unsel", **display)
        for idx, option in enumerate(options)
    ]
    tiles[chosen] = gr.Image(value=options[chosen]["name"], label=f"{chosen}", elem_id="sel", **display)

    experiment["selected_image"] = chosen
    return [experiment, *tiles]
# Stylesheet passed to gr.Blocks(css=...). Selectors target the elem_id values used
# by components in this file:
#   "unsel"/"sel" are candidate tiles; the blue border marks the clicked tile.
#   "corrupted" is the reference image.
#   "padded" is the header row/column.
# NOTE(review): `draggable` is an HTML attribute, not a CSS property, so browsers
# ignore the `draggable: false` declarations below. "reducedHeight" does not appear
# to be referenced by any component in the visible code.
css = """
#unsel {border: solid 5px transparent !important; border-radius: 15px !important; draggable: false}
#sel {border: solid 5px #00c0ff !important; border-radius: 15px !important; draggable: false}
#corrupted {margin-left: 5%; margin-right: 5%; padding: 0 !important; draggable: false}
#reducedHeight {height: 10px !important}
#padded {padding-left: 2%; padding-right: 2%}
"""
# Page layout: a header row (corrupted reference image, instructions, buttons)
# followed by a NUMBER_OF_ROWS x NUMBER_OF_IMAGES_PER_ROW grid of candidate images.
with gr.Blocks(title="Unsupervised Image Editing", css=css) as demo:
    # Experiment state. It is seeded with one sample at startup and overwritten
    # by next() through demo.load below.
    experiment = gr.State(generate_new_experiment())
    # NOTE(review): the Row/Column nesting follows the header text ("sample on the
    # left"); confirm it matches the intended layout.
    with gr.Row(elem_id="padded"):
        corrupted_component = gr.Image(label="corr", elem_id="corrupted", show_label=False, show_download_button=False, show_share_button=False, interactive=False)
        with gr.Column(scale=3, elem_id="padded"):
            gr.Markdown("<div style='width: 100%'><h1 style='text-align: center; display: inline-block; width: 100%'>The sample on the left is a corrupted image</h1></div>")
            gr.Markdown("<div style='width: 100%'><h3 style='text-align: center; display: inline-block; width: 100%'>Below are decorrupted versions sampled from various models. Click on the picture you like best.<br/>⚠️Do not pay attention to the background⚠️</h3></div>")
            btn_skip = gr.Button("I have no preference")
            btn_submit = gr.Button("Submit preference")
    # Candidate grid. Each label is the tile's flat index into experiment["options"],
    # which on_select() parses back with int(evt.target.label).
    img_components = []
    for row in range(NUMBER_OF_ROWS):
        with gr.Row():
            for col in range(NUMBER_OF_IMAGES_PER_ROW):
                img_components.append(gr.Image(label=f"{row * NUMBER_OF_IMAGES_PER_ROW + col}", elem_id="unsel", show_label=False, show_download_button=False, show_share_button=False, interactive=False))
    # Both buttons go through save(); mode="skip" records "None" as the selection.
    btn_skip.click(partial(save, mode="skip"), [experiment, corrupted_component, *img_components], [experiment, corrupted_component, *img_components])
    btn_submit.click(partial(save, mode="save"), [experiment, corrupted_component, *img_components], [experiment, corrupted_component, *img_components])
    # Clicking any tile updates the selection and re-renders every tile.
    for img in img_components:
        img.select(on_select, [experiment, *img_components], [experiment, *img_components], show_progress="hidden")
    # Populate the first experiment on page load.
    demo.load(next, None, [experiment, corrupted_component, *img_components])
demo.launch()