File size: 7,998 Bytes
fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 6e0bc6b dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 e6d4376 b6c70ac dbd36d0 00c8a92 33a4bc0 dbd36d0 00c8a92 dbd36d0 8191154 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 33a4bc0 dbd36d0 33a4bc0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 6e0bc6b dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 6e0bc6b dbd36d0 6e0bc6b dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 33a4bc0 fea6c63 dbd36d0 fea6c63 dbd36d0 1617002 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 6e0bc6b fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 |
|
import datetime
from functools import partial
import json
from pathlib import Path
import random
import gradio as gr
import os
import firebase_admin
from firebase_admin import db, credentials
##################################################################
# Constants
##################################################################
# Shape of the grid of candidate images shown to the user:
# NUMBER_OF_ROWS rows, each holding NUMBER_OF_IMAGES_PER_ROW images (7 * 2 = 14 cells).
NUMBER_OF_IMAGES_PER_ROW: int = 7
NUMBER_OF_ROWS: int = 2
#################################################################################################################################################
# Authentication
#################################################################################################################################################
# All configuration is read from environment variables; a missing variable raises
# KeyError at import time.
#   FirebaseSecret - service-account credentials as a JSON string (parsed below)
#   FirebaseURL    - Realtime Database URL passed as ``databaseURL``
#   Dataset        - sub-folder of ./images to sample experiments from
FIREBASE_API_KEY, FIREBASE_URL, DATASET = (
    os.environ["FirebaseSecret"],
    os.environ["FirebaseURL"],
    os.environ["Dataset"],
)

# Initialise the Firebase Admin SDK with those credentials and database URL.
firebase_creds = credentials.Certificate(json.loads(FIREBASE_API_KEY))
firebase_app = firebase_admin.initialize_app(
    firebase_creds,
    {"databaseURL": FIREBASE_URL},
)
# Database node that receives one pushed record per user answer.
firebase_data_ref = db.reference("data")
##################################################################
# Data Layer
##################################################################
class Experiment(dict):
    """One comparison trial, stored as a plain ``dict`` so it can live in ``gr.State``.

    Keys (in insertion order): ``dataset``, ``corruption``, ``image_id``,
    ``corrupted`` (a dict with a ``name`` path), ``options`` (a list of dicts with
    ``name`` and ``algo``) and ``selected_image`` (index into ``options``, or ``None``
    while nothing has been chosen).
    """

    def __init__(self, dataset, corruption, image_id, corrupted, options, selected_image=None):
        fields = {
            "dataset": dataset,
            "corruption": corruption,
            "image_id": image_id,
            "corrupted": corrupted,
            "options": options,
            "selected_image": selected_image,
        }
        super().__init__(fields)
def experiment_to_dict(experiment, skip=False):
    """Flatten an experiment into the flat record that ``save`` pushes to Firebase.

    Args:
        experiment: mapping with ``dataset``, ``corruption``, ``image_id``,
            ``corrupted`` (dict with ``name``), ``options`` (list of dicts with
            ``name``/``algo``) and, unless ``skip`` is true, ``selected_image``.
        skip: when true the user expressed no preference; both selection fields are
            the literal string ``"None"`` and ``selected_image`` is not read.

    Returns:
        A new dict. When ``skip`` is false, ``experiment["selected_image"]`` must be a
        valid index into ``options``, otherwise indexing raises.
    """
    record = {
        # What was shown to the user.
        "dataset": experiment["dataset"],
        "corruption": experiment["corruption"],
        "image_number": experiment["image_id"],
        "corrupted_filename": experiment["corrupted"]["name"],
        "options": [option["name"] for option in experiment["options"]],
    }
    if skip:
        record["selected_image"] = "None"
        record["selected_algo"] = "None"
    else:
        chosen = experiment["options"][experiment["selected_image"]]
        record["selected_image"] = chosen["name"]
        record["selected_algo"] = chosen["algo"]
    return record
def generate_new_experiment() -> Experiment:
    """Sample a random corrupted image together with candidate restorations.

    Picks a random corruption directory, restricted to an allow-list, from
    ``./images/<DATASET>/*/<corruption>``. It then picks a random image directory
    inside it. From that image directory it takes one file whose name contains
    ``corrupted``, plus random restorations from the ``sde`` and ``ode``
    sub-directories. The two samples are sized so that together they fill every cell
    of the ``NUMBER_OF_ROWS x NUMBER_OF_IMAGES_PER_ROW`` grid. The options are then
    shuffled.

    Returns:
        A new ``Experiment`` whose ``selected_image`` is ``None``.

    Raises:
        IndexError: if no matching corruption directory, image directory or
            ``*corrupted*`` file exists (``random.choice`` on an empty list).
        ValueError: if ``sde`` or ``ode`` holds fewer files than must be sampled.
    """
    wanted_corruptions = {
        "spatter",
        "impulse_noise",
        "speckle_noise",
        "gaussian_noise",
        "pixelate",
        "jpeg_compression",
        "elastic_transform",
    }
    corruption_dirs = [
        path
        for path in Path(f"./images/{DATASET}").glob("*/*")
        if path.is_dir() and path.name in wanted_corruptions
    ]
    corruption = random.choice(corruption_dirs)
    # Only directories can contain the ``sde``/``ode`` sub-folders. Without this
    # filter a stray file (e.g. ``.DS_Store``) could be chosen. Its sub-folder globs
    # would then be empty and ``random.sample`` below would raise.
    image_id = random.choice([path for path in corruption.glob("*") if path.is_dir()])

    # Fill every grid cell, splitting the cells between the two algorithms. With an
    # even cell count each gets half; with an odd count SDEdit gets the extra image.
    total_slots = NUMBER_OF_IMAGES_PER_ROW * NUMBER_OF_ROWS
    num_ode = total_slots // 2
    num_sde = total_slots - num_ode

    corrupted_image = {"name": str(random.choice(list(image_id.glob("*corrupted*"))))}
    sdedit_images = [
        {"name": str(img), "algo": "SDEdit"}
        for img in random.sample(list((image_id / "sde").glob("*")), num_sde)
    ]
    odedit_images = [
        {"name": str(img), "algo": "ODEdit"}
        for img in random.sample(list((image_id / "ode").glob("*")), num_ode)
    ]
    options = sdedit_images + odedit_images
    # Shuffle so that a cell's position in the grid does not reveal its algorithm.
    random.shuffle(options)
    return Experiment(
        DATASET,
        corruption.name,
        image_id.name,
        corrupted_image,
        options,
    )
def save(experiment, corrupted_component, *img_components, mode):
    """Store the user's answer (or a skip) in Firebase, then load a new experiment.

    Args:
        experiment: current experiment state.
        corrupted_component: current value of the reference image component.
        *img_components: current values of the grid images.
        mode: ``"save"`` to store the selected option, ``"skip"`` to store a
            no-preference record. Bound with ``functools.partial`` by the buttons.

    Returns:
        Values for ``[experiment, corrupted_component, *img_components]``. If
        ``mode == "save"`` and nothing is selected, these are the unchanged inputs.
        Otherwise they come from ``next()`` for a fresh experiment.
    """
    if mode == "save":
        # Refuse to submit until an image has been clicked.
        if experiment is None or experiment["selected_image"] is None:
            gr.Warning("You must select an image before submitting")
            return [experiment, corrupted_component, *img_components]
    elif mode == "skip":
        # NOTE(review): a None experiment would raise TypeError here; only "save"
        # mode checks for None.
        experiment["selected_image"] = None

    is_skip = mode == "skip"
    record = experiment_to_dict(experiment, skip=is_skip)
    # Server-local wall-clock time (naive datetime, no timezone information).
    record["timestamp"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    firebase_data_ref.push(record)

    # Echo the stored record to stdout.
    separator = "====================="
    print(separator)
    print(record)
    print(separator)

    gr.Info("Your choice has been saved to Firebase")
    return next()
##################################################################
# UI Layer
##################################################################
def _image_update(path, label, elem_id):
    """Build a read-only, label-less ``gr.Image`` used to refresh one component.

    ``label`` is never displayed (``show_label=False``). For grid cells it carries the
    option index as a string, which ``on_select`` reads back from the event target.
    ``elem_id`` picks the CSS rule ("unsel", "sel" or "corrupted").
    """
    return gr.Image(
        value=path,
        label=label,
        elem_id=elem_id,
        show_label=False,
        show_download_button=False,
        show_share_button=False,
        interactive=False,
    )


def _option_updates(experiment):
    """Return one unselected grid image per option of ``experiment``, labelled by index."""
    return [
        _image_update(img["name"], f"{i}", "unsel")
        for i, img in enumerate(experiment["options"])
    ]


def next():  # NOTE: shadows the builtin ``next`` in this module; name kept for existing callers.
    """Start a fresh experiment and return updates for every output component.

    Returns:
        ``[experiment, corrupted_image, *grid_images]``, matching the output order
        used by the button and page-load listeners.
    """
    new_experiment = generate_new_experiment()
    new_img_components = _option_updates(new_experiment)
    new_corrupted_component = _image_update(new_experiment["corrupted"]["name"], "corr", "corrupted")
    return [new_experiment, new_corrupted_component, *new_img_components]


def on_select(evt: gr.SelectData, experiment, *img_components):  # SelectData is a subclass of EventData
    """Record the clicked option as the selection and highlight it in the grid.

    Args:
        evt: select event. ``evt.target.label`` is the clicked cell's index, as a
            string, assigned when the grid was built.
        experiment: current experiment state; ``selected_image`` is updated in place.
        *img_components: current grid values. Unused, but present because the
            listener passes every grid image as an input.

    Returns:
        ``[experiment, *grid_images]``, with only the clicked cell using "sel".

    Raises:
        IndexError: if the label is not a valid index into ``experiment["options"]``.
    """
    new_selected = int(evt.target.label)
    new_img_components = _option_updates(experiment)
    new_img_components[new_selected] = _image_update(
        experiment["options"][new_selected]["name"], f"{new_selected}", "sel"
    )
    experiment["selected_image"] = new_selected
    return [experiment, *new_img_components]
# Styles addressed through ``elem_id``:
#   unsel / sel  - frame of each grid image; transparent vs. blue (#00c0ff) border
#                  marks the current selection
#   corrupted    - the reference (corrupted) image
#   padded       - horizontal padding for the header row/column
#   reducedHeight - no component visible in this file uses it; kept in case one does
# NOTE: ``draggable`` is an HTML attribute, not a CSS property, so a
# ``draggable: false`` declaration in a stylesheet is ignored by browsers.
css = """
#unsel {border: solid 5px transparent !important; border-radius: 15px !important}
#sel {border: solid 5px #00c0ff !important; border-radius: 15px !important}
#corrupted {margin-left: 5%; margin-right: 5%; padding: 0 !important}
#reducedHeight {height: 10px !important}
#padded {padding-left: 2%; padding-right: 2%}
"""
with gr.Blocks(title="Unsupervised Image Editing", css=css) as demo:
    # Per-session experiment state; the initial value is replaced by ``next`` on page load.
    experiment = gr.State(generate_new_experiment())

    # Header: reference (corrupted) image on the left, instructions and buttons beside it.
    with gr.Row(elem_id="padded"):
        corrupted_component = gr.Image(label="corr", elem_id="corrupted", show_label=False, show_download_button=False, show_share_button=False, interactive=False)
        with gr.Column(scale=3, elem_id="padded"):
            gr.Markdown("<div style='width: 100%'><h1 style='text-align: center; display: inline-block; width: 100%'>The sample on the left is a corrupted image</h1></div>")
            gr.Markdown("<div style='width: 100%'><h3 style='text-align: center; display: inline-block; width: 100%'>Below are decorrupted versions sampled from various models. Click on the picture you like best.<br/>⚠️Do not pay attention to the background. Consider first fidelity, then quality⚠️</h3></div>")
            btn_skip = gr.Button("I have no preference")
            btn_submit = gr.Button("Submit preference")

    # Grid of candidate images. Each (hidden) label is the cell's flat index, which
    # ``on_select`` parses to find out which option was clicked.
    img_components = []
    for row in range(NUMBER_OF_ROWS):
        with gr.Row():
            for col in range(NUMBER_OF_IMAGES_PER_ROW):
                img_components.append(gr.Image(label=f"{row * NUMBER_OF_IMAGES_PER_ROW + col}", elem_id="unsel", show_label=False, show_download_button=False, show_share_button=False, interactive=False))

    # Both buttons and the page load read/refresh the same components:
    # state, reference image, then the grid.
    experiment_io = [experiment, corrupted_component, *img_components]
    btn_skip.click(partial(save, mode="skip"), experiment_io, experiment_io)
    btn_submit.click(partial(save, mode="save"), experiment_io, experiment_io)

    # Clicking any grid image updates the selection in the state and restyles the grid.
    selection_io = [experiment, *img_components]
    for img in img_components:
        img.select(on_select, selection_io, selection_io, show_progress="hidden")

    demo.load(next, None, experiment_io)

demo.launch()