henry000 commited on
Commit
856cce6
·
1 Parent(s): 1445811

🔧 [Add] TypeHint for yolo.model

Browse files
Files changed (2) hide show
  1. yolo/config/config.py +13 -0
  2. yolo/model/yolo.py +2 -1
yolo/config/config.py CHANGED
@@ -1,6 +1,8 @@
1
  from dataclasses import dataclass
2
  from typing import Dict, List, Union
3
 
 
 
4
 
5
  @dataclass
6
  class AnchorConfig:
@@ -100,6 +102,17 @@ class Download:
100
  datasets: Datasets
101
 
102
 
 
 
 
 
 
 
 
 
 
 
 
103
  @dataclass
104
  class Config:
105
  model: Model
 
1
  from dataclasses import dataclass
2
  from typing import Dict, List, Union
3
 
4
+ from torch import nn
5
+
6
 
7
  @dataclass
8
  class AnchorConfig:
 
102
  datasets: Datasets
103
 
104
 
105
+ @dataclass
106
+ class YOLOLayer(nn.Module):
107
+ source: Union[int, str, List[int]]
108
+ output: bool
109
+ tags: str
110
+ layer_type: str
111
+
112
+ def __post_init__(self):
113
+ super().__init__()
114
+
115
+
116
  @dataclass
117
  class Config:
118
  model: Model
yolo/model/yolo.py CHANGED
@@ -4,7 +4,7 @@ import torch.nn as nn
4
  from loguru import logger
5
  from omegaconf import ListConfig, OmegaConf
6
 
7
- from yolo.config.config import Config, Model
8
  from yolo.tools.layer_helper import get_layer_map
9
 
10
 
@@ -21,6 +21,7 @@ class YOLO(nn.Module):
21
  super(YOLO, self).__init__()
22
  self.num_classes = num_classes
23
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
 
24
  self.build_model(model_cfg.model)
25
 
26
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
 
4
  from loguru import logger
5
  from omegaconf import ListConfig, OmegaConf
6
 
7
+ from yolo.config.config import Config, Model, YOLOLayer
8
  from yolo.tools.layer_helper import get_layer_map
9
 
10
 
 
21
  super(YOLO, self).__init__()
22
  self.num_classes = num_classes
23
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
24
+ self.model: List[YOLOLayer] = nn.ModuleList()
25
  self.build_model(model_cfg.model)
26
 
27
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):