henry000 commited on
Commit
8f0b970
·
1 Parent(s): d009076

✨ [New] inference code and refactor train example

Browse files
examples/example_inference.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ import hydra
5
+ import torch
6
+
7
+ project_root = Path(__file__).resolve().parent.parent
8
+ sys.path.append(str(project_root))
9
+
10
+ from yolo.config.config import Config
11
+ from yolo.model.yolo import get_model
12
+ from yolo.tools.data_loader import create_dataloader
13
+ from yolo.tools.solver import ModelTester
14
+ from yolo.utils.logging_utils import custom_logger, validate_log_directory
15
+
16
+
17
+ @hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
18
+ def main(cfg: Config):
19
+ custom_logger()
20
+ save_path = validate_log_directory(cfg, cfg.name)
21
+
22
+ device = torch.device(cfg.device)
23
+ model = get_model(cfg).to(device)
24
+
25
+ save_path = validate_log_directory(cfg, cfg.name)
26
+ dataloader = create_dataloader(cfg)
27
+ device = torch.device(cfg.device)
28
+ model = get_model(cfg).to(device)
29
+
30
+ tester = ModelTester(cfg, model, save_path, device)
31
+ tester.solve(dataloader, cfg.task.epoch)
32
+
33
+
34
+ if __name__ == "__main__":
35
+ main()
examples/example_train.py CHANGED
@@ -3,30 +3,28 @@ from pathlib import Path
3
 
4
  import hydra
5
  import torch
6
- from loguru import logger
7
 
8
  project_root = Path(__file__).resolve().parent.parent
9
  sys.path.append(str(project_root))
10
 
11
  from yolo.config.config import Config
 
12
  from yolo.tools.data_loader import create_dataloader
13
- from yolo.tools.dataset_preparation import prepare_dataset
14
- from yolo.tools.trainer import ModelTrainer
15
  from yolo.utils.logging_utils import custom_logger, validate_log_directory
16
 
17
 
18
  @hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
19
  def main(cfg: Config):
20
  custom_logger()
21
- save_path = validate_log_directory(cfg.hyper.general, cfg.name)
22
- if cfg.download.auto:
23
- prepare_dataset(cfg.download)
24
-
25
  dataloader = create_dataloader(cfg)
26
  # TODO: get_device or rank, for DDP mode
27
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
- trainer = ModelTrainer(cfg, save_path, device)
29
- trainer.train(dataloader, cfg.hyper.train.epoch)
 
 
30
 
31
 
32
  if __name__ == "__main__":
 
3
 
4
  import hydra
5
  import torch
 
6
 
7
  project_root = Path(__file__).resolve().parent.parent
8
  sys.path.append(str(project_root))
9
 
10
  from yolo.config.config import Config
11
+ from yolo.model.yolo import get_model
12
  from yolo.tools.data_loader import create_dataloader
13
+ from yolo.tools.solver import ModelTrainer
 
14
  from yolo.utils.logging_utils import custom_logger, validate_log_directory
15
 
16
 
17
  @hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
18
  def main(cfg: Config):
19
  custom_logger()
20
+ save_path = validate_log_directory(cfg, cfg.name)
 
 
 
21
  dataloader = create_dataloader(cfg)
22
  # TODO: get_device or rank, for DDP mode
23
+ device = torch.device(cfg.device)
24
+ model = get_model(cfg).to(device)
25
+
26
+ trainer = ModelTrainer(cfg, model, save_path, device)
27
+ trainer.solve(dataloader, cfg.task.epoch)
28
 
29
 
30
  if __name__ == "__main__":