Zhenyu Li commited on
Commit
e773e71
·
1 Parent(s): 6459e4a
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -22,6 +22,7 @@
22
 
23
  # File author: Zhenyu Li
24
 
 
25
  from ControlNet.share import *
26
  import einops
27
  import torch
@@ -107,7 +108,6 @@ model.load_state_dict(load_state_dict(controlnet_ckp, location=DEVICE), strict=F
107
  model = model.to(DEVICE)
108
  ddim_sampler = DDIMSampler(model)
109
 
110
-
111
  # controlnet
112
  title = "# PatchFusion"
113
  description = """Official demo for **PatchFusion: An End-to-End Tile-Based Framework for High-Resolution Monocular Metric Depth Estimation**.
@@ -136,6 +136,11 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
136
  with torch.no_grad():
137
  w, h = input_image.size
138
  detected_map = predict_depth(depth_model, input_image, mode, patch_number, resolution, patch_size, device=DEVICE)
 
 
 
 
 
139
  detected_map = F.interpolate(torch.from_numpy(detected_map).unsqueeze(dim=0).unsqueeze(dim=0), (image_resolution, image_resolution), mode='bicubic', align_corners=True).squeeze().numpy()
140
 
141
  H, W = detected_map.shape
 
22
 
23
  # File author: Zhenyu Li
24
 
25
+ import gc
26
  from ControlNet.share import *
27
  import einops
28
  import torch
 
108
  model = model.to(DEVICE)
109
  ddim_sampler = DDIMSampler(model)
110
 
 
111
  # controlnet
112
  title = "# PatchFusion"
113
  description = """Official demo for **PatchFusion: An End-to-End Tile-Based Framework for High-Resolution Monocular Metric Depth Estimation**.
 
136
  with torch.no_grad():
137
  w, h = input_image.size
138
  detected_map = predict_depth(depth_model, input_image, mode, patch_number, resolution, patch_size, device=DEVICE)
139
+
140
+ del depth_model # after using the depth model, free the mem
141
+ gc.collect()
142
+ torch.cuda.empty_cache()
143
+
144
  detected_map = F.interpolate(torch.from_numpy(detected_map).unsqueeze(dim=0).unsqueeze(dim=0), (image_resolution, image_resolution), mode='bicubic', align_corners=True).squeeze().numpy()
145
 
146
  H, W = detected_map.shape