π [Update] Move class_num & class_list into the dataset config
Files changed:
- yolo/config/config.py +2 -2
- yolo/config/dataset/coco.yaml +3 -0
- yolo/config/dataset/dev.yaml +3 -0
- yolo/config/dataset/mock.yaml +3 -0
- yolo/config/general.yaml +0 -2
- yolo/lazy.py +1 -1
- yolo/tools/loss_functions.py +1 -1
- yolo/tools/solver.py +1 -1
- yolo/utils/deploy_utils.py +5 -3
yolo/config/config.py
@@ -45,6 +45,8 @@ class DownloadOptions:
 @dataclass
 class DatasetConfig:
     path: str
+    class_num: int
+    class_list: List[str]
     auto_download: Optional[DownloadOptions]


@@ -142,8 +144,6 @@ class Config:
     device: Union[str, int, List[int]]
     cpu_num: int

-    class_num: int
-    class_list: List[str]
     class_idx_id: List[int]
     image_size: List[int]

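For orientation, here is a minimal sketch of the two dataclasses after this change, assuming only the fields visible in the hunks above plus the `dataset` member that the call sites below rely on; the real classes in config.py carry additional fields:

from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class DatasetConfig:
    path: str
    class_num: int          # moved in from the top-level Config
    class_list: List[str]   # moved in from the top-level Config
    auto_download: Optional["DownloadOptions"]  # DownloadOptions is defined earlier in config.py


@dataclass
class Config:
    device: Union[str, int, List[int]]
    cpu_num: int
    class_idx_id: List[int]
    image_size: List[int]
    dataset: DatasetConfig  # not shown in the hunk; downstream code reads cfg.dataset.*
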
yolo/config/dataset/coco.yaml
@@ -2,6 +2,9 @@ path: data/coco
 train: train2017
 validation: val2017

+class_num: 80
+class_list: ['Person', 'Bicycle', 'Car', 'Motorcycle', 'Airplane', 'Bus', 'Train', 'Truck', 'Boat', 'Traffic light', 'Fire hydrant', 'Stop sign', 'Parking meter', 'Bench', 'Bird', 'Cat', 'Dog', 'Horse', 'Sheep', 'Cow', 'Elephant', 'Bear', 'Zebra', 'Giraffe', 'Backpack', 'Umbrella', 'Handbag', 'Tie', 'Suitcase', 'Frisbee', 'Skis', 'Snowboard', 'Sports ball', 'Kite', 'Baseball bat', 'Baseball glove', 'Skateboard', 'Surfboard', 'Tennis racket', 'Bottle', 'Wine glass', 'Cup', 'Fork', 'Knife', 'Spoon', 'Bowl', 'Banana', 'Apple', 'Sandwich', 'Orange', 'Broccoli', 'Carrot', 'Hot dog', 'Pizza', 'Donut', 'Cake', 'Chair', 'Couch', 'Potted plant', 'Bed', 'Dining table', 'Toilet', 'Tv', 'Laptop', 'Mouse', 'Remote', 'Keyboard', 'Cell phone', 'Microwave', 'Oven', 'Toaster', 'Sink', 'Refrigerator', 'Book', 'Clock', 'Vase', 'Scissors', 'Teddy bear', 'Hair drier', 'Toothbrush']
+
 auto_download:
   images:
     base_url: http://images.cocodataset.org/zips/

yolo/config/dataset/dev.yaml
@@ -2,4 +2,7 @@ path: data/dev
 train: train
 validation: val

+class_num: 80
+class_list: ['Person', 'Bicycle', 'Car', 'Motorcycle', 'Airplane', 'Bus', 'Train', 'Truck', 'Boat', 'Traffic light', 'Fire hydrant', 'Stop sign', 'Parking meter', 'Bench', 'Bird', 'Cat', 'Dog', 'Horse', 'Sheep', 'Cow', 'Elephant', 'Bear', 'Zebra', 'Giraffe', 'Backpack', 'Umbrella', 'Handbag', 'Tie', 'Suitcase', 'Frisbee', 'Skis', 'Snowboard', 'Sports ball', 'Kite', 'Baseball bat', 'Baseball glove', 'Skateboard', 'Surfboard', 'Tennis racket', 'Bottle', 'Wine glass', 'Cup', 'Fork', 'Knife', 'Spoon', 'Bowl', 'Banana', 'Apple', 'Sandwich', 'Orange', 'Broccoli', 'Carrot', 'Hot dog', 'Pizza', 'Donut', 'Cake', 'Chair', 'Couch', 'Potted plant', 'Bed', 'Dining table', 'Toilet', 'Tv', 'Laptop', 'Mouse', 'Remote', 'Keyboard', 'Cell phone', 'Microwave', 'Oven', 'Toaster', 'Sink', 'Refrigerator', 'Book', 'Clock', 'Vase', 'Scissors', 'Teddy bear', 'Hair drier', 'Toothbrush']
+
 auto_download:

yolo/config/dataset/mock.yaml
@@ -2,6 +2,9 @@ path: tests/data
 train: train
 validation: val

+class_num: 80
+class_list: ['Person', 'Bicycle', 'Car', 'Motorcycle', 'Airplane', 'Bus', 'Train', 'Truck', 'Boat', 'Traffic light', 'Fire hydrant', 'Stop sign', 'Parking meter', 'Bench', 'Bird', 'Cat', 'Dog', 'Horse', 'Sheep', 'Cow', 'Elephant', 'Bear', 'Zebra', 'Giraffe', 'Backpack', 'Umbrella', 'Handbag', 'Tie', 'Suitcase', 'Frisbee', 'Skis', 'Snowboard', 'Sports ball', 'Kite', 'Baseball bat', 'Baseball glove', 'Skateboard', 'Surfboard', 'Tennis racket', 'Bottle', 'Wine glass', 'Cup', 'Fork', 'Knife', 'Spoon', 'Bowl', 'Banana', 'Apple', 'Sandwich', 'Orange', 'Broccoli', 'Carrot', 'Hot dog', 'Pizza', 'Donut', 'Cake', 'Chair', 'Couch', 'Potted plant', 'Bed', 'Dining table', 'Toilet', 'Tv', 'Laptop', 'Mouse', 'Remote', 'Keyboard', 'Cell phone', 'Microwave', 'Oven', 'Toaster', 'Sink', 'Refrigerator', 'Book', 'Clock', 'Vase', 'Scissors', 'Teddy bear', 'Hair drier', 'Toothbrush']
+
 auto_download:
   images:
     base_url: https://github.com/WongKinYiu/yolov9mit/releases/download/v1.0-alpha/

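Because class_num and class_list are now ordinary keys of each dataset YAML, they can be read straight from a loaded dataset config. A minimal sketch, assuming OmegaConf (which these Hydra-style configs build on) and that it is run from the repository root:

from omegaconf import OmegaConf

# Load one dataset config on its own and sanity-check the new keys.
dataset_cfg = OmegaConf.load("yolo/config/dataset/coco.yaml")

assert dataset_cfg.class_num == 80
assert len(dataset_cfg.class_list) == dataset_cfg.class_num
print(dataset_cfg.class_list[:3])  # ['Person', 'Bicycle', 'Car']
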
yolo/config/general.yaml
@@ -1,8 +1,6 @@
 device: 0
 cpu_num: 16

-class_num: 80
-class_list: ['Person', 'Bicycle', 'Car', 'Motorcycle', 'Airplane', 'Bus', 'Train', 'Truck', 'Boat', 'Traffic light', 'Fire hydrant', 'Stop sign', 'Parking meter', 'Bench', 'Bird', 'Cat', 'Dog', 'Horse', 'Sheep', 'Cow', 'Elephant', 'Bear', 'Zebra', 'Giraffe', 'Backpack', 'Umbrella', 'Handbag', 'Tie', 'Suitcase', 'Frisbee', 'Skis', 'Snowboard', 'Sports ball', 'Kite', 'Baseball bat', 'Baseball glove', 'Skateboard', 'Surfboard', 'Tennis racket', 'Bottle', 'Wine glass', 'Cup', 'Fork', 'Knife', 'Spoon', 'Bowl', 'Banana', 'Apple', 'Sandwich', 'Orange', 'Broccoli', 'Carrot', 'Hot dog', 'Pizza', 'Donut', 'Cake', 'Chair', 'Couch', 'Potted plant', 'Bed', 'Dining table', 'Toilet', 'Tv', 'Laptop', 'Mouse', 'Remote', 'Keyboard', 'Cell phone', 'Microwave', 'Oven', 'Toaster', 'Sink', 'Refrigerator', 'Book', 'Clock', 'Vase', 'Scissors', 'Teddy bear', 'Hair drier', 'Toothbrush']
 image_size: [640, 640]

 out_path: runs

yolo/lazy.py
@@ -24,7 +24,7 @@ def main(cfg: Config):
     if getattr(cfg.task, "fast_inference", False):
         model = FastModelLoader(cfg).load_model(device)
     else:
-        model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight)
+        model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
     model = model.to(device)

     converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)

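The same cfg.dataset lookup recurs in loss_functions.py, solver.py, and deploy_utils.py below. A hypothetical helper (not part of this commit) that centralizes the new access path and adds a length check:

from typing import List, Tuple


def get_class_info(cfg) -> Tuple[int, List[str]]:
    """Read class metadata from the dataset config (hypothetical helper).

    After this change the values live under cfg.dataset rather than at the
    top level of cfg, so cfg.class_num / cfg.class_list no longer exist.
    """
    class_num = cfg.dataset.class_num
    class_list = list(cfg.dataset.class_list)
    if len(class_list) != class_num:
        raise ValueError(f"class_list has {len(class_list)} entries, expected {class_num}")
    return class_num, class_list


# Usage mirroring the lazy.py call site above:
# class_num, idx2label = get_class_info(cfg)
# model = create_model(cfg.model, class_num=class_num, weight_path=cfg.weight)
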
yolo/tools/loss_functions.py
@@ -109,7 +109,7 @@ class YOLOLoss:
 class DualLoss:
     def __init__(self, cfg: Config, vec2box) -> None:
         loss_cfg = cfg.task.loss
-        self.loss = YOLOLoss(loss_cfg, vec2box, class_num=cfg.class_num, reg_max=cfg.model.anchor.reg_max)
+        self.loss = YOLOLoss(loss_cfg, vec2box, class_num=cfg.dataset.class_num, reg_max=cfg.model.anchor.reg_max)

         self.aux_rate = loss_cfg.aux

yolo/tools/solver.py
@@ -162,7 +162,7 @@ class ModelTester:
         self.save_path = progress.save_path / "images"
         os.makedirs(self.save_path, exist_ok=True)
         self.save_predict = getattr(cfg.task, "save_predict", None)
-        self.idx2label = cfg.class_list
+        self.idx2label = cfg.dataset.class_list

     def solve(self, dataloader: StreamDataLoader):
         logger.info("π Start Inference!")

yolo/utils/deploy_utils.py
@@ -12,6 +12,8 @@ class FastModelLoader:
     def __init__(self, cfg: Config):
         self.cfg = cfg
         self.compiler = cfg.task.fast_inference
+        self.class_num = cfg.dataset.class_num
+
         self._validate_compiler()
         if cfg.weight == True:
             cfg.weight = Path("weights") / f"{cfg.model.name}.pt"
@@ -32,7 +34,7 @@ class FastModelLoader:
             return self._load_trt_model().to(device)
         elif self.compiler == "deploy":
             self.cfg.model.model.auxiliary = {}
-        return create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).to(device)
+        return create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).to(device)

     def _load_onnx_model(self, device):
         from onnxruntime import InferenceSession
@@ -67,7 +69,7 @@
         from onnxruntime import InferenceSession
         from torch.onnx import export

-        model = create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).eval()
+        model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
         dummy_input = torch.ones((1, 3, *self.cfg.image_size))
         export(
             model,
@@ -95,7 +97,7 @@
     def _create_trt_model(self):
         from torch2trt import torch2trt

-        model = create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).eval()
+        model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
         dummy_input = torch.ones((1, 3, *self.cfg.image_size)).cuda()
         logger.info(f"♻️ Creating TensorRT model")
         model_trt = torch2trt(model.cuda(), [dummy_input])
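
A quick way to exercise the new lookup paths without a full run is to build a throwaway config and mimic what FastModelLoader.__init__ now caches; a sketch assuming OmegaConf, with a deliberately tiny two-class dataset:

from omegaconf import OmegaConf

# Throwaway config with the post-change layout: class metadata under `dataset`.
cfg = OmegaConf.create(
    {
        "dataset": {"path": "data/toy", "class_num": 2, "class_list": ["Cat", "Dog"]},
        "image_size": [640, 640],
    }
)

# What FastModelLoader.__init__ now reads and caches as self.class_num:
class_num = cfg.dataset.class_num
assert class_num == len(cfg.dataset.class_list) == 2

# The keys are gone from the top level, so any stale cfg.class_num access fails fast.
assert "class_num" not in cfg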