Spaces:
Sleeping
Sleeping
Fix syntax error and add dynamic change of available physics depending on dataset
Browse files
app.py
CHANGED
@@ -90,39 +90,6 @@ def update_random_idx_and_generate_imgs(dataset: EvalDataset,
|
|
90 |
metrics)
|
91 |
return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
|
92 |
|
93 |
-
def save_imgs(dataset: EvalDataset, idx: int, physics: PhysicsWithGenerator,
|
94 |
-
model_a: EvalModel | BaselineModel, model_b: EvalModel | BaselineModel,
|
95 |
-
x: Image.Image, y: Image.Image,
|
96 |
-
out_a: Image.Image, out_b: Image.Image,
|
97 |
-
y_metrics_str: str,
|
98 |
-
out_a_metric_str : str, out_b_metric_str: str) -> None:
|
99 |
-
|
100 |
-
### PROCESSES STR
|
101 |
-
physics_params_str = ""
|
102 |
-
for param_name, param_value in physics.saved_params["updatable_params"].items():
|
103 |
-
physics_params_str += f"{param_name}_{param_value}-"
|
104 |
-
physics_params_str = physics_params_str[:-1] if physics_params_str.endswith("-") else physics_params_str
|
105 |
-
y_metrics_str = y_metrics_str.replace(" = ", "_").replace("\n", "-")
|
106 |
-
y_metrics_str = y_metrics_str[:-1] if y_metrics_str.endswith("-") else y_metrics_str
|
107 |
-
out_a_metric_str = out_a_metric_str.replace(" = ", "_").replace("\n", "-")
|
108 |
-
out_a_metric_str = out_a_metric_str[:-1] if out_a_metric_str.endswith("-") else out_a_metric_str
|
109 |
-
out_b_metric_str = out_b_metric_str.replace(" = ", "_").replace("\n", "-")
|
110 |
-
out_b_metric_str = out_b_metric_str[:-1] if out_b_metric_str.endswith("-") else out_b_metric_str
|
111 |
-
|
112 |
-
save_path = SAVE_IMG_DIR / f"{dataset.name}+{idx}+{physics.name}+{physics_params_str}+{y_metrics_str}+{model_a.name}+{out_a_metric_str}+{model_b.name}+{out_b_metric_str}.png"
|
113 |
-
titles = [f"{dataset.name}[{idx}]",
|
114 |
-
f"y = {physics.name}(x)",
|
115 |
-
f"{model_a.name}",
|
116 |
-
f"{model_b.name}"]
|
117 |
-
|
118 |
-
# Pil object -> torch.Tensor
|
119 |
-
to_tensor = transforms.ToTensor()
|
120 |
-
x = to_tensor(x)
|
121 |
-
y = to_tensor(y)
|
122 |
-
out_a = to_tensor(out_a)
|
123 |
-
out_b = to_tensor(out_b)
|
124 |
-
|
125 |
-
dinv.utils.plot([x, y, out_a, out_b], titles=titles, show=False, save_fn=save_path)
|
126 |
|
127 |
get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR)
|
128 |
get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR)
|
@@ -139,19 +106,12 @@ def get_physics(physics_name):
|
|
139 |
baseline = get_baseline_model_on_DEVICE_STR('DPIR')
|
140 |
return get_physics_on_DEVICE_STR(physics_name), baseline
|
141 |
|
142 |
-
def get_model(model_name, ckpt_pth):
|
143 |
-
if model_name in BaselineModel.all_baselines:
|
144 |
-
return get_baseline_model_on_DEVICE_STR(model_name)
|
145 |
-
else:
|
146 |
-
return get_eval_model_on_DEVICE_STR(model_name, ckpt_pth)
|
147 |
-
|
148 |
AVAILABLE_PHYSICS = PhysicsWithGenerator.all_physics
|
149 |
-
|
150 |
def get_dataset(dataset_name):
|
151 |
global AVAILABLE_PHYSICS
|
152 |
-
if dataset_name
|
153 |
AVAILABLE_PHYSICS = ['MRI']
|
154 |
-
elif dataset_name
|
155 |
AVAILABLE_PHYSICS = ['CT']
|
156 |
else:
|
157 |
AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard', 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
|
@@ -197,7 +157,7 @@ with gr.Blocks(title=title, css=custom_css) as interface:
|
|
197 |
y_image = gr.Image(label=f"{physics_name} IMAGE", interactive=False)
|
198 |
y_metrics = gr.Textbox(label="Metrics(y, x)", elem_classes=["fixed-textbox"],)
|
199 |
|
200 |
-
choose_physics = gr.Radio(choices=
|
201 |
label="List of PhysicsWithGenerator",
|
202 |
value=physics_name)
|
203 |
with gr.Row():
|
@@ -231,10 +191,10 @@ with gr.Blocks(title=title, css=custom_css) as interface:
|
|
231 |
use_generator_button = gr.Checkbox(label="Generate valid physics parameters", scale=1)
|
232 |
with gr.Column(scale=1):
|
233 |
load_button = gr.Button("Load images...")
|
234 |
-
|
235 |
|
236 |
### Event listeners
|
237 |
-
choose_dataset.change(fn=
|
238 |
inputs=choose_dataset,
|
239 |
outputs=dataset_placeholder)
|
240 |
choose_physics.change(fn=get_physics,
|
|
|
90 |
metrics)
|
91 |
return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
|
92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR)
|
95 |
get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR)
|
|
|
106 |
baseline = get_baseline_model_on_DEVICE_STR('DPIR')
|
107 |
return get_physics_on_DEVICE_STR(physics_name), baseline
|
108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
AVAILABLE_PHYSICS = PhysicsWithGenerator.all_physics
|
|
|
110 |
def get_dataset(dataset_name):
|
111 |
global AVAILABLE_PHYSICS
|
112 |
+
if dataset_name == 'MRI':
|
113 |
AVAILABLE_PHYSICS = ['MRI']
|
114 |
+
elif dataset_name == 'CT':
|
115 |
AVAILABLE_PHYSICS = ['CT']
|
116 |
else:
|
117 |
AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard', 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
|
|
|
157 |
y_image = gr.Image(label=f"{physics_name} IMAGE", interactive=False)
|
158 |
y_metrics = gr.Textbox(label="Metrics(y, x)", elem_classes=["fixed-textbox"],)
|
159 |
|
160 |
+
choose_physics = gr.Radio(choices=AVAILABLE_PHYSICS,
|
161 |
label="List of PhysicsWithGenerator",
|
162 |
value=physics_name)
|
163 |
with gr.Row():
|
|
|
191 |
use_generator_button = gr.Checkbox(label="Generate valid physics parameters", scale=1)
|
192 |
with gr.Column(scale=1):
|
193 |
load_button = gr.Button("Load images...")
|
194 |
+
load_random_button = gr.Button("Load randomly...")
|
195 |
|
196 |
### Event listeners
|
197 |
+
choose_dataset.change(fn=get_dataset,
|
198 |
inputs=choose_dataset,
|
199 |
outputs=dataset_placeholder)
|
200 |
choose_physics.change(fn=get_physics,
|