Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
4d5065f
1
Parent(s):
562c833
upload
Browse files- util/__init__.py +0 -0
- util/batchsize.py +59 -0
- util/image_util.py +172 -0
- util/seed_all.py +13 -0
util/__init__.py
ADDED
File without changes
|
util/batchsize.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import math
|
3 |
+
|
4 |
+
|
5 |
+
# Search table for suggested max. inference batch size
|
6 |
+
bs_search_table = [
|
7 |
+
# tested on A100-PCIE-80GB
|
8 |
+
{"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
|
9 |
+
{"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
|
10 |
+
# tested on A100-PCIE-40GB
|
11 |
+
{"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
|
12 |
+
{"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
|
13 |
+
{"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
|
14 |
+
{"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
|
15 |
+
# tested on RTX3090, RTX4090
|
16 |
+
{"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
|
17 |
+
{"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
|
18 |
+
{"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
|
19 |
+
{"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
|
20 |
+
{"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
|
21 |
+
{"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
|
22 |
+
# tested on GTX1080Ti
|
23 |
+
{"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
|
24 |
+
{"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
|
25 |
+
{"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
|
26 |
+
{"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
|
27 |
+
{"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
|
28 |
+
]
|
29 |
+
|
30 |
+
|
31 |
+
def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
|
32 |
+
"""
|
33 |
+
Automatically search for suitable operating batch size.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
ensemble_size (int): Number of predictions to be ensembled
|
37 |
+
input_res (int): Operating resolution of the input image.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
int: Operating batch size
|
41 |
+
"""
|
42 |
+
if not torch.cuda.is_available():
|
43 |
+
return 1
|
44 |
+
|
45 |
+
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
|
46 |
+
filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
|
47 |
+
for settings in sorted(
|
48 |
+
filtered_bs_search_table,
|
49 |
+
key=lambda k: (k["res"], -k["total_vram"]),
|
50 |
+
):
|
51 |
+
if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
|
52 |
+
bs = settings["bs"]
|
53 |
+
if bs > ensemble_size:
|
54 |
+
bs = ensemble_size
|
55 |
+
elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
|
56 |
+
bs = math.ceil(ensemble_size / 2)
|
57 |
+
return bs
|
58 |
+
|
59 |
+
return 1
|
util/image_util.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
from torchvision import transforms
|
6 |
+
|
7 |
+
def norm_to_rgb(norm):
|
8 |
+
# norm: (3, H, W), range from [-1, 1]
|
9 |
+
norm_rgb = ((norm + 1) * 0.5) * 255
|
10 |
+
norm_rgb = np.clip(norm_rgb, a_min=0, a_max=255)
|
11 |
+
norm_rgb = norm_rgb.astype(np.uint8)
|
12 |
+
return norm_rgb
|
13 |
+
|
14 |
+
def colorize_depth_maps(
|
15 |
+
depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
|
16 |
+
):
|
17 |
+
"""
|
18 |
+
Colorize depth maps.
|
19 |
+
"""
|
20 |
+
assert len(depth_map.shape) >= 2, "Invalid dimension"
|
21 |
+
|
22 |
+
if isinstance(depth_map, torch.Tensor):
|
23 |
+
depth = depth_map.detach().clone().squeeze().numpy()
|
24 |
+
elif isinstance(depth_map, np.ndarray):
|
25 |
+
depth = np.squeeze(depth_map.copy())
|
26 |
+
# reshape to [ (B,) H, W ]
|
27 |
+
if depth.ndim < 3:
|
28 |
+
depth = depth[np.newaxis, :, :]
|
29 |
+
|
30 |
+
# colorize
|
31 |
+
cm = matplotlib.colormaps[cmap]
|
32 |
+
depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
|
33 |
+
img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
|
34 |
+
img_colored_np = np.rollaxis(img_colored_np, 3, 1)
|
35 |
+
|
36 |
+
if valid_mask is not None:
|
37 |
+
if isinstance(depth_map, torch.Tensor):
|
38 |
+
valid_mask = valid_mask.detach().numpy()
|
39 |
+
valid_mask = np.squeeze(valid_mask) # [H, W] or [B, H, W]
|
40 |
+
if valid_mask.ndim < 3:
|
41 |
+
valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
|
42 |
+
else:
|
43 |
+
valid_mask = valid_mask[:, np.newaxis, :, :]
|
44 |
+
valid_mask = np.repeat(valid_mask, 3, axis=1)
|
45 |
+
img_colored_np[~valid_mask] = 0
|
46 |
+
|
47 |
+
if isinstance(depth_map, torch.Tensor):
|
48 |
+
img_colored = torch.from_numpy(img_colored_np).float()
|
49 |
+
elif isinstance(depth_map, np.ndarray):
|
50 |
+
img_colored = img_colored_np
|
51 |
+
|
52 |
+
return img_colored
|
53 |
+
|
54 |
+
|
55 |
+
def chw2hwc(chw):
|
56 |
+
assert 3 == len(chw.shape)
|
57 |
+
if isinstance(chw, torch.Tensor):
|
58 |
+
hwc = torch.permute(chw, (1, 2, 0))
|
59 |
+
elif isinstance(chw, np.ndarray):
|
60 |
+
hwc = np.moveaxis(chw, 0, -1)
|
61 |
+
return hwc
|
62 |
+
|
63 |
+
|
64 |
+
def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
|
65 |
+
"""
|
66 |
+
Resize image to limit maximum edge length while keeping aspect ratio
|
67 |
+
|
68 |
+
Args:
|
69 |
+
img (Image.Image): Image to be resized
|
70 |
+
max_edge_resolution (int): Maximum edge length (px).
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
Image.Image: Resized image.
|
74 |
+
"""
|
75 |
+
original_width, original_height = img.size
|
76 |
+
downscale_factor = min(
|
77 |
+
max_edge_resolution / original_width, max_edge_resolution / original_height
|
78 |
+
)
|
79 |
+
|
80 |
+
new_width = int(original_width * downscale_factor)
|
81 |
+
new_height = int(original_height * downscale_factor)
|
82 |
+
|
83 |
+
resized_img = img.resize((new_width, new_height))
|
84 |
+
return resized_img
|
85 |
+
|
86 |
+
def resize_max_res_integer_16(img: Image.Image, max_edge_resolution: int) -> Image.Image:
|
87 |
+
"""
|
88 |
+
Resize image to limit maximum edge length while keeping aspect ratio
|
89 |
+
|
90 |
+
Args:
|
91 |
+
img (Image.Image): Image to be resized
|
92 |
+
max_edge_resolution (int): Maximum edge length (px).
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
Image.Image: Resized image.
|
96 |
+
"""
|
97 |
+
original_width, original_height = img.size
|
98 |
+
downscale_factor = min(
|
99 |
+
max_edge_resolution / original_width, max_edge_resolution / original_height
|
100 |
+
)
|
101 |
+
|
102 |
+
new_width = int(original_width * downscale_factor) // 16 * 16 # make sure it is integer multiples of 16, used for pixart
|
103 |
+
new_height = int(original_height * downscale_factor) // 16 * 16 # make sure it is integer multiples of 16, used for pixart
|
104 |
+
|
105 |
+
resized_img = img.resize((new_width, new_height))
|
106 |
+
return resized_img
|
107 |
+
|
108 |
+
def resize_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
|
109 |
+
"""
|
110 |
+
Resize image to limit maximum edge length while keeping aspect ratio
|
111 |
+
|
112 |
+
Args:
|
113 |
+
img (Image.Image): Image to be resized
|
114 |
+
max_edge_resolution (int): Maximum edge length (px).
|
115 |
+
|
116 |
+
Returns:
|
117 |
+
Image.Image: Resized image.
|
118 |
+
"""
|
119 |
+
|
120 |
+
resized_img = img.resize((max_edge_resolution, max_edge_resolution))
|
121 |
+
return resized_img
|
122 |
+
|
123 |
+
class ResizeLongestEdge:
|
124 |
+
def __init__(self, max_size, interpolation=transforms.InterpolationMode.BILINEAR):
|
125 |
+
self.max_size = max_size
|
126 |
+
self.interpolation = interpolation
|
127 |
+
|
128 |
+
def __call__(self, img):
|
129 |
+
|
130 |
+
scale = self.max_size / max(img.width, img.height)
|
131 |
+
new_size = (int(img.height * scale), int(img.width * scale))
|
132 |
+
|
133 |
+
return transforms.functional.resize(img, new_size, self.interpolation)
|
134 |
+
|
135 |
+
class ResizeShortestEdge:
|
136 |
+
def __init__(self, min_size, interpolation=transforms.InterpolationMode.BILINEAR):
|
137 |
+
self.min_size = min_size
|
138 |
+
self.interpolation = interpolation
|
139 |
+
|
140 |
+
def __call__(self, img):
|
141 |
+
|
142 |
+
scale = self.min_size / min(img.width, img.height)
|
143 |
+
new_size = (int(img.height * scale), int(img.width * scale))
|
144 |
+
|
145 |
+
return transforms.functional.resize(img, new_size, self.interpolation)
|
146 |
+
|
147 |
+
class ResizeHard:
|
148 |
+
def __init__(self, size, interpolation=transforms.InterpolationMode.BILINEAR):
|
149 |
+
self.size = size
|
150 |
+
self.interpolation = interpolation
|
151 |
+
|
152 |
+
def __call__(self, img):
|
153 |
+
|
154 |
+
new_size = (int(self.size), int(self.size))
|
155 |
+
|
156 |
+
return transforms.functional.resize(img, new_size, self.interpolation)
|
157 |
+
|
158 |
+
|
159 |
+
class ResizeLongestEdgeInteger:
|
160 |
+
def __init__(self, max_size, interpolation=transforms.InterpolationMode.BILINEAR, integer=16):
|
161 |
+
self.max_size = max_size
|
162 |
+
self.interpolation = interpolation
|
163 |
+
self.integer = integer
|
164 |
+
|
165 |
+
def __call__(self, img):
|
166 |
+
|
167 |
+
scale = self.max_size / max(img.width, img.height)
|
168 |
+
new_size_h = int(img.height * scale) // self.integer * self.integer
|
169 |
+
new_size_w = int(img.width * scale) // self.integer * self.integer
|
170 |
+
new_size = (new_size_h, new_size_w)
|
171 |
+
|
172 |
+
return transforms.functional.resize(img, new_size, self.interpolation)
|
util/seed_all.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def seed_all(seed: int = 0):
|
7 |
+
"""
|
8 |
+
Set random seeds of all components.
|
9 |
+
"""
|
10 |
+
random.seed(seed)
|
11 |
+
np.random.seed(seed)
|
12 |
+
torch.manual_seed(seed)
|
13 |
+
torch.cuda.manual_seed_all(seed)
|