File size: 7,998 Bytes
fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 6e0bc6b dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 e6d4376 b6c70ac dbd36d0 00c8a92 33a4bc0 dbd36d0 00c8a92 dbd36d0 8191154 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 33a4bc0 dbd36d0 33a4bc0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 6e0bc6b dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 6e0bc6b dbd36d0 6e0bc6b dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 33a4bc0 fea6c63 dbd36d0 fea6c63 dbd36d0 1617002 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 6e0bc6b fea6c63 dbd36d0 fea6c63 dbd36d0 fea6c63 dbd36d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import datetime
from functools import partial
import json
from pathlib import Path
import random
import gradio as gr
import os
import firebase_admin
from firebase_admin import db, credentials
##################################################################
# Constants
##################################################################
# Shape of the grid of candidate restorations shown to the user.
# The product should stay even: generate_new_experiment draws
# (rows * columns) // 2 images from each of two models, so an odd product
# would leave one grid cell without an option.
NUMBER_OF_IMAGES_PER_ROW, NUMBER_OF_ROWS = 7, 2
#################################################################################################################################################
# Authentication
#################################################################################################################################################
# Configuration is read from environment variables; a missing variable raises
# KeyError while the module is being imported.
FIREBASE_API_KEY, FIREBASE_URL, DATASET = (
    os.environ[name] for name in ("FirebaseSecret", "FirebaseURL", "Dataset")
)
# FIREBASE_API_KEY holds a JSON document that is parsed and handed to
# credentials.Certificate. The default app is bound to the database at
# FIREBASE_URL, and results are pushed under its "data" node (see save()).
firebase_creds = credentials.Certificate(json.loads(FIREBASE_API_KEY))
firebase_app = firebase_admin.initialize_app(
    firebase_creds,
    {"databaseURL": FIREBASE_URL},
)
firebase_data_ref = db.reference("data")
##################################################################
# Data Layer
##################################################################
class Experiment(dict):
    """A single evaluation trial, represented as a plain dict.

    Keys, as filled in by generate_new_experiment and on_select:
      * ``dataset`` / ``corruption`` / ``image_id`` - where the trial came from;
      * ``corrupted`` - dict with a ``name`` path to the corrupted input;
      * ``options`` - list of dicts with ``name`` (path) and ``algo``;
      * ``selected_image`` - index into ``options``, or ``None`` while nothing
        has been picked.
    """

    def __init__(self, dataset, corruption, image_id, corrupted, options, selected_image=None):
        fields = {
            "dataset": dataset,
            "corruption": corruption,
            "image_id": image_id,
            "corrupted": corrupted,
            "options": options,
            "selected_image": selected_image,
        }
        super().__init__(fields)
def experiment_to_dict(experiment, skip=False):
    """Flatten an experiment into the record that is pushed to Firebase.

    Args:
        experiment: mapping with ``dataset``, ``corruption``, ``image_id``,
            ``corrupted`` (dict with ``name``) and ``options`` (list of dicts
            with ``name`` and ``algo``). Unless ``skip`` is true it must also
            hold an integer ``selected_image`` indexing into ``options``.
        skip: true when the user chose "no preference"; the selection fields
            are then set to the string ``"None"`` and ``selected_image`` is
            not read.

    Returns:
        A new dict; ``experiment`` itself is not modified.
    """
    record = {
        "dataset": experiment["dataset"],
        "corruption": experiment["corruption"],
        "image_number": experiment["image_id"],
        "corrupted_filename": experiment["corrupted"]["name"],
        "options": [option["name"] for option in experiment["options"]],
    }
    if skip:
        chosen_name = chosen_algo = "None"
    else:
        chosen = experiment["options"][experiment["selected_image"]]
        chosen_name = chosen["name"]
        chosen_algo = chosen["algo"]
    record["selected_image"] = chosen_name
    record["selected_algo"] = chosen_algo
    return record
def generate_new_experiment() -> Experiment:
    """Randomly assemble a new trial from the images on disk.

    Layout read by the globs below::

        ./images/<DATASET>/<subfolder>/<corruption>/<image_id>/
            *corrupted*   corrupted input(s); one is picked at random
            sde/*         SDEdit restorations
            ode/*         ODEdit restorations

    Only corruptions named in ``wanted_corruptions`` are eligible. Half of the
    option grid (``NUMBER_OF_IMAGES_PER_ROW * NUMBER_OF_ROWS // 2``) is drawn
    from ``sde/`` and half from ``ode/``. The two halves are shuffled together
    so the algorithms are interleaved in the grid.

    Raises:
        FileNotFoundError: no eligible corruption folder, no image folder, or
            no ``*corrupted*`` file was found.
        ValueError: ``sde/`` or ``ode/`` holds fewer images than needed.
    """
    wanted_corruptions = {
        "spatter",
        "impulse_noise",
        "speckle_noise",
        "gaussian_noise",
        "pixelate",
        "jpeg_compression",
        "elastic_transform",
    }
    dataset_root = Path(f"./images/{DATASET}")

    corruption_dirs = [
        d for d in dataset_root.glob("*/*") if d.is_dir() and d.name in wanted_corruptions
    ]
    if not corruption_dirs:
        raise FileNotFoundError(
            f"No folder for any of {sorted(wanted_corruptions)} under {dataset_root}"
        )
    corruption = random.choice(corruption_dirs)

    # Only directories can hold the corrupted input and the sde/ode samples;
    # skip stray files, matching the is_dir() filter used for corruptions.
    image_dirs = [d for d in corruption.glob("*") if d.is_dir()]
    if not image_dirs:
        raise FileNotFoundError(f"No image folders under {corruption}")
    image_id = random.choice(image_dirs)

    imgs_to_sample = (NUMBER_OF_IMAGES_PER_ROW * NUMBER_OF_ROWS) // 2

    corrupted_files = list(image_id.glob("*corrupted*"))
    if not corrupted_files:
        raise FileNotFoundError(f"No '*corrupted*' file in {image_id}")
    corrupted_image = {"name": str(random.choice(corrupted_files))}

    def _sample(folder_name, algo):
        # Draw imgs_to_sample distinct images from image_id/<folder_name>, tagged with algo.
        folder = image_id / folder_name
        candidates = list(folder.glob("*"))
        if len(candidates) < imgs_to_sample:
            raise ValueError(
                f"{folder} holds {len(candidates)} images but {imgs_to_sample} are needed"
            )
        return [
            {"name": str(img), "algo": algo}
            for img in random.sample(candidates, imgs_to_sample)
        ]

    total_images = _sample("sde", "SDEdit") + _sample("ode", "ODEdit")
    random.shuffle(total_images)
    return Experiment(
        DATASET,
        corruption.name,
        image_id.name,
        corrupted_image,
        total_images,
    )
def save(experiment, corrupted_component, *img_components, mode):
    """Record the user's vote (or "no preference") in Firebase and load a new trial.

    Args:
        experiment: current Experiment state.
        corrupted_component, *img_components: current values of the image
            components; they are only echoed back when a submission is rejected.
        mode: ``"save"`` to submit the selected option, ``"skip"`` to record
            that the user had no preference.

    Returns:
        Output values for ``[experiment, corrupted_component, *img_components]``.
        On rejection these are the unchanged inputs; otherwise they come from
        ``next()``.
    """
    if mode == "save":
        # A vote needs a selection first; keep the current trial on screen.
        if experiment is None or experiment["selected_image"] is None:
            gr.Warning("You must select an image before submitting")
            return [experiment, corrupted_component, *img_components]

    skipped = mode == "skip"
    if skipped:
        experiment["selected_image"] = None

    record = experiment_to_dict(experiment, skip=skipped)
    record["timestamp"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    firebase_data_ref.push(record)

    divider = "====================="
    print(divider)
    print(record)
    print(divider)

    gr.Info("Your choice has been saved to Firebase")
    return next()
##################################################################
# UI Layer
##################################################################
def next():
    """Draw a new experiment and build the component values that display it.

    NOTE: this shadows the builtin ``next``. The name is kept because the
    UI wiring and ``save`` refer to it.

    Returns:
        ``[experiment, corrupted_image, *option_images]``, matching the output
        list used by the UI. Each option image's label is its index in
        ``experiment["options"]``; ``on_select`` parses that label back.
    """
    fresh = generate_new_experiment()
    image_opts = dict(
        show_label=False,
        show_download_button=False,
        show_share_button=False,
        interactive=False,
    )
    option_images = []
    for position, option in enumerate(fresh["options"]):
        option_images.append(
            gr.Image(value=option["name"], label=f"{position}", elem_id="unsel", **image_opts)
        )
    corrupted_image = gr.Image(
        value=fresh["corrupted"]["name"], label="corr", elem_id="corrupted", **image_opts
    )
    return [fresh, corrupted_image, *option_images]
def on_select(evt: gr.SelectData, experiment, *img_components):  # SelectData is a subclass of EventData
    """Mark the clicked option as the current choice and re-render the grid.

    The clicked component's label is its index into ``experiment["options"]``,
    so it is parsed back to an int. ``img_components`` is accepted but not read.

    Returns:
        ``[experiment, *images]``: the updated state, then one image per option.
        The chosen image gets ``elem_id="sel"``; all others get ``"unsel"``.
    """
    chosen = int(evt.target.label)
    image_opts = dict(
        show_label=False,
        show_download_button=False,
        show_share_button=False,
        interactive=False,
    )
    grid = []
    for position, option in enumerate(experiment["options"]):
        grid.append(
            gr.Image(value=option["name"], label=f"{position}", elem_id="unsel", **image_opts)
        )
    picked = experiment["options"][chosen]
    grid[chosen] = gr.Image(value=picked["name"], label=f"{chosen}", elem_id="sel", **image_opts)
    experiment["selected_image"] = chosen
    return [experiment, *grid]
# Styles keyed by the elem_id values given to the gr.Image components.
# `draggable` is an HTML attribute, not a CSS property, so a
# `draggable: false` declaration is ignored by browsers. Dragging is instead
# disabled on the <img> elements themselves via `-webkit-user-drag`. Chromium
# and WebKit browsers honour it; Firefox has no CSS equivalent.
css = """
#unsel {border: solid 5px transparent !important; border-radius: 15px !important}
#sel {border: solid 5px #00c0ff !important; border-radius: 15px !important}
#corrupted {margin-left: 5%; margin-right: 5%; padding: 0 !important}
#unsel img, #sel img, #corrupted img {-webkit-user-drag: none; user-select: none}
#reducedHeight {height: 10px !important}
#padded {padding-left: 2%; padding-right: 2%}
"""
with gr.Blocks(title="Unsupervised Image Editing", css=css) as demo:
    # Per-session trial state; demo.load(next, ...) below replaces it when a page opens.
    experiment = gr.State(generate_new_experiment())

    with gr.Row(elem_id="padded"):
        corrupted_component = gr.Image(label="corr", elem_id="corrupted", show_label=False, show_download_button=False, show_share_button=False, interactive=False)
        with gr.Column(scale=3, elem_id="padded"):
            gr.Markdown("<div style='width: 100%'><h1 style='text-align: center; display: inline-block; width: 100%'>The sample on the left is a corrupted image</h1></div>")
            gr.Markdown("<div style='width: 100%'><h3 style='text-align: center; display: inline-block; width: 100%'>Below are decorrupted versions sampled from various models. Click on the picture you like best.<br/>⚠️Do not pay attention to the background. Consider first fidelity, then quality⚠️</h3></div>")
            btn_skip = gr.Button("I have no preference")
            btn_submit = gr.Button("Submit preference")

    # Grid of candidate restorations. Each label is the flat option index
    # that on_select parses back into experiment["selected_image"].
    img_components = []
    for row in range(NUMBER_OF_ROWS):
        with gr.Row():
            for col in range(NUMBER_OF_IMAGES_PER_ROW):
                img_components.append(gr.Image(label=f"{row * NUMBER_OF_IMAGES_PER_ROW + col}", elem_id="unsel", show_label=False, show_download_button=False, show_share_button=False, interactive=False))

    # State + every image: the inputs/outputs of save() and the outputs of next().
    trial_components = [experiment, corrupted_component, *img_components]
    btn_skip.click(partial(save, mode="skip"), trial_components, trial_components)
    btn_submit.click(partial(save, mode="save"), trial_components, trial_components)
    for img in img_components:
        # on_select only reads the experiment state, so the grid images are not
        # sent as inputs; otherwise Gradio would preprocess all of them on every click.
        img.select(on_select, [experiment], [experiment, *img_components], show_progress="hidden")
    demo.load(next, None, trial_components)

demo.launch()