π [Merge] branch 'main' into DATASET
Browse files- README.md +67 -37
- docs/CONTRIBUTING.md +44 -0
- docs/HOWTO.md +90 -0
- docs/MODELS.md +31 -0
- examples/example_train.py +5 -9
- yolo/config/config.py +17 -3
- yolo/config/hyper/default.yaml +23 -10
- yolo/model/module.py +4 -1
- yolo/model/yolo.py +3 -2
- yolo/tools/log_helper.py +45 -16
- yolo/tools/model_helper.py +23 -5
- yolo/tools/trainer.py +8 -8
- yolo/utils/dataloader.py +3 -2
README.md
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
-
#
|
2 |
-
An MIT license rewrite of YOLOv9
|
3 |
|
4 |

|
5 |
> [!IMPORTANT]
|
@@ -7,44 +6,75 @@ An MIT license rewrite of YOLOv9
|
|
7 |
>
|
8 |
> Use of this code is at your own risk and discretion. It is advisable to consult with the project owner before deploying or integrating into any critical systems.
|
9 |
|
10 |
-
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
|
|
|
|
|
15 |
|
16 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
[](https://star-history.com/#WongKinYiu/yolov9mit&Date)
|
19 |
|
20 |
-
##
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
- [x] Download script
|
32 |
-
- [ ] Auto Download
|
33 |
-
- [ ] xywh, xxyy, xcyc
|
34 |
-
- [x] Dataloder
|
35 |
-
- [x] Data arugment
|
36 |
-
- [ ] Model
|
37 |
-
- [ ] load model
|
38 |
-
- [ ] from yaml
|
39 |
-
- [ ] from github
|
40 |
-
- [x] trainer
|
41 |
-
- [x] train_one_iter
|
42 |
-
- [x] train_one_epoch
|
43 |
-
- [ ] DDP
|
44 |
-
- [x] EMA, OTA
|
45 |
-
- [ ] Loss
|
46 |
-
- [ ] Run
|
47 |
-
- [ ] train
|
48 |
-
- [ ] test
|
49 |
-
- [ ] demo
|
50 |
-
- [x] Configuration
|
|
|
1 |
+
# YOLO: Official Implementation of YOLOv{7, 9}
|
|
|
2 |
|
3 |

|
4 |
> [!IMPORTANT]
|
|
|
6 |
>
|
7 |
> Use of this code is at your own risk and discretion. It is advisable to consult with the project owner before deploying or integrating into any critical systems.
|
8 |
|
9 |
+
Welcome to the official implementation of the YOLOv7 and YOLOv9. This repository will contains the complete codebase, pre-trained models, and detailed instructions for training and deploying YOLOv9.
|
10 |
|
11 |
+
## TL;DR
|
12 |
+
- Official YOLOv9 model implementation.
|
13 |
+
- Features real-time detection with state-of-the-art accuracy.
|
14 |
+
<!-- - Includes pre-trained models and training scripts. -->
|
15 |
+
- Quick train: `python examples/example_train.py`
|
16 |
|
17 |
+
## Introduction
|
18 |
+
- [**YOLOv9**: Learning What You Want to Learn Using Programmable Gradient Information](https://arxiv.org/abs/2402.13616)
|
19 |
+
- [**YOLOv7**: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors](https://arxiv.org/abs/2207.02696)
|
20 |
|
21 |
+
## Installation
|
22 |
+
To get started with YOLOv9, clone this repository and install the required dependencies:
|
23 |
+
```shell
|
24 |
+
git clone [email protected]:WongKinYiu/yolov9mit.git
|
25 |
+
cd yolov9mit
|
26 |
+
pip install -r requirements.txt
|
27 |
+
```
|
28 |
+
|
29 |
+
<!--
|
30 |
+
```
|
31 |
+
pip install git+https://github.com/WongKinYiu/yolov9mit.git
|
32 |
+
```
|
33 |
+
-->
|
34 |
+
|
35 |
+
<!-- ### Quick Start
|
36 |
+
Run YOLOv9 on a pre-trained model with:
|
37 |
+
|
38 |
+
```shell
|
39 |
+
python examples/example_train.py hyper.data.batch_size=8
|
40 |
+
``` -->
|
41 |
+
|
42 |
+
<!-- ## Model Zoo[WIP]
|
43 |
+
Find pre-trained models with benchmarks on various datasets in the [Model Zoo](docs/MODELS). -->
|
44 |
+
|
45 |
+
## Training
|
46 |
+
For training YOLOv9 on your dataset:
|
47 |
+
|
48 |
+
Modify the configuration file data/config.yaml to point to your dataset.
|
49 |
+
Run the training script:
|
50 |
|
51 |
+
```shell
|
52 |
+
python examples/example_train.py hyper.data.batch_size=8 model=v9-c
|
53 |
+
```
|
54 |
+
|
55 |
+
More customization details, or ways to modify the model can be found [HOWTO](docs/HOWTO).
|
56 |
+
|
57 |
+
## Evaluation [WIP]
|
58 |
+
Evaluate the model performance using:
|
59 |
+
|
60 |
+
```shell
|
61 |
+
python examples/examples_evaluate.py weights=v9-c.pt
|
62 |
+
```
|
63 |
+
|
64 |
+
## Contributing
|
65 |
+
Contributions to the YOLOv9 project are welcome! See [CONTRIBUTING](docs/CONTRIBUTING.md) for how to help out.
|
66 |
+
|
67 |
+
## Star History
|
68 |
[](https://star-history.com/#WongKinYiu/yolov9mit&Date)
|
69 |
|
70 |
+
## Citations
|
71 |
+
```
|
72 |
+
@misc{wang2024yolov9,
|
73 |
+
title={YOLOv9: Learning What You Want to Learn Using Programmable Gradient Information},
|
74 |
+
author={Chien-Yao Wang and I-Hau Yeh and Hong-Yuan Mark Liao},
|
75 |
+
year={2024},
|
76 |
+
eprint={2402.13616},
|
77 |
+
archivePrefix={arXiv},
|
78 |
+
primaryClass={cs.CV}
|
79 |
+
}
|
80 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/CONTRIBUTING.md
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contributing to YOLO
|
2 |
+
|
3 |
+
Thank you for your interest in contributing to this project! We value your contributions and want to make the process as easy and enjoyable as possible. Below you will find the guidelines for contributing.
|
4 |
+
|
5 |
+
## Quick Links
|
6 |
+
- [Main README](../README.md)
|
7 |
+
- [License](../LICENSE)
|
8 |
+
- [Issue Tracker](https://github.com/WongKinYiu/yolov9mit/issues)
|
9 |
+
- [Pull Requests](https://github.com/WongKinYiu/yolov9mit/pulls)
|
10 |
+
|
11 |
+
## Testing and Formatting
|
12 |
+
We strive to maintain a high standard of quality in our codebase:
|
13 |
+
- **Testing:** We use `pytest` for testing. Please add tests for new code you create.
|
14 |
+
- **Formatting:** Our code follows a consistent style enforced by `isort` for imports sorting and `black` for code formatting. Run these tools to format your code before submitting a pull request.
|
15 |
+
|
16 |
+
## GitHub Actions
|
17 |
+
We utilize GitHub Actions for continuous integration. When you submit a pull request, automated tests and formatting checks will run. Ensure that these checks pass for your pull request to be accepted.
|
18 |
+
|
19 |
+
## How to Contribute
|
20 |
+
|
21 |
+
### Proposing Enhancements
|
22 |
+
For feature requests or improvements, open an issue with:
|
23 |
+
- A clear title and description.
|
24 |
+
- Explain why this enhancement would be useful.
|
25 |
+
- Considerations or potential implementation details.
|
26 |
+
|
27 |
+
## Pull Request Checklist
|
28 |
+
Before sending your pull request, always check the following:
|
29 |
+
- The code follows the [Python style guide](https://www.python.org/dev/peps/pep-0008/).
|
30 |
+
- Code and files are well organized.
|
31 |
+
- All tests pass.
|
32 |
+
- New code is covered by tests.
|
33 |
+
- We would be very happy if [gitmojiπ](https://www.npmjs.com/package/gitmoji-cli) could be used to assist the commit messageπ¬!
|
34 |
+
|
35 |
+
## Code Review Process
|
36 |
+
Once you submit a PR, maintainers will review your work, suggest changes if necessary, and merge it once itβs approved.
|
37 |
+
|
38 |
+
---
|
39 |
+
|
40 |
+
Your contributions are greatly appreciated and vital to the project's success!
|
41 |
+
|
42 |
+
Please feel free to contact [[email protected]](mailto:[email protected])!
|
43 |
+
|
44 |
+
|
docs/HOWTO.md
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# How To modified YOLO
|
2 |
+
|
3 |
+
To facilitate easy customization of the YOLO model, we've structured the codebase to allow for changes through configuration files and minimal code adjustments. This guide will walk you through the steps to customize various components of the model including the architecture, blocks, data loaders, and loss functions.
|
4 |
+
|
5 |
+
## Custom Model Architecture
|
6 |
+
|
7 |
+
You can change the model architecture simply by modifying the YAML configuration file. Here's how:
|
8 |
+
|
9 |
+
1. **Modify Architecture in Config:**
|
10 |
+
|
11 |
+
Navigate to your model's configuration file (typically formate like `yolo/config/model/v9-c.yaml`).
|
12 |
+
- Adjust the architecture settings under the `architecture` section. Ensure that every module you reference exists in `module.py`, or refer to the next section on how to add new modules.
|
13 |
+
|
14 |
+
```yaml
|
15 |
+
model:
|
16 |
+
foo:
|
17 |
+
- ADown:
|
18 |
+
args: {out_channels: 256}
|
19 |
+
- RepNCSPELAN:
|
20 |
+
source: -2
|
21 |
+
args: {out_channels: 512, part_channels: 256}
|
22 |
+
tags: B4
|
23 |
+
bar:
|
24 |
+
- Concat:
|
25 |
+
source: [-2, B4]
|
26 |
+
```
|
27 |
+
|
28 |
+
`tags`: Use this to labels any module you want, and could be the module source.
|
29 |
+
|
30 |
+
`source`: Set this to the index of the module output you wish to use as input; default is `-1` which refers to the last module's output. Capable tags, relative position, absolute position
|
31 |
+
|
32 |
+
`args`: A dictionary used to initialize parameters for convolutional or bottleneck layers.
|
33 |
+
|
34 |
+
`output`: Whether to serve as the output of the model.
|
35 |
+
|
36 |
+
## Custom Block
|
37 |
+
|
38 |
+
To add or modify a block in the model:
|
39 |
+
|
40 |
+
1. **Create a New Module:**
|
41 |
+
|
42 |
+
Define a new class in `module.py` that inherits from `nn.Module`.
|
43 |
+
|
44 |
+
The constructor should accept `in_channels` as a parameter. Make sure to calculate `out_channels` based on your model's requirements or configure it through the YAML file using `args`.
|
45 |
+
|
46 |
+
```python
|
47 |
+
class CustomBlock(nn.Module):
|
48 |
+
def __init__(self, in_channels, out_channels, **kwargs):
|
49 |
+
super().__init__()
|
50 |
+
self.module = # conv, bool, ...
|
51 |
+
def forward(self, x):
|
52 |
+
return self.module(x)
|
53 |
+
```
|
54 |
+
|
55 |
+
2. **Reference in Config:**
|
56 |
+
```yaml
|
57 |
+
...
|
58 |
+
- CustomBlock:
|
59 |
+
args: {out_channels: int, etc: ...}
|
60 |
+
...
|
61 |
+
...
|
62 |
+
```
|
63 |
+
|
64 |
+
|
65 |
+
## Custom Data Augmentation
|
66 |
+
|
67 |
+
Custom transformations should be designed to accept an image and its bounding boxes, and return them after applying the desired changes. Hereβs how you can define such a transformation:
|
68 |
+
|
69 |
+
|
70 |
+
1. **Define Dataset:**
|
71 |
+
|
72 |
+
Your class must have a `__call__` method that takes a PIL image and its corresponding bounding boxes as input, and returns them after processing.
|
73 |
+
|
74 |
+
|
75 |
+
```python
|
76 |
+
class CustomTransform:
|
77 |
+
def __init__(self, prob=0.5):
|
78 |
+
self.prob = prob
|
79 |
+
|
80 |
+
def __call__(self, image, boxes):
|
81 |
+
return image, boxes
|
82 |
+
```
|
83 |
+
2. **Update CustomTransform in Config:**
|
84 |
+
|
85 |
+
Specify your custom transformation in a YAML config `yolo/config/data/augment.yaml`. For examples:
|
86 |
+
```yaml
|
87 |
+
Mosaic: 1
|
88 |
+
# ... (Other Transform)
|
89 |
+
CustomTransform: 0.5
|
90 |
+
```
|
docs/MODELS.md
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# YOLO Model Zoo
|
2 |
+
|
3 |
+
Welcome to the YOLOv9 Model Zoo! Here, you will find a variety of pre-trained models tailored to different use cases and performance needs. Each model comes with detailed information about its training regime, performance metrics, and usage instructions.
|
4 |
+
|
5 |
+
## Standard Models
|
6 |
+
|
7 |
+
These models are trained on common datasets like COCO and provide a balance between speed and accuracy.
|
8 |
+
|
9 |
+
|
10 |
+
| Model | Test Size | AP<sup>val</sup> | AP<sub>50</sub><sup>val</sup> | AP<sub>75</sub><sup>val</sup> | Param. | FLOPs |
|
11 |
+
| :-- | :-: | :-: | :-: | :-: | :-: | :-: |
|
12 |
+
| [**YOLOv9-T**]() | 640 | **38.3%** | **53.1%** | **41.3%** | **2.0M** | **7.7G** |
|
13 |
+
| [**YOLOv9-S**]() | 640 | **46.8%** | **63.4%** | **50.7%** | **7.1M** | **26.4G** |
|
14 |
+
| [**YOLOv9-M**]() | 640 | **51.4%** | **68.1%** | **56.1%** | **20.0M** | **76.3G** |
|
15 |
+
| [**YOLOv9-C**]() | 640 | **53.0%** | **70.2%** | **57.8%** | **25.3M** | **102.1G** |
|
16 |
+
| [**YOLOv9-E**]() | 640 | **55.6%** | **72.8%** | **60.6%** | **57.3M** | **189.0G** |
|
17 |
+
| | | | | | | |
|
18 |
+
| [**YOLOv7**]() | 640 | **51.4%** | **69.7%** | **55.9%** |
|
19 |
+
| [**YOLOv7-X**]() | 640 | **53.1%** | **71.2%** | **57.8%** |
|
20 |
+
| [**YOLOv7-W6**]() | 1280 | **54.9%** | **72.6%** | **60.1%** |
|
21 |
+
| [**YOLOv7-E6**]() | 1280 | **56.0%** | **73.5%** | **61.2%** |
|
22 |
+
| [**YOLOv7-D6**]() | 1280 | **56.6%** | **74.0%** | **61.8%** |
|
23 |
+
| [**YOLOv7-E6E**]() | 1280 | **56.8%** | **74.4%** | **62.1%** |
|
24 |
+
|
25 |
+
## Download and Usage Instructions
|
26 |
+
|
27 |
+
To use these models, download them from the links provided and use the following command to run detection:
|
28 |
+
|
29 |
+
```bash
|
30 |
+
$yolo detect weights=path/to/model.pt img=640 conf=0.25 source=your_image.jpg
|
31 |
+
```
|
examples/example_train.py
CHANGED
@@ -9,29 +9,25 @@ project_root = Path(__file__).resolve().parent.parent
|
|
9 |
sys.path.append(str(project_root))
|
10 |
|
11 |
from yolo.config.config import Config
|
12 |
-
from yolo.
|
13 |
-
from yolo.tools.log_helper import custom_logger
|
14 |
from yolo.tools.trainer import Trainer
|
15 |
from yolo.utils.dataloader import get_dataloader
|
16 |
-
from yolo.utils.drawer import draw_model
|
17 |
from yolo.utils.get_dataset import prepare_dataset
|
18 |
|
19 |
|
20 |
@hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
|
21 |
def main(cfg: Config):
|
|
|
|
|
22 |
if cfg.download.auto:
|
23 |
prepare_dataset(cfg.download)
|
24 |
|
25 |
dataloader = get_dataloader(cfg)
|
26 |
-
model = get_model(cfg)
|
27 |
-
draw_model(model=model)
|
28 |
# TODO: get_device or rank, for DDP mode
|
29 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
-
|
31 |
-
trainer
|
32 |
-
trainer.train(dataloader, 10)
|
33 |
|
34 |
|
35 |
if __name__ == "__main__":
|
36 |
-
custom_logger()
|
37 |
main()
|
|
|
9 |
sys.path.append(str(project_root))
|
10 |
|
11 |
from yolo.config.config import Config
|
12 |
+
from yolo.tools.log_helper import custom_logger, get_valid_folder
|
|
|
13 |
from yolo.tools.trainer import Trainer
|
14 |
from yolo.utils.dataloader import get_dataloader
|
|
|
15 |
from yolo.utils.get_dataset import prepare_dataset
|
16 |
|
17 |
|
18 |
@hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
|
19 |
def main(cfg: Config):
|
20 |
+
custom_logger()
|
21 |
+
save_path = get_valid_folder(cfg.hyper.general, cfg.name)
|
22 |
if cfg.download.auto:
|
23 |
prepare_dataset(cfg.download)
|
24 |
|
25 |
dataloader = get_dataloader(cfg)
|
|
|
|
|
26 |
# TODO: get_device or rank, for DDP mode
|
27 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
28 |
+
trainer = Trainer(cfg, save_path, device)
|
29 |
+
trainer.train(dataloader, cfg.hyper.train.epoch)
|
|
|
30 |
|
31 |
|
32 |
if __name__ == "__main__":
|
|
|
33 |
main()
|
yolo/config/config.py
CHANGED
@@ -25,11 +25,10 @@ class Download:
|
|
25 |
@dataclass
|
26 |
class DataLoaderConfig:
|
27 |
batch_size: int
|
|
|
|
|
28 |
shuffle: bool
|
29 |
-
num_workers: int
|
30 |
pin_memory: bool
|
31 |
-
image_size: List[int]
|
32 |
-
class_num: int
|
33 |
|
34 |
|
35 |
@dataclass
|
@@ -54,6 +53,7 @@ class SchedulerArgs:
|
|
54 |
class SchedulerConfig:
|
55 |
type: str
|
56 |
args: SchedulerArgs
|
|
|
57 |
|
58 |
|
59 |
@dataclass
|
@@ -85,8 +85,22 @@ class TrainConfig:
|
|
85 |
loss: LossConfig
|
86 |
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
@dataclass
|
89 |
class HyperConfig:
|
|
|
90 |
data: DataLoaderConfig
|
91 |
train: TrainConfig
|
92 |
|
|
|
25 |
@dataclass
|
26 |
class DataLoaderConfig:
|
27 |
batch_size: int
|
28 |
+
class_num: int
|
29 |
+
image_size: List[int]
|
30 |
shuffle: bool
|
|
|
31 |
pin_memory: bool
|
|
|
|
|
32 |
|
33 |
|
34 |
@dataclass
|
|
|
53 |
class SchedulerConfig:
|
54 |
type: str
|
55 |
args: SchedulerArgs
|
56 |
+
warmup: Dict[str, Union[str, int, float]]
|
57 |
|
58 |
|
59 |
@dataclass
|
|
|
85 |
loss: LossConfig
|
86 |
|
87 |
|
88 |
+
@dataclass
|
89 |
+
class GeneralConfig:
|
90 |
+
out_path: str
|
91 |
+
task: str
|
92 |
+
device: Union[str, int, List[int]]
|
93 |
+
cpu_num: int
|
94 |
+
use_wandb: bool
|
95 |
+
lucky_number: 10
|
96 |
+
exist_ok: bool
|
97 |
+
resume_train: bool
|
98 |
+
use_TensorBoard: bool
|
99 |
+
|
100 |
+
|
101 |
@dataclass
|
102 |
class HyperConfig:
|
103 |
+
general: GeneralConfig
|
104 |
data: DataLoaderConfig
|
105 |
train: TrainConfig
|
106 |
|
yolo/config/hyper/default.yaml
CHANGED
@@ -1,17 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
data:
|
2 |
batch_size: 16
|
3 |
-
shuffle: True
|
4 |
-
num_workers: 16
|
5 |
-
pin_memory: True
|
6 |
class_num: 80
|
7 |
image_size: [640, 640]
|
|
|
|
|
8 |
train:
|
9 |
-
epoch:
|
10 |
optimizer:
|
11 |
-
type:
|
12 |
args:
|
13 |
-
lr: 0.
|
14 |
-
weight_decay: 0.
|
|
|
15 |
loss:
|
16 |
objective:
|
17 |
BCELoss: 0.5
|
@@ -26,10 +36,13 @@ train:
|
|
26 |
iou: 6.0
|
27 |
cls: 0.5
|
28 |
scheduler:
|
29 |
-
type:
|
|
|
|
|
30 |
args:
|
31 |
-
|
32 |
-
|
|
|
33 |
ema:
|
34 |
enabled: true
|
35 |
decay: 0.995
|
|
|
1 |
+
general:
|
2 |
+
out_path: runs
|
3 |
+
task: train
|
4 |
+
deivce: [0]
|
5 |
+
cpu_num: 16
|
6 |
+
use_wandb: False
|
7 |
+
lucky_number: 10
|
8 |
+
exist_ok: True
|
9 |
+
resume_train: False
|
10 |
+
use_TensorBoard: False
|
11 |
data:
|
12 |
batch_size: 16
|
|
|
|
|
|
|
13 |
class_num: 80
|
14 |
image_size: [640, 640]
|
15 |
+
shuffle: True
|
16 |
+
pin_memory: True
|
17 |
train:
|
18 |
+
epoch: 500
|
19 |
optimizer:
|
20 |
+
type: SGD
|
21 |
args:
|
22 |
+
lr: 0.01
|
23 |
+
weight_decay: 0.0005
|
24 |
+
momentum: 0.937
|
25 |
loss:
|
26 |
objective:
|
27 |
BCELoss: 0.5
|
|
|
36 |
iou: 6.0
|
37 |
cls: 0.5
|
38 |
scheduler:
|
39 |
+
type: LinearLR
|
40 |
+
warmup:
|
41 |
+
epochs: 3.0
|
42 |
args:
|
43 |
+
total_iters: ${hyper.train.epoch}
|
44 |
+
start_factor: 1
|
45 |
+
end_factor: 0.01
|
46 |
ema:
|
47 |
enabled: true
|
48 |
decay: 0.995
|
yolo/model/module.py
CHANGED
@@ -25,7 +25,7 @@ class Conv(nn.Module):
|
|
25 |
super().__init__()
|
26 |
kwargs.setdefault("padding", auto_pad(kernel_size, **kwargs))
|
27 |
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)
|
28 |
-
self.bn = nn.BatchNorm2d(out_channels)
|
29 |
self.act = get_activation(activation)
|
30 |
|
31 |
def forward(self, x: Tensor) -> Tensor:
|
@@ -69,6 +69,9 @@ class Detection(nn.Module):
|
|
69 |
Conv(in_channels, class_neck, 3), Conv(class_neck, class_neck, 3), nn.Conv2d(class_neck, num_classes, 1)
|
70 |
)
|
71 |
|
|
|
|
|
|
|
72 |
def forward(self, x: List[Tensor]) -> List[Tensor]:
|
73 |
anchor_x = self.anchor_conv(x)
|
74 |
class_x = self.class_conv(x)
|
|
|
25 |
super().__init__()
|
26 |
kwargs.setdefault("padding", auto_pad(kernel_size, **kwargs))
|
27 |
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)
|
28 |
+
self.bn = nn.BatchNorm2d(out_channels, eps=1e-3, momentum=3e-2)
|
29 |
self.act = get_activation(activation)
|
30 |
|
31 |
def forward(self, x: Tensor) -> Tensor:
|
|
|
69 |
Conv(in_channels, class_neck, 3), Conv(class_neck, class_neck, 3), nn.Conv2d(class_neck, num_classes, 1)
|
70 |
)
|
71 |
|
72 |
+
self.anchor_conv[-1].bias.data.fill_(1.0)
|
73 |
+
self.class_conv[-1].bias.data.fill_(-10)
|
74 |
+
|
75 |
def forward(self, x: List[Tensor]) -> List[Tensor]:
|
76 |
anchor_x = self.anchor_conv(x)
|
77 |
class_x = self.class_conv(x)
|
yolo/model/yolo.py
CHANGED
@@ -7,6 +7,7 @@ from omegaconf import ListConfig, OmegaConf
|
|
7 |
from yolo.config.config import Config, Model, YOLOLayer
|
8 |
from yolo.tools.layer_helper import get_layer_map
|
9 |
from yolo.tools.log_helper import log_model
|
|
|
10 |
|
11 |
|
12 |
class YOLO(nn.Module):
|
@@ -24,8 +25,6 @@ class YOLO(nn.Module):
|
|
24 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
25 |
self.model: List[YOLOLayer] = nn.ModuleList()
|
26 |
self.build_model(model_cfg.model)
|
27 |
-
# TODO: Move to other position
|
28 |
-
log_model(self.model)
|
29 |
|
30 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
31 |
self.layer_index = {}
|
@@ -126,4 +125,6 @@ def get_model(cfg: Config) -> YOLO:
|
|
126 |
OmegaConf.set_struct(cfg.model, False)
|
127 |
model = YOLO(cfg.model, cfg.hyper.data.class_num)
|
128 |
logger.info("β
Success load model")
|
|
|
|
|
129 |
return model
|
|
|
7 |
from yolo.config.config import Config, Model, YOLOLayer
|
8 |
from yolo.tools.layer_helper import get_layer_map
|
9 |
from yolo.tools.log_helper import log_model
|
10 |
+
from yolo.utils.drawer import draw_model
|
11 |
|
12 |
|
13 |
class YOLO(nn.Module):
|
|
|
25 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
26 |
self.model: List[YOLOLayer] = nn.ModuleList()
|
27 |
self.build_model(model_cfg.model)
|
|
|
|
|
28 |
|
29 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
30 |
self.layer_index = {}
|
|
|
125 |
OmegaConf.set_struct(cfg.model, False)
|
126 |
model = YOLO(cfg.model, cfg.hyper.data.class_num)
|
127 |
logger.info("β
Success load model")
|
128 |
+
log_model(model.model)
|
129 |
+
# draw_model(model=model)
|
130 |
return model
|
yolo/tools/log_helper.py
CHANGED
@@ -11,6 +11,7 @@ Example:
|
|
11 |
custom_logger()
|
12 |
"""
|
13 |
|
|
|
14 |
import sys
|
15 |
from typing import Dict, List
|
16 |
|
@@ -22,19 +23,20 @@ from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
|
|
22 |
from rich.table import Table
|
23 |
from torch import Tensor
|
24 |
|
25 |
-
from yolo.config.config import Config, YOLOLayer
|
26 |
|
27 |
|
28 |
def custom_logger():
|
29 |
logger.remove()
|
30 |
logger.add(
|
31 |
sys.stderr,
|
32 |
-
|
|
|
33 |
)
|
34 |
|
35 |
|
36 |
class CustomProgress:
|
37 |
-
def __init__(self, cfg: Config, use_wandb: bool = False):
|
38 |
self.progress = Progress(
|
39 |
TextColumn("[progress.description]{task.description}"),
|
40 |
BarColumn(bar_width=None),
|
@@ -44,18 +46,19 @@ class CustomProgress:
|
|
44 |
self.use_wandb = use_wandb
|
45 |
if self.use_wandb:
|
46 |
wandb.errors.term._log = custom_wandb_log
|
47 |
-
self.wandb = wandb.init(
|
|
|
|
|
48 |
|
49 |
def start_train(self, num_epochs: int):
|
50 |
-
self.task_epoch = self.progress.add_task("[cyan]Epochs", total=num_epochs)
|
51 |
|
52 |
-
def
|
53 |
-
self.
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
def start_batch(self, num_batches):
|
59 |
self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)
|
60 |
|
61 |
def one_batch(self, loss_dict: Dict[str, Tensor]):
|
@@ -63,21 +66,25 @@ class CustomProgress:
|
|
63 |
for loss_name, loss_value in loss_dict.items():
|
64 |
self.wandb.log({f"Loss/{loss_name}": loss_value})
|
65 |
|
66 |
-
loss_str = "
|
67 |
for loss_name, loss_val in loss_dict.items():
|
68 |
-
loss_str += f" {
|
69 |
|
70 |
self.progress.update(self.batch_task, advance=1, description=f"[green]Batches [white]{loss_str}")
|
71 |
|
72 |
-
def
|
73 |
self.progress.remove_task(self.batch_task)
|
|
|
|
|
|
|
|
|
74 |
|
75 |
|
76 |
def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
|
77 |
if silent:
|
78 |
return
|
79 |
for line in string.split("\n"):
|
80 |
-
logger.opt(raw=not newline).info("π " + line)
|
81 |
|
82 |
|
83 |
def log_model(model: List[YOLOLayer]):
|
@@ -99,3 +106,25 @@ def log_model(model: List[YOLOLayer]):
|
|
99 |
channels = "-"
|
100 |
table.add_row(str(idx), layer.layer_type, layer.tags, f"{layer_param:,}", channels)
|
101 |
console.print(table)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
custom_logger()
|
12 |
"""
|
13 |
|
14 |
+
import os
|
15 |
import sys
|
16 |
from typing import Dict, List
|
17 |
|
|
|
23 |
from rich.table import Table
|
24 |
from torch import Tensor
|
25 |
|
26 |
+
from yolo.config.config import Config, GeneralConfig, YOLOLayer
|
27 |
|
28 |
|
29 |
def custom_logger():
|
30 |
logger.remove()
|
31 |
logger.add(
|
32 |
sys.stderr,
|
33 |
+
colorize=True,
|
34 |
+
format="<fg #003385>[{time:MM/DD HH:mm:ss}]</> <level>{level: ^8}</level>| <level>{message}</level>",
|
35 |
)
|
36 |
|
37 |
|
38 |
class CustomProgress:
|
39 |
+
def __init__(self, cfg: Config, save_path: str, use_wandb: bool = False):
|
40 |
self.progress = Progress(
|
41 |
TextColumn("[progress.description]{task.description}"),
|
42 |
BarColumn(bar_width=None),
|
|
|
46 |
self.use_wandb = use_wandb
|
47 |
if self.use_wandb:
|
48 |
wandb.errors.term._log = custom_wandb_log
|
49 |
+
self.wandb = wandb.init(
|
50 |
+
project="YOLO", resume="allow", mode="online", dir=save_path, id=None, name=cfg.name
|
51 |
+
)
|
52 |
|
53 |
def start_train(self, num_epochs: int):
|
54 |
+
self.task_epoch = self.progress.add_task("[cyan]Epochs [white]| Loss | Box | DFL | BCE |", total=num_epochs)
|
55 |
|
56 |
+
def start_one_epoch(self, num_batches, optimizer, epoch_idx):
|
57 |
+
if self.use_wandb:
|
58 |
+
lr_values = [params["lr"] for params in optimizer.param_groups]
|
59 |
+
lr_names = ["bias", "norm", "conv"]
|
60 |
+
for lr_name, lr_value in zip(lr_names, lr_values):
|
61 |
+
self.wandb.log({f"Learning Rate/{lr_name}": lr_value}, step=epoch_idx)
|
|
|
62 |
self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)
|
63 |
|
64 |
def one_batch(self, loss_dict: Dict[str, Tensor]):
|
|
|
66 |
for loss_name, loss_value in loss_dict.items():
|
67 |
self.wandb.log({f"Loss/{loss_name}": loss_value})
|
68 |
|
69 |
+
loss_str = "| -.-- |"
|
70 |
for loss_name, loss_val in loss_dict.items():
|
71 |
+
loss_str += f" {loss_val:2.2f} |"
|
72 |
|
73 |
self.progress.update(self.batch_task, advance=1, description=f"[green]Batches [white]{loss_str}")
|
74 |
|
75 |
+
def finish_one_epoch(self):
|
76 |
self.progress.remove_task(self.batch_task)
|
77 |
+
self.progress.update(self.task_epoch, advance=1)
|
78 |
+
|
79 |
+
def finish_train(self):
|
80 |
+
self.wandb.finish()
|
81 |
|
82 |
|
83 |
def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
|
84 |
if silent:
|
85 |
return
|
86 |
for line in string.split("\n"):
|
87 |
+
logger.opt(raw=not newline, colors=True).info("π " + line)
|
88 |
|
89 |
|
90 |
def log_model(model: List[YOLOLayer]):
|
|
|
106 |
channels = "-"
|
107 |
table.add_row(str(idx), layer.layer_type, layer.tags, f"{layer_param:,}", channels)
|
108 |
console.print(table)
|
109 |
+
|
110 |
+
|
111 |
+
def get_valid_folder(general_cfg: GeneralConfig, exp_name):
|
112 |
+
base_path = os.path.join(general_cfg.out_path, general_cfg.task)
|
113 |
+
save_path = os.path.join(base_path, exp_name)
|
114 |
+
|
115 |
+
if not general_cfg.exist_ok:
|
116 |
+
index = 1
|
117 |
+
old_exp_name = exp_name
|
118 |
+
while os.path.isdir(save_path):
|
119 |
+
exp_name = f"{old_exp_name}{index}"
|
120 |
+
save_path = os.path.join(base_path, exp_name)
|
121 |
+
index += 1
|
122 |
+
if index > 1:
|
123 |
+
logger.opt(colors=True).warning(
|
124 |
+
f"π Experiment directory exists! Changed <red>{old_exp_name}</> to <green>{exp_name}</>"
|
125 |
+
)
|
126 |
+
|
127 |
+
os.makedirs(save_path, exist_ok=True)
|
128 |
+
logger.opt(colors=True).info(f"π Created log folder: <u><fg #808080>{save_path}</></>")
|
129 |
+
logger.add(os.path.join(save_path, "output.log"), backtrace=True, diagnose=True)
|
130 |
+
return save_path
|
yolo/tools/model_helper.py
CHANGED
@@ -2,9 +2,10 @@ from typing import Any, Dict, Type
|
|
2 |
|
3 |
import torch
|
4 |
from torch.optim import Optimizer
|
5 |
-
from torch.optim.lr_scheduler import _LRScheduler
|
6 |
|
7 |
from yolo.config.config import OptimizerConfig, SchedulerConfig
|
|
|
8 |
|
9 |
|
10 |
class EMA:
|
@@ -31,21 +32,38 @@ class EMA:
|
|
31 |
self.shadow[name].copy_(param.data)
|
32 |
|
33 |
|
34 |
-
def get_optimizer(
|
35 |
"""Create an optimizer for the given model parameters based on the configuration.
|
36 |
|
37 |
Returns:
|
38 |
An instance of the optimizer configured according to the provided settings.
|
39 |
"""
|
40 |
optimizer_class: Type[Optimizer] = getattr(torch.optim, optim_cfg.type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
return optimizer_class(model_parameters, **optim_cfg.args)
|
42 |
|
43 |
|
44 |
-
def get_scheduler(optimizer: Optimizer,
|
45 |
"""Create a learning rate scheduler for the given optimizer based on the configuration.
|
46 |
|
47 |
Returns:
|
48 |
An instance of the scheduler configured according to the provided settings.
|
49 |
"""
|
50 |
-
scheduler_class: Type[_LRScheduler] = getattr(torch.optim.lr_scheduler,
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
import torch
|
4 |
from torch.optim import Optimizer
|
5 |
+
from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
|
6 |
|
7 |
from yolo.config.config import OptimizerConfig, SchedulerConfig
|
8 |
+
from yolo.model.yolo import YOLO
|
9 |
|
10 |
|
11 |
class EMA:
|
|
|
32 |
self.shadow[name].copy_(param.data)
|
33 |
|
34 |
|
35 |
+
def get_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
|
36 |
"""Create an optimizer for the given model parameters based on the configuration.
|
37 |
|
38 |
Returns:
|
39 |
An instance of the optimizer configured according to the provided settings.
|
40 |
"""
|
41 |
optimizer_class: Type[Optimizer] = getattr(torch.optim, optim_cfg.type)
|
42 |
+
|
43 |
+
bias_params = [p for name, p in model.named_parameters() if "bias" in name]
|
44 |
+
norm_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" in name]
|
45 |
+
conv_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" not in name]
|
46 |
+
|
47 |
+
model_parameters = [
|
48 |
+
{"params": bias_params, "nestrov": True, "momentum": 0.937},
|
49 |
+
{"params": conv_params, "weight_decay": 0.0},
|
50 |
+
{"params": norm_params, "weight_decay": 1e-5},
|
51 |
+
]
|
52 |
return optimizer_class(model_parameters, **optim_cfg.args)
|
53 |
|
54 |
|
55 |
+
def get_scheduler(optimizer: Optimizer, schedule_cfg: SchedulerConfig) -> _LRScheduler:
|
56 |
"""Create a learning rate scheduler for the given optimizer based on the configuration.
|
57 |
|
58 |
Returns:
|
59 |
An instance of the scheduler configured according to the provided settings.
|
60 |
"""
|
61 |
+
scheduler_class: Type[_LRScheduler] = getattr(torch.optim.lr_scheduler, schedule_cfg.type)
|
62 |
+
schedule = scheduler_class(optimizer, **schedule_cfg.args)
|
63 |
+
if hasattr(schedule_cfg, "warmup"):
|
64 |
+
wepoch = schedule_cfg.warmup.epochs
|
65 |
+
lambda1 = lambda epoch: 0.1 + 0.9 * (epoch + 1 / wepoch) if epoch < wepoch else 1
|
66 |
+
lambda2 = lambda epoch: 10 - 9 * (epoch + 1 / wepoch) if epoch < wepoch else 1
|
67 |
+
warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2, lambda1])
|
68 |
+
schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[2])
|
69 |
+
return schedule
|
yolo/tools/trainer.py
CHANGED
@@ -6,22 +6,23 @@ from torch import Tensor
|
|
6 |
from torch.cuda.amp import GradScaler, autocast
|
7 |
|
8 |
from yolo.config.config import Config, TrainConfig
|
9 |
-
from yolo.model.yolo import
|
10 |
from yolo.tools.log_helper import CustomProgress
|
11 |
from yolo.tools.model_helper import EMA, get_optimizer, get_scheduler
|
12 |
from yolo.utils.loss import get_loss_function
|
13 |
|
14 |
|
15 |
class Trainer:
|
16 |
-
def __init__(self,
|
17 |
train_cfg: TrainConfig = cfg.hyper.train
|
|
|
18 |
|
19 |
self.model = model.to(device)
|
20 |
self.device = device
|
21 |
-
self.optimizer = get_optimizer(model
|
22 |
self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
|
23 |
self.loss_fn = get_loss_function(cfg)
|
24 |
-
self.progress = CustomProgress(cfg, use_wandb=True)
|
25 |
|
26 |
if getattr(train_cfg.ema, "enabled", False):
|
27 |
self.ema = EMA(model, decay=train_cfg.ema.decay)
|
@@ -46,7 +47,6 @@ class Trainer:
|
|
46 |
def train_one_epoch(self, dataloader):
|
47 |
self.model.train()
|
48 |
total_loss = 0
|
49 |
-
self.progress.start_batch(len(dataloader))
|
50 |
|
51 |
for data, targets in dataloader:
|
52 |
loss, loss_each = self.train_one_batch(data, targets)
|
@@ -57,7 +57,6 @@ class Trainer:
|
|
57 |
if self.scheduler:
|
58 |
self.scheduler.step()
|
59 |
|
60 |
-
self.progress.finish_batch()
|
61 |
return total_loss / len(dataloader)
|
62 |
|
63 |
def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
|
@@ -79,8 +78,9 @@ class Trainer:
|
|
79 |
self.progress.start_train(num_epochs)
|
80 |
for epoch in range(num_epochs):
|
81 |
|
82 |
-
|
83 |
-
self.
|
|
|
84 |
|
85 |
logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
|
86 |
if (epoch + 1) % 5 == 0:
|
|
|
6 |
from torch.cuda.amp import GradScaler, autocast
|
7 |
|
8 |
from yolo.config.config import Config, TrainConfig
|
9 |
+
from yolo.model.yolo import get_model
|
10 |
from yolo.tools.log_helper import CustomProgress
|
11 |
from yolo.tools.model_helper import EMA, get_optimizer, get_scheduler
|
12 |
from yolo.utils.loss import get_loss_function
|
13 |
|
14 |
|
15 |
class Trainer:
|
16 |
+
def __init__(self, cfg: Config, save_path: str, device):
|
17 |
train_cfg: TrainConfig = cfg.hyper.train
|
18 |
+
model = get_model(cfg)
|
19 |
|
20 |
self.model = model.to(device)
|
21 |
self.device = device
|
22 |
+
self.optimizer = get_optimizer(model, train_cfg.optimizer)
|
23 |
self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
|
24 |
self.loss_fn = get_loss_function(cfg)
|
25 |
+
self.progress = CustomProgress(cfg, save_path, use_wandb=True)
|
26 |
|
27 |
if getattr(train_cfg.ema, "enabled", False):
|
28 |
self.ema = EMA(model, decay=train_cfg.ema.decay)
|
|
|
47 |
def train_one_epoch(self, dataloader):
|
48 |
self.model.train()
|
49 |
total_loss = 0
|
|
|
50 |
|
51 |
for data, targets in dataloader:
|
52 |
loss, loss_each = self.train_one_batch(data, targets)
|
|
|
57 |
if self.scheduler:
|
58 |
self.scheduler.step()
|
59 |
|
|
|
60 |
return total_loss / len(dataloader)
|
61 |
|
62 |
def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
|
|
|
78 |
self.progress.start_train(num_epochs)
|
79 |
for epoch in range(num_epochs):
|
80 |
|
81 |
+
self.progress.start_one_epoch(len(dataloader), self.optimizer, epoch)
|
82 |
+
epoch_loss = self.train_one_epoch(dataloader)
|
83 |
+
self.progress.finish_one_epoch()
|
84 |
|
85 |
logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
|
86 |
if (epoch + 1) % 5 == 0:
|
yolo/utils/dataloader.py
CHANGED
@@ -7,6 +7,7 @@ import numpy as np
|
|
7 |
import torch
|
8 |
from loguru import logger
|
9 |
from PIL import Image
|
|
|
10 |
from torch.utils.data import DataLoader, Dataset
|
11 |
from torchvision.transforms import functional as TF
|
12 |
from tqdm.rich import tqdm
|
@@ -74,7 +75,7 @@ class YoloDataset(Dataset):
|
|
74 |
|
75 |
data = []
|
76 |
valid_inputs = 0
|
77 |
-
for image_name in
|
78 |
if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
|
79 |
continue
|
80 |
image_id, _ = path.splitext(image_name)
|
@@ -159,7 +160,7 @@ class YoloDataLoader(DataLoader):
|
|
159 |
dataset,
|
160 |
batch_size=hyper.batch_size,
|
161 |
shuffle=hyper.shuffle,
|
162 |
-
num_workers=hyper.
|
163 |
pin_memory=hyper.pin_memory,
|
164 |
collate_fn=self.collate_fn,
|
165 |
)
|
|
|
7 |
import torch
|
8 |
from loguru import logger
|
9 |
from PIL import Image
|
10 |
+
from rich.progress import track
|
11 |
from torch.utils.data import DataLoader, Dataset
|
12 |
from torchvision.transforms import functional as TF
|
13 |
from tqdm.rich import tqdm
|
|
|
75 |
|
76 |
data = []
|
77 |
valid_inputs = 0
|
78 |
+
for image_name in track(images_list, description="Filtering data"):
|
79 |
if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
|
80 |
continue
|
81 |
image_id, _ = path.splitext(image_name)
|
|
|
160 |
dataset,
|
161 |
batch_size=hyper.batch_size,
|
162 |
shuffle=hyper.shuffle,
|
163 |
+
num_workers=config.hyper.general.cpu_num,
|
164 |
pin_memory=hyper.pin_memory,
|
165 |
collate_fn=self.collate_fn,
|
166 |
)
|