[Merge] branch 'main' into DATASET

Files changed:
- .github/workflows/main.yaml +18 -4
- config/config.py +52 -0
- config/hyper/default.yaml +14 -0
- tests/test_utils/test_dataaugment.py +6 -7
- tools/model_helper.py +51 -0
- tools/trainer.py +63 -0
- train.py +9 -2
- utils/loss.py +2 -0
.github/workflows/main.yaml
CHANGED

@@ -1,4 +1,4 @@
-name: YOLOv9 - Model test
+name: YOLOv9 - Model test and Code Style Check
 
 on:
   push:
@@ -8,7 +8,6 @@ on:
 
 jobs:
   build:
-
     runs-on: ubuntu-latest
 
     steps:
@@ -17,10 +16,25 @@ jobs:
         uses: actions/setup-python@v2
         with:
           python-version: 3.8
+
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
           pip install -r requirements.txt
+
+      - name: Install pre-commit
+        run: pip install pre-commit
+
+      - name: Cache pre-commit environment
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pre-commit
+          key: ${{ runner.os }}-precommit-${{ hashFiles('**/.pre-commit-config.yaml') }}
+          restore-keys: |
+            ${{ runner.os }}-precommit-
+
+      - name: Run pre-commit (black and isort)
+        run: pre-commit run --all-files
+
       - name: Test with pytest
-        run:
-          pytest
+        run: pytest
config/config.py
CHANGED

@@ -14,6 +14,57 @@ class Download:
     path: str
 
 
+@dataclass
+class DataLoaderConfig:
+    batch_size: int
+    shuffle: bool
+    num_workers: int
+    pin_memory: bool
+
+
+@dataclass
+class OptimizerArgs:
+    lr: float
+    weight_decay: float
+
+
+@dataclass
+class OptimizerConfig:
+    type: str
+    args: OptimizerArgs
+
+
+@dataclass
+class SchedulerArgs:
+    step_size: int
+    gamma: float
+
+
+@dataclass
+class SchedulerConfig:
+    type: str
+    args: SchedulerArgs
+
+
+@dataclass
+class EMAConfig:
+    enabled: bool
+    decay: float
+
+
+@dataclass
+class TrainConfig:
+    optimizer: OptimizerConfig
+    scheduler: SchedulerConfig
+    ema: EMAConfig
+
+
+@dataclass
+class HyperConfig:
+    data: DataLoaderConfig
+    train: TrainConfig
+
+
 @dataclass
 class Dataset:
     file_name: str
@@ -37,3 +88,4 @@ class Download:
 class Config:
     model: Model
     download: Download
+    hyper: HyperConfig
config/hyper/default.yaml
CHANGED

@@ -3,3 +3,17 @@ data:
   shuffle: True
   num_workers: 4
   pin_memory: True
+train:
+  optimizer:
+    type: Adam
+    args:
+      lr: 0.001
+      weight_decay: 0.0001
+  scheduler:
+    type: StepLR
+    args:
+      step_size: 10
+      gamma: 0.1
+  ema:
+    enabled: true
+    decay: 0.995
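
The dataclasses added in config/config.py mirror this YAML one-to-one, so the hyper-parameter group can be validated as a structured config. A minimal sketch of that round trip (hypothetical smoke test, not part of this change; it assumes the file path config/hyper/default.yaml shown above):

from omegaconf import OmegaConf

from config.config import HyperConfig

# Build a typed schema from the new dataclasses and merge the YAML into it;
# an unknown key or a non-numeric lr would be rejected at merge time.
schema = OmegaConf.structured(HyperConfig)
hyper = OmegaConf.merge(schema, OmegaConf.load("config/hyper/default.yaml"))

print(hyper.train.optimizer.type)        # Adam
print(hyper.train.scheduler.args.gamma)  # 0.1
print(hyper.train.ema.decay)             # 0.995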
tests/test_utils/test_dataaugment.py
CHANGED

@@ -6,23 +6,22 @@ from PIL import Image
 from torchvision.transforms import functional as TF
 
 sys.path.append("./")
-from utils.data_augment import Compose, Mosaic, …
+from utils.data_augment import Compose, HorizontalFlip, Mosaic, VerticalFlip
 
 
-def …
+def test_horizontal_flip():
     # Create a mock image and bounding boxes
     img = Image.new("RGB", (100, 100), color="red")
-    boxes = torch.tensor([[1, 0.…
+    boxes = torch.tensor([[1, 0.05, 0.1, 0.7, 0.9]])  # class, xmin, ymin, xmax, ymax
 
-    flip_transform = …
+    flip_transform = HorizontalFlip(prob=1)  # Set probability to 1 to ensure flip
     flipped_img, flipped_boxes = flip_transform(img, boxes)
 
     # Assert image is flipped by comparing it to a manually flipped image
     assert TF.hflip(img) == flipped_img
 
     # Assert bounding boxes are flipped correctly
-    expected_boxes = torch.tensor([[1, 0.…
-    expected_boxes[:, [1, 3]] = 1 - expected_boxes[:, [3, 1]]
+    expected_boxes = torch.tensor([[1, 0.3, 0.1, 0.95, 0.9]])
     assert torch.allclose(flipped_boxes, expected_boxes), "Bounding boxes were not flipped correctly"
 
 
@@ -60,5 +59,5 @@ def test_mosaic():
     # Checks here would depend on the exact expected behavior of the mosaic function,
     # such as dimensions and content of the output image and boxes.
 
-    assert mosaic_img.size == (…
+    assert mosaic_img.size == (100, 100), "Mosaic image size should be same"
     assert len(mosaic_boxes) > 0, "Should have some bounding boxes"
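
For reference, the expected boxes in test_horizontal_flip follow from mirroring the x-coordinates: with (class, xmin, ymin, xmax, ymax) rows in normalized coordinates, a horizontal flip maps xmin to 1 - xmax and xmax to 1 - xmin, which turns [0.05, 0.1, 0.7, 0.9] into [0.3, 0.1, 0.95, 0.9]. A sketch of a transform consistent with that expectation (illustrative only; the real HorizontalFlip lives in utils/data_augment.py and may differ in details):

import torch
from PIL import Image
from torchvision.transforms import functional as TF


class HorizontalFlipSketch:
    """Illustrative flip transform matching the behaviour the test asserts."""

    def __init__(self, prob: float = 0.5):
        self.prob = prob

    def __call__(self, img: Image.Image, boxes: torch.Tensor):
        if torch.rand(1).item() < self.prob:
            img = TF.hflip(img)
            boxes = boxes.clone()
            # Columns 1 and 3 are xmin/xmax: swap them and mirror around 1.
            boxes[:, [1, 3]] = 1.0 - boxes[:, [3, 1]]
        return img, boxes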
tools/model_helper.py
ADDED

@@ -0,0 +1,51 @@
+from typing import Any, Dict, Type
+
+import torch
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+from config.config import OptimizerConfig, SchedulerConfig
+
+
+class EMA:
+    def __init__(self, model: torch.nn.Module, decay: float):
+        self.model = model
+        self.decay = decay
+        self.shadow = {name: param.clone().detach() for name, param in model.named_parameters()}
+
+    def update(self):
+        """Update the shadow parameters using the current model parameters."""
+        for name, param in self.model.named_parameters():
+            assert name in self.shadow, "All model parameters should have a corresponding shadow parameter."
+            new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
+            self.shadow[name] = new_average.clone()
+
+    def apply_shadow(self):
+        """Apply the shadow parameters to the model."""
+        for name, param in self.model.named_parameters():
+            param.data.copy_(self.shadow[name])
+
+    def restore(self):
+        """Restore the original parameters from the shadow."""
+        for name, param in self.model.named_parameters():
+            self.shadow[name].copy_(param.data)
+
+
+def get_optimizer(model_parameters, optim_cfg: OptimizerConfig) -> Optimizer:
+    """Create an optimizer for the given model parameters based on the configuration.
+
+    Returns:
+        An instance of the optimizer configured according to the provided settings.
+    """
+    optimizer_class: Type[Optimizer] = getattr(torch.optim, optim_cfg.type)
+    return optimizer_class(model_parameters, **optim_cfg.args)
+
+
+def get_scheduler(optimizer: Optimizer, schedul_cfg: SchedulerConfig) -> _LRScheduler:
+    """Create a learning rate scheduler for the given optimizer based on the configuration.
+
+    Returns:
+        An instance of the scheduler configured according to the provided settings.
+    """
+    scheduler_class: Type[_LRScheduler] = getattr(torch.optim.lr_scheduler, schedul_cfg.type)
+    return scheduler_class(optimizer, **schedul_cfg.args)
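
Usage note: get_optimizer and get_scheduler resolve classes by name, so optim_cfg.type and schedul_cfg.type must match the torch.optim and torch.optim.lr_scheduler class names exactly, and **optim_cfg.args assumes args behaves like a mapping (true for an OmegaConf DictConfig, not for a plain dataclass instance). apply_shadow writes the EMA weights into the model, while restore copies the model's current parameters back into the shadow dict. A small standalone sketch of how the pieces fit together (hypothetical example, not from this change):

import torch
from omegaconf import OmegaConf

from tools.model_helper import EMA, get_optimizer, get_scheduler

model = torch.nn.Linear(4, 2)

# DictConfigs shaped like OptimizerConfig / SchedulerConfig from this diff.
optim_cfg = OmegaConf.create({"type": "Adam", "args": {"lr": 1e-3, "weight_decay": 1e-4}})
sched_cfg = OmegaConf.create({"type": "StepLR", "args": {"step_size": 10, "gamma": 0.1}})

optimizer = get_optimizer(model.parameters(), optim_cfg)
scheduler = get_scheduler(optimizer, sched_cfg)
ema = EMA(model, decay=0.995)

loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()
ema.update()      # fold the updated weights into the shadow copy
scheduler.step()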
tools/trainer.py
ADDED

@@ -0,0 +1,63 @@
+import torch
+from loguru import logger
+from tqdm import tqdm
+
+from config.config import TrainConfig
+from model.yolo import YOLO
+from tools.model_helper import EMA, get_optimizer, get_scheduler
+from utils.loss import get_loss_function
+
+
+class Trainer:
+    def __init__(self, model: YOLO, train_cfg: TrainConfig, device):
+        self.model = model.to(device)
+        self.device = device
+        self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
+        self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
+        self.loss_fn = get_loss_function()
+
+        if train_cfg.ema.get("enabled", False):
+            self.ema = EMA(model, decay=train_cfg.ema.decay)
+        else:
+            self.ema = None
+
+    def train_one_batch(self, data, targets):
+        data, targets = data.to(self.device), targets.to(self.device)
+        self.optimizer.zero_grad()
+        outputs = self.model(data)
+        loss = self.loss_fn(outputs, targets)
+        loss.backward()
+        self.optimizer.step()
+        if self.ema:
+            self.ema.update()
+        return loss.item()
+
+    def train_one_epoch(self, dataloader):
+        self.model.train()
+        total_loss = 0
+        for data, targets in tqdm(dataloader, desc="Training"):
+            loss = self.train_one_batch(data, targets)
+            total_loss += loss
+        if self.scheduler:
+            self.scheduler.step()
+        return total_loss / len(dataloader)
+
+    def save_checkpoint(self, epoch, filename="checkpoint.pt"):
+        checkpoint = {
+            "epoch": epoch,
+            "model_state_dict": self.model.state_dict(),
+            "optimizer_state_dict": self.optimizer.state_dict(),
+        }
+        if self.ema:
+            self.ema.apply_shadow()
+            checkpoint["model_state_dict_ema"] = self.model.state_dict()
+            self.ema.restore()
+        torch.save(checkpoint, filename)
+
+    def train(self, dataloader, num_epochs):
+        logger.info("start train")
+        for epoch in range(num_epochs):
+            epoch_loss = self.train_one_epoch(dataloader)
+            logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
+            if (epoch + 1) % 5 == 0:
+                self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
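
The checkpoints written every five epochs store the epoch index, the raw model weights, the optimizer state and, when EMA is enabled, an EMA snapshot under "model_state_dict_ema". A hypothetical loader for resuming from them (load_checkpoint is not part of this change):

import torch


def load_checkpoint(model: torch.nn.Module, optimizer: torch.optim.Optimizer, path: str) -> int:
    """Restore a checkpoint written by Trainer.save_checkpoint; returns the next epoch index."""
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    # An EMA snapshot is stored as "model_state_dict_ema" when EMA is enabled;
    # load that one instead when the weights are meant for evaluation.
    return ckpt["epoch"] + 1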
train.py
CHANGED

@@ -1,20 +1,27 @@
 import hydra
+import torch
 from loguru import logger
 
 from config.config import Config
 from model.yolo import get_model
 from tools.log_helper import custom_logger
-from …
+from tools.trainer import Trainer
+from utils.dataloader import get_dataloader
 from utils.get_dataset import prepare_dataset
 
 
 @hydra.main(config_path="config", config_name="config", version_base=None)
 def main(cfg: Config):
-    dataset = YoloDataset(cfg)
     if cfg.download.auto:
         prepare_dataset(cfg.download)
 
+    dataloader = get_dataloader(cfg)
     model = get_model(cfg.model)
+    # TODO: get_device or rank, for DDP mode
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    trainer = Trainer(model, cfg.hyper.train, device)
+    trainer.train(dataloader, 10)
 
 
 if __name__ == "__main__":
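
Since main() now hands cfg.hyper.train to the Trainer, the hyper group has to be composed into the top-level Hydra config (e.g. via a "hyper: default" entry in the defaults list, which this diff does not show). Assuming that wiring, individual hyper-parameters can be overridden at launch with standard Hydra syntax, for example:

python train.py hyper.train.optimizer.args.lr=0.01 hyper.train.ema.enabled=false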
utils/loss.py
ADDED

@@ -0,0 +1,2 @@
+def get_loss_function(*args, **kwargs):
+    raise NotImplementedError
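
Until the real YOLO loss is implemented, Trainer.__init__ will raise NotImplementedError as soon as it calls get_loss_function(). A throwaway stand-in that keeps the training loop importable for local smoke tests (purely hypothetical, not part of this change, and not a valid detection loss):

import torch


def get_loss_function(*args, **kwargs):
    # Smoke-test placeholder only: pretends outputs and targets are dense
    # tensors of the same shape and returns their mean squared error.
    return torch.nn.MSELoss()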