henry000 commited on
Commit
936317c
Β·
2 Parent(s): a0c7025 230a441

πŸ”€ [Merge] branch 'main' into SETUP

Browse files
.gitignore CHANGED
@@ -42,6 +42,7 @@ htmlcov/
42
  .coverage
43
  .coverage.*
44
  .cache
 
45
  nosetests.xml
46
  coverage.xml
47
  *.cover
@@ -140,3 +141,6 @@ runs
140
 
141
  # Ignore npm packages (if using frontend components)
142
  node_modules/
 
 
 
 
42
  .coverage
43
  .coverage.*
44
  .cache
45
+ *.cache
46
  nosetests.xml
47
  coverage.xml
48
  *.cover
 
141
 
142
  # Ignore npm packages (if using frontend components)
143
  node_modules/
144
+
145
+ # Not ignore image for demo
146
+ !demo/images/inference/*
demo/images/inference/image.png ADDED
examples/example_inference.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ import hydra
5
+ import torch
6
+
7
+ project_root = Path(__file__).resolve().parent.parent
8
+ sys.path.append(str(project_root))
9
+
10
+ from yolo.config.config import Config
11
+ from yolo.model.yolo import get_model
12
+ from yolo.tools.data_loader import create_dataloader
13
+ from yolo.tools.solver import ModelTester
14
+ from yolo.utils.logging_utils import custom_logger, validate_log_directory
15
+
16
+
17
+ @hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
18
+ def main(cfg: Config):
19
+ custom_logger()
20
+ save_path = validate_log_directory(cfg, cfg.name)
21
+
22
+ device = torch.device(cfg.device)
23
+ model = get_model(cfg).to(device)
24
+
25
+ save_path = validate_log_directory(cfg, cfg.name)
26
+ dataloader = create_dataloader(cfg)
27
+ device = torch.device(cfg.device)
28
+ model = get_model(cfg).to(device)
29
+
30
+ tester = ModelTester(cfg, model, save_path, device)
31
+ tester.solve(dataloader)
32
+
33
+
34
+ if __name__ == "__main__":
35
+ main()
examples/example_train.py CHANGED
@@ -3,30 +3,28 @@ from pathlib import Path
3
 
4
  import hydra
5
  import torch
6
- from loguru import logger
7
 
8
  project_root = Path(__file__).resolve().parent.parent
9
  sys.path.append(str(project_root))
10
 
11
  from yolo.config.config import Config
 
12
  from yolo.tools.data_loader import create_dataloader
13
- from yolo.tools.dataset_preparation import prepare_dataset
14
- from yolo.tools.trainer import ModelTrainer
15
  from yolo.utils.logging_utils import custom_logger, validate_log_directory
16
 
17
 
18
  @hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
19
  def main(cfg: Config):
20
  custom_logger()
21
- save_path = validate_log_directory(cfg.hyper.general, cfg.name)
22
- if cfg.download.auto:
23
- prepare_dataset(cfg.download)
24
-
25
  dataloader = create_dataloader(cfg)
26
  # TODO: get_device or rank, for DDP mode
27
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
- trainer = ModelTrainer(cfg, save_path, device)
29
- trainer.train(dataloader, cfg.hyper.train.epoch)
 
 
30
 
31
 
32
  if __name__ == "__main__":
 
3
 
4
  import hydra
5
  import torch
 
6
 
7
  project_root = Path(__file__).resolve().parent.parent
8
  sys.path.append(str(project_root))
9
 
10
  from yolo.config.config import Config
11
+ from yolo.model.yolo import get_model
12
  from yolo.tools.data_loader import create_dataloader
13
+ from yolo.tools.solver import ModelTrainer
 
14
  from yolo.utils.logging_utils import custom_logger, validate_log_directory
15
 
16
 
17
  @hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
18
  def main(cfg: Config):
19
  custom_logger()
20
+ save_path = validate_log_directory(cfg, cfg.name)
 
 
 
21
  dataloader = create_dataloader(cfg)
22
  # TODO: get_device or rank, for DDP mode
23
+ device = torch.device(cfg.device)
24
+ model = get_model(cfg).to(device)
25
+
26
+ trainer = ModelTrainer(cfg, model, save_path, device)
27
+ trainer.solve(dataloader, cfg.task.epoch)
28
 
29
 
30
  if __name__ == "__main__":
examples/lazy.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ import hydra
5
+ import torch
6
+
7
+ project_root = Path(__file__).resolve().parent.parent
8
+ sys.path.append(str(project_root))
9
+
10
+ from yolo.config.config import Config
11
+ from yolo.model.yolo import get_model
12
+ from yolo.tools.data_loader import create_dataloader
13
+ from yolo.tools.solver import ModelTester, ModelTrainer
14
+ from yolo.utils.logging_utils import custom_logger, validate_log_directory
15
+
16
+
17
+ @hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
18
+ def main(cfg: Config):
19
+ custom_logger()
20
+
21
+ custom_logger()
22
+ save_path = validate_log_directory(cfg, cfg.name)
23
+ dataloader = create_dataloader(cfg)
24
+ device = torch.device(cfg.device)
25
+ model = get_model(cfg).to(device)
26
+
27
+ if cfg.task.task == "train":
28
+ trainer = ModelTrainer(cfg, model, save_path, device)
29
+ trainer.solve(dataloader)
30
+
31
+ if cfg.task.task == "inference":
32
+ tester = ModelTester(cfg, model, save_path, device)
33
+ tester.solve(dataloader)
34
+
35
+
36
+ if __name__ == "__main__":
37
+ main()
tests/test_model/test_yolo.py CHANGED
@@ -19,6 +19,7 @@ def test_build_model():
19
  cfg = compose(config_name=config_name)
20
 
21
  OmegaConf.set_struct(cfg.model, False)
 
22
  model = YOLO(cfg.model, 80)
23
  assert len(model.model) == 38
24
 
@@ -26,6 +27,7 @@ def test_build_model():
26
  def test_get_model():
27
  with initialize(config_path=config_path, version_base=None):
28
  cfg = compose(config_name=config_name)
 
29
  model = get_model(cfg)
30
  assert isinstance(model, YOLO)
31
 
 
19
  cfg = compose(config_name=config_name)
20
 
21
  OmegaConf.set_struct(cfg.model, False)
22
+ cfg.weight = None
23
  model = YOLO(cfg.model, 80)
24
  assert len(model.model) == 38
25
 
 
27
  def test_get_model():
28
  with initialize(config_path=config_path, version_base=None):
29
  cfg = compose(config_name=config_name)
30
+ cfg.weight = None
31
  model = get_model(cfg)
32
  assert isinstance(model, YOLO)
33
 
tests/test_utils/test_dataaugment.py CHANGED
@@ -39,7 +39,7 @@ def test_compose():
39
  return image, boxes
40
 
41
  compose = AugmentationComposer([mock_transform, mock_transform])
42
- img = Image.new("RGB", (10, 10), color="blue")
43
  boxes = torch.tensor([[0, 0.2, 0.2, 0.8, 0.8]])
44
 
45
  transformed_img, transformed_boxes = compose(img, boxes)
 
39
  return image, boxes
40
 
41
  compose = AugmentationComposer([mock_transform, mock_transform])
42
+ img = Image.new("RGB", (640, 640), color="blue")
43
  boxes = torch.tensor([[0, 0.2, 0.2, 0.8, 0.8]])
44
 
45
  transformed_img, transformed_boxes = compose(img, boxes)
yolo/config/config.py CHANGED
@@ -135,6 +135,8 @@ class Config:
135
  use_wandb: bool
136
  use_TensorBoard: bool
137
 
 
 
138
 
139
  @dataclass
140
  class YOLOLayer(nn.Module):
 
135
  use_wandb: bool
136
  use_TensorBoard: bool
137
 
138
+ weight: Optional[str]
139
+
140
 
141
  @dataclass
142
  class YOLOLayer(nn.Module):
yolo/config/general.yaml CHANGED
@@ -1,4 +1,4 @@
1
- deivce: [0]
2
  cpu_num: 16
3
 
4
  class_num: 80
@@ -9,4 +9,6 @@ exist_ok: True
9
 
10
  lucky_number: 10
11
  use_wandb: False
12
- use_TensorBoard: False
 
 
 
1
+ device: 0
2
  cpu_num: 16
3
 
4
  class_num: 80
 
9
 
10
  lucky_number: 10
11
  use_wandb: False
12
+ use_TensorBoard: False
13
+
14
+ weight: v9-c.pt
yolo/config/task/dataset/coco.yaml CHANGED
@@ -1,4 +1,6 @@
1
  path: data/coco
 
 
2
 
3
  auto_download:
4
  images:
 
1
  path: data/coco
2
+ train: train2017
3
+
4
 
5
  auto_download:
6
  images:
yolo/config/task/dataset/demo.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ path: demo
2
+
3
+ auto_download:
yolo/config/task/inference.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task: inference
2
+ defaults:
3
+ - dataset: demo
4
+ data:
5
+ batch_size: 16
6
+ shuffle: False
7
+ pin_memory: True
8
+ data_augment: {}
9
+ nms:
10
+ min_confidence: 0.75
11
+ min_iou: 0.5
yolo/model/yolo.py CHANGED
@@ -1,10 +1,13 @@
 
1
  from typing import Any, Dict, List, Union
2
 
 
3
  import torch.nn as nn
4
  from loguru import logger
5
  from omegaconf import ListConfig, OmegaConf
6
 
7
  from yolo.config.config import Config, Model, YOLOLayer
 
8
  from yolo.tools.drawer import draw_model
9
  from yolo.utils.logging_utils import log_model_structure
10
  from yolo.utils.module_utils import get_layer_map
@@ -125,6 +128,14 @@ def get_model(cfg: Config) -> YOLO:
125
  OmegaConf.set_struct(cfg.model, False)
126
  model = YOLO(cfg.model, cfg.class_num)
127
  logger.info("βœ… Success load model")
 
 
 
 
 
 
 
 
128
  log_model_structure(model.model)
129
  draw_model(model=model)
130
  return model
 
1
+ import os
2
  from typing import Any, Dict, List, Union
3
 
4
+ import torch
5
  import torch.nn as nn
6
  from loguru import logger
7
  from omegaconf import ListConfig, OmegaConf
8
 
9
  from yolo.config.config import Config, Model, YOLOLayer
10
+ from yolo.tools.dataset_preparation import prepare_weight
11
  from yolo.tools.drawer import draw_model
12
  from yolo.utils.logging_utils import log_model_structure
13
  from yolo.utils.module_utils import get_layer_map
 
128
  OmegaConf.set_struct(cfg.model, False)
129
  model = YOLO(cfg.model, cfg.class_num)
130
  logger.info("βœ… Success load model")
131
+ if cfg.weight:
132
+ if os.path.exists(cfg.weight):
133
+ model.model.load_state_dict(torch.load(cfg.weight))
134
+ logger.info("βœ… Success load model weight")
135
+ else:
136
+ logger.info(f"🌐 Weight {cfg.weight} not found, try downloading")
137
+ prepare_weight(weight_name=cfg.weight)
138
+
139
  log_model_structure(model.model)
140
  draw_model(model=model)
141
  return model
yolo/tools/dataset_preparation.py CHANGED
@@ -1,10 +1,10 @@
1
  import os
2
  import zipfile
 
3
 
4
  import requests
5
- from hydra import main
6
  from loguru import logger
7
- from tqdm import tqdm
8
 
9
  from yolo.config.config import DatasetConfig
10
 
@@ -13,18 +13,24 @@ def download_file(url, destination):
13
  """
14
  Downloads a file from the specified URL to the destination path with progress logging.
15
  """
16
- logger.info(f"Downloading {os.path.basename(destination)}...")
17
  with requests.get(url, stream=True) as response:
18
  response.raise_for_status()
19
  total_size = int(response.headers.get("content-length", 0))
20
- progress = tqdm(total=total_size, unit="iB", unit_scale=True, desc=os.path.basename(destination), leave=True)
21
-
22
- with open(destination, "wb") as file:
23
- for data in response.iter_content(chunk_size=1024 * 1024): # 1 MB chunks
24
- file.write(data)
25
- progress.update(len(data))
26
- progress.close()
27
- logger.info("Download completed.")
 
 
 
 
 
 
 
28
 
29
 
30
  def unzip_file(source, destination):
@@ -46,7 +52,6 @@ def check_files(directory, expected_count=None):
46
  return len(files) == expected_count if expected_count is not None else bool(files)
47
 
48
 
49
- @main(config_path="../config/data", config_name="download", version_base=None)
50
  def prepare_dataset(cfg: DatasetConfig):
51
  """
52
  Prepares dataset by downloading and unzipping if necessary.
@@ -76,6 +81,19 @@ def prepare_dataset(cfg: DatasetConfig):
76
  logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  if __name__ == "__main__":
80
  import sys
81
 
@@ -83,4 +101,4 @@ if __name__ == "__main__":
83
  from utils.logging_utils import custom_logger
84
 
85
  custom_logger()
86
- prepare_dataset()
 
1
  import os
2
  import zipfile
3
+ from typing import Optional
4
 
5
  import requests
 
6
  from loguru import logger
7
+ from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
8
 
9
  from yolo.config.config import DatasetConfig
10
 
 
13
  """
14
  Downloads a file from the specified URL to the destination path with progress logging.
15
  """
 
16
  with requests.get(url, stream=True) as response:
17
  response.raise_for_status()
18
  total_size = int(response.headers.get("content-length", 0))
19
+ with Progress(
20
+ TextColumn("[progress.description]{task.description}"),
21
+ BarColumn(),
22
+ "[progress.percentage]{task.percentage:>3.1f}%",
23
+ "β€’",
24
+ "{task.completed}/{task.total} bytes",
25
+ "β€’",
26
+ TimeRemainingColumn(),
27
+ ) as progress:
28
+ task = progress.add_task(f"πŸ“₯ Downloading {os.path.basename(destination)}...", total=total_size)
29
+ with open(destination, "wb") as file:
30
+ for data in response.iter_content(chunk_size=1024 * 1024): # 1 MB chunks
31
+ file.write(data)
32
+ progress.update(task, advance=len(data))
33
+ logger.info("βœ… Download completed.")
34
 
35
 
36
  def unzip_file(source, destination):
 
52
  return len(files) == expected_count if expected_count is not None else bool(files)
53
 
54
 
 
55
  def prepare_dataset(cfg: DatasetConfig):
56
  """
57
  Prepares dataset by downloading and unzipping if necessary.
 
81
  logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
82
 
83
 
84
+ def prepare_weight(downlaod_link: Optional[str] = None, weight_name: str = "v9-c.pt"):
85
+ if downlaod_link is None:
86
+ downlaod_link = "https://github.com/WongKinYiu/yolov9mit/releases/download/v1.0-alpha/"
87
+ weight_link = f"{downlaod_link}{weight_name}"
88
+
89
+ if os.path.exists(weight_name):
90
+ logger.info(f"Weight file '{weight_name}' already exists.")
91
+ try:
92
+ download_file(weight_link, weight_name)
93
+ except requests.exceptions.RequestException as e:
94
+ logger.warning(f"Failed to download the weight file: {e}")
95
+
96
+
97
  if __name__ == "__main__":
98
  import sys
99
 
 
101
  from utils.logging_utils import custom_logger
102
 
103
  custom_logger()
104
+ prepare_weight()
yolo/tools/{trainer.py β†’ solver.py} RENAMED
@@ -6,8 +6,10 @@ from torch import Tensor
6
  from torch.cuda.amp import GradScaler, autocast
7
 
8
  from yolo.config.config import Config, TrainConfig
9
- from yolo.model.yolo import get_model
 
10
  from yolo.tools.loss_functions import get_loss_function
 
11
  from yolo.utils.logging_utils import ProgressTracker
12
  from yolo.utils.model_utils import (
13
  ExponentialMovingAverage,
@@ -17,16 +19,15 @@ from yolo.utils.model_utils import (
17
 
18
 
19
  class ModelTrainer:
20
- def __init__(self, cfg: Config, save_path: str, device):
21
  train_cfg: TrainConfig = cfg.task
22
- model = get_model(cfg)
23
-
24
- self.model = model.to(device)
25
  self.device = device
26
  self.optimizer = create_optimizer(model, train_cfg.optimizer)
27
  self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
28
  self.loss_fn = get_loss_function(cfg)
29
- self.progress = ProgressTracker(cfg, save_path, use_wandb=True)
 
30
 
31
  if getattr(train_cfg.ema, "enabled", False):
32
  self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
@@ -75,8 +76,9 @@ class ModelTrainer:
75
  self.ema.restore()
76
  torch.save(checkpoint, filename)
77
 
78
- def train(self, dataloader, num_epochs):
79
  logger.info("πŸš„ Start Training!")
 
80
 
81
  with self.progress.progress:
82
  self.progress.start_train(num_epochs)
@@ -89,3 +91,27 @@ class ModelTrainer:
89
  logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
90
  if (epoch + 1) % 5 == 0:
91
  self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from torch.cuda.amp import GradScaler, autocast
7
 
8
  from yolo.config.config import Config, TrainConfig
9
+ from yolo.model.yolo import YOLO
10
+ from yolo.tools.drawer import draw_bboxes
11
  from yolo.tools.loss_functions import get_loss_function
12
+ from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms
13
  from yolo.utils.logging_utils import ProgressTracker
14
  from yolo.utils.model_utils import (
15
  ExponentialMovingAverage,
 
19
 
20
 
21
  class ModelTrainer:
22
+ def __init__(self, cfg: Config, model: YOLO, save_path: str, device):
23
  train_cfg: TrainConfig = cfg.task
24
+ self.model = model
 
 
25
  self.device = device
26
  self.optimizer = create_optimizer(model, train_cfg.optimizer)
27
  self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
28
  self.loss_fn = get_loss_function(cfg)
29
+ self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
30
+ self.num_epochs = cfg.task.epoch
31
 
32
  if getattr(train_cfg.ema, "enabled", False):
33
  self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
 
76
  self.ema.restore()
77
  torch.save(checkpoint, filename)
78
 
79
+ def solve(self, dataloader):
80
  logger.info("πŸš„ Start Training!")
81
+ num_epochs = self.num_epochs
82
 
83
  with self.progress.progress:
84
  self.progress.start_train(num_epochs)
 
91
  logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
92
  if (epoch + 1) % 5 == 0:
93
  self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
94
+
95
+
96
+ class ModelTester:
97
+ def __init__(self, cfg: Config, model: YOLO, save_path: str, device):
98
+ self.model = model
99
+ self.device = device
100
+ self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
101
+
102
+ self.anchor2box = AnchorBoxConverter(cfg, device)
103
+ self.nms = cfg.task.nms
104
+ self.save_path = save_path
105
+
106
+ def solve(self, dataloader):
107
+ logger.info("πŸ‘€ Start Inference!")
108
+
109
+ for images, _ in dataloader:
110
+ images = images.to(self.device)
111
+ with torch.no_grad():
112
+ raw_output = self.model(images)
113
+ predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
114
+
115
+ nms_out = bbox_nms(predict, self.nms)
116
+ for image, bbox in zip(images, nms_out):
117
+ draw_bboxes(image, bbox, scaled_bbox=False, save_path=self.save_path)