henry000 commited on
Commit
a0c7025
·
1 Parent(s): 7967aab

✨ [New] make it pip installable!

Browse files
Files changed (4) hide show
  1. pyproject.toml +11 -2
  2. yolo/__init__.py +0 -0
  3. yolo/config/__init__.py +0 -0
  4. yolo/lazy.py +35 -0
pyproject.toml CHANGED
@@ -1,16 +1,25 @@
1
  [project]
2
  name = "yolo"
3
- version = "0.0.0"
4
  dynamic = ["dependencies"]
5
 
6
  [tool.setuptools.dynamic]
7
  dependencies = {file = ["requirements.txt"]}
8
 
9
  [tool.setuptools.packages.find]
10
- where = ["yolo"]
 
 
 
 
 
 
11
 
12
  [build-system]
13
  build-backend = "setuptools.build_meta"
14
  requires = [
15
  "setuptools",
16
  ]
 
 
 
 
1
  [project]
2
  name = "yolo"
3
+ version = "0.1.0"
4
  dynamic = ["dependencies"]
5
 
6
  [tool.setuptools.dynamic]
7
  dependencies = {file = ["requirements.txt"]}
8
 
9
  [tool.setuptools.packages.find]
10
+ where = ["."]
11
+ include = ["yolo*"]
12
+
13
+ [tool.setuptools]
14
+ package-data = {"yolo" = ["**/*.yaml"]}
15
+ include-package-data = true
16
+
17
 
18
  [build-system]
19
  build-backend = "setuptools.build_meta"
20
  requires = [
21
  "setuptools",
22
  ]
23
+
24
+ [project.scripts]
25
+ yolo = "yolo.lazy:main"
yolo/__init__.py ADDED
File without changes
yolo/config/__init__.py ADDED
File without changes
yolo/lazy.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ import hydra
5
+ import torch
6
+
7
+ project_root = Path(__file__).resolve().parent.parent
8
+ sys.path.append(str(project_root))
9
+
10
+ from yolo.config.config import Config
11
+ from yolo.model.yolo import get_model
12
+ from yolo.tools.data_loader import create_dataloader
13
+ from yolo.tools.solver import ModelTester, ModelTrainer
14
+ from yolo.utils.logging_utils import custom_logger, validate_log_directory
15
+
16
+
17
+ @hydra.main(config_path="config", config_name="config", version_base=None)
18
+ def main(cfg: Config):
19
+ custom_logger()
20
+ save_path = validate_log_directory(cfg, cfg.name)
21
+ dataloader = create_dataloader(cfg)
22
+ device = torch.device(cfg.device)
23
+ model = get_model(cfg).to(device)
24
+
25
+ if cfg.task.task == "train":
26
+ trainer = ModelTrainer(cfg, model, save_path, device)
27
+ trainer.solve(dataloader)
28
+
29
+ if cfg.task.task == "inference":
30
+ tester = ModelTester(cfg, model, save_path, device)
31
+ tester.solve(dataloader)
32
+
33
+
34
+ if __name__ == "__main__":
35
+ main()