π [Merge] branch 'main' into SETUP
Browse files- .gitignore +4 -0
- demo/images/inference/image.png +0 -0
- examples/example_inference.py +35 -0
- examples/example_train.py +8 -10
- examples/lazy.py +37 -0
- tests/test_model/test_yolo.py +2 -0
- tests/test_utils/test_dataaugment.py +1 -1
- yolo/config/config.py +2 -0
- yolo/config/general.yaml +4 -2
- yolo/config/task/dataset/coco.yaml +2 -0
- yolo/config/task/dataset/demo.yaml +3 -0
- yolo/config/task/inference.yaml +11 -0
- yolo/model/yolo.py +11 -0
- yolo/tools/dataset_preparation.py +31 -13
- yolo/tools/{trainer.py β solver.py} +33 -7
.gitignore
CHANGED
@@ -42,6 +42,7 @@ htmlcov/
|
|
42 |
.coverage
|
43 |
.coverage.*
|
44 |
.cache
|
|
|
45 |
nosetests.xml
|
46 |
coverage.xml
|
47 |
*.cover
|
@@ -140,3 +141,6 @@ runs
|
|
140 |
|
141 |
# Ignore npm packages (if using frontend components)
|
142 |
node_modules/
|
|
|
|
|
|
|
|
42 |
.coverage
|
43 |
.coverage.*
|
44 |
.cache
|
45 |
+
*.cache
|
46 |
nosetests.xml
|
47 |
coverage.xml
|
48 |
*.cover
|
|
|
141 |
|
142 |
# Ignore npm packages (if using frontend components)
|
143 |
node_modules/
|
144 |
+
|
145 |
+
# Not ignore image for demo
|
146 |
+
!demo/images/inference/*
|
demo/images/inference/image.png
ADDED
![]() |
examples/example_inference.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import hydra
|
5 |
+
import torch
|
6 |
+
|
7 |
+
project_root = Path(__file__).resolve().parent.parent
|
8 |
+
sys.path.append(str(project_root))
|
9 |
+
|
10 |
+
from yolo.config.config import Config
|
11 |
+
from yolo.model.yolo import get_model
|
12 |
+
from yolo.tools.data_loader import create_dataloader
|
13 |
+
from yolo.tools.solver import ModelTester
|
14 |
+
from yolo.utils.logging_utils import custom_logger, validate_log_directory
|
15 |
+
|
16 |
+
|
17 |
+
@hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
|
18 |
+
def main(cfg: Config):
|
19 |
+
custom_logger()
|
20 |
+
save_path = validate_log_directory(cfg, cfg.name)
|
21 |
+
|
22 |
+
device = torch.device(cfg.device)
|
23 |
+
model = get_model(cfg).to(device)
|
24 |
+
|
25 |
+
save_path = validate_log_directory(cfg, cfg.name)
|
26 |
+
dataloader = create_dataloader(cfg)
|
27 |
+
device = torch.device(cfg.device)
|
28 |
+
model = get_model(cfg).to(device)
|
29 |
+
|
30 |
+
tester = ModelTester(cfg, model, save_path, device)
|
31 |
+
tester.solve(dataloader)
|
32 |
+
|
33 |
+
|
34 |
+
if __name__ == "__main__":
|
35 |
+
main()
|
examples/example_train.py
CHANGED
@@ -3,30 +3,28 @@ from pathlib import Path
|
|
3 |
|
4 |
import hydra
|
5 |
import torch
|
6 |
-
from loguru import logger
|
7 |
|
8 |
project_root = Path(__file__).resolve().parent.parent
|
9 |
sys.path.append(str(project_root))
|
10 |
|
11 |
from yolo.config.config import Config
|
|
|
12 |
from yolo.tools.data_loader import create_dataloader
|
13 |
-
from yolo.tools.
|
14 |
-
from yolo.tools.trainer import ModelTrainer
|
15 |
from yolo.utils.logging_utils import custom_logger, validate_log_directory
|
16 |
|
17 |
|
18 |
@hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
|
19 |
def main(cfg: Config):
|
20 |
custom_logger()
|
21 |
-
save_path = validate_log_directory(cfg
|
22 |
-
if cfg.download.auto:
|
23 |
-
prepare_dataset(cfg.download)
|
24 |
-
|
25 |
dataloader = create_dataloader(cfg)
|
26 |
# TODO: get_device or rank, for DDP mode
|
27 |
-
device = torch.device(
|
28 |
-
|
29 |
-
|
|
|
|
|
30 |
|
31 |
|
32 |
if __name__ == "__main__":
|
|
|
3 |
|
4 |
import hydra
|
5 |
import torch
|
|
|
6 |
|
7 |
project_root = Path(__file__).resolve().parent.parent
|
8 |
sys.path.append(str(project_root))
|
9 |
|
10 |
from yolo.config.config import Config
|
11 |
+
from yolo.model.yolo import get_model
|
12 |
from yolo.tools.data_loader import create_dataloader
|
13 |
+
from yolo.tools.solver import ModelTrainer
|
|
|
14 |
from yolo.utils.logging_utils import custom_logger, validate_log_directory
|
15 |
|
16 |
|
17 |
@hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
|
18 |
def main(cfg: Config):
|
19 |
custom_logger()
|
20 |
+
save_path = validate_log_directory(cfg, cfg.name)
|
|
|
|
|
|
|
21 |
dataloader = create_dataloader(cfg)
|
22 |
# TODO: get_device or rank, for DDP mode
|
23 |
+
device = torch.device(cfg.device)
|
24 |
+
model = get_model(cfg).to(device)
|
25 |
+
|
26 |
+
trainer = ModelTrainer(cfg, model, save_path, device)
|
27 |
+
trainer.solve(dataloader, cfg.task.epoch)
|
28 |
|
29 |
|
30 |
if __name__ == "__main__":
|
examples/lazy.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import hydra
|
5 |
+
import torch
|
6 |
+
|
7 |
+
project_root = Path(__file__).resolve().parent.parent
|
8 |
+
sys.path.append(str(project_root))
|
9 |
+
|
10 |
+
from yolo.config.config import Config
|
11 |
+
from yolo.model.yolo import get_model
|
12 |
+
from yolo.tools.data_loader import create_dataloader
|
13 |
+
from yolo.tools.solver import ModelTester, ModelTrainer
|
14 |
+
from yolo.utils.logging_utils import custom_logger, validate_log_directory
|
15 |
+
|
16 |
+
|
17 |
+
@hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
|
18 |
+
def main(cfg: Config):
|
19 |
+
custom_logger()
|
20 |
+
|
21 |
+
custom_logger()
|
22 |
+
save_path = validate_log_directory(cfg, cfg.name)
|
23 |
+
dataloader = create_dataloader(cfg)
|
24 |
+
device = torch.device(cfg.device)
|
25 |
+
model = get_model(cfg).to(device)
|
26 |
+
|
27 |
+
if cfg.task.task == "train":
|
28 |
+
trainer = ModelTrainer(cfg, model, save_path, device)
|
29 |
+
trainer.solve(dataloader)
|
30 |
+
|
31 |
+
if cfg.task.task == "inference":
|
32 |
+
tester = ModelTester(cfg, model, save_path, device)
|
33 |
+
tester.solve(dataloader)
|
34 |
+
|
35 |
+
|
36 |
+
if __name__ == "__main__":
|
37 |
+
main()
|
tests/test_model/test_yolo.py
CHANGED
@@ -19,6 +19,7 @@ def test_build_model():
|
|
19 |
cfg = compose(config_name=config_name)
|
20 |
|
21 |
OmegaConf.set_struct(cfg.model, False)
|
|
|
22 |
model = YOLO(cfg.model, 80)
|
23 |
assert len(model.model) == 38
|
24 |
|
@@ -26,6 +27,7 @@ def test_build_model():
|
|
26 |
def test_get_model():
|
27 |
with initialize(config_path=config_path, version_base=None):
|
28 |
cfg = compose(config_name=config_name)
|
|
|
29 |
model = get_model(cfg)
|
30 |
assert isinstance(model, YOLO)
|
31 |
|
|
|
19 |
cfg = compose(config_name=config_name)
|
20 |
|
21 |
OmegaConf.set_struct(cfg.model, False)
|
22 |
+
cfg.weight = None
|
23 |
model = YOLO(cfg.model, 80)
|
24 |
assert len(model.model) == 38
|
25 |
|
|
|
27 |
def test_get_model():
|
28 |
with initialize(config_path=config_path, version_base=None):
|
29 |
cfg = compose(config_name=config_name)
|
30 |
+
cfg.weight = None
|
31 |
model = get_model(cfg)
|
32 |
assert isinstance(model, YOLO)
|
33 |
|
tests/test_utils/test_dataaugment.py
CHANGED
@@ -39,7 +39,7 @@ def test_compose():
|
|
39 |
return image, boxes
|
40 |
|
41 |
compose = AugmentationComposer([mock_transform, mock_transform])
|
42 |
-
img = Image.new("RGB", (
|
43 |
boxes = torch.tensor([[0, 0.2, 0.2, 0.8, 0.8]])
|
44 |
|
45 |
transformed_img, transformed_boxes = compose(img, boxes)
|
|
|
39 |
return image, boxes
|
40 |
|
41 |
compose = AugmentationComposer([mock_transform, mock_transform])
|
42 |
+
img = Image.new("RGB", (640, 640), color="blue")
|
43 |
boxes = torch.tensor([[0, 0.2, 0.2, 0.8, 0.8]])
|
44 |
|
45 |
transformed_img, transformed_boxes = compose(img, boxes)
|
yolo/config/config.py
CHANGED
@@ -135,6 +135,8 @@ class Config:
|
|
135 |
use_wandb: bool
|
136 |
use_TensorBoard: bool
|
137 |
|
|
|
|
|
138 |
|
139 |
@dataclass
|
140 |
class YOLOLayer(nn.Module):
|
|
|
135 |
use_wandb: bool
|
136 |
use_TensorBoard: bool
|
137 |
|
138 |
+
weight: Optional[str]
|
139 |
+
|
140 |
|
141 |
@dataclass
|
142 |
class YOLOLayer(nn.Module):
|
yolo/config/general.yaml
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
|
2 |
cpu_num: 16
|
3 |
|
4 |
class_num: 80
|
@@ -9,4 +9,6 @@ exist_ok: True
|
|
9 |
|
10 |
lucky_number: 10
|
11 |
use_wandb: False
|
12 |
-
use_TensorBoard: False
|
|
|
|
|
|
1 |
+
device: 0
|
2 |
cpu_num: 16
|
3 |
|
4 |
class_num: 80
|
|
|
9 |
|
10 |
lucky_number: 10
|
11 |
use_wandb: False
|
12 |
+
use_TensorBoard: False
|
13 |
+
|
14 |
+
weight: v9-c.pt
|
yolo/config/task/dataset/coco.yaml
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
path: data/coco
|
|
|
|
|
2 |
|
3 |
auto_download:
|
4 |
images:
|
|
|
1 |
path: data/coco
|
2 |
+
train: train2017
|
3 |
+
|
4 |
|
5 |
auto_download:
|
6 |
images:
|
yolo/config/task/dataset/demo.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
path: demo
|
2 |
+
|
3 |
+
auto_download:
|
yolo/config/task/inference.yaml
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
task: inference
|
2 |
+
defaults:
|
3 |
+
- dataset: demo
|
4 |
+
data:
|
5 |
+
batch_size: 16
|
6 |
+
shuffle: False
|
7 |
+
pin_memory: True
|
8 |
+
data_augment: {}
|
9 |
+
nms:
|
10 |
+
min_confidence: 0.75
|
11 |
+
min_iou: 0.5
|
yolo/model/yolo.py
CHANGED
@@ -1,10 +1,13 @@
|
|
|
|
1 |
from typing import Any, Dict, List, Union
|
2 |
|
|
|
3 |
import torch.nn as nn
|
4 |
from loguru import logger
|
5 |
from omegaconf import ListConfig, OmegaConf
|
6 |
|
7 |
from yolo.config.config import Config, Model, YOLOLayer
|
|
|
8 |
from yolo.tools.drawer import draw_model
|
9 |
from yolo.utils.logging_utils import log_model_structure
|
10 |
from yolo.utils.module_utils import get_layer_map
|
@@ -125,6 +128,14 @@ def get_model(cfg: Config) -> YOLO:
|
|
125 |
OmegaConf.set_struct(cfg.model, False)
|
126 |
model = YOLO(cfg.model, cfg.class_num)
|
127 |
logger.info("β
Success load model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
log_model_structure(model.model)
|
129 |
draw_model(model=model)
|
130 |
return model
|
|
|
1 |
+
import os
|
2 |
from typing import Any, Dict, List, Union
|
3 |
|
4 |
+
import torch
|
5 |
import torch.nn as nn
|
6 |
from loguru import logger
|
7 |
from omegaconf import ListConfig, OmegaConf
|
8 |
|
9 |
from yolo.config.config import Config, Model, YOLOLayer
|
10 |
+
from yolo.tools.dataset_preparation import prepare_weight
|
11 |
from yolo.tools.drawer import draw_model
|
12 |
from yolo.utils.logging_utils import log_model_structure
|
13 |
from yolo.utils.module_utils import get_layer_map
|
|
|
128 |
OmegaConf.set_struct(cfg.model, False)
|
129 |
model = YOLO(cfg.model, cfg.class_num)
|
130 |
logger.info("β
Success load model")
|
131 |
+
if cfg.weight:
|
132 |
+
if os.path.exists(cfg.weight):
|
133 |
+
model.model.load_state_dict(torch.load(cfg.weight))
|
134 |
+
logger.info("β
Success load model weight")
|
135 |
+
else:
|
136 |
+
logger.info(f"π Weight {cfg.weight} not found, try downloading")
|
137 |
+
prepare_weight(weight_name=cfg.weight)
|
138 |
+
|
139 |
log_model_structure(model.model)
|
140 |
draw_model(model=model)
|
141 |
return model
|
yolo/tools/dataset_preparation.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
import os
|
2 |
import zipfile
|
|
|
3 |
|
4 |
import requests
|
5 |
-
from hydra import main
|
6 |
from loguru import logger
|
7 |
-
from
|
8 |
|
9 |
from yolo.config.config import DatasetConfig
|
10 |
|
@@ -13,18 +13,24 @@ def download_file(url, destination):
|
|
13 |
"""
|
14 |
Downloads a file from the specified URL to the destination path with progress logging.
|
15 |
"""
|
16 |
-
logger.info(f"Downloading {os.path.basename(destination)}...")
|
17 |
with requests.get(url, stream=True) as response:
|
18 |
response.raise_for_status()
|
19 |
total_size = int(response.headers.get("content-length", 0))
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
|
30 |
def unzip_file(source, destination):
|
@@ -46,7 +52,6 @@ def check_files(directory, expected_count=None):
|
|
46 |
return len(files) == expected_count if expected_count is not None else bool(files)
|
47 |
|
48 |
|
49 |
-
@main(config_path="../config/data", config_name="download", version_base=None)
|
50 |
def prepare_dataset(cfg: DatasetConfig):
|
51 |
"""
|
52 |
Prepares dataset by downloading and unzipping if necessary.
|
@@ -76,6 +81,19 @@ def prepare_dataset(cfg: DatasetConfig):
|
|
76 |
logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
|
77 |
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
if __name__ == "__main__":
|
80 |
import sys
|
81 |
|
@@ -83,4 +101,4 @@ if __name__ == "__main__":
|
|
83 |
from utils.logging_utils import custom_logger
|
84 |
|
85 |
custom_logger()
|
86 |
-
|
|
|
1 |
import os
|
2 |
import zipfile
|
3 |
+
from typing import Optional
|
4 |
|
5 |
import requests
|
|
|
6 |
from loguru import logger
|
7 |
+
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
|
8 |
|
9 |
from yolo.config.config import DatasetConfig
|
10 |
|
|
|
13 |
"""
|
14 |
Downloads a file from the specified URL to the destination path with progress logging.
|
15 |
"""
|
|
|
16 |
with requests.get(url, stream=True) as response:
|
17 |
response.raise_for_status()
|
18 |
total_size = int(response.headers.get("content-length", 0))
|
19 |
+
with Progress(
|
20 |
+
TextColumn("[progress.description]{task.description}"),
|
21 |
+
BarColumn(),
|
22 |
+
"[progress.percentage]{task.percentage:>3.1f}%",
|
23 |
+
"β’",
|
24 |
+
"{task.completed}/{task.total} bytes",
|
25 |
+
"β’",
|
26 |
+
TimeRemainingColumn(),
|
27 |
+
) as progress:
|
28 |
+
task = progress.add_task(f"π₯ Downloading {os.path.basename(destination)}...", total=total_size)
|
29 |
+
with open(destination, "wb") as file:
|
30 |
+
for data in response.iter_content(chunk_size=1024 * 1024): # 1 MB chunks
|
31 |
+
file.write(data)
|
32 |
+
progress.update(task, advance=len(data))
|
33 |
+
logger.info("β
Download completed.")
|
34 |
|
35 |
|
36 |
def unzip_file(source, destination):
|
|
|
52 |
return len(files) == expected_count if expected_count is not None else bool(files)
|
53 |
|
54 |
|
|
|
55 |
def prepare_dataset(cfg: DatasetConfig):
|
56 |
"""
|
57 |
Prepares dataset by downloading and unzipping if necessary.
|
|
|
81 |
logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
|
82 |
|
83 |
|
84 |
+
def prepare_weight(downlaod_link: Optional[str] = None, weight_name: str = "v9-c.pt"):
|
85 |
+
if downlaod_link is None:
|
86 |
+
downlaod_link = "https://github.com/WongKinYiu/yolov9mit/releases/download/v1.0-alpha/"
|
87 |
+
weight_link = f"{downlaod_link}{weight_name}"
|
88 |
+
|
89 |
+
if os.path.exists(weight_name):
|
90 |
+
logger.info(f"Weight file '{weight_name}' already exists.")
|
91 |
+
try:
|
92 |
+
download_file(weight_link, weight_name)
|
93 |
+
except requests.exceptions.RequestException as e:
|
94 |
+
logger.warning(f"Failed to download the weight file: {e}")
|
95 |
+
|
96 |
+
|
97 |
if __name__ == "__main__":
|
98 |
import sys
|
99 |
|
|
|
101 |
from utils.logging_utils import custom_logger
|
102 |
|
103 |
custom_logger()
|
104 |
+
prepare_weight()
|
yolo/tools/{trainer.py β solver.py}
RENAMED
@@ -6,8 +6,10 @@ from torch import Tensor
|
|
6 |
from torch.cuda.amp import GradScaler, autocast
|
7 |
|
8 |
from yolo.config.config import Config, TrainConfig
|
9 |
-
from yolo.model.yolo import
|
|
|
10 |
from yolo.tools.loss_functions import get_loss_function
|
|
|
11 |
from yolo.utils.logging_utils import ProgressTracker
|
12 |
from yolo.utils.model_utils import (
|
13 |
ExponentialMovingAverage,
|
@@ -17,16 +19,15 @@ from yolo.utils.model_utils import (
|
|
17 |
|
18 |
|
19 |
class ModelTrainer:
|
20 |
-
def __init__(self, cfg: Config, save_path: str, device):
|
21 |
train_cfg: TrainConfig = cfg.task
|
22 |
-
model =
|
23 |
-
|
24 |
-
self.model = model.to(device)
|
25 |
self.device = device
|
26 |
self.optimizer = create_optimizer(model, train_cfg.optimizer)
|
27 |
self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
|
28 |
self.loss_fn = get_loss_function(cfg)
|
29 |
-
self.progress = ProgressTracker(cfg, save_path, use_wandb
|
|
|
30 |
|
31 |
if getattr(train_cfg.ema, "enabled", False):
|
32 |
self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
|
@@ -75,8 +76,9 @@ class ModelTrainer:
|
|
75 |
self.ema.restore()
|
76 |
torch.save(checkpoint, filename)
|
77 |
|
78 |
-
def
|
79 |
logger.info("π Start Training!")
|
|
|
80 |
|
81 |
with self.progress.progress:
|
82 |
self.progress.start_train(num_epochs)
|
@@ -89,3 +91,27 @@ class ModelTrainer:
|
|
89 |
logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
|
90 |
if (epoch + 1) % 5 == 0:
|
91 |
self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from torch.cuda.amp import GradScaler, autocast
|
7 |
|
8 |
from yolo.config.config import Config, TrainConfig
|
9 |
+
from yolo.model.yolo import YOLO
|
10 |
+
from yolo.tools.drawer import draw_bboxes
|
11 |
from yolo.tools.loss_functions import get_loss_function
|
12 |
+
from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms
|
13 |
from yolo.utils.logging_utils import ProgressTracker
|
14 |
from yolo.utils.model_utils import (
|
15 |
ExponentialMovingAverage,
|
|
|
19 |
|
20 |
|
21 |
class ModelTrainer:
|
22 |
+
def __init__(self, cfg: Config, model: YOLO, save_path: str, device):
|
23 |
train_cfg: TrainConfig = cfg.task
|
24 |
+
self.model = model
|
|
|
|
|
25 |
self.device = device
|
26 |
self.optimizer = create_optimizer(model, train_cfg.optimizer)
|
27 |
self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
|
28 |
self.loss_fn = get_loss_function(cfg)
|
29 |
+
self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
|
30 |
+
self.num_epochs = cfg.task.epoch
|
31 |
|
32 |
if getattr(train_cfg.ema, "enabled", False):
|
33 |
self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
|
|
|
76 |
self.ema.restore()
|
77 |
torch.save(checkpoint, filename)
|
78 |
|
79 |
+
def solve(self, dataloader):
|
80 |
logger.info("π Start Training!")
|
81 |
+
num_epochs = self.num_epochs
|
82 |
|
83 |
with self.progress.progress:
|
84 |
self.progress.start_train(num_epochs)
|
|
|
91 |
logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
|
92 |
if (epoch + 1) % 5 == 0:
|
93 |
self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
|
94 |
+
|
95 |
+
|
96 |
+
class ModelTester:
|
97 |
+
def __init__(self, cfg: Config, model: YOLO, save_path: str, device):
|
98 |
+
self.model = model
|
99 |
+
self.device = device
|
100 |
+
self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
|
101 |
+
|
102 |
+
self.anchor2box = AnchorBoxConverter(cfg, device)
|
103 |
+
self.nms = cfg.task.nms
|
104 |
+
self.save_path = save_path
|
105 |
+
|
106 |
+
def solve(self, dataloader):
|
107 |
+
logger.info("π Start Inference!")
|
108 |
+
|
109 |
+
for images, _ in dataloader:
|
110 |
+
images = images.to(self.device)
|
111 |
+
with torch.no_grad():
|
112 |
+
raw_output = self.model(images)
|
113 |
+
predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
|
114 |
+
|
115 |
+
nms_out = bbox_nms(predict, self.nms)
|
116 |
+
for image, bbox in zip(images, nms_out):
|
117 |
+
draw_bboxes(image, bbox, scaled_bbox=False, save_path=self.save_path)
|