[Merge] branch 'INFERENCE' into TEST

Changed files:
- yolo/config/config.py +85 -0
- yolo/lazy.py +5 -1
- yolo/tools/data_augmentation.py +23 -18
- yolo/tools/data_loader.py +13 -9
- yolo/tools/solver.py +18 -17
- yolo/utils/logging_utils.py +2 -2
- yolo/utils/model_utils.py +42 -4
yolo/config/config.py

```diff
@@ -142,6 +142,7 @@ class Config:
 
     class_num: int
    class_list: List[str]
+    class_idx_id: List[int]
     image_size: List[int]
 
     out_path: str
@@ -164,3 +165,87 @@ class YOLOLayer(nn.Module):
 
     def __post_init__(self):
         super().__init__()
+
+
+IDX_TO_ID = [
+    1,
+    2,
+    3,
+    4,
+    5,
+    6,
+    7,
+    8,
+    9,
+    10,
+    11,
+    13,
+    14,
+    15,
+    16,
+    17,
+    18,
+    19,
+    20,
+    21,
+    22,
+    23,
+    24,
+    25,
+    27,
+    28,
+    31,
+    32,
+    33,
+    34,
+    35,
+    36,
+    37,
+    38,
+    39,
+    40,
+    41,
+    42,
+    43,
+    44,
+    46,
+    47,
+    48,
+    49,
+    50,
+    51,
+    52,
+    53,
+    54,
+    55,
+    56,
+    57,
+    58,
+    59,
+    60,
+    61,
+    62,
+    63,
+    64,
+    65,
+    67,
+    70,
+    72,
+    73,
+    74,
+    75,
+    76,
+    77,
+    78,
+    79,
+    80,
+    81,
+    82,
+    84,
+    85,
+    86,
+    87,
+    88,
+    89,
+    90,
+]
```
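The new `class_idx_id` field and the `IDX_TO_ID` table exist because COCO's 80 classes do not have contiguous category IDs: IDs 12, 26, 29, 30, 45, 66, 68, 69, 71, and 83 are unused, while model class indices run 0..79. A minimal sketch of the mapping (the compressed list form and the `ID_TO_IDX` inverse are illustrations, not part of the diff):

```python
# Why IDX_TO_ID is needed: convert a contiguous model class index (0..79)
# to COCO's sparse category_id (1..90 with gaps).
IDX_TO_ID = [
    *range(1, 12), 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28,
    *range(31, 45), *range(46, 66), 67, 70, *range(72, 83), *range(84, 91),
]

# Hypothetical inverse lookup, for going from COCO IDs back to class indices.
ID_TO_IDX = {coco_id: idx for idx, coco_id in enumerate(IDX_TO_ID)}

assert len(IDX_TO_ID) == 80
assert IDX_TO_ID[0] == 1    # class index 0 ("person") -> category_id 1
assert IDX_TO_ID[11] == 13  # index 11 skips the unused ID 12
assert ID_TO_IDX[90] == 79  # the last COCO ID maps back to index 79
```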
yolo/lazy.py

```diff
@@ -9,7 +9,7 @@ sys.path.append(str(project_root))
 from yolo.config.config import Config
 from yolo.model.yolo import create_model
 from yolo.tools.data_loader import create_dataloader
-from yolo.tools.solver import ModelTester, ModelTrainer
+from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
 from yolo.utils.bounding_box_utils import Vec2Box
 from yolo.utils.deploy_utils import FastModelLoader
 from yolo.utils.logging_utils import ProgressLogger
@@ -37,6 +37,10 @@ def main(cfg: Config):
         tester = ModelTester(cfg, model, vec2box, progress, device)
         tester.solve(dataloader)
 
+    if cfg.task.task == "validation":
+        valider = ModelValidator(cfg.task, model, vec2box, progress, device)
+        valider.solve(dataloader)
+
 
 if __name__ == "__main__":
     main()
```
yolo/tools/data_augmentation.py

```diff
@@ -10,7 +10,7 @@ class AugmentationComposer:
     def __init__(self, transforms, image_size: int = [640, 640]):
         self.transforms = transforms
         # TODO: handle List of image_size [640, 640]
-        self.image_size = image_size[0]
+        self.image_size = image_size
         self.pad_resize = PadAndResize(self.image_size)
 
         for transform in self.transforms:
@@ -29,27 +29,32 @@ class AugmentationComposer:
 
 
 class PadAndResize:
-    def __init__(self, image_size):
+    def __init__(self, image_size, background_color=(128, 128, 128)):
         """Initialize the object with the target image size."""
-        self.…
+        self.target_width, self.target_height = image_size
+        self.background_color = background_color
 
-    def __call__(self, image, boxes):
-        …
-        scale = self.…
-        …
-        square_img.paste(image, (left, top))
-        …
-        boxes[:, 1] …
-        boxes[:, 2] …
-        boxes[:, 3] …
-        boxes[:, 4] …
-        …
-        return …
+    def __call__(self, image: Image, boxes):
+        img_width, img_height = image.size
+        scale = min(self.target_width / img_width, self.target_height / img_height)
+        new_width, new_height = int(img_width * scale), int(img_height * scale)
+
+        resized_image = image.resize((new_width, new_height), Image.LANCZOS)
+
+        pad_left = (self.target_width - new_width) // 2
+        pad_top = (self.target_height - new_height) // 2
+        padded_image = Image.new("RGB", (self.target_width, self.target_height), self.background_color)
+        padded_image.paste(resized_image, (pad_left, pad_top))
+
+        boxes[:, 1] *= scale  # xmin
+        boxes[:, 2] *= scale  # ymin
+        boxes[:, 3] *= scale  # xmax
+        boxes[:, 4] *= scale  # ymax
+        boxes[:, [1, 3]] += pad_left
+        boxes[:, [2, 4]] += pad_top
+
+        transform_info = torch.tensor([scale, pad_left, pad_top, pad_left, pad_top])
+        return padded_image, boxes, transform_info
@@ -94,7 +99,7 @@ class Mosaic:
 
         assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."
 
-        img_sz = self.parent.image_size  # Assuming `image_size` is defined in parent
+        img_sz = self.parent.image_size[0]  # Assuming `image_size` is defined in parent
         more_data = self.parent.get_more_data(3)  # get 3 more images randomly
 
         data = [(image, boxes)] + more_data
```
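For reference, a worked example of the new letterbox arithmetic in `PadAndResize` (a standalone sketch that mirrors the diff above; the 1280×720 input is an arbitrary choice):

```python
# Letterbox a 16:9 frame into a 640x640 canvas, as the new __call__ does.
from PIL import Image

target_w, target_h = 640, 640
image = Image.new("RGB", (1280, 720))  # stand-in input frame

scale = min(target_w / image.width, target_h / image.height)        # 0.5
new_w, new_h = int(image.width * scale), int(image.height * scale)  # 640 x 360
pad_left = (target_w - new_w) // 2                                  # 0
pad_top = (target_h - new_h) // 2                                   # 140

resized = image.resize((new_w, new_h), Image.LANCZOS)
padded = Image.new("RGB", (target_w, target_h), (128, 128, 128))
padded.paste(resized, (pad_left, pad_top))

# transform_info = [scale, pad_left, pad_top, pad_left, pad_top] records
# exactly what is needed to invert the mapping later:
# original = (letterboxed - pad) / scale.
assert (scale, pad_left, pad_top) == (0.5, 0, 140)
```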
yolo/tools/data_loader.py

```diff
@@ -141,16 +141,16 @@ class YoloDataset(Dataset):
     def get_data(self, idx):
         img_path, bboxes = self.data[idx]
         img = Image.open(img_path).convert("RGB")
-        return img, bboxes
+        return img, bboxes, img_path
 
     def get_more_data(self, num: int = 1):
         indices = torch.randint(0, len(self), (num,))
-        return [self.get_data(idx) for idx in indices]
+        return [self.get_data(idx)[:2] for idx in indices]
 
     def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
-        img, bboxes = self.get_data(idx)
-        img, bboxes, …
-        return img, bboxes
+        img, bboxes, img_path = self.get_data(idx)
+        img, bboxes, rev_tensor = self.transform(img, bboxes)
+        return img, bboxes, rev_tensor, img_path
 
     def __len__(self) -> int:
         return len(self.data)
@@ -195,9 +195,11 @@ class YoloDataLoader(DataLoader):
             batch_targets[idx, :target_size] = batch[idx][1]
         batch_targets[:, :, 1:] *= self.image_size
 
-        batch_images = …
 
-        return batch_images, batch_targets
+        batch_images, _, batch_reverse, batch_path = zip(*batch)
+        batch_images = torch.stack(batch_images)
+        batch_reverse = torch.stack(batch_reverse)
 
+        return batch_images, batch_targets, batch_reverse, batch_path
@@ -261,12 +263,14 @@ class StreamDataLoader:
         if isinstance(frame, np.ndarray):
             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             frame = Image.fromarray(frame)
+        origin_frame = frame
         frame, _, rev_tensor = self.transform(frame, torch.zeros(0, 5))
         frame = frame[None]
+        rev_tensor = rev_tensor[None]
         if not self.is_stream:
-            self.queue.put(frame)
+            self.queue.put((frame, rev_tensor, origin_frame))
         else:
-            self.current_frame = frame
+            self.current_frame = (frame, rev_tensor, origin_frame)
 
     def __iter__(self) -> Generator[Tensor, None, None]:
         return self
```
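A minimal sketch of what the reworked collate now produces: each dataset item is a 4-tuple, and `zip(*batch)` transposes the list of tuples into per-field tuples (toy tensors and file names, for illustration only):

```python
# Each item: (image, boxes, rev_tensor, img_path), as __getitem__ now returns.
import torch

batch = [
    (torch.zeros(3, 640, 640), torch.zeros(2, 5),
     torch.tensor([0.5, 0.0, 140.0, 0.0, 140.0]), "images/000001.jpg"),
    (torch.zeros(3, 640, 640), torch.zeros(1, 5),
     torch.tensor([1.0, 80.0, 0.0, 80.0, 0.0]), "images/000002.jpg"),
]

batch_images, _, batch_reverse, batch_path = zip(*batch)
batch_images = torch.stack(batch_images)    # (2, 3, 640, 640)
batch_reverse = torch.stack(batch_reverse)  # (2, 5): scale + pads per image
assert batch_path == ("images/000001.jpg", "images/000002.jpg")
```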
yolo/tools/solver.py

```diff
@@ -1,3 +1,4 @@
+import json
 import os
 import time
 
@@ -15,12 +16,14 @@ from yolo.model.yolo import YOLO
 from yolo.tools.data_loader import StreamDataLoader, create_dataloader
 from yolo.tools.drawer import draw_bboxes, draw_model
 from yolo.tools.loss_functions import create_loss_function
-from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
+from yolo.utils.bounding_box_utils import Vec2Box
 from yolo.utils.logging_utils import ProgressLogger, log_model_structure
 from yolo.utils.model_utils import (
     ExponentialMovingAverage,
+    PostProccess,
     create_optimizer,
     create_scheduler,
+    predicts_to_json,
 )
@@ -72,7 +75,7 @@ class ModelTrainer:
         self.model.train()
         total_loss = 0
 
-        for images, targets in dataloader:
+        for images, targets, *_ in dataloader:
             loss, loss_each = self.train_one_batch(images, targets)
 
             total_loss += loss
@@ -136,8 +139,9 @@ class ModelTester:
 
         last_time = time.time()
         try:
-            for idx, images in enumerate(dataloader):
+            for idx, (images, rev_tensor, origin_frame) in enumerate(dataloader):
                 images = images.to(self.device)
+                rev_tensor = rev_tensor.to(self.device)
                 with torch.no_grad():
                     predicts = self.model(images)
                     predicts = self.vec2box(predicts["Main"])
@@ -175,32 +179,29 @@ class ModelValidator:
         validation_cfg: ValidationConfig,
         model: YOLO,
         vec2box: Vec2Box,
-        device,
         progress: ProgressLogger,
+        device,
     ):
         self.model = model
-        self.vec2box = vec2box
         self.device = device
         self.progress = progress
 
-        self.…
+        self.post_proccess = PostProccess(vec2box, validation_cfg.nms)
+        self.json_path = os.path.join(self.progress.save_path, f"predict.json")
 
     def solve(self, dataloader):
         # logger.info("🧪 Start Validation!")
         self.model.eval()
-
-        iou_thresholds = torch.arange(0.5, 1.0, 0.05)
-        map_all = []
+        predict_json = []
         self.progress.start_one_epoch(len(dataloader))
-        for images, targets in dataloader:
-            images, targets = images.to(self.device), targets.to(self.device)
+        for images, targets, rev_tensor, img_paths in dataloader:
+            images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
             with torch.no_grad():
                 predicts = self.model(images)
-                predicts = self.vec2box(predicts["Main"])
-                nms_out = …
-            for idx, predict in enumerate(nms_out):
-                map_value = calculate_map(predict, targets[idx], iou_thresholds)
-                map_all.append(map_value[0])
-            self.progress.one_batch(mapp=torch.Tensor(map_all).mean())
+                predicts = self.post_proccess(predicts, rev_tensor)
+            self.progress.one_batch()
 
+            predict_json.extend(predicts_to_json(img_paths, predicts))
         self.progress.finish_one_epoch()
+        with open(self.json_path, "w") as f:
+            json.dump(predict_json, f)
```
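Since `ModelValidator` now writes COCO-format detections to `predict.json` instead of computing mAP in the loop, scoring presumably moves offline. A sketch of evaluating that file with pycocotools (the ground-truth annotation path is an assumption, not part of this diff):

```python
# Score the dumped predict.json against COCO val annotations.
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

coco_gt = COCO("data/coco/annotations/instances_val2017.json")  # assumed path
coco_dt = coco_gt.loadRes("predict.json")  # file written by ModelValidator

coco_eval = COCOeval(coco_gt, coco_dt, iouType="bbox")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()  # prints AP@[.5:.95], AP50, AP75, ...
```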
yolo/utils/logging_utils.py

```diff
@@ -72,9 +72,9 @@ class ProgressLogger:
             self.wandb.log({f"Learning Rate/{lr_name}": lr_value}, step=epoch_idx)
         self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)
 
-    def one_batch(self, loss_dict: Dict[str, Tensor] = None, …
+    def one_batch(self, loss_dict: Dict[str, Tensor] = None):
         if loss_dict is None:
-            self.progress.update(self.batch_task, advance=1, description=f"[green]…
+            self.progress.update(self.batch_task, advance=1, description=f"[green]Validating")
             return
         if self.use_wandb:
             for loss_name, loss_value in loss_dict.items():
```
yolo/utils/model_utils.py

```diff
@@ -1,17 +1,18 @@
 import os
-from …
+from pathlib import Path
+from typing import List, Optional, Type, Union
 
 import torch
 import torch.distributed as dist
 from loguru import logger
 from omegaconf import ListConfig
-from torch import …
-from torch.nn.parallel import DistributedDataParallel as DDP
+from torch import Tensor
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
 
-from yolo.config.config import OptimizerConfig, SchedulerConfig
+from yolo.config.config import IDX_TO_ID, NMSConfig, OptimizerConfig, SchedulerConfig
 from yolo.model.yolo import YOLO
+from yolo.utils.bounding_box_utils import bbox_nms, transform_bbox
 
 
 class ExponentialMovingAverage:
@@ -93,3 +94,40 @@ def get_device(device_spec: Union[str, int, List[int]]) -> torch.device:
         device_spec = initialize_distributed()
     device = torch.device(device_spec)
     return device, ddp_flag
+
+
+class PostProccess:
+    """
+    TODO: function document
+    Scale predictions back to original-image coordinates and run NMS on pred_bbox.
+    """
+
+    def __init__(self, vec2box, nms_cfg: NMSConfig) -> None:
+        self.vec2box = vec2box
+        self.nms = nms_cfg
+
+    def __call__(self, predict, rev_tensor: Optional[Tensor]):
+        pred_class, _, pred_bbox = self.vec2box(predict["Main"])
+        if rev_tensor is not None:
+            pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None]
+        pred_bbox = bbox_nms(pred_class, pred_bbox, self.nms)
+        return pred_bbox
+
+
+def predicts_to_json(img_paths, predicts):
+    """
+    TODO: function document
+    Turn a batch of image paths and predictions (n x 6 per image) into a list
+    of COCO-style detection dicts.
+    """
+    batch_json = []
+    for img_path, bboxes in zip(img_paths, predicts):
+        bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
+        for cls, *pos, conf in bboxes:
+            bbox = {
+                "image_id": int(Path(img_path).stem),
+                "category_id": IDX_TO_ID[int(cls)],
+                "bbox": [float(p) for p in pos],
+                "score": float(conf),
+            }
+            batch_json.append(bbox)
+    return batch_json
```
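A quick numeric check of the rescaling in `PostProccess.__call__`: with `rev_tensor` rows laid out as `[scale, pad_left, pad_top, pad_left, pad_top]` (matching `transform_info` from `PadAndResize`), subtracting the pads and dividing by the scale maps boxes from letterboxed 640×640 coordinates back to the original image (standalone sketch, values from the letterbox example above):

```python
# Invert the 1280x720 -> 640x640 letterbox (scale 0.5, pad_top 140).
import torch

rev_tensor = torch.tensor([[0.5, 0.0, 140.0, 0.0, 140.0]])  # (batch, 5)
pred_bbox = torch.tensor([[[100.0, 240.0, 500.0, 440.0]]])  # (batch, n, xyxy)

# Same broadcasting as PostProccess: pads from columns 1:, scale from column 0.
restored = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None]
assert torch.equal(restored, torch.tensor([[[200.0, 200.0, 1000.0, 600.0]]]))
```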