Feng Wang
commited on
Commit
·
0511196
1
Parent(s):
6ea9687
feat(exp): get_trainer method, add pre-commit (#1263)
Browse files- .pre-commit-config.yaml +43 -0
- tools/train.py +4 -4
- yolox/core/trainer.py +2 -1
- yolox/exp/yolox_base.py +7 -3
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
repos:
|
2 |
+
- repo: https://github.com/pycqa/flake8
|
3 |
+
rev: 3.8.3
|
4 |
+
hooks:
|
5 |
+
- id: flake8
|
6 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
7 |
+
rev: v3.1.0
|
8 |
+
hooks:
|
9 |
+
- id: check-added-large-files
|
10 |
+
- id: check-docstring-first
|
11 |
+
- id: check-executables-have-shebangs
|
12 |
+
- id: check-json
|
13 |
+
- id: check-yaml
|
14 |
+
args: ["--unsafe"]
|
15 |
+
- id: debug-statements
|
16 |
+
- id: end-of-file-fixer
|
17 |
+
- id: requirements-txt-fixer
|
18 |
+
- id: trailing-whitespace
|
19 |
+
- repo: https://github.com/jorisroovers/gitlint
|
20 |
+
rev: v0.15.1
|
21 |
+
hooks:
|
22 |
+
- id: gitlint
|
23 |
+
- repo: https://github.com/pycqa/isort
|
24 |
+
rev: 4.3.21
|
25 |
+
hooks:
|
26 |
+
- id: isort
|
27 |
+
|
28 |
+
- repo: https://github.com/PyCQA/autoflake
|
29 |
+
rev: v1.4
|
30 |
+
hooks:
|
31 |
+
- id: autoflake
|
32 |
+
name: Remove unused variables and imports
|
33 |
+
entry: autoflake
|
34 |
+
language: python
|
35 |
+
args:
|
36 |
+
[
|
37 |
+
"--in-place",
|
38 |
+
"--remove-all-unused-imports",
|
39 |
+
"--remove-unused-variables",
|
40 |
+
"--expand-star-imports",
|
41 |
+
"--ignore-init-module-imports",
|
42 |
+
]
|
43 |
+
files: \.py$
|
tools/train.py
CHANGED
@@ -10,8 +10,8 @@ from loguru import logger
|
|
10 |
import torch
|
11 |
import torch.backends.cudnn as cudnn
|
12 |
|
13 |
-
from yolox.core import
|
14 |
-
from yolox.exp import get_exp
|
15 |
from yolox.utils import configure_module, configure_nccl, configure_omp, get_num_devices
|
16 |
|
17 |
|
@@ -97,7 +97,7 @@ def make_parser():
|
|
97 |
|
98 |
|
99 |
@logger.catch
|
100 |
-
def main(exp, args):
|
101 |
if exp.seed is not None:
|
102 |
random.seed(exp.seed)
|
103 |
torch.manual_seed(exp.seed)
|
@@ -113,7 +113,7 @@ def main(exp, args):
|
|
113 |
configure_omp()
|
114 |
cudnn.benchmark = True
|
115 |
|
116 |
-
trainer =
|
117 |
trainer.train()
|
118 |
|
119 |
|
|
|
10 |
import torch
|
11 |
import torch.backends.cudnn as cudnn
|
12 |
|
13 |
+
from yolox.core import launch
|
14 |
+
from yolox.exp import Exp, get_exp
|
15 |
from yolox.utils import configure_module, configure_nccl, configure_omp, get_num_devices
|
16 |
|
17 |
|
|
|
97 |
|
98 |
|
99 |
@logger.catch
|
100 |
+
def main(exp: Exp, args):
|
101 |
if exp.seed is not None:
|
102 |
random.seed(exp.seed)
|
103 |
torch.manual_seed(exp.seed)
|
|
|
113 |
configure_omp()
|
114 |
cudnn.benchmark = True
|
115 |
|
116 |
+
trainer = exp.get_trainer(args)
|
117 |
trainer.train()
|
118 |
|
119 |
|
yolox/core/trainer.py
CHANGED
@@ -12,6 +12,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
|
12 |
from torch.utils.tensorboard import SummaryWriter
|
13 |
|
14 |
from yolox.data import DataPrefetcher
|
|
|
15 |
from yolox.utils import (
|
16 |
MeterBuffer,
|
17 |
ModelEMA,
|
@@ -33,7 +34,7 @@ from yolox.utils import (
|
|
33 |
|
34 |
|
35 |
class Trainer:
|
36 |
-
def __init__(self, exp, args):
|
37 |
# init function only defines some basic attr, other attrs like model, optimizer are built in
|
38 |
# before_train methods.
|
39 |
self.exp = exp
|
|
|
12 |
from torch.utils.tensorboard import SummaryWriter
|
13 |
|
14 |
from yolox.data import DataPrefetcher
|
15 |
+
from yolox.exp import Exp
|
16 |
from yolox.utils import (
|
17 |
MeterBuffer,
|
18 |
ModelEMA,
|
|
|
34 |
|
35 |
|
36 |
class Trainer:
|
37 |
+
def __init__(self, exp: Exp, args):
|
38 |
# init function only defines some basic attr, other attrs like model, optimizer are built in
|
39 |
# before_train methods.
|
40 |
self.exp = exp
|
yolox/exp/yolox_base.py
CHANGED
@@ -127,9 +127,7 @@ class Exp(BaseExp):
|
|
127 |
self.model.train()
|
128 |
return self.model
|
129 |
|
130 |
-
def get_data_loader(
|
131 |
-
self, batch_size, is_distributed, no_aug=False, cache_img=False
|
132 |
-
):
|
133 |
from yolox.data import (
|
134 |
COCODataset,
|
135 |
TrainTransform,
|
@@ -314,5 +312,11 @@ class Exp(BaseExp):
|
|
314 |
)
|
315 |
return evaluator
|
316 |
|
|
|
|
|
|
|
|
|
|
|
|
|
317 |
def eval(self, model, evaluator, is_distributed, half=False):
|
318 |
return evaluator.evaluate(model, is_distributed, half)
|
|
|
127 |
self.model.train()
|
128 |
return self.model
|
129 |
|
130 |
+
def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=False):
|
|
|
|
|
131 |
from yolox.data import (
|
132 |
COCODataset,
|
133 |
TrainTransform,
|
|
|
312 |
)
|
313 |
return evaluator
|
314 |
|
315 |
+
def get_trainer(self, args):
|
316 |
+
from yolox.core import Trainer
|
317 |
+
trainer = Trainer(self, args)
|
318 |
+
# NOTE: trainer shouldn't be an attribute of exp object
|
319 |
+
return trainer
|
320 |
+
|
321 |
def eval(self, model, evaluator, is_distributed, half=False):
|
322 |
return evaluator.evaluate(model, is_distributed, half)
|