π¨ [Update] logging model place
Browse files- examples/example_train.py +0 -2
- yolo/model/yolo.py +3 -2
examples/example_train.py
CHANGED
@@ -13,7 +13,6 @@ from yolo.model.yolo import get_model
|
|
13 |
from yolo.tools.log_helper import custom_logger
|
14 |
from yolo.tools.trainer import Trainer
|
15 |
from yolo.utils.dataloader import get_dataloader
|
16 |
-
from yolo.utils.drawer import draw_model
|
17 |
from yolo.utils.get_dataset import prepare_dataset
|
18 |
|
19 |
|
@@ -24,7 +23,6 @@ def main(cfg: Config):
|
|
24 |
|
25 |
dataloader = get_dataloader(cfg)
|
26 |
model = get_model(cfg)
|
27 |
-
draw_model(model=model)
|
28 |
# TODO: get_device or rank, for DDP mode
|
29 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
|
|
|
13 |
from yolo.tools.log_helper import custom_logger
|
14 |
from yolo.tools.trainer import Trainer
|
15 |
from yolo.utils.dataloader import get_dataloader
|
|
|
16 |
from yolo.utils.get_dataset import prepare_dataset
|
17 |
|
18 |
|
|
|
23 |
|
24 |
dataloader = get_dataloader(cfg)
|
25 |
model = get_model(cfg)
|
|
|
26 |
# TODO: get_device or rank, for DDP mode
|
27 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
28 |
|
yolo/model/yolo.py
CHANGED
@@ -7,6 +7,7 @@ from omegaconf import ListConfig, OmegaConf
|
|
7 |
from yolo.config.config import Config, Model, YOLOLayer
|
8 |
from yolo.tools.layer_helper import get_layer_map
|
9 |
from yolo.tools.log_helper import log_model
|
|
|
10 |
|
11 |
|
12 |
class YOLO(nn.Module):
|
@@ -24,8 +25,6 @@ class YOLO(nn.Module):
|
|
24 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
25 |
self.model: List[YOLOLayer] = nn.ModuleList()
|
26 |
self.build_model(model_cfg.model)
|
27 |
-
# TODO: Move to other position
|
28 |
-
log_model(self.model)
|
29 |
|
30 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
31 |
self.layer_index = {}
|
@@ -126,4 +125,6 @@ def get_model(cfg: Config) -> YOLO:
|
|
126 |
OmegaConf.set_struct(cfg.model, False)
|
127 |
model = YOLO(cfg.model, cfg.hyper.data.class_num)
|
128 |
logger.info("β
Success load model")
|
|
|
|
|
129 |
return model
|
|
|
7 |
from yolo.config.config import Config, Model, YOLOLayer
|
8 |
from yolo.tools.layer_helper import get_layer_map
|
9 |
from yolo.tools.log_helper import log_model
|
10 |
+
from yolo.utils.drawer import draw_model
|
11 |
|
12 |
|
13 |
class YOLO(nn.Module):
|
|
|
25 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
26 |
self.model: List[YOLOLayer] = nn.ModuleList()
|
27 |
self.build_model(model_cfg.model)
|
|
|
|
|
28 |
|
29 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
30 |
self.layer_index = {}
|
|
|
125 |
OmegaConf.set_struct(cfg.model, False)
|
126 |
model = YOLO(cfg.model, cfg.hyper.data.class_num)
|
127 |
logger.info("β
Success load model")
|
128 |
+
log_model(model.model)
|
129 |
+
draw_model(model=model)
|
130 |
return model
|