msong97 commited on
Commit
2776aea
·
1 Parent(s): 15a7da9

Free memory after inference to reduce potential memory overflow

Browse files
Files changed (1) hide show
  1. app.py +18 -14
app.py CHANGED
@@ -109,14 +109,18 @@ def generate_imgs(x: torch.Tensor,
109
  # - torch.Tensor object -> Pil object
110
  process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip")
111
  to_pil = transforms.ToPILImage()
112
- x = to_pil(process_img(x)[0].to('cpu'))
113
- y = to_pil(process_img(y_plot)[0].to('cpu'))
114
- out = to_pil(process_img(out)[0].to('cpu'))
115
- out_baseline = to_pil(process_img(out_baseline)[0].to('cpu'))
116
 
 
 
 
 
117
  print(torch.cuda.memory_allocated() / 1024**2)
118
- return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
119
 
 
120
 
121
  get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
122
  get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
@@ -175,10 +179,10 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
175
 
176
  # Display images
177
  with gr.Row():
178
- gt_img = gr.Image(label="Ground-truth image", interactive=True, key=0)
179
- observed_img = gr.Image(label="Observed image", interactive=False, key=1)
180
- model_a_out = gr.Image(label="RAM output", interactive=False, key=2)
181
- model_b_out = gr.Image(label="DPIR output", interactive=False, key=3)
182
 
183
  # Manage datasets and display metric values
184
  with gr.Row():
@@ -187,16 +191,16 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
187
  choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
188
  label="Datasets",
189
  value=dataset.name)
190
- idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", key=4)
191
  with gr.Row():
192
  load_button = gr.Button("Run on index image from dataset", size='md')
193
  load_random_button = gr.Button("Run on random image from dataset", size='md')
194
  with gr.Column(scale=1, min_width=160):
195
- observed_metrics = gr.Textbox(label="Observed metric", lines=3, key=5)
196
  with gr.Column(scale=1, min_width=160):
197
- out_a_metric = gr.Textbox(label="RAM output metrics", lines=3, key=6)
198
  with gr.Column(scale=1, min_width=160):
199
- out_b_metric = gr.Textbox(label="DPIR output metrics", lines=3, key=7)
200
 
201
  # Manage physics
202
  with gr.Row():
@@ -204,7 +208,7 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
204
  choose_physics = gr.Radio(choices=available_physics,
205
  label="Physics",
206
  value=physics.name)
207
- use_generator_button = gr.Checkbox(label="Generate physics parameters during inference", value=True, key=8)
208
  with gr.Column(scale=1):
209
  with gr.Row():
210
  key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
 
109
  # - torch.Tensor object -> Pil object
110
  process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip")
111
  to_pil = transforms.ToPILImage()
112
+ x_pil = to_pil(process_img(x)[0].to('cpu'))
113
+ y_pil = to_pil(process_img(y_plot)[0].to('cpu'))
114
+ out_pil = to_pil(process_img(out)[0].to('cpu'))
115
+ out_baseline_pil = to_pil(process_img(out_baseline)[0].to('cpu'))
116
 
117
+
118
+ # Free memory
119
+ del x, y, out, out_baseline, y_plot
120
+ torch.cuda.empty_cache()
121
  print(torch.cuda.memory_allocated() / 1024**2)
 
122
 
123
+ return x_pil, y_pil, out_pil, out_baseline_pil, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
124
 
125
  get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
126
  get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
 
179
 
180
  # Display images
181
  with gr.Row():
182
+ gt_img = gr.Image(label="Ground-truth image", interactive=True, key='gt_img')
183
+ observed_img = gr.Image(label="Observed image", interactive=False, key='observed_img')
184
+ model_a_out = gr.Image(label="RAM output", interactive=False, key='ram_out')
185
+ model_b_out = gr.Image(label="DPIR output", interactive=False, key='dpir_out')
186
 
187
  # Manage datasets and display metric values
188
  with gr.Row():
 
191
  choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
192
  label="Datasets",
193
  value=dataset.name)
194
+ idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", key='idx_slider')
195
  with gr.Row():
196
  load_button = gr.Button("Run on index image from dataset", size='md')
197
  load_random_button = gr.Button("Run on random image from dataset", size='md')
198
  with gr.Column(scale=1, min_width=160):
199
+ observed_metrics = gr.Textbox(label="Observed metric", lines=3, key='metrics')
200
  with gr.Column(scale=1, min_width=160):
201
+ out_a_metric = gr.Textbox(label="RAM output metrics", lines=3, key='ram_metrics')
202
  with gr.Column(scale=1, min_width=160):
203
+ out_b_metric = gr.Textbox(label="DPIR output metrics", lines=3, key='dpir_metrics')
204
 
205
  # Manage physics
206
  with gr.Row():
 
208
  choose_physics = gr.Radio(choices=available_physics,
209
  label="Physics",
210
  value=physics.name)
211
+ use_generator_button = gr.Checkbox(label="Generate physics parameters during inference", value=True, key='use_gen')
212
  with gr.Column(scale=1):
213
  with gr.Row():
214
  key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),