🔧 [Add] TypeHint for yolo.model
Browse files- yolo/config/config.py +13 -0
- yolo/model/yolo.py +2 -1
yolo/config/config.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
from dataclasses import dataclass
|
2 |
from typing import Dict, List, Union
|
3 |
|
|
|
|
|
4 |
|
5 |
@dataclass
|
6 |
class AnchorConfig:
|
@@ -100,6 +102,17 @@ class Download:
|
|
100 |
datasets: Datasets
|
101 |
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
@dataclass
|
104 |
class Config:
|
105 |
model: Model
|
|
|
1 |
from dataclasses import dataclass
|
2 |
from typing import Dict, List, Union
|
3 |
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
|
7 |
@dataclass
|
8 |
class AnchorConfig:
|
|
|
102 |
datasets: Datasets
|
103 |
|
104 |
|
105 |
+
@dataclass
|
106 |
+
class YOLOLayer(nn.Module):
|
107 |
+
source: Union[int, str, List[int]]
|
108 |
+
output: bool
|
109 |
+
tags: str
|
110 |
+
layer_type: str
|
111 |
+
|
112 |
+
def __post_init__(self):
|
113 |
+
super().__init__()
|
114 |
+
|
115 |
+
|
116 |
@dataclass
|
117 |
class Config:
|
118 |
model: Model
|
yolo/model/yolo.py
CHANGED
@@ -4,7 +4,7 @@ import torch.nn as nn
|
|
4 |
from loguru import logger
|
5 |
from omegaconf import ListConfig, OmegaConf
|
6 |
|
7 |
-
from yolo.config.config import Config, Model
|
8 |
from yolo.tools.layer_helper import get_layer_map
|
9 |
|
10 |
|
@@ -21,6 +21,7 @@ class YOLO(nn.Module):
|
|
21 |
super(YOLO, self).__init__()
|
22 |
self.num_classes = num_classes
|
23 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
|
|
24 |
self.build_model(model_cfg.model)
|
25 |
|
26 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
|
|
4 |
from loguru import logger
|
5 |
from omegaconf import ListConfig, OmegaConf
|
6 |
|
7 |
+
from yolo.config.config import Config, Model, YOLOLayer
|
8 |
from yolo.tools.layer_helper import get_layer_map
|
9 |
|
10 |
|
|
|
21 |
super(YOLO, self).__init__()
|
22 |
self.num_classes = num_classes
|
23 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
24 |
+
self.model: List[YOLOLayer] = nn.ModuleList()
|
25 |
self.build_model(model_cfg.model)
|
26 |
|
27 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|