henry000 commited on
Commit
3eb85fd
Β·
1 Parent(s): ba6baa6

🚚 [Update] the class_num&class_list to dataset

Browse files
yolo/config/config.py CHANGED
@@ -45,6 +45,8 @@ class DownloadOptions:
45
  @dataclass
46
  class DatasetConfig:
47
  path: str
 
 
48
  auto_download: Optional[DownloadOptions]
49
 
50
 
@@ -142,8 +144,6 @@ class Config:
142
  device: Union[str, int, List[int]]
143
  cpu_num: int
144
 
145
- class_num: int
146
- class_list: List[str]
147
  class_idx_id: List[int]
148
  image_size: List[int]
149
 
 
45
  @dataclass
46
  class DatasetConfig:
47
  path: str
48
+ class_num: int
49
+ class_list: List[str]
50
  auto_download: Optional[DownloadOptions]
51
 
52
 
 
144
  device: Union[str, int, List[int]]
145
  cpu_num: int
146
 
 
 
147
  class_idx_id: List[int]
148
  image_size: List[int]
149
 
yolo/config/dataset/coco.yaml CHANGED
@@ -2,6 +2,9 @@ path: data/coco
2
  train: train2017
3
  validation: val2017
4
 
 
 
 
5
  auto_download:
6
  images:
7
  base_url: http://images.cocodataset.org/zips/
 
2
  train: train2017
3
  validation: val2017
4
 
5
+ class_num: 80
6
+ class_list: ['Person', 'Bicycle', 'Car', 'Motorcycle', 'Airplane', 'Bus', 'Train', 'Truck', 'Boat', 'Traffic light', 'Fire hydrant', 'Stop sign', 'Parking meter', 'Bench', 'Bird', 'Cat', 'Dog', 'Horse', 'Sheep', 'Cow', 'Elephant', 'Bear', 'Zebra', 'Giraffe', 'Backpack', 'Umbrella', 'Handbag', 'Tie', 'Suitcase', 'Frisbee', 'Skis', 'Snowboard', 'Sports ball', 'Kite', 'Baseball bat', 'Baseball glove', 'Skateboard', 'Surfboard', 'Tennis racket', 'Bottle', 'Wine glass', 'Cup', 'Fork', 'Knife', 'Spoon', 'Bowl', 'Banana', 'Apple', 'Sandwich', 'Orange', 'Broccoli', 'Carrot', 'Hot dog', 'Pizza', 'Donut', 'Cake', 'Chair', 'Couch', 'Potted plant', 'Bed', 'Dining table', 'Toilet', 'Tv', 'Laptop', 'Mouse', 'Remote', 'Keyboard', 'Cell phone', 'Microwave', 'Oven', 'Toaster', 'Sink', 'Refrigerator', 'Book', 'Clock', 'Vase', 'Scissors', 'Teddy bear', 'Hair drier', 'Toothbrush']
7
+
8
  auto_download:
9
  images:
10
  base_url: http://images.cocodataset.org/zips/
yolo/config/dataset/dev.yaml CHANGED
@@ -2,4 +2,7 @@ path: data/dev
2
  train: train
3
  validation: val
4
 
 
 
 
5
  auto_download:
 
2
  train: train
3
  validation: val
4
 
5
+ class_num: 80
6
+ class_list: ['Person', 'Bicycle', 'Car', 'Motorcycle', 'Airplane', 'Bus', 'Train', 'Truck', 'Boat', 'Traffic light', 'Fire hydrant', 'Stop sign', 'Parking meter', 'Bench', 'Bird', 'Cat', 'Dog', 'Horse', 'Sheep', 'Cow', 'Elephant', 'Bear', 'Zebra', 'Giraffe', 'Backpack', 'Umbrella', 'Handbag', 'Tie', 'Suitcase', 'Frisbee', 'Skis', 'Snowboard', 'Sports ball', 'Kite', 'Baseball bat', 'Baseball glove', 'Skateboard', 'Surfboard', 'Tennis racket', 'Bottle', 'Wine glass', 'Cup', 'Fork', 'Knife', 'Spoon', 'Bowl', 'Banana', 'Apple', 'Sandwich', 'Orange', 'Broccoli', 'Carrot', 'Hot dog', 'Pizza', 'Donut', 'Cake', 'Chair', 'Couch', 'Potted plant', 'Bed', 'Dining table', 'Toilet', 'Tv', 'Laptop', 'Mouse', 'Remote', 'Keyboard', 'Cell phone', 'Microwave', 'Oven', 'Toaster', 'Sink', 'Refrigerator', 'Book', 'Clock', 'Vase', 'Scissors', 'Teddy bear', 'Hair drier', 'Toothbrush']
7
+
8
  auto_download:
yolo/config/dataset/mock.yaml CHANGED
@@ -2,6 +2,9 @@ path: tests/data
2
  train: train
3
  validation: val
4
 
 
 
 
5
  auto_download:
6
  images:
7
  base_url: https://github.com/WongKinYiu/yolov9mit/releases/download/v1.0-alpha/
 
2
  train: train
3
  validation: val
4
 
5
+ class_num: 80
6
+ class_list: ['Person', 'Bicycle', 'Car', 'Motorcycle', 'Airplane', 'Bus', 'Train', 'Truck', 'Boat', 'Traffic light', 'Fire hydrant', 'Stop sign', 'Parking meter', 'Bench', 'Bird', 'Cat', 'Dog', 'Horse', 'Sheep', 'Cow', 'Elephant', 'Bear', 'Zebra', 'Giraffe', 'Backpack', 'Umbrella', 'Handbag', 'Tie', 'Suitcase', 'Frisbee', 'Skis', 'Snowboard', 'Sports ball', 'Kite', 'Baseball bat', 'Baseball glove', 'Skateboard', 'Surfboard', 'Tennis racket', 'Bottle', 'Wine glass', 'Cup', 'Fork', 'Knife', 'Spoon', 'Bowl', 'Banana', 'Apple', 'Sandwich', 'Orange', 'Broccoli', 'Carrot', 'Hot dog', 'Pizza', 'Donut', 'Cake', 'Chair', 'Couch', 'Potted plant', 'Bed', 'Dining table', 'Toilet', 'Tv', 'Laptop', 'Mouse', 'Remote', 'Keyboard', 'Cell phone', 'Microwave', 'Oven', 'Toaster', 'Sink', 'Refrigerator', 'Book', 'Clock', 'Vase', 'Scissors', 'Teddy bear', 'Hair drier', 'Toothbrush']
7
+
8
  auto_download:
9
  images:
10
  base_url: https://github.com/WongKinYiu/yolov9mit/releases/download/v1.0-alpha/
yolo/config/general.yaml CHANGED
@@ -1,8 +1,6 @@
1
  device: 0
2
  cpu_num: 16
3
 
4
- class_num: 80
5
- class_list: ['Person', 'Bicycle', 'Car', 'Motorcycle', 'Airplane', 'Bus', 'Train', 'Truck', 'Boat', 'Traffic light', 'Fire hydrant', 'Stop sign', 'Parking meter', 'Bench', 'Bird', 'Cat', 'Dog', 'Horse', 'Sheep', 'Cow', 'Elephant', 'Bear', 'Zebra', 'Giraffe', 'Backpack', 'Umbrella', 'Handbag', 'Tie', 'Suitcase', 'Frisbee', 'Skis', 'Snowboard', 'Sports ball', 'Kite', 'Baseball bat', 'Baseball glove', 'Skateboard', 'Surfboard', 'Tennis racket', 'Bottle', 'Wine glass', 'Cup', 'Fork', 'Knife', 'Spoon', 'Bowl', 'Banana', 'Apple', 'Sandwich', 'Orange', 'Broccoli', 'Carrot', 'Hot dog', 'Pizza', 'Donut', 'Cake', 'Chair', 'Couch', 'Potted plant', 'Bed', 'Dining table', 'Toilet', 'Tv', 'Laptop', 'Mouse', 'Remote', 'Keyboard', 'Cell phone', 'Microwave', 'Oven', 'Toaster', 'Sink', 'Refrigerator', 'Book', 'Clock', 'Vase', 'Scissors', 'Teddy bear', 'Hair drier', 'Toothbrush']
6
  image_size: [640, 640]
7
 
8
  out_path: runs
 
1
  device: 0
2
  cpu_num: 16
3
 
 
 
4
  image_size: [640, 640]
5
 
6
  out_path: runs
yolo/lazy.py CHANGED
@@ -24,7 +24,7 @@ def main(cfg: Config):
24
  if getattr(cfg.task, "fast_inference", False):
25
  model = FastModelLoader(cfg).load_model(device)
26
  else:
27
- model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight)
28
  model = model.to(device)
29
 
30
  converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)
 
24
  if getattr(cfg.task, "fast_inference", False):
25
  model = FastModelLoader(cfg).load_model(device)
26
  else:
27
+ model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
28
  model = model.to(device)
29
 
30
  converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)
yolo/tools/loss_functions.py CHANGED
@@ -109,7 +109,7 @@ class YOLOLoss:
109
  class DualLoss:
110
  def __init__(self, cfg: Config, vec2box) -> None:
111
  loss_cfg = cfg.task.loss
112
- self.loss = YOLOLoss(loss_cfg, vec2box, class_num=cfg.class_num, reg_max=cfg.model.anchor.reg_max)
113
 
114
  self.aux_rate = loss_cfg.aux
115
 
 
109
  class DualLoss:
110
  def __init__(self, cfg: Config, vec2box) -> None:
111
  loss_cfg = cfg.task.loss
112
+ self.loss = YOLOLoss(loss_cfg, vec2box, class_num=cfg.dataset.class_num, reg_max=cfg.model.anchor.reg_max)
113
 
114
  self.aux_rate = loss_cfg.aux
115
 
yolo/tools/solver.py CHANGED
@@ -162,7 +162,7 @@ class ModelTester:
162
  self.save_path = progress.save_path / "images"
163
  os.makedirs(self.save_path, exist_ok=True)
164
  self.save_predict = getattr(cfg.task, "save_predict", None)
165
- self.idx2label = cfg.class_list
166
 
167
  def solve(self, dataloader: StreamDataLoader):
168
  logger.info("πŸ‘€ Start Inference!")
 
162
  self.save_path = progress.save_path / "images"
163
  os.makedirs(self.save_path, exist_ok=True)
164
  self.save_predict = getattr(cfg.task, "save_predict", None)
165
+ self.idx2label = cfg.dataset.class_list
166
 
167
  def solve(self, dataloader: StreamDataLoader):
168
  logger.info("πŸ‘€ Start Inference!")
yolo/utils/deploy_utils.py CHANGED
@@ -12,6 +12,8 @@ class FastModelLoader:
12
  def __init__(self, cfg: Config):
13
  self.cfg = cfg
14
  self.compiler = cfg.task.fast_inference
 
 
15
  self._validate_compiler()
16
  if cfg.weight == True:
17
  cfg.weight = Path("weights") / f"{cfg.model.name}.pt"
@@ -32,7 +34,7 @@ class FastModelLoader:
32
  return self._load_trt_model().to(device)
33
  elif self.compiler == "deploy":
34
  self.cfg.model.model.auxiliary = {}
35
- return create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).to(device)
36
 
37
  def _load_onnx_model(self, device):
38
  from onnxruntime import InferenceSession
@@ -67,7 +69,7 @@ class FastModelLoader:
67
  from onnxruntime import InferenceSession
68
  from torch.onnx import export
69
 
70
- model = create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).eval()
71
  dummy_input = torch.ones((1, 3, *self.cfg.image_size))
72
  export(
73
  model,
@@ -95,7 +97,7 @@ class FastModelLoader:
95
  def _create_trt_model(self):
96
  from torch2trt import torch2trt
97
 
98
- model = create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).eval()
99
  dummy_input = torch.ones((1, 3, *self.cfg.image_size)).cuda()
100
  logger.info(f"♻️ Creating TensorRT model")
101
  model_trt = torch2trt(model.cuda(), [dummy_input])
 
12
  def __init__(self, cfg: Config):
13
  self.cfg = cfg
14
  self.compiler = cfg.task.fast_inference
15
+ self.class_num = cfg.dataset.class_num
16
+
17
  self._validate_compiler()
18
  if cfg.weight == True:
19
  cfg.weight = Path("weights") / f"{cfg.model.name}.pt"
 
34
  return self._load_trt_model().to(device)
35
  elif self.compiler == "deploy":
36
  self.cfg.model.model.auxiliary = {}
37
+ return create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).to(device)
38
 
39
  def _load_onnx_model(self, device):
40
  from onnxruntime import InferenceSession
 
69
  from onnxruntime import InferenceSession
70
  from torch.onnx import export
71
 
72
+ model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
73
  dummy_input = torch.ones((1, 3, *self.cfg.image_size))
74
  export(
75
  model,
 
97
  def _create_trt_model(self):
98
  from torch2trt import torch2trt
99
 
100
+ model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
101
  dummy_input = torch.ones((1, 3, *self.cfg.image_size)).cuda()
102
  logger.info(f"♻️ Creating TensorRT model")
103
  model_trt = torch2trt(model.cuda(), [dummy_input])