Spaces:
Sleeping
Sleeping
Free memory after inference to reduce potential memory overflow
Browse files
app.py
CHANGED
@@ -109,14 +109,18 @@ def generate_imgs(x: torch.Tensor,
|
|
109 |
# - torch.Tensor object -> Pil object
|
110 |
process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip")
|
111 |
to_pil = transforms.ToPILImage()
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
|
|
|
|
|
|
|
|
|
117 |
print(torch.cuda.memory_allocated() / 1024**2)
|
118 |
-
return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
|
119 |
|
|
|
120 |
|
121 |
get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
|
122 |
get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
|
@@ -175,10 +179,10 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
|
|
175 |
|
176 |
# Display images
|
177 |
with gr.Row():
|
178 |
-
gt_img = gr.Image(label="Ground-truth image", interactive=True, key=
|
179 |
-
observed_img = gr.Image(label="Observed image", interactive=False, key=
|
180 |
-
model_a_out = gr.Image(label="RAM output", interactive=False, key=
|
181 |
-
model_b_out = gr.Image(label="DPIR output", interactive=False, key=
|
182 |
|
183 |
# Manage datasets and display metric values
|
184 |
with gr.Row():
|
@@ -187,16 +191,16 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
|
|
187 |
choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
|
188 |
label="Datasets",
|
189 |
value=dataset.name)
|
190 |
-
idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", key=
|
191 |
with gr.Row():
|
192 |
load_button = gr.Button("Run on index image from dataset", size='md')
|
193 |
load_random_button = gr.Button("Run on random image from dataset", size='md')
|
194 |
with gr.Column(scale=1, min_width=160):
|
195 |
-
observed_metrics = gr.Textbox(label="Observed metric", lines=3, key=
|
196 |
with gr.Column(scale=1, min_width=160):
|
197 |
-
out_a_metric = gr.Textbox(label="RAM output metrics", lines=3, key=
|
198 |
with gr.Column(scale=1, min_width=160):
|
199 |
-
out_b_metric = gr.Textbox(label="DPIR output metrics", lines=3, key=
|
200 |
|
201 |
# Manage physics
|
202 |
with gr.Row():
|
@@ -204,7 +208,7 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
|
|
204 |
choose_physics = gr.Radio(choices=available_physics,
|
205 |
label="Physics",
|
206 |
value=physics.name)
|
207 |
-
use_generator_button = gr.Checkbox(label="Generate physics parameters during inference", value=True, key=
|
208 |
with gr.Column(scale=1):
|
209 |
with gr.Row():
|
210 |
key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
|
|
|
109 |
# - torch.Tensor object -> Pil object
|
110 |
process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip")
|
111 |
to_pil = transforms.ToPILImage()
|
112 |
+
x_pil = to_pil(process_img(x)[0].to('cpu'))
|
113 |
+
y_pil = to_pil(process_img(y_plot)[0].to('cpu'))
|
114 |
+
out_pil = to_pil(process_img(out)[0].to('cpu'))
|
115 |
+
out_baseline_pil = to_pil(process_img(out_baseline)[0].to('cpu'))
|
116 |
|
117 |
+
|
118 |
+
# Free memory
|
119 |
+
del x, y, out, out_baseline, y_plot
|
120 |
+
torch.cuda.empty_cache()
|
121 |
print(torch.cuda.memory_allocated() / 1024**2)
|
|
|
122 |
|
123 |
+
return x_pil, y_pil, out_pil, out_baseline_pil, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
|
124 |
|
125 |
get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
|
126 |
get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
|
|
|
179 |
|
180 |
# Display images
|
181 |
with gr.Row():
|
182 |
+
gt_img = gr.Image(label="Ground-truth image", interactive=True, key='gt_img')
|
183 |
+
observed_img = gr.Image(label="Observed image", interactive=False, key='observed_img')
|
184 |
+
model_a_out = gr.Image(label="RAM output", interactive=False, key='ram_out')
|
185 |
+
model_b_out = gr.Image(label="DPIR output", interactive=False, key='dpir_out')
|
186 |
|
187 |
# Manage datasets and display metric values
|
188 |
with gr.Row():
|
|
|
191 |
choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
|
192 |
label="Datasets",
|
193 |
value=dataset.name)
|
194 |
+
idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", key='idx_slider')
|
195 |
with gr.Row():
|
196 |
load_button = gr.Button("Run on index image from dataset", size='md')
|
197 |
load_random_button = gr.Button("Run on random image from dataset", size='md')
|
198 |
with gr.Column(scale=1, min_width=160):
|
199 |
+
observed_metrics = gr.Textbox(label="Observed metric", lines=3, key='metrics')
|
200 |
with gr.Column(scale=1, min_width=160):
|
201 |
+
out_a_metric = gr.Textbox(label="RAM output metrics", lines=3, key='ram_metrics')
|
202 |
with gr.Column(scale=1, min_width=160):
|
203 |
+
out_b_metric = gr.Textbox(label="DPIR output metrics", lines=3, key='dpir_metrics')
|
204 |
|
205 |
# Manage physics
|
206 |
with gr.Row():
|
|
|
208 |
choose_physics = gr.Radio(choices=available_physics,
|
209 |
label="Physics",
|
210 |
value=physics.name)
|
211 |
+
use_generator_button = gr.Checkbox(label="Generate physics parameters during inference", value=True, key='use_gen')
|
212 |
with gr.Column(scale=1):
|
213 |
with gr.Row():
|
214 |
key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
|