guangkaixu committed
Commit 4d5065f · Parent: 562c833
Files changed (4)
  1. util/__init__.py +0 -0
  2. util/batchsize.py +59 -0
  3. util/image_util.py +172 -0
  4. util/seed_all.py +13 -0
util/__init__.py ADDED
File without changes
util/batchsize.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ import math
+
+
+ # Search table for suggested max. inference batch size
+ bs_search_table = [
+     # tested on A100-PCIE-80GB
+     {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
+     {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
+     # tested on A100-PCIE-40GB
+     {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
+     {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
+     {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
+     {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
+     # tested on RTX3090, RTX4090
+     {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
+     {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
+     {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
+     {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
+     {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
+     {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
+     # tested on GTX1080Ti
+     {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
+     {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
+     {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
+     {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
+     {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
+ ]
+
+
+ def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
+     """
+     Automatically search for a suitable operating batch size.
+
+     Args:
+         ensemble_size (int): Number of predictions to be ensembled.
+         input_res (int): Operating resolution of the input image.
+         dtype (torch.dtype): Data type of inference.
+
+     Returns:
+         int: Operating batch size.
+     """
+     if not torch.cuda.is_available():
+         return 1
+
+     # total (not free) VRAM of the current device, in GiB
+     total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
+     filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
+     for settings in sorted(
+         filtered_bs_search_table,
+         key=lambda k: (k["res"], -k["total_vram"]),
+     ):
+         if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
+             bs = settings["bs"]
+             # Cap at ensemble_size; for ensembles only slightly larger than bs,
+             # split into two near-equal chunks instead of one large and one tiny.
+             if bs > ensemble_size:
+                 bs = ensemble_size
+             elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
+                 bs = math.ceil(ensemble_size / 2)
+             return bs
+
+     return 1
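
For orientation, here is a minimal usage sketch (not part of the commit) of how find_batch_size would typically be called before splitting an ensemble of inference passes into chunks; the ensemble size and resolution values below are arbitrary examples.

import torch

from util.batchsize import find_batch_size

# Illustrative settings; a real caller would take these from its config.
ensemble_size = 10
processing_res = 768
dtype = torch.float16

bs = find_batch_size(ensemble_size=ensemble_size, input_res=processing_res, dtype=dtype)
print(f"running {ensemble_size} predictions in chunks of {bs}")
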
util/image_util.py ADDED
@@ -0,0 +1,172 @@
+ import matplotlib
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+
+
+ def norm_to_rgb(norm):
+     # norm: (3, H, W), values in [-1, 1]
+     norm_rgb = ((norm + 1) * 0.5) * 255
+     norm_rgb = np.clip(norm_rgb, a_min=0, a_max=255)
+     norm_rgb = norm_rgb.astype(np.uint8)
+     return norm_rgb
+
+
+ def colorize_depth_maps(
+     depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
+ ):
+     """
+     Colorize depth maps.
+     """
+     assert len(depth_map.shape) >= 2, "Invalid dimension"
+
+     if isinstance(depth_map, torch.Tensor):
+         depth = depth_map.detach().clone().squeeze().cpu().numpy()
+     elif isinstance(depth_map, np.ndarray):
+         depth = np.squeeze(depth_map.copy())
+     # reshape to [ (B,) H, W ]
+     if depth.ndim < 3:
+         depth = depth[np.newaxis, :, :]
+
+     # colorize
+     cm = matplotlib.colormaps[cmap]
+     depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
+     img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3]  # values in [0, 1]
+     img_colored_np = np.rollaxis(img_colored_np, 3, 1)  # [B, 3, H, W]
+
+     if valid_mask is not None:
+         if isinstance(valid_mask, torch.Tensor):
+             valid_mask = valid_mask.detach().cpu().numpy()
+         valid_mask = np.squeeze(valid_mask)  # [H, W] or [B, H, W]
+         if valid_mask.ndim < 3:
+             valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
+         else:
+             valid_mask = valid_mask[:, np.newaxis, :, :]
+         valid_mask = np.repeat(valid_mask, 3, axis=1)
+         img_colored_np[~valid_mask] = 0
+
+     if isinstance(depth_map, torch.Tensor):
+         img_colored = torch.from_numpy(img_colored_np).float()
+     elif isinstance(depth_map, np.ndarray):
+         img_colored = img_colored_np
+
+     return img_colored
+
+
+ def chw2hwc(chw):
+     assert 3 == len(chw.shape)
+     if isinstance(chw, torch.Tensor):
+         hwc = torch.permute(chw, (1, 2, 0))
+     elif isinstance(chw, np.ndarray):
+         hwc = np.moveaxis(chw, 0, -1)
+     return hwc
+
+
+ def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
+     """
+     Resize an image so that its longer edge does not exceed
+     `max_edge_resolution`, keeping the aspect ratio.
+
+     Args:
+         img (Image.Image): Image to be resized.
+         max_edge_resolution (int): Maximum edge length (px).
+
+     Returns:
+         Image.Image: Resized image.
+     """
+     original_width, original_height = img.size
+     downscale_factor = min(
+         max_edge_resolution / original_width, max_edge_resolution / original_height
+     )
+
+     new_width = int(original_width * downscale_factor)
+     new_height = int(original_height * downscale_factor)
+
+     resized_img = img.resize((new_width, new_height))
+     return resized_img
+
+
+ def resize_max_res_integer_16(img: Image.Image, max_edge_resolution: int) -> Image.Image:
+     """
+     Resize an image so that its longer edge does not exceed
+     `max_edge_resolution`, keeping the aspect ratio and rounding both
+     edges down to integer multiples of 16 (as required for PixArt).
+
+     Args:
+         img (Image.Image): Image to be resized.
+         max_edge_resolution (int): Maximum edge length (px).
+
+     Returns:
+         Image.Image: Resized image.
+     """
+     original_width, original_height = img.size
+     downscale_factor = min(
+         max_edge_resolution / original_width, max_edge_resolution / original_height
+     )
+
+     # round down to integer multiples of 16, used for PixArt
+     new_width = int(original_width * downscale_factor) // 16 * 16
+     new_height = int(original_height * downscale_factor) // 16 * 16
+
+     resized_img = img.resize((new_width, new_height))
+     return resized_img
+
+
+ def resize_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
+     """
+     Resize an image to a fixed square resolution; the aspect ratio is
+     not preserved.
+
+     Args:
+         img (Image.Image): Image to be resized.
+         max_edge_resolution (int): Target edge length (px).
+
+     Returns:
+         Image.Image: Resized image.
+     """
+     resized_img = img.resize((max_edge_resolution, max_edge_resolution))
+     return resized_img
+
+
+ class ResizeLongestEdge:
+     def __init__(self, max_size, interpolation=transforms.InterpolationMode.BILINEAR):
+         self.max_size = max_size
+         self.interpolation = interpolation
+
+     def __call__(self, img):
+         scale = self.max_size / max(img.width, img.height)
+         new_size = (int(img.height * scale), int(img.width * scale))
+         return transforms.functional.resize(img, new_size, self.interpolation)
+
+
+ class ResizeShortestEdge:
+     def __init__(self, min_size, interpolation=transforms.InterpolationMode.BILINEAR):
+         self.min_size = min_size
+         self.interpolation = interpolation
+
+     def __call__(self, img):
+         scale = self.min_size / min(img.width, img.height)
+         new_size = (int(img.height * scale), int(img.width * scale))
+         return transforms.functional.resize(img, new_size, self.interpolation)
+
+
+ class ResizeHard:
+     def __init__(self, size, interpolation=transforms.InterpolationMode.BILINEAR):
+         self.size = size
+         self.interpolation = interpolation
+
+     def __call__(self, img):
+         new_size = (int(self.size), int(self.size))
+         return transforms.functional.resize(img, new_size, self.interpolation)
+
+
+ class ResizeLongestEdgeInteger:
+     def __init__(self, max_size, interpolation=transforms.InterpolationMode.BILINEAR, integer=16):
+         self.max_size = max_size
+         self.interpolation = interpolation
+         self.integer = integer
+
+     def __call__(self, img):
+         scale = self.max_size / max(img.width, img.height)
+         # round each edge down to a multiple of `integer`
+         new_size_h = int(img.height * scale) // self.integer * self.integer
+         new_size_w = int(img.width * scale) // self.integer * self.integer
+         new_size = (new_size_h, new_size_w)
+         return transforms.functional.resize(img, new_size, self.interpolation)
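
As a sketch of how these helpers compose (again illustrative, not part of the commit), the following colorizes a depth map with colorize_depth_maps, converts it to HWC layout with chw2hwc, and saves it as a PNG; the depth array here is random placeholder data standing in for a model prediction.

import numpy as np
from PIL import Image

from util.image_util import chw2hwc, colorize_depth_maps

# Placeholder prediction in [0, 1]; real code would use a model output.
depth = np.random.rand(480, 640).astype(np.float32)

colored = colorize_depth_maps(depth, min_depth=0.0, max_depth=1.0, cmap="Spectral")
colored_hwc = chw2hwc(colored[0])  # [1, 3, H, W] -> [H, W, 3], values in [0, 1]
Image.fromarray((colored_hwc * 255).astype(np.uint8)).save("depth_colored.png")
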
util/seed_all.py ADDED
@@ -0,0 +1,13 @@
+ import numpy as np
+ import random
+ import torch
+
+
+ def seed_all(seed: int = 0):
+     """
+     Set random seeds of all components.
+     """
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
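
A typical call site for seed_all (illustrative): seed everything once at the top of an inference script so repeated runs with the same inputs are reproducible.

from util.seed_all import seed_all

seed_all(2024)  # any fixed integer works; 2024 is an arbitrary choice
# ... construct the pipeline and run inference deterministically from here
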