π§ [Update] num_classes position, move to hyper
Browse files- examples/example_train.py +1 -1
- yolo/config/model/v7-base.yaml +0 -2
- yolo/config/model/v9-c.yaml +4 -1
- yolo/model/yolo.py +6 -6
examples/example_train.py
CHANGED
@@ -23,7 +23,7 @@ def main(cfg: Config):
|
|
23 |
prepare_dataset(cfg.download)
|
24 |
|
25 |
dataloader = get_dataloader(cfg)
|
26 |
-
model = get_model(cfg
|
27 |
draw_model(model=model)
|
28 |
# TODO: get_device or rank, for DDP mode
|
29 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
23 |
prepare_dataset(cfg.download)
|
24 |
|
25 |
dataloader = get_dataloader(cfg)
|
26 |
+
model = get_model(cfg)
|
27 |
draw_model(model=model)
|
28 |
# TODO: get_device or rank, for DDP mode
|
29 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
yolo/config/model/v7-base.yaml
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
num_classes: 80
|
2 |
-
|
3 |
anchor:
|
4 |
reg_max: 16
|
5 |
strides: [8, 16, 32]
|
|
|
|
|
|
|
1 |
anchor:
|
2 |
reg_max: 16
|
3 |
strides: [8, 16, 32]
|
yolo/config/model/v9-c.yaml
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
-
|
|
|
|
|
2 |
|
3 |
model:
|
4 |
backbone:
|
@@ -117,3 +119,4 @@ model:
|
|
117 |
|
118 |
- MultiheadDetection:
|
119 |
source: [A3, A4, A5, P3, P4, P5]
|
|
|
|
1 |
+
anchor:
|
2 |
+
reg_max: 16
|
3 |
+
strides: [8, 16, 32]
|
4 |
|
5 |
model:
|
6 |
backbone:
|
|
|
119 |
|
120 |
- MultiheadDetection:
|
121 |
source: [A3, A4, A5, P3, P4, P5]
|
122 |
+
output: True
|
yolo/model/yolo.py
CHANGED
@@ -4,7 +4,7 @@ import torch.nn as nn
|
|
4 |
from loguru import logger
|
5 |
from omegaconf import ListConfig, OmegaConf
|
6 |
|
7 |
-
from yolo.config.config import Model
|
8 |
from yolo.tools.layer_helper import get_layer_map
|
9 |
|
10 |
|
@@ -17,9 +17,9 @@ class YOLO(nn.Module):
|
|
17 |
parameters, and any other relevant configuration details.
|
18 |
"""
|
19 |
|
20 |
-
def __init__(self, model_cfg: Model):
|
21 |
super(YOLO, self).__init__()
|
22 |
-
self.num_classes =
|
23 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
24 |
self.build_model(model_cfg.model)
|
25 |
|
@@ -101,7 +101,7 @@ class YOLO(nn.Module):
|
|
101 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
102 |
|
103 |
|
104 |
-
def get_model(
|
105 |
"""Constructs and returns a model from a Dictionary configuration file.
|
106 |
|
107 |
Args:
|
@@ -110,7 +110,7 @@ def get_model(model_cfg: dict) -> YOLO:
|
|
110 |
Returns:
|
111 |
YOLO: An instance of the model defined by the given configuration.
|
112 |
"""
|
113 |
-
OmegaConf.set_struct(
|
114 |
-
model = YOLO(
|
115 |
logger.info("β
Success load model")
|
116 |
return model
|
|
|
4 |
from loguru import logger
|
5 |
from omegaconf import ListConfig, OmegaConf
|
6 |
|
7 |
+
from yolo.config.config import Config, Model
|
8 |
from yolo.tools.layer_helper import get_layer_map
|
9 |
|
10 |
|
|
|
17 |
parameters, and any other relevant configuration details.
|
18 |
"""
|
19 |
|
20 |
+
def __init__(self, model_cfg: Model, num_classes: int):
|
21 |
super(YOLO, self).__init__()
|
22 |
+
self.num_classes = num_classes
|
23 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
24 |
self.build_model(model_cfg.model)
|
25 |
|
|
|
101 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
102 |
|
103 |
|
104 |
+
def get_model(cfg: Config) -> YOLO:
|
105 |
"""Constructs and returns a model from a Dictionary configuration file.
|
106 |
|
107 |
Args:
|
|
|
110 |
Returns:
|
111 |
YOLO: An instance of the model defined by the given configuration.
|
112 |
"""
|
113 |
+
OmegaConf.set_struct(cfg.model, False)
|
114 |
+
model = YOLO(cfg.model, cfg.hyper.data.class_num)
|
115 |
logger.info("β
Success load model")
|
116 |
return model
|