|
import datetime |
|
from functools import partial |
|
import json |
|
from pathlib import Path |
|
import random |
|
import gradio as gr |
|
import os |
|
import firebase_admin |
|
from firebase_admin import db, credentials |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shape of the grid of candidate tiles shown below the corrupted image.
# generate_new_experiment samples half of the grid cells from each algorithm.
NUMBER_OF_IMAGES_PER_ROW: int = 7  # tiles inside each gr.Row
NUMBER_OF_ROWS: int = 2  # how many gr.Row containers are stacked
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Deployment configuration read from environment variables; a missing
# variable raises KeyError at import time.
FIREBASE_API_KEY = os.environ['FirebaseSecret']  # JSON credentials for credentials.Certificate
FIREBASE_URL = os.environ['FirebaseURL']  # passed as the app's 'databaseURL'
DATASET = os.environ['Dataset']  # sub-folder of ./images that experiments are drawn from

firebase_creds = credentials.Certificate(json.loads(FIREBASE_API_KEY))
try:
    # If this module is executed again in the same process (e.g. Gradio's
    # auto-reload mode), the default app already exists and initialize_app
    # would raise ValueError; reuse the existing app instead.
    firebase_app = firebase_admin.get_app()
except ValueError:
    firebase_app = firebase_admin.initialize_app(firebase_creds, {'databaseURL': FIREBASE_URL})

# Every saved experiment record is pushed as a child of the "data" node.
firebase_data_ref = db.reference("data")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Experiment(dict):
    """One preference trial, kept as a plain ``dict`` so it can sit in ``gr.State``.

    Keys, in insertion order:
        dataset: dataset the images were drawn from.
        corruption: name of the sampled corruption directory.
        image_id: name of the sampled image directory.
        corrupted: ``{"name": <path>}`` describing the corrupted reference image.
        options: list of candidate dicts (``"name"``, ``"algo"``) shown in the grid.
        selected_image: index into ``options`` picked by the user, or ``None``.
    """

    def __init__(self, dataset, corruption, image_id, corrupted, options, selected_image=None):
        # Build from ordered (key, value) pairs so the key order stays fixed.
        fields = (
            ("dataset", dataset),
            ("corruption", corruption),
            ("image_id", image_id),
            ("corrupted", corrupted),
            ("options", options),
            ("selected_image", selected_image),
        )
        super().__init__(fields)
|
|
|
def experiment_to_dict(experiment, skip=False):
    """Flatten an experiment into the flat record that gets saved.

    Args:
        experiment: mapping with the keys built by ``Experiment``.
        skip: when true the user gave no preference; both selection fields
            are recorded as the string ``"None"`` and
            ``experiment["selected_image"]`` is not read.

    Returns:
        A new dict with ``dataset``, ``corruption``, ``image_number``,
        ``corrupted_filename``, ``options`` (candidate file names),
        ``selected_image`` (chosen file name) and ``selected_algo``.
    """
    record = {
        "dataset": experiment["dataset"],
        "corruption": experiment["corruption"],
        "image_number": experiment["image_id"],
        "corrupted_filename": experiment["corrupted"]["name"],
        "options": [option["name"] for option in experiment["options"]],
    }

    if skip:
        record["selected_image"] = "None"
        record["selected_algo"] = "None"
        return record

    picked = experiment["options"][experiment["selected_image"]]
    record["selected_image"] = picked["name"]
    record["selected_algo"] = picked["algo"]
    return record
|
|
|
def _sample_algo_images(folder, algo, count):
    """Return ``count`` randomly chosen entries of *folder*, each tagged with *algo*.

    Raises:
        ValueError: if *folder* holds fewer than ``count`` entries.
    """
    return [
        {"name": str(img), "algo": algo}
        for img in random.sample(list(folder.glob("*")), count)
    ]


def generate_new_experiment() -> Experiment:
    """Randomly assemble a new preference experiment from ``./images/{DATASET}``.

    Picks a wanted corruption directory two levels below the dataset folder,
    then a random image directory inside it. It takes one ``*corrupted*`` file
    from that directory as the reference. It then samples restored candidates
    from its ``sde`` and ``ode`` sub-folders, one per grid cell, and shuffles
    them together.

    Returns:
        A new ``Experiment`` whose ``selected_image`` is ``None``.

    Raises:
        IndexError: if no matching corruption directory, image directory or
            corrupted file exists.
        ValueError: if an ``sde``/``ode`` folder holds fewer images than needed.
    """
    wanted_corruptions = frozenset({
        "spatter", "impulse_noise", "speckle_noise", "gaussian_noise",
        "pixelate", "jpeg_compression", "elastic_transform",
    })
    corruption = random.choice([
        f for f in Path(f"./images/{DATASET}").glob("*/*")
        if f.is_dir() and f.name in wanted_corruptions
    ])
    # Only directories can hold the per-image sde/ode folders; a stray file
    # (e.g. .DS_Store) would make every glob below empty and crash.
    image_id = random.choice([p for p in corruption.glob("*") if p.is_dir()])

    # Split the grid between the two algorithms. With an odd cell count the
    # ODEdit side takes the extra image so every grid component gets a value.
    total_cells = NUMBER_OF_IMAGES_PER_ROW * NUMBER_OF_ROWS
    n_sde = total_cells // 2
    n_ode = total_cells - n_sde

    corrupted_image = {"name": str(random.choice(list(image_id.glob("*corrupted*"))))}
    total_images = (
        _sample_algo_images(image_id / "sde", "SDEdit", n_sde)
        + _sample_algo_images(image_id / "ode", "ODEdit", n_ode)
    )
    # Mix the two algorithms so grid position does not reveal the source.
    random.shuffle(total_images)

    return Experiment(
        DATASET,
        corruption.name,
        image_id.name,
        corrupted_image,
        total_images,
    )
|
|
|
def save(experiment, corrupted_component, *img_components, mode):
    """Store the user's choice (or a skip) in Firebase and serve a new experiment.

    The skip/submit buttons bind this with ``functools.partial`` and ``mode``
    set to ``"skip"`` or ``"save"``.

    Args:
        experiment: the session's current experiment (from ``gr.State``).
        corrupted_component, *img_components: current values of the image
            components; echoed back unchanged only when submitting without a
            selection.
        mode: ``"save"`` records the selected image; ``"skip"`` records a
            no-preference entry.

    Returns:
        ``[experiment, corrupted_component, *img_components]``. These are either
        the untouched inputs (submit with nothing selected) or the fresh values
        produced by ``next()``.
    """
    if mode == "save":
        if experiment is None or experiment["selected_image"] is None:
            gr.Warning("You must select an image before submitting")
            return [experiment, corrupted_component, *img_components]
    elif mode == "skip":
        experiment["selected_image"] = None

    record = experiment_to_dict(experiment, skip=mode == "skip")
    record["timestamp"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    firebase_data_ref.push(record)

    banner = "====================="
    print(banner)
    print(record)
    print(banner)

    gr.Info("Your choice has been saved to Firebase")
    return next()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def next():
    """Draw a new experiment and build refreshed image components for it.

    NOTE: this module-level name shadows the builtin ``next``.

    Returns:
        ``[experiment, corrupted image component, *candidate image components]``.
        Candidates are labelled with their index in ``options`` and styled as
        unselected (``elem_id="unsel"``).
    """
    fresh = generate_new_experiment()
    common = dict(
        show_label=False,
        show_download_button=False,
        show_share_button=False,
        interactive=False,
    )

    tiles = []
    for idx, option in enumerate(fresh["options"]):
        tiles.append(gr.Image(value=option["name"], label=f"{idx}", elem_id="unsel", **common))

    reference = gr.Image(value=fresh["corrupted"]["name"], label="corr", elem_id="corrupted", **common)

    return [fresh, reference, *tiles]
|
|
|
def on_select(evt: gr.SelectData, experiment, *img_components):
    """Mark the clicked candidate tile as selected.

    The clicked component's label holds its index in ``experiment["options"]``
    as a string. That tile is re-rendered with ``elem_id="sel"`` and every
    other tile with ``elem_id="unsel"``. The index is stored in
    ``experiment["selected_image"]``.

    Args:
        evt: Gradio selection event; ``evt.target`` is the clicked image.
        experiment: the session's current experiment (mutated in place).
        *img_components: accepted for compatibility with the event wiring; unused.

    Returns:
        ``[experiment, *image components]``.
    """
    chosen = int(evt.target.label)
    options = experiment["options"]

    def tile(idx, elem_id):
        return gr.Image(
            value=options[idx]["name"],
            label=f"{idx}",
            elem_id=elem_id,
            show_label=False,
            show_download_button=False,
            show_share_button=False,
            interactive=False,
        )

    tiles = [tile(idx, "unsel") for idx in range(len(options))]
    tiles[chosen] = tile(chosen, "sel")

    experiment["selected_image"] = chosen
    return [experiment, *tiles]
|
|
|
# Custom styles attached to components through elem_id:
#   unsel / sel  -> candidate tiles; "sel" draws a #00c0ff border on the chosen one
#   corrupted    -> the corrupted reference image
#   padded       -> side padding for the header row and column
# NOTE(review): `draggable` is an HTML attribute, not a CSS property, so
# browsers drop `draggable: false` and image dragging is NOT disabled here.
# NOTE(review): no component in this file uses elem_id="reducedHeight" —
# confirm whether that rule is still needed.
css = """

#unsel {border: solid 5px transparent !important; border-radius: 15px !important; draggable: false}

#sel {border: solid 5px #00c0ff !important; border-radius: 15px !important; draggable: false}

#corrupted {margin-left: 5%; margin-right: 5%; padding: 0 !important; draggable: false}

#reducedHeight {height: 10px !important}

#padded {padding-left: 2%; padding-right: 2%}

"""
|
|
|
# Page layout: corrupted reference image next to the instructions and buttons,
# followed by a NUMBER_OF_ROWS x NUMBER_OF_IMAGES_PER_ROW grid of candidates.
with gr.Blocks(title="Unsupervised Image Editing", css=css) as demo:
    # Per-session experiment; seeded here, then replaced by `next` on page load.
    experiment = gr.State(generate_new_experiment())

    with gr.Row(elem_id="padded"):
        corrupted_component = gr.Image(label="corr", elem_id="corrupted", show_label=False, show_download_button=False, show_share_button=False, interactive=False)
        with gr.Column(scale=3, elem_id="padded"):
            gr.Markdown("<div style='width: 100%'><h1 style='text-align: center; display: inline-block; width: 100%'>The sample on the left is a corrupted image</h1></div>")
            gr.Markdown("<div style='width: 100%'><h3 style='text-align: center; display: inline-block; width: 100%'>Below are decorrupted versions sampled from various models. Click on the picture you like best.<br/>⚠️Do not pay attention to the background. Consider first fidelity, then quality⚠️</h3></div>")
            btn_skip = gr.Button("I have no preference")
            btn_submit = gr.Button("Submit preference")

    img_components = []
    for row in range(NUMBER_OF_ROWS):
        with gr.Row():
            for col in range(NUMBER_OF_IMAGES_PER_ROW):
                # The label carries the flat grid index; on_select parses it back.
                img_components.append(gr.Image(label=f"{row * NUMBER_OF_IMAGES_PER_ROW + col}", elem_id="unsel", show_label=False, show_download_button=False, show_share_button=False, interactive=False))

    # save echoes the image values back when nothing is selected, so it keeps
    # every component as an input.
    btn_skip.click(partial(save, mode="skip"), [experiment, corrupted_component, *img_components], [experiment, corrupted_component, *img_components])
    btn_submit.click(partial(save, mode="save"), [experiment, corrupted_component, *img_components], [experiment, corrupted_component, *img_components])
    for img in img_components:
        # on_select never reads the image values, so only the state is sent.
        # This avoids preprocessing every grid image on each click.
        img.select(on_select, [experiment], [experiment, *img_components], show_progress="hidden")

    demo.load(next, None, [experiment, corrupted_component, *img_components])

demo.launch()