henry000 commited on
Commit
1851849
Β·
1 Parent(s): 724fd6f

πŸ”§ [Update] num_classes position, move to hyper

Browse files
examples/example_train.py CHANGED
@@ -23,7 +23,7 @@ def main(cfg: Config):
23
  prepare_dataset(cfg.download)
24
 
25
  dataloader = get_dataloader(cfg)
26
- model = get_model(cfg.model)
27
  draw_model(model=model)
28
  # TODO: get_device or rank, for DDP mode
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
23
  prepare_dataset(cfg.download)
24
 
25
  dataloader = get_dataloader(cfg)
26
+ model = get_model(cfg)
27
  draw_model(model=model)
28
  # TODO: get_device or rank, for DDP mode
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
yolo/config/model/v7-base.yaml CHANGED
@@ -1,5 +1,3 @@
1
- num_classes: 80
2
-
3
  anchor:
4
  reg_max: 16
5
  strides: [8, 16, 32]
 
 
 
1
  anchor:
2
  reg_max: 16
3
  strides: [8, 16, 32]
yolo/config/model/v9-c.yaml CHANGED
@@ -1,4 +1,6 @@
1
- num_classes: 80
 
 
2
 
3
  model:
4
  backbone:
@@ -117,3 +119,4 @@ model:
117
 
118
  - MultiheadDetection:
119
  source: [A3, A4, A5, P3, P4, P5]
 
 
1
+ anchor:
2
+ reg_max: 16
3
+ strides: [8, 16, 32]
4
 
5
  model:
6
  backbone:
 
119
 
120
  - MultiheadDetection:
121
  source: [A3, A4, A5, P3, P4, P5]
122
+ output: True
yolo/model/yolo.py CHANGED
@@ -4,7 +4,7 @@ import torch.nn as nn
4
  from loguru import logger
5
  from omegaconf import ListConfig, OmegaConf
6
 
7
- from yolo.config.config import Model
8
  from yolo.tools.layer_helper import get_layer_map
9
 
10
 
@@ -17,9 +17,9 @@ class YOLO(nn.Module):
17
  parameters, and any other relevant configuration details.
18
  """
19
 
20
- def __init__(self, model_cfg: Model):
21
  super(YOLO, self).__init__()
22
- self.num_classes = model_cfg["num_classes"]
23
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
24
  self.build_model(model_cfg.model)
25
 
@@ -101,7 +101,7 @@ class YOLO(nn.Module):
101
  raise ValueError(f"Unsupported layer type: {layer_type}")
102
 
103
 
104
- def get_model(model_cfg: dict) -> YOLO:
105
  """Constructs and returns a model from a Dictionary configuration file.
106
 
107
  Args:
@@ -110,7 +110,7 @@ def get_model(model_cfg: dict) -> YOLO:
110
  Returns:
111
  YOLO: An instance of the model defined by the given configuration.
112
  """
113
- OmegaConf.set_struct(model_cfg, False)
114
- model = YOLO(model_cfg)
115
  logger.info("βœ… Success load model")
116
  return model
 
4
  from loguru import logger
5
  from omegaconf import ListConfig, OmegaConf
6
 
7
+ from yolo.config.config import Config, Model
8
  from yolo.tools.layer_helper import get_layer_map
9
 
10
 
 
17
  parameters, and any other relevant configuration details.
18
  """
19
 
20
+ def __init__(self, model_cfg: Model, num_classes: int):
21
  super(YOLO, self).__init__()
22
+ self.num_classes = num_classes
23
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
24
  self.build_model(model_cfg.model)
25
 
 
101
  raise ValueError(f"Unsupported layer type: {layer_type}")
102
 
103
 
104
+ def get_model(cfg: Config) -> YOLO:
105
  """Constructs and returns a model from a Dictionary configuration file.
106
 
107
  Args:
 
110
  Returns:
111
  YOLO: An instance of the model defined by the given configuration.
112
  """
113
+ OmegaConf.set_struct(cfg.model, False)
114
+ model = YOLO(cfg.model, cfg.hyper.data.class_num)
115
  logger.info("βœ… Success load model")
116
  return model