henry000 commited on
Commit
a33e03b
Β·
1 Parent(s): 67ec59a

🚚 [Rename] get_model to create, rename examples

Browse files
examples/lazy.py DELETED
@@ -1,37 +0,0 @@
1
- import sys
2
- from pathlib import Path
3
-
4
- import hydra
5
- import torch
6
-
7
- project_root = Path(__file__).resolve().parent.parent
8
- sys.path.append(str(project_root))
9
-
10
- from yolo.config.config import Config
11
- from yolo.model.yolo import get_model
12
- from yolo.tools.data_loader import create_dataloader
13
- from yolo.tools.solver import ModelTester, ModelTrainer
14
- from yolo.utils.logging_utils import custom_logger, validate_log_directory
15
-
16
-
17
- @hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
18
- def main(cfg: Config):
19
- custom_logger()
20
-
21
- custom_logger()
22
- save_path = validate_log_directory(cfg, cfg.name)
23
- dataloader = create_dataloader(cfg)
24
- device = torch.device(cfg.device)
25
- model = get_model(cfg).to(device)
26
-
27
- if cfg.task.task == "train":
28
- trainer = ModelTrainer(cfg, model, save_path, device)
29
- trainer.solve(dataloader)
30
-
31
- if cfg.task.task == "inference":
32
- tester = ModelTester(cfg, model, save_path, device)
33
- tester.solve(dataloader)
34
-
35
-
36
- if __name__ == "__main__":
37
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/notebook_colab.ipynb ADDED
File without changes
examples/notebook_inference.ipynb ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": []
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": []
16
+ }
17
+ ],
18
+ "metadata": {
19
+ "language_info": {
20
+ "name": "python"
21
+ }
22
+ },
23
+ "nbformat": 4,
24
+ "nbformat_minor": 2
25
+ }
examples/notebook_train.ipynb ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": []
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": []
16
+ }
17
+ ],
18
+ "metadata": {
19
+ "language_info": {
20
+ "name": "python"
21
+ }
22
+ },
23
+ "nbformat": 4,
24
+ "nbformat_minor": 2
25
+ }
examples/{example_inference.py β†’ sample_inference.py} RENAMED
@@ -8,7 +8,7 @@ project_root = Path(__file__).resolve().parent.parent
8
  sys.path.append(str(project_root))
9
 
10
  from yolo.config.config import Config
11
- from yolo.model.yolo import get_model
12
  from yolo.tools.data_loader import create_dataloader
13
  from yolo.tools.solver import ModelTester
14
  from yolo.utils.logging_utils import custom_logger, validate_log_directory
@@ -17,15 +17,11 @@ from yolo.utils.logging_utils import custom_logger, validate_log_directory
17
  @hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
18
  def main(cfg: Config):
19
  custom_logger()
20
- save_path = validate_log_directory(cfg, cfg.name)
21
-
22
- device = torch.device(cfg.device)
23
- model = get_model(cfg).to(device)
24
-
25
  save_path = validate_log_directory(cfg, cfg.name)
26
  dataloader = create_dataloader(cfg)
 
27
  device = torch.device(cfg.device)
28
- model = get_model(cfg).to(device)
29
 
30
  tester = ModelTester(cfg, model, save_path, device)
31
  tester.solve(dataloader)
 
8
  sys.path.append(str(project_root))
9
 
10
  from yolo.config.config import Config
11
+ from yolo.model.yolo import create_model
12
  from yolo.tools.data_loader import create_dataloader
13
  from yolo.tools.solver import ModelTester
14
  from yolo.utils.logging_utils import custom_logger, validate_log_directory
 
17
  @hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
18
  def main(cfg: Config):
19
  custom_logger()
 
 
 
 
 
20
  save_path = validate_log_directory(cfg, cfg.name)
21
  dataloader = create_dataloader(cfg)
22
+
23
  device = torch.device(cfg.device)
24
+ model = create_model(cfg).to(device)
25
 
26
  tester = ModelTester(cfg, model, save_path, device)
27
  tester.solve(dataloader)
examples/{example_train.py β†’ sample_train.py} RENAMED
@@ -8,7 +8,7 @@ project_root = Path(__file__).resolve().parent.parent
8
  sys.path.append(str(project_root))
9
 
10
  from yolo.config.config import Config
11
- from yolo.model.yolo import get_model
12
  from yolo.tools.data_loader import create_dataloader
13
  from yolo.tools.solver import ModelTrainer
14
  from yolo.utils.logging_utils import custom_logger, validate_log_directory
@@ -21,7 +21,7 @@ def main(cfg: Config):
21
  dataloader = create_dataloader(cfg)
22
  # TODO: get_device or rank, for DDP mode
23
  device = torch.device(cfg.device)
24
- model = get_model(cfg).to(device)
25
 
26
  trainer = ModelTrainer(cfg, model, save_path, device)
27
  trainer.solve(dataloader, cfg.task.epoch)
 
8
  sys.path.append(str(project_root))
9
 
10
  from yolo.config.config import Config
11
+ from yolo.model.yolo import create_model
12
  from yolo.tools.data_loader import create_dataloader
13
  from yolo.tools.solver import ModelTrainer
14
  from yolo.utils.logging_utils import custom_logger, validate_log_directory
 
21
  dataloader = create_dataloader(cfg)
22
  # TODO: get_device or rank, for DDP mode
23
  device = torch.device(cfg.device)
24
+ model = create_model(cfg).to(device)
25
 
26
  trainer = ModelTrainer(cfg, model, save_path, device)
27
  trainer.solve(dataloader, cfg.task.epoch)
tests/test_model/test_yolo.py CHANGED
@@ -8,7 +8,7 @@ from omegaconf import OmegaConf
8
  project_root = Path(__file__).resolve().parent.parent.parent
9
  sys.path.append(str(project_root))
10
 
11
- from yolo.model.yolo import YOLO, get_model
12
 
13
  config_path = "../../yolo/config"
14
  config_name = "config"
@@ -24,18 +24,18 @@ def test_build_model():
24
  assert len(model.model) == 38
25
 
26
 
27
- def test_get_model():
28
  with initialize(config_path=config_path, version_base=None):
29
  cfg = compose(config_name=config_name)
30
  cfg.weight = None
31
- model = get_model(cfg)
32
  assert isinstance(model, YOLO)
33
 
34
 
35
  def test_yolo_forward_output_shape():
36
  with initialize(config_path=config_path, version_base=None):
37
  cfg = compose(config_name=config_name)
38
- model = get_model(cfg)
39
  # 2 - batch size, 3 - number of channels, 640x640 - image dimensions
40
  dummy_input = torch.rand(2, 3, 640, 640)
41
 
 
8
  project_root = Path(__file__).resolve().parent.parent.parent
9
  sys.path.append(str(project_root))
10
 
11
+ from yolo.model.yolo import YOLO, create_model
12
 
13
  config_path = "../../yolo/config"
14
  config_name = "config"
 
24
  assert len(model.model) == 38
25
 
26
 
27
+ def test_create_model():
28
  with initialize(config_path=config_path, version_base=None):
29
  cfg = compose(config_name=config_name)
30
  cfg.weight = None
31
+ model = create_model(cfg)
32
  assert isinstance(model, YOLO)
33
 
34
 
35
  def test_yolo_forward_output_shape():
36
  with initialize(config_path=config_path, version_base=None):
37
  cfg = compose(config_name=config_name)
38
+ model = create_model(cfg)
39
  # 2 - batch size, 3 - number of channels, 640x640 - image dimensions
40
  dummy_input = torch.rand(2, 3, 640, 640)
41
 
yolo/lazy.py CHANGED
@@ -8,7 +8,7 @@ project_root = Path(__file__).resolve().parent.parent
8
  sys.path.append(str(project_root))
9
 
10
  from yolo.config.config import Config
11
- from yolo.model.yolo import get_model
12
  from yolo.tools.data_loader import create_dataloader
13
  from yolo.tools.solver import ModelTester, ModelTrainer
14
  from yolo.utils.logging_utils import custom_logger, validate_log_directory
@@ -20,7 +20,7 @@ def main(cfg: Config):
20
  save_path = validate_log_directory(cfg, cfg.name)
21
  dataloader = create_dataloader(cfg)
22
  device = torch.device(cfg.device)
23
- model = get_model(cfg).to(device)
24
 
25
  if cfg.task.task == "train":
26
  trainer = ModelTrainer(cfg, model, save_path, device)
 
8
  sys.path.append(str(project_root))
9
 
10
  from yolo.config.config import Config
11
+ from yolo.model.yolo import create_model
12
  from yolo.tools.data_loader import create_dataloader
13
  from yolo.tools.solver import ModelTester, ModelTrainer
14
  from yolo.utils.logging_utils import custom_logger, validate_log_directory
 
20
  save_path = validate_log_directory(cfg, cfg.name)
21
  dataloader = create_dataloader(cfg)
22
  device = torch.device(cfg.device)
23
+ model = create_model(cfg).to(device)
24
 
25
  if cfg.task.task == "train":
26
  trainer = ModelTrainer(cfg, model, save_path, device)
yolo/model/yolo.py CHANGED
@@ -116,7 +116,7 @@ class YOLO(nn.Module):
116
  raise ValueError(f"Unsupported layer type: {layer_type}")
117
 
118
 
119
- def get_model(cfg: Config) -> YOLO:
120
  """Constructs and returns a model from a Dictionary configuration file.
121
 
122
  Args:
 
116
  raise ValueError(f"Unsupported layer type: {layer_type}")
117
 
118
 
119
+ def create_model(cfg: Config) -> YOLO:
120
  """Constructs and returns a model from a Dictionary configuration file.
121
 
122
  Args:
yolo/tools/drawer.py CHANGED
@@ -14,6 +14,7 @@ def draw_bboxes(
14
  *,
15
  scaled_bbox: bool = True,
16
  save_path: str = "",
 
17
  ):
18
  """
19
  Draw bounding boxes on an image.
@@ -46,7 +47,7 @@ def draw_bboxes(
46
  draw.rectangle(shape, outline="red", width=3)
47
  draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
48
 
49
- save_image_path = os.path.join(save_path, "visualize.png")
50
  img.save(save_image_path) # Save the image with annotations
51
  logger.info(f"πŸ’Ύ Saved visualize image at {save_image_path}")
52
  return img
@@ -56,9 +57,9 @@ def draw_model(*, model_cfg=None, model=None, v7_base=False):
56
  from graphviz import Digraph
57
 
58
  if model_cfg:
59
- from yolo.model.yolo import get_model
60
 
61
- model = get_model(model_cfg)
62
  elif model is None:
63
  raise ValueError("Drawing Object is None")
64
 
 
14
  *,
15
  scaled_bbox: bool = True,
16
  save_path: str = "",
17
+ save_name: str = "visualize.png",
18
  ):
19
  """
20
  Draw bounding boxes on an image.
 
47
  draw.rectangle(shape, outline="red", width=3)
48
  draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
49
 
50
+ save_image_path = os.path.join(save_path, save_name)
51
  img.save(save_image_path) # Save the image with annotations
52
  logger.info(f"πŸ’Ύ Saved visualize image at {save_image_path}")
53
  return img
 
57
  from graphviz import Digraph
58
 
59
  if model_cfg:
60
+ from yolo.model.yolo import create_model
61
 
62
+ model = create_model(model_cfg)
63
  elif model is None:
64
  raise ValueError("Drawing Object is None")
65