File size: 7,998 Bytes
fea6c63
 
 
dbd36d0
fea6c63
dbd36d0
 
fea6c63
 
dbd36d0
 
 
 
 
 
 
6e0bc6b
dbd36d0
fea6c63
 
 
 
 
 
 
 
 
 
dbd36d0
fea6c63
dbd36d0
 
 
 
fea6c63
 
dbd36d0
 
 
fea6c63
 
dbd36d0
 
 
 
 
 
 
 
 
 
fea6c63
dbd36d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fea6c63
dbd36d0
fea6c63
dbd36d0
e6d4376
b6c70ac
dbd36d0
 
 
 
 
00c8a92
 
33a4bc0
dbd36d0
00c8a92
 
dbd36d0
 
8191154
dbd36d0
 
 
 
 
 
 
fea6c63
 
dbd36d0
 
 
 
 
 
fea6c63
dbd36d0
 
 
 
 
fea6c63
dbd36d0
 
 
33a4bc0
dbd36d0
 
33a4bc0
fea6c63
dbd36d0
 
 
fea6c63
 
dbd36d0
 
fea6c63
dbd36d0
6e0bc6b
dbd36d0
 
 
fea6c63
dbd36d0
fea6c63
dbd36d0
 
fea6c63
dbd36d0
6e0bc6b
dbd36d0
 
 
6e0bc6b
dbd36d0
fea6c63
dbd36d0
fea6c63
dbd36d0
fea6c63
 
dbd36d0
 
 
33a4bc0
 
fea6c63
 
dbd36d0
 
fea6c63
dbd36d0
 
 
 
1617002
dbd36d0
 
fea6c63
dbd36d0
 
fea6c63
dbd36d0
6e0bc6b
fea6c63
dbd36d0
 
 
 
fea6c63
dbd36d0
fea6c63
dbd36d0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import datetime
from functools import partial
import json
from pathlib import Path
import random
import gradio as gr
import os
import firebase_admin
from firebase_admin import db, credentials


##################################################################
# Constants
##################################################################


NUMBER_OF_ROWS = 2  # rows in the candidate grid
NUMBER_OF_IMAGES_PER_ROW = 7  # candidate images per grid row; the grid shows ROWS * PER_ROW images


#################################################################################################################################################
# Authentication
#################################################################################################################################################


# Deployment configuration is read from environment variables; a missing variable
# raises KeyError while this module is being imported.
FIREBASE_API_KEY = os.environ["FirebaseSecret"]  # JSON text, parsed below into the service-account credential
FIREBASE_URL = os.environ["FirebaseURL"]  # Realtime Database URL handed to initialize_app
DATASET = os.environ["Dataset"]  # dataset name, also used as a folder under ./images

# Initialise the default Firebase app and keep a handle on the "data" node, where
# every submitted answer is pushed.
_service_account_info = json.loads(FIREBASE_API_KEY)
firebase_creds = credentials.Certificate(_service_account_info)
firebase_app = firebase_admin.initialize_app(firebase_creds, options={"databaseURL": FIREBASE_URL})
firebase_data_ref = db.reference("data")


##################################################################
# Data Layer
##################################################################


class Experiment(dict):
    """One comparison trial, stored as a plain ``dict`` subclass.

    Keys (in insertion order):
        dataset: dataset name the trial was drawn from.
        corruption: name of the corruption folder.
        image_id: name of the image folder inside the corruption folder.
        corrupted: ``{"name": <path>}`` describing the corrupted reference image.
        options: list of candidate dicts, each ``{"name": <path>, "algo": <label>}``.
        selected_image: index into ``options`` chosen by the user, or ``None``.
    """

    def __init__(self, dataset, corruption, image_id, corrupted, options, selected_image=None):
        fields = {
            "dataset": dataset,
            "corruption": corruption,
            "image_id": image_id,
            "corrupted": corrupted,
            "options": options,
            "selected_image": selected_image,
        }
        super().__init__(fields)
    
def experiment_to_dict(experiment, skip=False):
    """Flatten an experiment into the record that gets stored.

    Args:
        experiment: mapping with the ``Experiment`` keys (dataset, corruption,
            image_id, corrupted, options, selected_image).
        skip: True when the user expressed no preference. The selection fields are
            then the literal string ``"None"``, and ``selected_image`` is not read.

    Returns:
        A new dict with keys ``dataset``, ``corruption``, ``image_number``,
        ``corrupted_filename``, ``options`` (the option file names, in display order),
        ``selected_image`` and ``selected_algo``.

    Raises:
        TypeError: when ``skip`` is False and ``experiment["selected_image"]`` is None,
            because a list cannot be indexed with None.
    """
    record = {
        # which trial this was
        "dataset": experiment["dataset"],
        "corruption": experiment["corruption"],
        "image_number": experiment["image_id"],
        # what was shown to the user
        "corrupted_filename": experiment["corrupted"]["name"],
        "options": [option["name"] for option in experiment["options"]],
    }

    if skip:
        chosen_name = chosen_algo = "None"
    else:
        chosen = experiment["options"][experiment["selected_image"]]
        chosen_name = chosen["name"]
        chosen_algo = chosen["algo"]

    # what the user picked
    record.update(selected_image=chosen_name, selected_algo=chosen_algo)
    return record
    
def generate_new_experiment() -> Experiment:
    """Randomly draw a new comparison experiment from the local image tree.

    The directory layout is inferred from the globs below (verify against the data):
    ``./images/<DATASET>/*/<corruption>/<image_id>/``. Each ``<image_id>`` folder
    holds at least one ``*corrupted*`` file plus ``sde/`` and ``ode/`` sub-folders
    of candidate outputs. Only corruptions named in ``wanted_corruptions`` are used.

    The grid has ``NUMBER_OF_IMAGES_PER_ROW * NUMBER_OF_ROWS`` slots. Half of them
    (rounded down) are filled with ODEdit samples and the rest with SDEdit samples.
    The option count therefore always equals the number of image components, even
    when the total is odd. The combined list is shuffled so the two algorithms
    appear in random positions.

    Returns:
        An ``Experiment`` whose ``selected_image`` is ``None``.

    Raises:
        FileNotFoundError: if no matching corruption folder, image folder or
            corrupted image exists.
        ValueError: if an ``sde``/``ode`` folder holds fewer files than needed.
    """
    wanted_corruptions = {
        "spatter", "impulse_noise", "speckle_noise", "gaussian_noise",
        "pixelate", "jpeg_compression", "elastic_transform",
    }
    dataset_root = Path(f"./images/{DATASET}")

    def _choose(candidates, error_message):
        # random.choice([]) raises an anonymous IndexError; say what is missing instead.
        if not candidates:
            raise FileNotFoundError(error_message)
        return random.choice(candidates)

    def _sample(folder, count, algo):
        # Draw `count` distinct files from `folder`, tagged with the producing algorithm.
        population = list(folder.glob("*"))
        if len(population) < count:
            raise ValueError(f"{folder} holds {len(population)} files but {count} are required")
        return [{"name": str(img), "algo": algo} for img in random.sample(population, count)]

    # NOTE: the RNG is consumed in the same order as before (choice, choice, choice,
    # sample, sample, shuffle), so seeded runs with an even grid size are unchanged.
    corruption = _choose(
        [d for d in dataset_root.glob("*/*") if d.is_dir() and d.name in wanted_corruptions],
        f"No corruption folder named one of {sorted(wanted_corruptions)} under {dataset_root}",
    )
    image_id = _choose(list(corruption.glob("*")), f"No image folders in {corruption}")
    corrupted_image = {
        "name": str(_choose(list(image_id.glob("*corrupted*")), f"No *corrupted* image in {image_id}")),
    }

    total_slots = NUMBER_OF_IMAGES_PER_ROW * NUMBER_OF_ROWS
    ode_count = total_slots // 2
    sde_count = total_slots - ode_count  # SDEdit takes the extra slot when the total is odd

    total_images = _sample(image_id / "sde", sde_count, "SDEdit") + _sample(image_id / "ode", ode_count, "ODEdit")
    random.shuffle(total_images)

    return Experiment(
        DATASET,
        corruption.name,
        image_id.name,
        corrupted_image,
        total_images,
    )

def save(experiment, corrupted_component, *img_components, mode):
    """Record the user's answer for the current experiment, then load a new one.

    The UI binds this with ``functools.partial`` to ``mode="save"`` (submit button)
    or ``mode="skip"`` (no-preference button).

    Args:
        experiment: current experiment state.
        corrupted_component: current value of the corrupted-image component.
        *img_components: current values of the candidate image components.
        mode: ``"save"`` requires a selected image. ``"skip"`` stores the answer with
            no selection.

    Returns:
        If a submit is rejected for lack of a selection, the inputs unchanged as
        ``[experiment, corrupted_component, *img_components]``. Otherwise the result
        of ``next()``: a fresh experiment and its image components.
    """
    skipping = mode == "skip"

    if mode == "save":
        nothing_selected = experiment is None or experiment["selected_image"] is None
        if nothing_selected:
            gr.Warning("You must select an image before submitting")
            return [experiment, corrupted_component, *img_components]

    if skipping:
        experiment["selected_image"] = None

    record = experiment_to_dict(experiment, skip=skipping)
    # Local wall-clock time of the server process, stored as text.
    record["timestamp"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    firebase_data_ref.push(record)

    separator = "====================="
    print(separator, record, separator, sep="\n")

    gr.Info("Your choice has been saved to Firebase")
    return next()


##################################################################
# UI Layer
##################################################################


def next():
    """Draw a fresh experiment and build the image components that display it.

    Note: this shadows the built-in ``next`` at module level. The name is kept
    because ``save`` and ``demo.load`` refer to it.

    Returns:
        ``[experiment, corrupted_image_component, *option_image_components]``. Each
        option component is labelled with its index in ``experiment["options"]``
        and starts unselected (``elem_id="unsel"``).
    """
    experiment = generate_new_experiment()
    common = dict(show_label=False, show_download_button=False, show_share_button=False, interactive=False)

    option_views = []
    for index, option in enumerate(experiment["options"]):
        option_views.append(gr.Image(value=option["name"], label=f"{index}", elem_id="unsel", **common))

    corrupted_view = gr.Image(value=experiment["corrupted"]["name"], label="corr", elem_id="corrupted", **common)

    return [experiment, corrupted_view, *option_views]

def on_select(evt: gr.SelectData, experiment, *img_components):  # SelectData is a subclass of EventData
    """Highlight the clicked candidate and store it as the current selection.

    The clicked component's label is parsed as an integer, which is its index in
    ``experiment["options"]``. All candidates are rebuilt with ``elem_id="unsel"``,
    except the clicked one, which gets ``elem_id="sel"``.

    Returns:
        ``[experiment, *image_components]``, with ``experiment["selected_image"]``
        set to the clicked index.
    """
    chosen = int(evt.target.label)
    options = experiment["options"]
    common = dict(show_label=False, show_download_button=False, show_share_button=False, interactive=False)

    views = [
        gr.Image(value=option["name"], label=f"{index}", elem_id="unsel", **common)
        for index, option in enumerate(options)
    ]
    views[chosen] = gr.Image(value=options[chosen]["name"], label=f"{chosen}", elem_id="sel", **common)

    experiment["selected_image"] = chosen
    return [experiment, *views]

css = """
#unsel {border: solid 5px transparent !important; border-radius: 15px !important; draggable: false}
#sel {border: solid 5px #00c0ff !important; border-radius: 15px !important; draggable: false}
#corrupted {margin-left: 5%; margin-right: 5%; padding: 0 !important; draggable: false}
#reducedHeight {height: 10px !important}
#padded {padding-left: 2%; padding-right: 2%}
"""

with gr.Blocks(title="Unsupervised Image Editing", css=css) as demo:
    experiment = gr.State(generate_new_experiment())

    with gr.Row(elem_id="padded"):
        corrupted_component = gr.Image(label="corr", elem_id="corrupted", show_label=False, show_download_button=False, show_share_button=False, interactive=False)
        with gr.Column(scale=3, elem_id="padded"):
            gr.Markdown("<div style='width: 100%'><h1 style='text-align: center; display: inline-block; width: 100%'>The sample on the left is a corrupted image</h1></div>")
            gr.Markdown("<div style='width: 100%'><h3 style='text-align: center; display: inline-block; width: 100%'>Below are decorrupted versions sampled from various models. Click on the picture you like best.<br/>⚠️Do not pay attention to the background. Consider first fidelity, then quality⚠️</h3></div>")
            btn_skip = gr.Button("I have no preference")
            btn_submit = gr.Button("Submit preference")

    img_components = []
    for row in range(NUMBER_OF_ROWS):
        with gr.Row():
            for col in range(NUMBER_OF_IMAGES_PER_ROW):
                img_components.append(gr.Image(label=f"{row * NUMBER_OF_IMAGES_PER_ROW + col}", elem_id="unsel", show_label=False, show_download_button=False, show_share_button=False, interactive=False))

    btn_skip.click(partial(save, mode="skip"), [experiment, corrupted_component, *img_components], [experiment, corrupted_component, *img_components])
    btn_submit.click(partial(save, mode="save"), [experiment, corrupted_component, *img_components], [experiment, corrupted_component, *img_components])
    for img in img_components:
        img.select(on_select, [experiment, *img_components], [experiment, *img_components], show_progress="hidden")

    demo.load(next, None, [experiment, corrupted_component, *img_components])

demo.launch()