lucytuan commited on
Commit
c8b07ff
Β·
2 Parent(s): de1ec48 61ddf44

πŸ”€ [Merge] branch 'DATASET' of https://github.com/WongKinYiu/yolov9mit into DATASET

Browse files
.github/workflows/main.yaml CHANGED
@@ -1,4 +1,4 @@
1
- name: YOLOv9 - Model test
2
 
3
  on:
4
  push:
@@ -8,7 +8,6 @@ on:
8
 
9
  jobs:
10
  build:
11
-
12
  runs-on: ubuntu-latest
13
 
14
  steps:
@@ -17,10 +16,25 @@ jobs:
17
  uses: actions/setup-python@v2
18
  with:
19
  python-version: 3.8
 
20
  - name: Install dependencies
21
  run: |
22
  python -m pip install --upgrade pip
23
  pip install -r requirements.txt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  - name: Test with pytest
25
- run: |
26
- pytest
 
1
+ name: YOLOv9 - Model test and Code Style Check
2
 
3
  on:
4
  push:
 
8
 
9
  jobs:
10
  build:
 
11
  runs-on: ubuntu-latest
12
 
13
  steps:
 
16
  uses: actions/setup-python@v2
17
  with:
18
  python-version: 3.8
19
+
20
  - name: Install dependencies
21
  run: |
22
  python -m pip install --upgrade pip
23
  pip install -r requirements.txt
24
+
25
+ - name: Install pre-commit
26
+ run: pip install pre-commit
27
+
28
+ - name: Cache pre-commit environment
29
+ uses: actions/cache@v2
30
+ with:
31
+ path: ~/.cache/pre-commit
32
+ key: ${{ runner.os }}-precommit-${{ hashFiles('**/.pre-commit-config.yaml') }}
33
+ restore-keys: |
34
+ ${{ runner.os }}-precommit-
35
+
36
+ - name: Run pre-commit (black and isort)
37
+ run: pre-commit run --all-files
38
+
39
  - name: Test with pytest
40
+ run: pytest
 
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Kin-Yiu, Wong
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -13,20 +13,22 @@ While the project's structure is still being finalized, we ask that potential co
13
 
14
  If you are interested in contributing, please keep an eye on project updates or contact us directly at [[email protected]](mailto:[email protected]) for more information.
15
 
 
16
 
 
17
 
18
  ## To-Do Lists
19
  - [ ] Project Setup
20
  - [X] requirements
21
- - [ ] LICENSE
22
  - [ ] README
23
- - [ ] pytests
24
  - [ ] setup.py/pip install
25
- - [ ] log format
26
  - [ ] hugging face
27
  - [ ] Data proccess
28
  - [ ] Dataset
29
- - [ ] Download script
30
  - [ ] Auto Download
31
  - [ ] xywh, xxyy, xcyc
32
  - [ ] Dataloder
@@ -35,14 +37,14 @@ If you are interested in contributing, please keep an eye on project updates or
35
  - [ ] load model
36
  - [ ] from yaml
37
  - [ ] from github
38
- - [ ] trainer
39
- - [ ] train_one_iter
40
- - [ ] train_one_epoch
41
- - [ ] DDP, EMA, OTA
 
 
42
  - [ ] Run
43
  - [ ] train
44
  - [ ] test
45
  - [ ] demo
46
- - [ ] Configuration
47
- - [ ] hyperparams: dataclass
48
- - [ ] model cfg: yaml
 
13
 
14
  If you are interested in contributing, please keep an eye on project updates or contact us directly at [[email protected]](mailto:[email protected]) for more information.
15
 
16
+ ## Star History
17
 
18
+ [![Star History Chart](https://api.star-history.com/svg?repos=WongKinYiu/yolov9mit&type=Date)](https://star-history.com/#WongKinYiu/yolov9mit&Date)
19
 
20
  ## To-Do Lists
21
  - [ ] Project Setup
22
  - [X] requirements
23
+ - [x] LICENSE
24
  - [ ] README
25
+ - [x] pytests
26
  - [ ] setup.py/pip install
27
+ - [x] log format
28
  - [ ] hugging face
29
  - [ ] Data proccess
30
  - [ ] Dataset
31
+ - [x] Download script
32
  - [ ] Auto Download
33
  - [ ] xywh, xxyy, xcyc
34
  - [ ] Dataloder
 
37
  - [ ] load model
38
  - [ ] from yaml
39
  - [ ] from github
40
+ - [x] trainer
41
+ - [x] train_one_iter
42
+ - [x] train_one_epoch
43
+ - [ ] DDP
44
+ - [x] EMA, OTA
45
+ - [ ] Loss
46
  - [ ] Run
47
  - [ ] train
48
  - [ ] test
49
  - [ ] demo
50
+ - [x] Configuration
 
 
config/config.py CHANGED
@@ -14,6 +14,57 @@ class Download:
14
  path: str
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  @dataclass
18
  class Dataset:
19
  file_name: str
@@ -37,3 +88,4 @@ class Download:
37
  class Config:
38
  model: Model
39
  download: Download
 
 
14
  path: str
15
 
16
 
17
+ @dataclass
18
+ class DataLoaderConfig:
19
+ batch_size: int
20
+ shuffle: bool
21
+ num_workers: int
22
+ pin_memory: bool
23
+
24
+
25
+ @dataclass
26
+ class OptimizerArgs:
27
+ lr: float
28
+ weight_decay: float
29
+
30
+
31
+ @dataclass
32
+ class OptimizerConfig:
33
+ type: str
34
+ args: OptimizerArgs
35
+
36
+
37
+ @dataclass
38
+ class SchedulerArgs:
39
+ step_size: int
40
+ gamma: float
41
+
42
+
43
+ @dataclass
44
+ class SchedulerConfig:
45
+ type: str
46
+ args: SchedulerArgs
47
+
48
+
49
+ @dataclass
50
+ class EMAConfig:
51
+ enabled: bool
52
+ decay: float
53
+
54
+
55
+ @dataclass
56
+ class TrainConfig:
57
+ optimizer: OptimizerConfig
58
+ scheduler: SchedulerConfig
59
+ ema: EMAConfig
60
+
61
+
62
+ @dataclass
63
+ class HyperConfig:
64
+ data: DataLoaderConfig
65
+ train: TrainConfig
66
+
67
+
68
  @dataclass
69
  class Dataset:
70
  file_name: str
 
88
  class Config:
89
  model: Model
90
  download: Download
91
+ hyper: HyperConfig
config/hyper/default.yaml CHANGED
@@ -3,3 +3,17 @@ data:
3
  shuffle: True
4
  num_workers: 4
5
  pin_memory: True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  shuffle: True
4
  num_workers: 4
5
  pin_memory: True
6
+ train:
7
+ optimizer:
8
+ type: Adam
9
+ args:
10
+ lr: 0.001
11
+ weight_decay: 0.0001
12
+ scheduler:
13
+ type: StepLR
14
+ args:
15
+ step_size: 10
16
+ gamma: 0.1
17
+ ema:
18
+ enabled: true
19
+ decay: 0.995
tests/test_utils/test_dataaugment.py CHANGED
@@ -6,23 +6,22 @@ from PIL import Image
6
  from torchvision.transforms import functional as TF
7
 
8
  sys.path.append("./")
9
- from utils.data_augment import Compose, Mosaic, RandomHorizontalFlip
10
 
11
 
12
- def test_random_horizontal_flip():
13
  # Create a mock image and bounding boxes
14
  img = Image.new("RGB", (100, 100), color="red")
15
- boxes = torch.tensor([[1, 0.1, 0.1, 0.9, 0.9]]) # class, xmin, ymin, xmax, ymax
16
 
17
- flip_transform = RandomHorizontalFlip(prob=1) # Set probability to 1 to ensure flip
18
  flipped_img, flipped_boxes = flip_transform(img, boxes)
19
 
20
  # Assert image is flipped by comparing it to a manually flipped image
21
  assert TF.hflip(img) == flipped_img
22
 
23
  # Assert bounding boxes are flipped correctly
24
- expected_boxes = torch.tensor([[1, 0.1, 0.1, 0.9, 0.9]])
25
- expected_boxes[:, [1, 3]] = 1 - expected_boxes[:, [3, 1]]
26
  assert torch.allclose(flipped_boxes, expected_boxes), "Bounding boxes were not flipped correctly"
27
 
28
 
@@ -60,5 +59,5 @@ def test_mosaic():
60
  # Checks here would depend on the exact expected behavior of the mosaic function,
61
  # such as dimensions and content of the output image and boxes.
62
 
63
- assert mosaic_img.size == (200, 200), "Mosaic image size should be doubled"
64
  assert len(mosaic_boxes) > 0, "Should have some bounding boxes"
 
6
  from torchvision.transforms import functional as TF
7
 
8
  sys.path.append("./")
9
+ from utils.data_augment import Compose, HorizontalFlip, Mosaic, VerticalFlip
10
 
11
 
12
+ def test_horizontal_flip():
13
  # Create a mock image and bounding boxes
14
  img = Image.new("RGB", (100, 100), color="red")
15
+ boxes = torch.tensor([[1, 0.05, 0.1, 0.7, 0.9]]) # class, xmin, ymin, xmax, ymax
16
 
17
+ flip_transform = HorizontalFlip(prob=1) # Set probability to 1 to ensure flip
18
  flipped_img, flipped_boxes = flip_transform(img, boxes)
19
 
20
  # Assert image is flipped by comparing it to a manually flipped image
21
  assert TF.hflip(img) == flipped_img
22
 
23
  # Assert bounding boxes are flipped correctly
24
+ expected_boxes = torch.tensor([[1, 0.3, 0.1, 0.95, 0.9]])
 
25
  assert torch.allclose(flipped_boxes, expected_boxes), "Bounding boxes were not flipped correctly"
26
 
27
 
 
59
  # Checks here would depend on the exact expected behavior of the mosaic function,
60
  # such as dimensions and content of the output image and boxes.
61
 
62
+ assert mosaic_img.size == (100, 100), "Mosaic image size should be same"
63
  assert len(mosaic_boxes) > 0, "Should have some bounding boxes"
tools/model_helper.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Type
2
+
3
+ import torch
4
+ from torch.optim import Optimizer
5
+ from torch.optim.lr_scheduler import _LRScheduler
6
+
7
+ from config.config import OptimizerConfig, SchedulerConfig
8
+
9
+
10
+ class EMA:
11
+ def __init__(self, model: torch.nn.Module, decay: float):
12
+ self.model = model
13
+ self.decay = decay
14
+ self.shadow = {name: param.clone().detach() for name, param in model.named_parameters()}
15
+
16
+ def update(self):
17
+ """Update the shadow parameters using the current model parameters."""
18
+ for name, param in self.model.named_parameters():
19
+ assert name in self.shadow, "All model parameters should have a corresponding shadow parameter."
20
+ new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
21
+ self.shadow[name] = new_average.clone()
22
+
23
+ def apply_shadow(self):
24
+ """Apply the shadow parameters to the model."""
25
+ for name, param in self.model.named_parameters():
26
+ param.data.copy_(self.shadow[name])
27
+
28
+ def restore(self):
29
+ """Restore the original parameters from the shadow."""
30
+ for name, param in self.model.named_parameters():
31
+ self.shadow[name].copy_(param.data)
32
+
33
+
34
+ def get_optimizer(model_parameters, optim_cfg: OptimizerConfig) -> Optimizer:
35
+ """Create an optimizer for the given model parameters based on the configuration.
36
+
37
+ Returns:
38
+ An instance of the optimizer configured according to the provided settings.
39
+ """
40
+ optimizer_class: Type[Optimizer] = getattr(torch.optim, optim_cfg.type)
41
+ return optimizer_class(model_parameters, **optim_cfg.args)
42
+
43
+
44
+ def get_scheduler(optimizer: Optimizer, schedul_cfg: SchedulerConfig) -> _LRScheduler:
45
+ """Create a learning rate scheduler for the given optimizer based on the configuration.
46
+
47
+ Returns:
48
+ An instance of the scheduler configured according to the provided settings.
49
+ """
50
+ scheduler_class: Type[_LRScheduler] = getattr(torch.optim.lr_scheduler, schedul_cfg.type)
51
+ return scheduler_class(optimizer, **schedul_cfg.args)
tools/trainer.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from loguru import logger
3
+ from tqdm import tqdm
4
+
5
+ from config.config import TrainConfig
6
+ from model.yolo import YOLO
7
+ from tools.model_helper import EMA, get_optimizer, get_scheduler
8
+ from utils.loss import get_loss_function
9
+
10
+
11
+ class Trainer:
12
+ def __init__(self, model: YOLO, train_cfg: TrainConfig, device):
13
+ self.model = model.to(device)
14
+ self.device = device
15
+ self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
16
+ self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
17
+ self.loss_fn = get_loss_function()
18
+
19
+ if train_cfg.ema.get("enabled", False):
20
+ self.ema = EMA(model, decay=train_cfg.ema.decay)
21
+ else:
22
+ self.ema = None
23
+
24
+ def train_one_batch(self, data, targets):
25
+ data, targets = data.to(self.device), targets.to(self.device)
26
+ self.optimizer.zero_grad()
27
+ outputs = self.model(data)
28
+ loss = self.loss_fn(outputs, targets)
29
+ loss.backward()
30
+ self.optimizer.step()
31
+ if self.ema:
32
+ self.ema.update()
33
+ return loss.item()
34
+
35
+ def train_one_epoch(self, dataloader):
36
+ self.model.train()
37
+ total_loss = 0
38
+ for data, targets in tqdm(dataloader, desc="Training"):
39
+ loss = self.train_one_batch(data, targets)
40
+ total_loss += loss
41
+ if self.scheduler:
42
+ self.scheduler.step()
43
+ return total_loss / len(dataloader)
44
+
45
+ def save_checkpoint(self, epoch, filename="checkpoint.pt"):
46
+ checkpoint = {
47
+ "epoch": epoch,
48
+ "model_state_dict": self.model.state_dict(),
49
+ "optimizer_state_dict": self.optimizer.state_dict(),
50
+ }
51
+ if self.ema:
52
+ self.ema.apply_shadow()
53
+ checkpoint["model_state_dict_ema"] = self.model.state_dict()
54
+ self.ema.restore()
55
+ torch.save(checkpoint, filename)
56
+
57
+ def train(self, dataloader, num_epochs):
58
+ logger.info("start train")
59
+ for epoch in range(num_epochs):
60
+ epoch_loss = self.train_one_epoch(dataloader)
61
+ logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
62
+ if (epoch + 1) % 5 == 0:
63
+ self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
train.py CHANGED
@@ -1,20 +1,27 @@
1
  import hydra
 
2
  from loguru import logger
3
 
4
  from config.config import Config
5
  from model.yolo import get_model
6
  from tools.log_helper import custom_logger
7
- from utils.dataloader import YoloDataset
 
8
  from utils.get_dataset import prepare_dataset
9
 
10
 
11
  @hydra.main(config_path="config", config_name="config", version_base=None)
12
  def main(cfg: Config):
13
- dataset = YoloDataset(cfg)
14
  if cfg.download.auto:
15
  prepare_dataset(cfg.download)
16
 
 
17
  model = get_model(cfg.model)
 
 
 
 
 
18
 
19
 
20
  if __name__ == "__main__":
 
1
  import hydra
2
+ import torch
3
  from loguru import logger
4
 
5
  from config.config import Config
6
  from model.yolo import get_model
7
  from tools.log_helper import custom_logger
8
+ from tools.trainer import Trainer
9
+ from utils.dataloader import get_dataloader
10
  from utils.get_dataset import prepare_dataset
11
 
12
 
13
  @hydra.main(config_path="config", config_name="config", version_base=None)
14
  def main(cfg: Config):
 
15
  if cfg.download.auto:
16
  prepare_dataset(cfg.download)
17
 
18
+ dataloader = get_dataloader(cfg)
19
  model = get_model(cfg.model)
20
+ # TODO: get_device or rank, for DDP mode
21
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+
23
+ trainer = Trainer(model, cfg.hyper.train, device)
24
+ trainer.train(dataloader, 10)
25
 
26
 
27
  if __name__ == "__main__":
utils/converter_json2txt.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import Dict, List, Optional
4
+
5
+ from tqdm import tqdm
6
+
7
+
8
+ def discretize_categories(categories: List[Dict[str, int]]) -> Dict[int, int]:
9
+ """
10
+ Maps each unique 'id' in the list of category dictionaries to a sequential integer index.
11
+ Indices are assigned based on the sorted 'id' values.
12
+ """
13
+ sorted_categories = sorted(categories, key=lambda category: category["id"])
14
+ return {category["id"]: index for index, category in enumerate(sorted_categories)}
15
+
16
+
17
+ def process_annotations(
18
+ image_annotations: Dict[int, List[Dict]],
19
+ image_info_dict: Dict[int, tuple],
20
+ output_dir: str,
21
+ id_to_idx: Optional[Dict[int, int]] = None,
22
+ ) -> None:
23
+ """
24
+ Process and save annotations to files, with option to remap category IDs.
25
+ """
26
+ for image_id, annotations in tqdm(image_annotations.items(), desc="Processing annotations"):
27
+ file_path = os.path.join(output_dir, f"{image_id:0>12}.txt")
28
+ if not annotations:
29
+ continue
30
+ with open(file_path, "w") as file:
31
+ for annotation in annotations:
32
+ process_annotation(annotation, image_info_dict[image_id], id_to_idx, file)
33
+
34
+
35
+ def process_annotation(annotation: Dict, image_dims: tuple, id_to_idx: Optional[Dict[int, int]], file) -> None:
36
+ """
37
+ Convert a single annotation's segmentation and write it to the open file handle.
38
+ """
39
+ category_id = annotation["category_id"]
40
+ segmentation = (
41
+ annotation["segmentation"][0]
42
+ if annotation["segmentation"] and isinstance(annotation["segmentation"][0], list)
43
+ else None
44
+ )
45
+
46
+ if segmentation is None:
47
+ return
48
+
49
+ img_width, img_height = image_dims
50
+ normalized_segmentation = normalize_segmentation(segmentation, img_width, img_height)
51
+
52
+ if id_to_idx:
53
+ category_id = id_to_idx.get(category_id, category_id)
54
+
55
+ file.write(f"{category_id} {' '.join(normalized_segmentation)}\n")
56
+
57
+
58
+ def normalize_segmentation(segmentation: List[float], img_width: int, img_height: int) -> List[str]:
59
+ """
60
+ Normalize and format segmentation coordinates.
61
+ """
62
+ return [f"{x/img_width:.6f}" if i % 2 == 0 else f"{x/img_height:.6f}" for i, x in enumerate(segmentation)]
63
+
64
+
65
+ def convert_annotations(json_file: str, output_dir: str) -> None:
66
+ """
67
+ Load annotation data from a JSON file and process all annotations.
68
+ """
69
+ with open(json_file) as file:
70
+ data = json.load(file)
71
+
72
+ os.makedirs(output_dir, exist_ok=True)
73
+
74
+ image_info_dict = {img["id"]: (img["width"], img["height"]) for img in data.get("images", [])}
75
+ id_to_idx = discretize_categories(data.get("categories", [])) if "categories" in data else None
76
+ image_annotations = {img_id: [] for img_id in image_info_dict}
77
+
78
+ for annotation in data.get("annotations", []):
79
+ if not annotation.get("iscrowd", False):
80
+ image_annotations[annotation["image_id"]].append(annotation)
81
+
82
+ process_annotations(image_annotations, image_info_dict, output_dir, id_to_idx)
83
+
84
+
85
+ convert_annotations("./data/coco/annotations/instances_train2017.json", "./data/coco/labels/train2017/")
86
+ convert_annotations("./data/coco/annotations/instances_val2017.json", "./data/coco/labels/val2017/")
utils/dataloader.py CHANGED
@@ -96,7 +96,7 @@ class YoloDataset(Dataset):
96
  cache[phase_name] = data
97
 
98
  cache.close()
99
- logger.info("Loaded {} cache", phase_name)
100
  data = cache[phase_name]
101
  return data
102
 
 
96
  cache[phase_name] = data
97
 
98
  cache.close()
99
+ logger.info("πŸ“¦ Loaded {} cache", phase_name)
100
  data = cache[phase_name]
101
  return data
102
 
utils/loss.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ def get_loss_function(*args, **kwargs):
2
+ raise NotImplementedError