Feng Wang commited on
Commit
0511196
·
1 Parent(s): 6ea9687

feat(exp): get_trainer method, add pre-commit (#1263)

Browse files
.pre-commit-config.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pycqa/flake8
3
+ rev: 3.8.3
4
+ hooks:
5
+ - id: flake8
6
+ - repo: https://github.com/pre-commit/pre-commit-hooks
7
+ rev: v3.1.0
8
+ hooks:
9
+ - id: check-added-large-files
10
+ - id: check-docstring-first
11
+ - id: check-executables-have-shebangs
12
+ - id: check-json
13
+ - id: check-yaml
14
+ args: ["--unsafe"]
15
+ - id: debug-statements
16
+ - id: end-of-file-fixer
17
+ - id: requirements-txt-fixer
18
+ - id: trailing-whitespace
19
+ - repo: https://github.com/jorisroovers/gitlint
20
+ rev: v0.15.1
21
+ hooks:
22
+ - id: gitlint
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 4.3.21
25
+ hooks:
26
+ - id: isort
27
+
28
+ - repo: https://github.com/PyCQA/autoflake
29
+ rev: v1.4
30
+ hooks:
31
+ - id: autoflake
32
+ name: Remove unused variables and imports
33
+ entry: autoflake
34
+ language: python
35
+ args:
36
+ [
37
+ "--in-place",
38
+ "--remove-all-unused-imports",
39
+ "--remove-unused-variables",
40
+ "--expand-star-imports",
41
+ "--ignore-init-module-imports",
42
+ ]
43
+ files: \.py$
tools/train.py CHANGED
@@ -10,8 +10,8 @@ from loguru import logger
10
  import torch
11
  import torch.backends.cudnn as cudnn
12
 
13
- from yolox.core import Trainer, launch
14
- from yolox.exp import get_exp
15
  from yolox.utils import configure_module, configure_nccl, configure_omp, get_num_devices
16
 
17
 
@@ -97,7 +97,7 @@ def make_parser():
97
 
98
 
99
  @logger.catch
100
- def main(exp, args):
101
  if exp.seed is not None:
102
  random.seed(exp.seed)
103
  torch.manual_seed(exp.seed)
@@ -113,7 +113,7 @@ def main(exp, args):
113
  configure_omp()
114
  cudnn.benchmark = True
115
 
116
- trainer = Trainer(exp, args)
117
  trainer.train()
118
 
119
 
 
10
  import torch
11
  import torch.backends.cudnn as cudnn
12
 
13
+ from yolox.core import launch
14
+ from yolox.exp import Exp, get_exp
15
  from yolox.utils import configure_module, configure_nccl, configure_omp, get_num_devices
16
 
17
 
 
97
 
98
 
99
  @logger.catch
100
+ def main(exp: Exp, args):
101
  if exp.seed is not None:
102
  random.seed(exp.seed)
103
  torch.manual_seed(exp.seed)
 
113
  configure_omp()
114
  cudnn.benchmark = True
115
 
116
+ trainer = exp.get_trainer(args)
117
  trainer.train()
118
 
119
 
yolox/core/trainer.py CHANGED
@@ -12,6 +12,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
12
  from torch.utils.tensorboard import SummaryWriter
13
 
14
  from yolox.data import DataPrefetcher
 
15
  from yolox.utils import (
16
  MeterBuffer,
17
  ModelEMA,
@@ -33,7 +34,7 @@ from yolox.utils import (
33
 
34
 
35
  class Trainer:
36
- def __init__(self, exp, args):
37
  # init function only defines some basic attr, other attrs like model, optimizer are built in
38
  # before_train methods.
39
  self.exp = exp
 
12
  from torch.utils.tensorboard import SummaryWriter
13
 
14
  from yolox.data import DataPrefetcher
15
+ from yolox.exp import Exp
16
  from yolox.utils import (
17
  MeterBuffer,
18
  ModelEMA,
 
34
 
35
 
36
  class Trainer:
37
+ def __init__(self, exp: Exp, args):
38
  # init function only defines some basic attr, other attrs like model, optimizer are built in
39
  # before_train methods.
40
  self.exp = exp
yolox/exp/yolox_base.py CHANGED
@@ -127,9 +127,7 @@ class Exp(BaseExp):
127
  self.model.train()
128
  return self.model
129
 
130
- def get_data_loader(
131
- self, batch_size, is_distributed, no_aug=False, cache_img=False
132
- ):
133
  from yolox.data import (
134
  COCODataset,
135
  TrainTransform,
@@ -314,5 +312,11 @@ class Exp(BaseExp):
314
  )
315
  return evaluator
316
 
 
 
 
 
 
 
317
  def eval(self, model, evaluator, is_distributed, half=False):
318
  return evaluator.evaluate(model, is_distributed, half)
 
127
  self.model.train()
128
  return self.model
129
 
130
+ def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=False):
 
 
131
  from yolox.data import (
132
  COCODataset,
133
  TrainTransform,
 
312
  )
313
  return evaluator
314
 
315
+ def get_trainer(self, args):
316
+ from yolox.core import Trainer
317
+ trainer = Trainer(self, args)
318
+ # NOTE: trainer shouldn't be an attribute of exp object
319
+ return trainer
320
+
321
  def eval(self, model, evaluator, is_distributed, half=False):
322
  return evaluator.evaluate(model, is_distributed, half)