✏️ [Fix] Process <- Proccess typo
Browse files- yolo/__init__.py +2 -2
- yolo/tools/data_loader.py +1 -1
- yolo/tools/solver.py +3 -3
- yolo/utils/model_utils.py +1 -1
yolo/__init__.py
CHANGED
@@ -10,7 +10,7 @@ from yolo.utils.logging_utils import (
|
|
10 |
YOLORichModelSummary,
|
11 |
YOLORichProgressBar,
|
12 |
)
|
13 |
-
from yolo.utils.model_utils import
|
14 |
|
15 |
all = [
|
16 |
"create_model",
|
@@ -29,5 +29,5 @@ all = [
|
|
29 |
"create_dataloader",
|
30 |
"FastModelLoader",
|
31 |
"TrainModel",
|
32 |
-
"
|
33 |
]
|
|
|
10 |
YOLORichModelSummary,
|
11 |
YOLORichProgressBar,
|
12 |
)
|
13 |
+
from yolo.utils.model_utils import PostProcess
|
14 |
|
15 |
all = [
|
16 |
"create_model",
|
|
|
29 |
"create_dataloader",
|
30 |
"FastModelLoader",
|
31 |
"TrainModel",
|
32 |
+
"PostProcess",
|
33 |
]
|
yolo/tools/data_loader.py
CHANGED
@@ -170,7 +170,7 @@ def collate_fn(batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]
|
|
170 |
"""
|
171 |
batch_size = len(batch)
|
172 |
target_sizes = [item[1].size(0) for item in batch]
|
173 |
-
# TODO: Improve readability of these
|
174 |
# TODO: remove maxBbox or reduce loss function memory usage
|
175 |
batch_targets = torch.zeros(batch_size, min(max(target_sizes), 100), 5)
|
176 |
batch_targets[:, :, 0] = -1
|
|
|
170 |
"""
|
171 |
batch_size = len(batch)
|
172 |
target_sizes = [item[1].size(0) for item in batch]
|
173 |
+
# TODO: Improve readability of these process
|
174 |
# TODO: remove maxBbox or reduce loss function memory usage
|
175 |
batch_targets = torch.zeros(batch_size, min(max(target_sizes), 100), 5)
|
176 |
batch_targets[:, :, 0] = -1
|
yolo/tools/solver.py
CHANGED
@@ -6,7 +6,7 @@ from yolo.model.yolo import create_model
|
|
6 |
from yolo.tools.data_loader import create_dataloader
|
7 |
from yolo.tools.loss_functions import create_loss_function
|
8 |
from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
|
9 |
-
from yolo.utils.model_utils import
|
10 |
|
11 |
|
12 |
class BaseModel(LightningModule):
|
@@ -34,14 +34,14 @@ class ValidateModel(BaseModel):
|
|
34 |
self.vec2box = create_converter(
|
35 |
self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
|
36 |
)
|
37 |
-
self.
|
38 |
|
39 |
def val_dataloader(self):
|
40 |
return self.val_loader
|
41 |
|
42 |
def validation_step(self, batch, batch_idx):
|
43 |
batch_size, images, targets, rev_tensor, img_paths = batch
|
44 |
-
predicts = self.
|
45 |
batch_metrics = self.metric(
|
46 |
[to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
|
47 |
)
|
|
|
6 |
from yolo.tools.data_loader import create_dataloader
|
7 |
from yolo.tools.loss_functions import create_loss_function
|
8 |
from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
|
9 |
+
from yolo.utils.model_utils import PostProcess, create_optimizer, create_scheduler
|
10 |
|
11 |
|
12 |
class BaseModel(LightningModule):
|
|
|
34 |
self.vec2box = create_converter(
|
35 |
self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
|
36 |
)
|
37 |
+
self.post_process = PostProcess(self.vec2box, self.validation_cfg.nms)
|
38 |
|
39 |
def val_dataloader(self):
|
40 |
return self.val_loader
|
41 |
|
42 |
def validation_step(self, batch, batch_idx):
|
43 |
batch_size, images, targets, rev_tensor, img_paths = batch
|
44 |
+
predicts = self.post_process(self(images))
|
45 |
batch_metrics = self.metric(
|
46 |
[to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
|
47 |
)
|
yolo/utils/model_utils.py
CHANGED
@@ -124,7 +124,7 @@ def get_device(device_spec: Union[str, int, List[int]]) -> torch.device:
|
|
124 |
return device, ddp_flag
|
125 |
|
126 |
|
127 |
-
class
|
128 |
"""
|
129 |
TODO: function document
|
130 |
scale back the prediction and do nms for pred_bbox
|
|
|
124 |
return device, ddp_flag
|
125 |
|
126 |
|
127 |
+
class PostProcess:
|
128 |
"""
|
129 |
TODO: function document
|
130 |
scale back the prediction and do nms for pred_bbox
|