Spaces:
Sleeping
Sleeping
Define global variables across all users -> reduce loading time
Browse files
app.py
CHANGED
@@ -16,50 +16,52 @@ from factories import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDatase
|
|
16 |
|
17 |
|
18 |
### Config
|
19 |
-
|
20 |
-
torch.
|
|
|
|
|
21 |
|
22 |
|
23 |
### Gradio Utils
|
24 |
def generate_imgs_from_user(image,
|
25 |
-
model: EvalModel, baseline: BaselineModel,
|
26 |
physics: PhysicsWithGenerator, use_gen: bool,
|
|
|
27 |
metrics: List[Metric]):
|
|
|
28 |
if image is None:
|
29 |
return None, None, None, None, None, None, None, None
|
30 |
|
31 |
-
# PIL image -> torch.Tensor
|
32 |
x = transforms.ToTensor()(image).unsqueeze(0).to(DEVICE_STR)
|
33 |
|
34 |
-
return generate_imgs(x,
|
35 |
|
36 |
def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
|
37 |
-
model: EvalModel, baseline: BaselineModel,
|
38 |
physics: PhysicsWithGenerator, use_gen: bool,
|
|
|
39 |
metrics: List[Metric]):
|
40 |
### Load 1 image
|
41 |
x = dataset[idx] # shape : (C, H, W)
|
42 |
x = x.unsqueeze(0) # shape : (1, C, H, W)
|
43 |
|
44 |
-
return generate_imgs(x,
|
45 |
|
46 |
def generate_random_imgs_from_dataset(dataset: EvalDataset,
|
47 |
-
model: EvalModel,
|
48 |
-
baseline: BaselineModel,
|
49 |
physics: PhysicsWithGenerator,
|
50 |
use_gen: bool,
|
|
|
|
|
51 |
metrics: List[Metric]):
|
52 |
idx = random.randint(0, len(dataset)-1)
|
53 |
x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs_from_dataset(
|
54 |
-
dataset, idx,
|
55 |
)
|
56 |
return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
|
57 |
|
58 |
def generate_imgs(x: torch.Tensor,
|
59 |
-
model: EvalModel, baseline: BaselineModel,
|
60 |
physics: PhysicsWithGenerator, use_gen: bool,
|
|
|
61 |
metrics: List[Metric]):
|
62 |
-
|
63 |
### Compute y
|
64 |
y = physics(x, use_gen) # possible reduction in img shape due to Blurring
|
65 |
|
@@ -114,11 +116,9 @@ def generate_imgs(x: torch.Tensor,
|
|
114 |
return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
|
115 |
|
116 |
|
117 |
-
get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR)
|
118 |
-
get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR)
|
119 |
-
get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
|
120 |
get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
|
121 |
get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
|
|
|
122 |
|
123 |
def get_dataset(dataset_name):
|
124 |
if dataset_name == 'MRI':
|
@@ -142,37 +142,42 @@ def get_dataset(dataset_name):
|
|
142 |
return dataset, idx, physics, baseline, available_physics
|
143 |
|
144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
### Gradio Blocks interface
|
146 |
|
147 |
title = "Inverse problem playground" # displayed on gradio tab and in the gradio page
|
148 |
with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
|
149 |
gr.Markdown("## " + title)
|
150 |
|
151 |
-
###
|
152 |
-
# Issue: giving directly a `torch.nn.module` to `gr.State(...)` since it has __call__ method
|
153 |
-
# Solution: using lambda expression
|
154 |
-
model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", ""))
|
155 |
-
model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))
|
156 |
-
metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
|
157 |
-
|
158 |
dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
|
159 |
-
physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_easy"))
|
160 |
available_physics_placeholder = gr.State(['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
|
161 |
'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])
|
|
|
|
|
|
|
|
|
162 |
|
163 |
|
164 |
-
### LAYOUT
|
165 |
-
# Display images
|
166 |
-
with gr.Row():
|
167 |
-
gt_img = gr.Image(label="Ground-truth image", interactive=True)
|
168 |
-
observed_img = gr.Image(label="Observed image", interactive=False)
|
169 |
-
model_a_out = gr.Image(label="RAM output", interactive=False)
|
170 |
-
model_b_out = gr.Image(label="DPIR output", interactive=False)
|
171 |
-
|
172 |
@gr.render(inputs=[dataset_placeholder, physics_placeholder, available_physics_placeholder])
|
173 |
def dynamic_layout(dataset, physics, available_physics):
|
174 |
### LAYOUT
|
175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
# Manage datasets and display metric values
|
177 |
with gr.Row():
|
178 |
with gr.Column(scale=1, min_width=160):
|
@@ -180,19 +185,16 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
|
|
180 |
choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
|
181 |
label="Datasets",
|
182 |
value=dataset.name)
|
183 |
-
idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", key=
|
184 |
with gr.Row():
|
185 |
load_button = gr.Button("Run on index image from dataset", size='md')
|
186 |
load_random_button = gr.Button("Run on random image from dataset", size='md')
|
187 |
with gr.Column(scale=1, min_width=160):
|
188 |
-
observed_metrics = gr.Textbox(label="
|
189 |
-
lines=1)
|
190 |
with gr.Column(scale=1, min_width=160):
|
191 |
-
out_a_metric = gr.Textbox(label="
|
192 |
-
lines=1)
|
193 |
with gr.Column(scale=1, min_width=160):
|
194 |
-
out_b_metric = gr.Textbox(label="
|
195 |
-
lines=1)
|
196 |
|
197 |
# Manage physics
|
198 |
with gr.Row():
|
@@ -200,19 +202,18 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
|
|
200 |
choose_physics = gr.Radio(choices=available_physics,
|
201 |
label="Physics",
|
202 |
value=physics.name)
|
203 |
-
use_generator_button = gr.Checkbox(label="Generate physics parameters during inference", value=True)
|
204 |
with gr.Column(scale=1):
|
205 |
with gr.Row():
|
206 |
key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
|
207 |
-
label="Updatable
|
208 |
value_text = gr.Textbox(label="Update Value")
|
209 |
-
update_button = gr.Button("Manually update parameter value")
|
210 |
with gr.Column(scale=2):
|
211 |
physics_params = gr.Textbox(label="Physics parameters",
|
212 |
lines=5,
|
213 |
value=physics.display_saved_params())
|
214 |
|
215 |
-
|
216 |
### Event listeners
|
217 |
|
218 |
choose_dataset.change(fn=get_dataset,
|
@@ -223,34 +224,28 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
|
|
223 |
outputs=[physics_placeholder])
|
224 |
update_button.click(fn=physics.update_and_display_params,
|
225 |
inputs=[key_selector, value_text], outputs=physics_params)
|
226 |
-
run_button.click(fn=
|
227 |
inputs=[gt_img,
|
228 |
-
model_a_placeholder,
|
229 |
-
model_b_placeholder,
|
230 |
physics_placeholder,
|
231 |
use_generator_button,
|
232 |
-
|
233 |
outputs=[gt_img, observed_img, model_a_out, model_b_out,
|
234 |
physics_params, observed_metrics, out_a_metric, out_b_metric])
|
235 |
-
load_button.click(fn=
|
236 |
inputs=[dataset_placeholder,
|
237 |
idx_slider,
|
238 |
-
model_a_placeholder,
|
239 |
-
model_b_placeholder,
|
240 |
physics_placeholder,
|
241 |
use_generator_button,
|
242 |
-
|
243 |
outputs=[gt_img, observed_img, model_a_out, model_b_out,
|
244 |
physics_params, observed_metrics, out_a_metric, out_b_metric])
|
245 |
-
load_random_button.click(fn=
|
246 |
inputs=[dataset_placeholder,
|
247 |
-
model_a_placeholder,
|
248 |
-
model_b_placeholder,
|
249 |
physics_placeholder,
|
250 |
use_generator_button,
|
251 |
-
|
252 |
outputs=[idx_slider, gt_img, observed_img, model_a_out, model_b_out,
|
253 |
physics_params, observed_metrics, out_a_metric, out_b_metric])
|
254 |
|
255 |
|
256 |
-
interface.launch()
|
|
|
16 |
|
17 |
|
18 |
### Config
|
19 |
+
# run model inference on NVIDIA gpu if available
|
20 |
+
DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'
|
21 |
+
# stops tracking values for gradients
|
22 |
+
torch.set_grad_enabled(False)
|
23 |
|
24 |
|
25 |
### Gradio Utils
|
26 |
def generate_imgs_from_user(image,
|
|
|
27 |
physics: PhysicsWithGenerator, use_gen: bool,
|
28 |
+
baseline: BaselineModel, model: EvalModel,
|
29 |
metrics: List[Metric]):
|
30 |
+
# Happens when user image is missing
|
31 |
if image is None:
|
32 |
return None, None, None, None, None, None, None, None
|
33 |
|
34 |
+
# PIL image -> torch.Tensor / (1, C, H, W) / move to DEVICE_STR
|
35 |
x = transforms.ToTensor()(image).unsqueeze(0).to(DEVICE_STR)
|
36 |
|
37 |
+
return generate_imgs(x, physics, use_gen, baseline, model, metrics)
|
38 |
|
39 |
def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
|
|
|
40 |
physics: PhysicsWithGenerator, use_gen: bool,
|
41 |
+
baseline: BaselineModel, model: EvalModel,
|
42 |
metrics: List[Metric]):
|
43 |
### Load 1 image
|
44 |
x = dataset[idx] # shape : (C, H, W)
|
45 |
x = x.unsqueeze(0) # shape : (1, C, H, W)
|
46 |
|
47 |
+
return generate_imgs(x, physics, use_gen, baseline, model, metrics)
|
48 |
|
49 |
def generate_random_imgs_from_dataset(dataset: EvalDataset,
|
|
|
|
|
50 |
physics: PhysicsWithGenerator,
|
51 |
use_gen: bool,
|
52 |
+
baseline: BaselineModel,
|
53 |
+
model: EvalModel,
|
54 |
metrics: List[Metric]):
|
55 |
idx = random.randint(0, len(dataset)-1)
|
56 |
x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs_from_dataset(
|
57 |
+
dataset, idx, physics, use_gen, baseline, model, metrics
|
58 |
)
|
59 |
return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
|
60 |
|
61 |
def generate_imgs(x: torch.Tensor,
|
|
|
62 |
physics: PhysicsWithGenerator, use_gen: bool,
|
63 |
+
baseline: BaselineModel, model: EvalModel,
|
64 |
metrics: List[Metric]):
|
|
|
65 |
### Compute y
|
66 |
y = physics(x, use_gen) # possible reduction in img shape due to Blurring
|
67 |
|
|
|
116 |
return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
|
117 |
|
118 |
|
|
|
|
|
|
|
119 |
get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
|
120 |
get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
|
121 |
+
get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
|
122 |
|
123 |
def get_dataset(dataset_name):
|
124 |
if dataset_name == 'MRI':
|
|
|
142 |
return dataset, idx, physics, baseline, available_physics
|
143 |
|
144 |
|
145 |
+
# global variables shared by all users
|
146 |
+
ram_model = EvalModel("unext_emb_physics_config_C", device_str=DEVICE_STR)
|
147 |
+
psnr = Metric.get_list_metrics(["PSNR"], device_str=DEVICE_STR)
|
148 |
+
|
149 |
+
generate_imgs_from_user_partial = partial(generate_imgs_from_user, model=ram_model, metrics=psnr)
|
150 |
+
generate_imgs_from_dataset_partial = partial(generate_imgs_from_dataset, model=ram_model, metrics=psnr)
|
151 |
+
generate_random_imgs_from_dataset_partial = partial(generate_random_imgs_from_dataset, model=ram_model, metrics=psnr)
|
152 |
+
|
153 |
+
|
154 |
### Gradio Blocks interface
|
155 |
|
156 |
title = "Inverse problem playground" # displayed on gradio tab and in the gradio page
|
157 |
with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
|
158 |
gr.Markdown("## " + title)
|
159 |
|
160 |
+
### USER-SPECIFIC VARIABLES
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
|
|
|
162 |
available_physics_placeholder = gr.State(['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
|
163 |
'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])
|
164 |
+
# Issue giving directly a `torch.nn.module` to `gr.State(...)` since it has __call__ method
|
165 |
+
# Solution: using lambda expression
|
166 |
+
physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_easy"))
|
167 |
+
model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))
|
168 |
|
169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
@gr.render(inputs=[dataset_placeholder, physics_placeholder, available_physics_placeholder])
|
171 |
def dynamic_layout(dataset, physics, available_physics):
|
172 |
### LAYOUT
|
173 |
|
174 |
+
# Display images
|
175 |
+
with gr.Row():
|
176 |
+
gt_img = gr.Image(label="Ground-truth image", interactive=True, key=0)
|
177 |
+
observed_img = gr.Image(label="Observed image", interactive=False, key=1)
|
178 |
+
model_a_out = gr.Image(label="RAM output", interactive=False, key=2)
|
179 |
+
model_b_out = gr.Image(label="DPIR output", interactive=False, key=3)
|
180 |
+
|
181 |
# Manage datasets and display metric values
|
182 |
with gr.Row():
|
183 |
with gr.Column(scale=1, min_width=160):
|
|
|
185 |
choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
|
186 |
label="Datasets",
|
187 |
value=dataset.name)
|
188 |
+
idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", key=4)
|
189 |
with gr.Row():
|
190 |
load_button = gr.Button("Run on index image from dataset", size='md')
|
191 |
load_random_button = gr.Button("Run on random image from dataset", size='md')
|
192 |
with gr.Column(scale=1, min_width=160):
|
193 |
+
observed_metrics = gr.Textbox(label="Observed metric", lines=1, key=5)
|
|
|
194 |
with gr.Column(scale=1, min_width=160):
|
195 |
+
out_a_metric = gr.Textbox(label="RAM output metrics", lines=1, key=6)
|
|
|
196 |
with gr.Column(scale=1, min_width=160):
|
197 |
+
out_b_metric = gr.Textbox(label="DPIR output metrics", lines=1, key=7)
|
|
|
198 |
|
199 |
# Manage physics
|
200 |
with gr.Row():
|
|
|
202 |
choose_physics = gr.Radio(choices=available_physics,
|
203 |
label="Physics",
|
204 |
value=physics.name)
|
205 |
+
use_generator_button = gr.Checkbox(label="Generate physics parameters during inference", value=True, key=8)
|
206 |
with gr.Column(scale=1):
|
207 |
with gr.Row():
|
208 |
key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
|
209 |
+
label="Updatable Key")
|
210 |
value_text = gr.Textbox(label="Update Value")
|
211 |
+
update_button = gr.Button("Manually update parameter value", size='md')
|
212 |
with gr.Column(scale=2):
|
213 |
physics_params = gr.Textbox(label="Physics parameters",
|
214 |
lines=5,
|
215 |
value=physics.display_saved_params())
|
216 |
|
|
|
217 |
### Event listeners
|
218 |
|
219 |
choose_dataset.change(fn=get_dataset,
|
|
|
224 |
outputs=[physics_placeholder])
|
225 |
update_button.click(fn=physics.update_and_display_params,
|
226 |
inputs=[key_selector, value_text], outputs=physics_params)
|
227 |
+
run_button.click(fn=generate_imgs_from_user_partial,
|
228 |
inputs=[gt_img,
|
|
|
|
|
229 |
physics_placeholder,
|
230 |
use_generator_button,
|
231 |
+
model_b_placeholder],
|
232 |
outputs=[gt_img, observed_img, model_a_out, model_b_out,
|
233 |
physics_params, observed_metrics, out_a_metric, out_b_metric])
|
234 |
+
load_button.click(fn=generate_imgs_from_dataset_partial,
|
235 |
inputs=[dataset_placeholder,
|
236 |
idx_slider,
|
|
|
|
|
237 |
physics_placeholder,
|
238 |
use_generator_button,
|
239 |
+
model_b_placeholder],
|
240 |
outputs=[gt_img, observed_img, model_a_out, model_b_out,
|
241 |
physics_params, observed_metrics, out_a_metric, out_b_metric])
|
242 |
+
load_random_button.click(fn=generate_random_imgs_from_dataset_partial,
|
243 |
inputs=[dataset_placeholder,
|
|
|
|
|
244 |
physics_placeholder,
|
245 |
use_generator_button,
|
246 |
+
model_b_placeholder],
|
247 |
outputs=[idx_slider, gt_img, observed_img, model_a_out, model_b_out,
|
248 |
physics_params, observed_metrics, out_a_metric, out_b_metric])
|
249 |
|
250 |
|
251 |
+
interface.launch()
|