🔀 [Merge] branch 'DATASET' of https://github.com/WongKinYiu/yolov9mit into DATASET
Browse files- .github/workflows/main.yaml +18 -4
- LICENSE +21 -0
- README.md +13 -11
- config/config.py +52 -0
- config/hyper/default.yaml +14 -0
- tests/test_utils/test_dataaugment.py +6 -7
- tools/model_helper.py +51 -0
- tools/trainer.py +63 -0
- train.py +9 -2
- utils/converter_json2txt.py +86 -0
- utils/dataloader.py +1 -1
- utils/loss.py +2 -0
.github/workflows/main.yaml
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
name: YOLOv9 - Model test
|
2 |
|
3 |
on:
|
4 |
push:
|
@@ -8,7 +8,6 @@ on:
|
|
8 |
|
9 |
jobs:
|
10 |
build:
|
11 |
-
|
12 |
runs-on: ubuntu-latest
|
13 |
|
14 |
steps:
|
@@ -17,10 +16,25 @@ jobs:
|
|
17 |
uses: actions/setup-python@v2
|
18 |
with:
|
19 |
python-version: 3.8
|
|
|
20 |
- name: Install dependencies
|
21 |
run: |
|
22 |
python -m pip install --upgrade pip
|
23 |
pip install -r requirements.txt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
- name: Test with pytest
|
25 |
-
run:
|
26 |
-
pytest
|
|
|
1 |
+
name: YOLOv9 - Model test and Code Style Check
|
2 |
|
3 |
on:
|
4 |
push:
|
|
|
8 |
|
9 |
jobs:
|
10 |
build:
|
|
|
11 |
runs-on: ubuntu-latest
|
12 |
|
13 |
steps:
|
|
|
16 |
uses: actions/setup-python@v2
|
17 |
with:
|
18 |
python-version: 3.8
|
19 |
+
|
20 |
- name: Install dependencies
|
21 |
run: |
|
22 |
python -m pip install --upgrade pip
|
23 |
pip install -r requirements.txt
|
24 |
+
|
25 |
+
- name: Install pre-commit
|
26 |
+
run: pip install pre-commit
|
27 |
+
|
28 |
+
- name: Cache pre-commit environment
|
29 |
+
uses: actions/cache@v2
|
30 |
+
with:
|
31 |
+
path: ~/.cache/pre-commit
|
32 |
+
key: ${{ runner.os }}-precommit-${{ hashFiles('**/.pre-commit-config.yaml') }}
|
33 |
+
restore-keys: |
|
34 |
+
${{ runner.os }}-precommit-
|
35 |
+
|
36 |
+
- name: Run pre-commit (black and isort)
|
37 |
+
run: pre-commit run --all-files
|
38 |
+
|
39 |
- name: Test with pytest
|
40 |
+
run: pytest
|
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 Kin-Yiu, Wong
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -13,20 +13,22 @@ While the project's structure is still being finalized, we ask that potential co
|
|
13 |
|
14 |
If you are interested in contributing, please keep an eye on project updates or contact us directly at [[email protected]](mailto:[email protected]) for more information.
|
15 |
|
|
|
16 |
|
|
|
17 |
|
18 |
## To-Do Lists
|
19 |
- [ ] Project Setup
|
20 |
- [X] requirements
|
21 |
-
- [
|
22 |
- [ ] README
|
23 |
-
- [
|
24 |
- [ ] setup.py/pip install
|
25 |
-
- [
|
26 |
- [ ] hugging face
|
27 |
- [ ] Data process
|
28 |
- [ ] Dataset
|
29 |
-
- [
|
30 |
- [ ] Auto Download
|
31 |
- [ ] xywh, xxyy, xcyc
|
32 |
- [ ] Dataloader
|
@@ -35,14 +37,14 @@ If you are interested in contributing, please keep an eye on project updates or
|
|
35 |
- [ ] load model
|
36 |
- [ ] from yaml
|
37 |
- [ ] from github
|
38 |
-
- [
|
39 |
-
- [
|
40 |
-
- [
|
41 |
-
- [ ] DDP
|
|
|
|
|
42 |
- [ ] Run
|
43 |
- [ ] train
|
44 |
- [ ] test
|
45 |
- [ ] demo
|
46 |
-
- [
|
47 |
-
- [ ] hyperparams: dataclass
|
48 |
-
- [ ] model cfg: yaml
|
|
|
13 |
|
14 |
If you are interested in contributing, please keep an eye on project updates or contact us directly at [[email protected]](mailto:[email protected]) for more information.
|
15 |
|
16 |
+
## Star History
|
17 |
|
18 |
+
[![Star History Chart](https://api.star-history.com/svg?repos=WongKinYiu/yolov9mit&type=Date)](https://star-history.com/#WongKinYiu/yolov9mit&Date)
|
19 |
|
20 |
## To-Do Lists
|
21 |
- [ ] Project Setup
|
22 |
- [X] requirements
|
23 |
+
- [x] LICENSE
|
24 |
- [ ] README
|
25 |
+
- [x] pytests
|
26 |
- [ ] setup.py/pip install
|
27 |
+
- [x] log format
|
28 |
- [ ] hugging face
|
29 |
- [ ] Data process
|
30 |
- [ ] Dataset
|
31 |
+
- [x] Download script
|
32 |
- [ ] Auto Download
|
33 |
- [ ] xywh, xxyy, xcyc
|
34 |
- [ ] Dataloader
|
|
|
37 |
- [ ] load model
|
38 |
- [ ] from yaml
|
39 |
- [ ] from github
|
40 |
+
- [x] trainer
|
41 |
+
- [x] train_one_iter
|
42 |
+
- [x] train_one_epoch
|
43 |
+
- [ ] DDP
|
44 |
+
- [x] EMA, OTA
|
45 |
+
- [ ] Loss
|
46 |
- [ ] Run
|
47 |
- [ ] train
|
48 |
- [ ] test
|
49 |
- [ ] demo
|
50 |
+
- [x] Configuration
|
|
|
|
config/config.py
CHANGED
@@ -14,6 +14,57 @@ class Download:
|
|
14 |
path: str
|
15 |
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
@dataclass
|
18 |
class Dataset:
|
19 |
file_name: str
|
@@ -37,3 +88,4 @@ class Download:
|
|
37 |
class Config:
|
38 |
model: Model
|
39 |
download: Download
|
|
|
|
14 |
path: str
|
15 |
|
16 |
|
17 |
+
@dataclass
class DataLoaderConfig:
    """Data-loading options; this is the type of ``HyperConfig.data``."""

    batch_size: int
    shuffle: bool
    num_workers: int
    pin_memory: bool


@dataclass
class OptimizerArgs:
    """Keyword arguments belonging to an optimizer configuration."""

    lr: float
    weight_decay: float


@dataclass
class OptimizerConfig:
    """An optimizer selection: a ``type`` name plus its ``args``."""

    type: str
    args: OptimizerArgs


@dataclass
class SchedulerArgs:
    """Keyword arguments belonging to a scheduler configuration."""

    step_size: int
    gamma: float


@dataclass
class SchedulerConfig:
    """A learning-rate scheduler selection: a ``type`` name plus its ``args``."""

    type: str
    args: SchedulerArgs


@dataclass
class EMAConfig:
    """Exponential-moving-average switch and its decay factor."""

    enabled: bool
    decay: float


@dataclass
class TrainConfig:
    """Training section: optimizer, scheduler and EMA settings."""

    optimizer: OptimizerConfig
    scheduler: SchedulerConfig
    ema: EMAConfig


@dataclass
class HyperConfig:
    """Top of the hyper-parameter tree: data-loading and training sections."""

    data: DataLoaderConfig
    train: TrainConfig
|
66 |
+
|
67 |
+
|
68 |
@dataclass
|
69 |
class Dataset:
|
70 |
file_name: str
|
|
|
88 |
class Config:
|
89 |
model: Model
|
90 |
download: Download
|
91 |
+
hyper: HyperConfig
|
config/hyper/default.yaml
CHANGED
@@ -3,3 +3,17 @@ data:
|
|
3 |
shuffle: True
|
4 |
num_workers: 4
|
5 |
pin_memory: True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
shuffle: True
|
4 |
num_workers: 4
|
5 |
pin_memory: True
|
6 |
+
train:
|
7 |
+
optimizer:
|
8 |
+
type: Adam
|
9 |
+
args:
|
10 |
+
lr: 0.001
|
11 |
+
weight_decay: 0.0001
|
12 |
+
scheduler:
|
13 |
+
type: StepLR
|
14 |
+
args:
|
15 |
+
step_size: 10
|
16 |
+
gamma: 0.1
|
17 |
+
ema:
|
18 |
+
enabled: true
|
19 |
+
decay: 0.995
|
tests/test_utils/test_dataaugment.py
CHANGED
@@ -6,23 +6,22 @@ from PIL import Image
|
|
6 |
from torchvision.transforms import functional as TF
|
7 |
|
8 |
sys.path.append("./")
|
9 |
-
from utils.data_augment import Compose, Mosaic,
|
10 |
|
11 |
|
12 |
-
def
|
13 |
# Create a mock image and bounding boxes
|
14 |
img = Image.new("RGB", (100, 100), color="red")
|
15 |
-
boxes = torch.tensor([[1, 0.
|
16 |
|
17 |
-
flip_transform =
|
18 |
flipped_img, flipped_boxes = flip_transform(img, boxes)
|
19 |
|
20 |
# Assert image is flipped by comparing it to a manually flipped image
|
21 |
assert TF.hflip(img) == flipped_img
|
22 |
|
23 |
# Assert bounding boxes are flipped correctly
|
24 |
-
expected_boxes = torch.tensor([[1, 0.
|
25 |
-
expected_boxes[:, [1, 3]] = 1 - expected_boxes[:, [3, 1]]
|
26 |
assert torch.allclose(flipped_boxes, expected_boxes), "Bounding boxes were not flipped correctly"
|
27 |
|
28 |
|
@@ -60,5 +59,5 @@ def test_mosaic():
|
|
60 |
# Checks here would depend on the exact expected behavior of the mosaic function,
|
61 |
# such as dimensions and content of the output image and boxes.
|
62 |
|
63 |
-
assert mosaic_img.size == (
|
64 |
assert len(mosaic_boxes) > 0, "Should have some bounding boxes"
|
|
|
6 |
from torchvision.transforms import functional as TF
|
7 |
|
8 |
sys.path.append("./")
|
9 |
+
from utils.data_augment import Compose, HorizontalFlip, Mosaic, VerticalFlip
|
10 |
|
11 |
|
12 |
+
def test_horizontal_flip():
    """HorizontalFlip with prob=1 must mirror the image and the boxes' x-coordinates."""
    # Mock inputs: a solid image and one box laid out as (class, xmin, ymin, xmax, ymax).
    source_img = Image.new("RGB", (100, 100), color="red")
    source_boxes = torch.tensor([[1, 0.05, 0.1, 0.7, 0.9]])
    # Mirrored box: xmin' = 1 - 0.7, xmax' = 1 - 0.05; the y values are untouched.
    expected_boxes = torch.tensor([[1, 0.3, 0.1, 0.95, 0.9]])

    always_flip = HorizontalFlip(prob=1)  # probability 1 guarantees the flip happens
    flipped_img, flipped_boxes = always_flip(source_img, source_boxes)

    # The image must match a manually flipped reference.
    assert TF.hflip(source_img) == flipped_img

    # The boxes must match the mirrored coordinates.
    assert torch.allclose(flipped_boxes, expected_boxes), "Bounding boxes were not flipped correctly"
|
26 |
|
27 |
|
|
|
59 |
# Checks here would depend on the exact expected behavior of the mosaic function,
|
60 |
# such as dimensions and content of the output image and boxes.
|
61 |
|
62 |
+
assert mosaic_img.size == (100, 100), "Mosaic image size should be same"
|
63 |
assert len(mosaic_boxes) > 0, "Should have some bounding boxes"
|
tools/model_helper.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, Type
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.optim import Optimizer
|
5 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
6 |
+
|
7 |
+
from config.config import OptimizerConfig, SchedulerConfig
|
8 |
+
|
9 |
+
|
10 |
+
class EMA:
    """Exponential moving average (EMA) of a model's named parameters.

    A detached "shadow" copy of every named parameter is kept and blended with
    the live weights on each :meth:`update`. :meth:`apply_shadow` temporarily
    swaps the averaged weights into the model and :meth:`restore` puts the live
    weights back.
    """

    def __init__(self, model: torch.nn.Module, decay: float):
        """Snapshot the current parameters of ``model`` as the initial shadow.

        Args:
            model: Module whose named parameters are tracked.
            decay: Weight given to the previous shadow value on each update.
        """
        self.model = model
        self.decay = decay
        self.shadow = {name: param.clone().detach() for name, param in model.named_parameters()}
        # Live weights stashed by apply_shadow() so that restore() can undo the swap.
        self.backup = {}

    def update(self):
        """Update the shadow parameters using the current model parameters."""
        for name, param in self.model.named_parameters():
            assert name in self.shadow, "All model parameters should have a corresponding shadow parameter."
            # shadow <- (1 - decay) * live + decay * shadow
            self.shadow[name] = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]

    def apply_shadow(self):
        """Copy the shadow (averaged) parameters into the model.

        The live parameters are backed up first so :meth:`restore` can bring
        them back. If a backup is already pending (a second call without an
        intervening restore), it is kept, so the true live weights are never
        overwritten by shadow values.
        """
        if not self.backup:
            self.backup = {name: param.data.clone() for name, param in self.model.named_parameters()}
        for name, param in self.model.named_parameters():
            param.data.copy_(self.shadow[name])

    def restore(self):
        """Restore the live parameters saved by the last :meth:`apply_shadow`.

        The shadow values are left untouched. Calling this without a pending
        backup is a no-op.
        """
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data.copy_(self.backup[name])
        self.backup = {}
|
32 |
+
|
33 |
+
|
34 |
+
def get_optimizer(model_parameters, optim_cfg: OptimizerConfig) -> Optimizer:
    """Build the ``torch.optim`` optimizer selected by the configuration.

    Args:
        model_parameters: Parameters passed as the first constructor argument.
        optim_cfg: Configuration whose ``type`` names an attribute of
            ``torch.optim`` and whose ``args`` is unpacked as keyword
            arguments (so it must behave like a mapping at runtime).

    Returns:
        The optimizer instance created from those settings.
    """
    constructor: Type[Optimizer] = getattr(torch.optim, optim_cfg.type)
    optimizer = constructor(model_parameters, **optim_cfg.args)
    return optimizer
|
42 |
+
|
43 |
+
|
44 |
+
def get_scheduler(optimizer: Optimizer, schedul_cfg: SchedulerConfig) -> _LRScheduler:
    """Build the learning-rate scheduler selected by the configuration.

    Args:
        optimizer: Optimizer passed as the first constructor argument.
        schedul_cfg: Configuration whose ``type`` names an attribute of
            ``torch.optim.lr_scheduler`` and whose ``args`` is unpacked as
            keyword arguments (so it must behave like a mapping at runtime).

    Returns:
        The scheduler instance created from those settings.
    """
    constructor: Type[_LRScheduler] = getattr(torch.optim.lr_scheduler, schedul_cfg.type)
    scheduler = constructor(optimizer, **schedul_cfg.args)
    return scheduler
|
tools/trainer.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from loguru import logger
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
from config.config import TrainConfig
|
6 |
+
from model.yolo import YOLO
|
7 |
+
from tools.model_helper import EMA, get_optimizer, get_scheduler
|
8 |
+
from utils.loss import get_loss_function
|
9 |
+
|
10 |
+
|
11 |
+
class Trainer:
    """Training loop around a model: optimizer, scheduler, loss and optional EMA."""

    def __init__(self, model: YOLO, train_cfg: TrainConfig, device):
        """Move the model to ``device`` and build the training helpers.

        Args:
            model: Model to train.
            train_cfg: Training configuration providing ``optimizer``,
                ``scheduler`` and ``ema`` sections.
            device: Device the model and every batch are moved to.
        """
        self.model = model.to(device)
        self.device = device
        self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
        self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
        self.loss_fn = get_loss_function()

        # EMA tracking is optional; ``self.ema`` stays None when it is disabled.
        if train_cfg.ema.get("enabled", False):
            self.ema = EMA(model, decay=train_cfg.ema.decay)
        else:
            self.ema = None

    def train_one_batch(self, data, targets):
        """Run one optimization step on a batch and return the loss as a float."""
        data, targets = data.to(self.device), targets.to(self.device)
        self.optimizer.zero_grad()
        outputs = self.model(data)
        loss = self.loss_fn(outputs, targets)
        loss.backward()
        self.optimizer.step()
        if self.ema:
            self.ema.update()
        return loss.item()

    def train_one_epoch(self, dataloader):
        """Train on every batch of ``dataloader`` and return the mean batch loss.

        The scheduler, when present, is stepped once after the whole epoch.
        """
        self.model.train()
        total_loss = 0
        for data, targets in tqdm(dataloader, desc="Training"):
            loss = self.train_one_batch(data, targets)
            total_loss += loss
        if self.scheduler:
            self.scheduler.step()
        return total_loss / len(dataloader)

    def _snapshot_state_dict(self):
        """Return a deep copy of the model's current state dict.

        ``state_dict()`` returns a shallow copy whose tensors reference the
        module's parameters, so a later in-place weight change would also
        change a stored, un-copied state dict.
        """
        import copy

        return copy.deepcopy(self.model.state_dict())

    def save_checkpoint(self, epoch, filename="checkpoint.pt"):
        """Save epoch, model weights, optimizer state and (if enabled) EMA weights.

        Args:
            epoch: Value stored under the ``"epoch"`` key.
            filename: Destination path given to ``torch.save``.
        """
        checkpoint = {
            "epoch": epoch,
            # Snapshot now: the EMA swap below rewrites the parameters in place.
            "model_state_dict": self._snapshot_state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        if self.ema:
            # Swap the averaged weights in only long enough to copy them out.
            self.ema.apply_shadow()
            try:
                checkpoint["model_state_dict_ema"] = self._snapshot_state_dict()
            finally:
                self.ema.restore()
        torch.save(checkpoint, filename)

    def train(self, dataloader, num_epochs):
        """Train for ``num_epochs`` epochs, logging each loss and checkpointing every 5th epoch."""
        logger.info("start train")
        for epoch in range(num_epochs):
            epoch_loss = self.train_one_epoch(dataloader)
            logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
            if (epoch + 1) % 5 == 0:
                # The stored epoch is 0-based while the file name uses the 1-based count.
                self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
|
train.py
CHANGED
@@ -1,20 +1,27 @@
|
|
1 |
import hydra
|
|
|
2 |
from loguru import logger
|
3 |
|
4 |
from config.config import Config
|
5 |
from model.yolo import get_model
|
6 |
from tools.log_helper import custom_logger
|
7 |
-
from
|
|
|
8 |
from utils.get_dataset import prepare_dataset
|
9 |
|
10 |
|
11 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
12 |
def main(cfg: Config):
|
13 |
-
dataset = YoloDataset(cfg)
|
14 |
if cfg.download.auto:
|
15 |
prepare_dataset(cfg.download)
|
16 |
|
|
|
17 |
model = get_model(cfg.model)
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
|
20 |
if __name__ == "__main__":
|
|
|
1 |
import hydra
|
2 |
+
import torch
|
3 |
from loguru import logger
|
4 |
|
5 |
from config.config import Config
|
6 |
from model.yolo import get_model
|
7 |
from tools.log_helper import custom_logger
|
8 |
+
from tools.trainer import Trainer
|
9 |
+
from utils.dataloader import get_dataloader
|
10 |
from utils.get_dataset import prepare_dataset
|
11 |
|
12 |
|
13 |
@hydra.main(config_path="config", config_name="config", version_base=None)
def main(cfg: Config):
    """Optionally prepare the dataset, then build the dataloader, model and trainer and train."""
    download_cfg = cfg.download
    if download_cfg.auto:
        prepare_dataset(download_cfg)

    dataloader = get_dataloader(cfg)
    model = get_model(cfg.model)
    # TODO: get_device or rank, for DDP mode
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    trainer = Trainer(model, cfg.hyper.train, device)
    # The number of epochs is currently fixed at 10.
    trainer.train(dataloader, 10)
|
25 |
|
26 |
|
27 |
if __name__ == "__main__":
|
utils/converter_json2txt.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from typing import Dict, List, Optional
|
4 |
+
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
def discretize_categories(categories: List[Dict[str, int]]) -> Dict[int, int]:
    """
    Build a lookup from each category 'id' to a zero-based sequential index.
    The index of an id is its position among all ids in ascending order.
    """
    ordered_ids = sorted(category["id"] for category in categories)
    lookup = {}
    for position, category_id in enumerate(ordered_ids):
        lookup[category_id] = position
    return lookup
|
15 |
+
|
16 |
+
|
17 |
+
def process_annotations(
    image_annotations: Dict[int, List[Dict]],
    image_info_dict: Dict[int, tuple],
    output_dir: str,
    id_to_idx: Optional[Dict[int, int]] = None,
) -> None:
    """
    Write one label file per image that has annotations, optionally remapping category IDs.

    The file is named after the image id, left-padded with zeros to 12 characters
    and suffixed with ".txt". Images with no annotations produce no file.
    """
    progress = tqdm(image_annotations.items(), desc="Processing annotations")
    for image_id, annotations in progress:
        label_path = os.path.join(output_dir, f"{image_id:0>12}.txt")
        if annotations:
            with open(label_path, "w") as label_file:
                image_dims = image_info_dict[image_id]
                for annotation in annotations:
                    process_annotation(annotation, image_dims, id_to_idx, label_file)
|
33 |
+
|
34 |
+
|
35 |
+
def process_annotation(annotation: Dict, image_dims: tuple, id_to_idx: Optional[Dict[int, int]], file) -> None:
    """
    Convert a single annotation's segmentation and write it to the open file handle.

    Only polygon segmentations (a non-empty list whose first element is a list
    of coordinates) are written, and only the first polygon is used. Anything
    else, such as an empty list or a dict-style (RLE) segmentation, is skipped
    without writing.

    Args:
        annotation: Annotation with "category_id" and "segmentation" keys.
        image_dims: (width, height) used to normalize the coordinates.
        id_to_idx: Optional category-id remapping; ids missing from it are kept as is.
        file: Writable text handle that receives one line per annotation.
    """
    category_id = annotation["category_id"]
    segmentation = annotation["segmentation"]

    # Check the container type before indexing: a dict is truthy but has no [0].
    is_polygon_list = isinstance(segmentation, list) and bool(segmentation) and isinstance(segmentation[0], list)
    if not is_polygon_list:
        return

    polygon = segmentation[0]
    img_width, img_height = image_dims
    normalized_segmentation = normalize_segmentation(polygon, img_width, img_height)

    if id_to_idx:
        category_id = id_to_idx.get(category_id, category_id)

    file.write(f"{category_id} {' '.join(normalized_segmentation)}\n")
|
56 |
+
|
57 |
+
|
58 |
+
def normalize_segmentation(segmentation: List[float], img_width: int, img_height: int) -> List[str]:
    """
    Scale a flat coordinate list by the image size and format each value with six decimals.
    """
    formatted = []
    for position, value in enumerate(segmentation):
        # Even positions are divided by the width, odd positions by the height.
        if position % 2 == 0:
            formatted.append(f"{value/img_width:.6f}")
        else:
            formatted.append(f"{value/img_height:.6f}")
    return formatted
|
63 |
+
|
64 |
+
|
65 |
+
def convert_annotations(json_file: str, output_dir: str) -> None:
    """
    Read an annotation JSON file and write the per-image label files into output_dir.

    Annotations flagged "iscrowd" are left out. Category ids are remapped to
    sequential indices when the JSON has a "categories" entry.
    """
    with open(json_file) as file:
        data = json.load(file)

    os.makedirs(output_dir, exist_ok=True)

    # image id -> (width, height)
    image_info_dict = {}
    for img in data.get("images", []):
        image_info_dict[img["id"]] = (img["width"], img["height"])

    if "categories" in data:
        id_to_idx = discretize_categories(data["categories"])
    else:
        id_to_idx = None

    # Every known image starts with an empty annotation list.
    image_annotations = {img_id: [] for img_id in image_info_dict}
    for annotation in data.get("annotations", []):
        if annotation.get("iscrowd", False):
            continue
        image_annotations[annotation["image_id"]].append(annotation)

    process_annotations(image_annotations, image_info_dict, output_dir, id_to_idx)
|
83 |
+
|
84 |
+
|
85 |
+
# Convert the train and val annotation files; this runs whenever the module is executed or imported.
convert_annotations(
    json_file="./data/coco/annotations/instances_train2017.json",
    output_dir="./data/coco/labels/train2017/",
)
convert_annotations(
    json_file="./data/coco/annotations/instances_val2017.json",
    output_dir="./data/coco/labels/val2017/",
)
|
utils/dataloader.py
CHANGED
@@ -96,7 +96,7 @@ class YoloDataset(Dataset):
|
|
96 |
cache[phase_name] = data
|
97 |
|
98 |
cache.close()
|
99 |
-
logger.info("Loaded {} cache", phase_name)
|
100 |
data = cache[phase_name]
|
101 |
return data
|
102 |
|
|
|
96 |
cache[phase_name] = data
|
97 |
|
98 |
cache.close()
|
99 |
+
logger.info("π¦ Loaded {} cache", phase_name)
|
100 |
data = cache[phase_name]
|
101 |
return data
|
102 |
|
utils/loss.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
def get_loss_function(*args, **kwargs):
    """Placeholder for the loss factory; it is not implemented yet.

    Raises:
        NotImplementedError: Always, whatever arguments are passed.
    """
    raise NotImplementedError()
|