henry000 commited on
Commit
6a39ae1
·
1 Parent(s): 9867c3f

✅ [Add] test, increase test coverage for dev mode

Browse files
tests/conftest.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ import pytest
5
+ import torch
6
+ from hydra import compose, initialize
7
+
8
+ project_root = Path(__file__).resolve().parent.parent
9
+ sys.path.append(str(project_root))
10
+
11
+ from yolo import Config, Vec2Box, create_model
12
+ from yolo.model.yolo import YOLO
13
+ from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader
14
+ from yolo.utils.logging_utils import ProgressLogger, set_seed
15
+
16
+
17
+ def pytest_configure(config):
18
+ config.addinivalue_line("markers", "requires_cuda: mark test to run only if CUDA is available")
19
+
20
+
21
+ def get_cfg(overrides=[]) -> Config:
22
+ config_path = "../yolo/config"
23
+ with initialize(config_path=config_path, version_base=None):
24
+ cfg: Config = compose(config_name="config", overrides=overrides)
25
+ set_seed(cfg.lucky_number)
26
+ return cfg
27
+
28
+
29
+ @pytest.fixture(scope="session")
30
+ def train_cfg() -> Config:
31
+ return get_cfg(overrides=["task=train", "dataset=mock"])
32
+
33
+
34
+ @pytest.fixture(scope="session")
35
+ def validation_cfg():
36
+ return get_cfg(overrides=["task=validation", "dataset=mock"])
37
+
38
+
39
+ @pytest.fixture(scope="session")
40
+ def inference_cfg():
41
+ return get_cfg(overrides=["task=inference"])
42
+
43
+
44
+ @pytest.fixture(scope="session")
45
+ def device():
46
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
+
48
+
49
+ @pytest.fixture(scope="session")
50
+ def train_progress_logger(train_cfg: Config):
51
+ progress_logger = ProgressLogger(train_cfg, exp_name=train_cfg.name)
52
+ return progress_logger
53
+
54
+
55
+ @pytest.fixture(scope="session")
56
+ def validation_progress_logger(validation_cfg: Config):
57
+ progress_logger = ProgressLogger(validation_cfg, exp_name=validation_cfg.name)
58
+ return progress_logger
59
+
60
+
61
+ @pytest.fixture(scope="session")
62
+ def model(train_cfg: Config, device) -> YOLO:
63
+ model = create_model(train_cfg.model)
64
+ return model.to(device)
65
+
66
+
67
+ @pytest.fixture(scope="session")
68
+ def vec2box(train_cfg: Config, model: YOLO, device) -> Vec2Box:
69
+ vec2box = Vec2Box(model, train_cfg.image_size, device)
70
+ return vec2box
71
+
72
+
73
+ @pytest.fixture(scope="session")
74
+ def train_dataloader(train_cfg: Config):
75
+ return YoloDataLoader(train_cfg.task.data, train_cfg.dataset, train_cfg.task.task)
76
+
77
+
78
+ @pytest.fixture(scope="session")
79
+ def validation_dataloader(validation_cfg: Config):
80
+ return YoloDataLoader(validation_cfg.task.data, validation_cfg.dataset, validation_cfg.task.task)
81
+
82
+
83
+ @pytest.fixture(scope="session")
84
+ def file_stream_data_loader(inference_cfg: Config):
85
+ return StreamDataLoader(inference_cfg.task.data)
86
+
87
+
88
+ @pytest.fixture(scope="session")
89
+ def directory_stream_data_loader(inference_cfg: Config):
90
+ inference_cfg.task.data.source = "tests/data/images/train"
91
+ return StreamDataLoader(inference_cfg.task.data)
tests/test_tools/test_data_loader.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ import pytest
5
+
6
+ project_root = Path(__file__).resolve().parent.parent.parent
7
+ sys.path.append(str(project_root))
8
+
9
+ from yolo.config.config import Config, TrainConfig
10
+ from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader, create_dataloader
11
+
12
+
13
+ def test_create_dataloader_cache(train_cfg: Config):
14
+ train_cfg.task.data.shuffle = False
15
+ train_cfg.task.data.batch_size = 2
16
+
17
+ cache_file = Path("tests/data/train.cache")
18
+ cache_file.unlink(missing_ok=True)
19
+
20
+ make_cache_loader = create_dataloader(train_cfg.task.data, train_cfg.dataset)
21
+ load_cache_loader = create_dataloader(train_cfg.task.data, train_cfg.dataset)
22
+ m_batch_size, m_images, _, m_reverse_tensors, m_image_paths = next(iter(make_cache_loader))
23
+ l_batch_size, l_images, _, l_reverse_tensors, l_image_paths = next(iter(load_cache_loader))
24
+ assert m_batch_size == l_batch_size
25
+ assert m_images.shape == l_images.shape
26
+ assert m_reverse_tensors.shape == l_reverse_tensors.shape
27
+ assert m_image_paths == l_image_paths
28
+
29
+
30
+ def test_training_data_loader_correctness(train_dataloader: YoloDataLoader):
31
+ """Test that the training data loader produces correctly shaped data and metadata."""
32
+ batch_size, images, _, reverse_tensors, image_paths = next(iter(train_dataloader))
33
+ assert batch_size == 2
34
+ assert images.shape == (2, 3, 640, 640)
35
+ assert reverse_tensors.shape == (2, 5)
36
+ expected_paths = [
37
+ Path("tests/data/images/train/000000050725.jpg"),
38
+ Path("tests/data/images/train/000000167848.jpg"),
39
+ ]
40
+ assert image_paths == expected_paths
41
+
42
+
43
+ def test_validation_data_loader_correctness(validation_dataloader: YoloDataLoader):
44
+ batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader))
45
+ assert batch_size == 4
46
+ assert images.shape == (4, 3, 640, 640)
47
+ assert targets.shape == (4, 18, 5)
48
+ assert reverse_tensors.shape == (4, 5)
49
+ expected_paths = [
50
+ Path("tests/data/images/val/000000151480.jpg"),
51
+ Path("tests/data/images/val/000000284106.jpg"),
52
+ Path("tests/data/images/val/000000323571.jpg"),
53
+ Path("tests/data/images/val/000000570456.jpg"),
54
+ ]
55
+ assert image_paths == expected_paths
56
+
57
+
58
+ def test_file_stream_data_loader_frame(file_stream_data_loader: StreamDataLoader):
59
+ """Test the frame output from the file stream data loader."""
60
+ frame, rev_tensor, origin_frame = next(iter(file_stream_data_loader))
61
+ assert frame.shape == (1, 3, 640, 640)
62
+ assert rev_tensor.shape == (1, 5)
63
+ assert origin_frame.size == (1024, 768)
64
+
65
+
66
+ def test_directory_stream_data_loader_frame(directory_stream_data_loader: StreamDataLoader):
67
+ """Test the frame output from the directory stream data loader."""
68
+ frame, rev_tensor, origin_frame = next(iter(directory_stream_data_loader))
69
+ assert frame.shape == (1, 3, 640, 640)
70
+ assert rev_tensor.shape == (1, 5)
71
+ assert origin_frame.size == (480, 640)
tests/test_tools/test_dataset_preparation.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ project_root = Path(__file__).resolve().parent.parent.parent
7
+ sys.path.append(str(project_root))
8
+
9
+ from yolo.config.config import Config
10
+ from yolo.tools.dataset_preparation import prepare_dataset, prepare_weight
11
+
12
+
13
+ def test_prepare_dataset(train_cfg: Config):
14
+ dataset_path = Path("tests/data")
15
+ if dataset_path.exists():
16
+ shutil.rmtree(dataset_path)
17
+ prepare_dataset(train_cfg.dataset, task="train")
18
+ prepare_dataset(train_cfg.dataset, task="val")
19
+
20
+ images_path = Path("tests/data/images")
21
+ for data_type in images_path.iterdir():
22
+ assert len(os.listdir(data_type)) == 5
23
+
24
+ annotations_path = Path("tests/data/annotations")
25
+ assert os.listdir(annotations_path) == ["instances_val.json", "instances_train.json"]
26
+
27
+
28
+ def test_prepare_weight():
29
+ prepare_weight()
tests/test_tools/test_drawer.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ from PIL import Image
5
+ from torch import tensor
6
+
7
+ project_root = Path(__file__).resolve().parent.parent.parent
8
+ sys.path.append(str(project_root))
9
+
10
+ from yolo.config.config import Config
11
+ from yolo.model.yolo import YOLO
12
+ from yolo.tools.drawer import draw_bboxes, draw_model
13
+
14
+
15
+ def test_draw_model_by_config(train_cfg: Config):
16
+ """Test the drawing of a model based on a configuration."""
17
+ draw_model(model_cfg=train_cfg.model)
18
+
19
+
20
+ def test_draw_model_by_model(model: YOLO):
21
+ """Test the drawing of a YOLO model."""
22
+ draw_model(model=model)
23
+
24
+
25
+ def test_draw_bboxes():
26
+ """Test drawing bounding boxes on an image."""
27
+ predictions = tensor([[0, 60, 60, 160, 160, 0.5], [0, 40, 40, 120, 120, 0.5]])
28
+ pil_image = Image.open("tests/data/images/train/000000050725.jpg")
29
+ draw_bboxes(pil_image, [predictions])
tests/test_tools/test_solver.py CHANGED
@@ -1,114 +1,70 @@
1
  import sys
2
  from pathlib import Path
3
- from unittest.mock import MagicMock, patch
4
 
5
  import pytest
6
- import torch
7
- from hydra import compose, initialize
8
 
9
  project_root = Path(__file__).resolve().parent.parent.parent
10
  sys.path.append(str(project_root))
11
 
12
- from yolo.config.config import (
13
- Config,
14
- DataConfig,
15
- LossConfig,
16
- TrainConfig,
17
- ValidationConfig,
18
- )
19
- from yolo.model.yolo import YOLO, create_model
20
- from yolo.tools.data_loader import create_dataloader
21
- from yolo.tools.loss_functions import create_loss_function
22
- from yolo.tools.solver import ( # Adjust the import to your module
23
- ModelTester,
24
- ModelTrainer,
25
- ModelValidator,
26
- )
27
  from yolo.utils.bounding_box_utils import Vec2Box
28
- from yolo.utils.logging_utils import ProgressLogger
29
- from yolo.utils.model_utils import (
30
- ExponentialMovingAverage,
31
- create_optimizer,
32
- create_scheduler,
33
- )
34
 
35
 
36
  @pytest.fixture
37
- def cfg() -> Config:
38
- with initialize(config_path="../../yolo/config", version_base=None):
39
- cfg: Config = compose(config_name="config")
40
- cfg.weight = None
41
- return cfg
42
 
43
 
44
- @pytest.fixture
45
- def cfg_validaion() -> Config:
46
- with initialize(config_path="../../yolo/config", version_base=None):
47
- cfg: Config = compose(config_name="config", overrides=["task=validation"])
48
- cfg.weight = None
49
- return cfg
50
 
51
 
52
- @pytest.fixture
53
- def cfg_inference() -> Config:
54
- with initialize(config_path="../../yolo/config", version_base=None):
55
- cfg: Config = compose(config_name="config", overrides=["task=inference"])
56
- cfg.weight = None
57
- return cfg
58
 
59
 
60
  @pytest.fixture
61
- def device() -> torch.device:
62
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63
- return device
64
 
65
 
66
- @pytest.fixture
67
- def model(cfg: Config, device) -> YOLO:
68
- model = create_model(cfg.model, weight_path=None)
69
- return model.to(device)
70
 
71
 
72
- @pytest.fixture
73
- def vec2box(cfg: Config, model: YOLO, device) -> Vec2Box:
74
- model = create_model(cfg.model, weight_path=None).to(device)
75
- vec2box = Vec2Box(model, cfg.image_size, device)
76
- return vec2box
77
 
78
 
79
  @pytest.fixture
80
- def progress_logger(cfg: Config):
81
- progress_logger = ProgressLogger(cfg, exp_name=cfg.name)
82
- return progress_logger
83
-
84
-
85
- # def test_model_trainer_initialization(cfg: Config, model: YOLO, vec2box: Vec2Box, progress_logger, device):
86
- # trainer = ModelTrainer(cfg, model, vec2box, progress_logger, device, use_ddp=False)
87
- # assert trainer.model == model
88
- # assert trainer.device == device
89
- # assert trainer.optimizer is not None
90
- # assert trainer.scheduler is not None
91
- # assert trainer.loss_fn is not None
92
- # assert trainer.progress == progress_logger
93
-
94
 
95
- # def test_model_trainer_train_one_batch(config, model, vec2box, progress_logger, device):
96
- # trainer = ModelTrainer(config, model, vec2box, progress_logger, device, use_ddp=False)
97
- # images = torch.rand(1, 3, 224, 224)
98
- # targets = torch.rand(1, 5)
99
- # loss_item = trainer.train_one_batch(images, targets)
100
- # assert isinstance(loss_item, dict)
101
 
 
102
 
103
- def test_model_validator_initialization(cfg_validaion: Config, model: YOLO, vec2box: Vec2Box, progress_logger, device):
104
- validator = ModelValidator(cfg_validaion.task, cfg_validaion.dataset, model, vec2box, progress_logger, device)
105
- assert validator.model == model
106
- assert validator.device == device
107
- assert validator.progress == progress_logger
108
 
109
 
110
- def test_model_tester_initialization(cfg_inference: Config, model: YOLO, vec2box: Vec2Box, progress_logger, device):
111
- tester = ModelTester(cfg_inference, model, vec2box, progress_logger, device)
112
- assert tester.model == model
113
- assert tester.device == device
114
- assert tester.progress == progress_logger
 
1
  import sys
2
  from pathlib import Path
 
3
 
4
  import pytest
5
+ from torch import allclose, tensor
 
6
 
7
  project_root = Path(__file__).resolve().parent.parent.parent
8
  sys.path.append(str(project_root))
9
 
10
+ from yolo.config.config import Config
11
+ from yolo.model.yolo import YOLO
12
+ from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader
13
+ from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
 
 
 
 
 
 
 
 
 
 
 
14
  from yolo.utils.bounding_box_utils import Vec2Box
 
 
 
 
 
 
15
 
16
 
17
  @pytest.fixture
18
+ def model_validator(validation_cfg: Config, model: YOLO, vec2box: Vec2Box, validation_progress_logger, device):
19
+ validator = ModelValidator(
20
+ validation_cfg.task, validation_cfg.dataset, model, vec2box, validation_progress_logger, device
21
+ )
22
+ return validator
23
 
24
 
25
+ def test_model_validator_initialization(model_validator: ModelValidator):
26
+ assert isinstance(model_validator.model, YOLO)
27
+ assert hasattr(model_validator, "solve")
 
 
 
28
 
29
 
30
+ def test_model_validator_solve_mock_dataset(model_validator: ModelValidator, validation_dataloader: YoloDataLoader):
31
+ mAPs = model_validator.solve(validation_dataloader)
32
+ except_mAPs = {"mAP.5": tensor(0.6969), "mAP.5:.95": tensor(0.4195)}
33
+ assert allclose(mAPs["mAP.5"], except_mAPs["mAP.5"], rtol=1e-4)
34
+ print(mAPs)
35
+ assert allclose(mAPs["mAP.5:.95"], except_mAPs["mAP.5:.95"], rtol=1e-4)
36
 
37
 
38
  @pytest.fixture
39
+ def model_tester(inference_cfg: Config, model: YOLO, vec2box: Vec2Box, validation_progress_logger, device):
40
+ tester = ModelTester(inference_cfg, model, vec2box, validation_progress_logger, device)
41
+ return tester
42
 
43
 
44
+ def test_model_tester_initialization(model_tester: ModelTester):
45
+ assert isinstance(model_tester.model, YOLO)
46
+ assert hasattr(model_tester, "solve")
 
47
 
48
 
49
+ def test_model_tester_solve_single_image(model_tester: ModelTester, file_stream_data_loader: StreamDataLoader):
50
+ model_tester.solve(file_stream_data_loader)
 
 
 
51
 
52
 
53
  @pytest.fixture
54
+ def model_trainer(train_cfg: Config, model: YOLO, vec2box: Vec2Box, train_progress_logger, device):
55
+ train_cfg.task.epoch = 2
56
+ trainer = ModelTrainer(train_cfg, model, vec2box, train_progress_logger, device, use_ddp=False)
57
+ return trainer
 
 
 
 
 
 
 
 
 
 
58
 
 
 
 
 
 
 
59
 
60
+ def test_model_trainer_initialization(model_trainer: ModelTrainer):
61
 
62
+ assert isinstance(model_trainer.model, YOLO)
63
+ assert hasattr(model_trainer, "solve")
64
+ assert model_trainer.optimizer is not None
65
+ assert model_trainer.scheduler is not None
66
+ assert model_trainer.loss_fn is not None
67
 
68
 
69
+ def test_model_trainer_solve_mock_dataset(model_trainer: ModelTrainer, train_dataloader: YoloDataLoader):
70
+ model_trainer.solve(train_dataloader)
 
 
 
yolo/tools/data_loader.py CHANGED
@@ -111,7 +111,7 @@ class YoloDataset(Dataset):
111
  logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
112
  return data
113
 
114
- def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[torch.Tensor, None]:
115
  """
116
  Loads and validates bounding box data is [0, 1] from a label file.
117
 
@@ -119,7 +119,7 @@ class YoloDataset(Dataset):
119
  label_path (str): The filepath to the label file containing bounding box data.
120
 
121
  Returns:
122
- torch.Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
123
  """
124
  bboxes = []
125
  for seg_data in seg_data_one_img:
@@ -145,7 +145,7 @@ class YoloDataset(Dataset):
145
  indices = torch.randint(0, len(self), (num,))
146
  return [self.get_data(idx)[:2] for idx in indices]
147
 
148
- def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
149
  img, bboxes, img_path = self.get_data(idx)
150
  img, bboxes, rev_tensor = self.transform(img, bboxes)
151
  return img, bboxes, rev_tensor, img_path
@@ -170,17 +170,17 @@ class YoloDataLoader(DataLoader):
170
  collate_fn=self.collate_fn,
171
  )
172
 
173
- def collate_fn(self, batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
174
  """
175
  A collate function to handle batching of images and their corresponding targets.
176
 
177
  Args:
178
  batch (list of tuples): Each tuple contains:
179
- - image (torch.Tensor): The image tensor.
180
- - labels (torch.Tensor): The tensor of labels for the image.
181
 
182
  Returns:
183
- Tuple[torch.Tensor, List[torch.Tensor]]: A tuple containing:
184
  - A tensor of batched images.
185
  - A list of tensors, each corresponding to bboxes for each image in the batch.
186
  """
@@ -213,7 +213,7 @@ def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: st
213
 
214
  class StreamDataLoader:
215
  def __init__(self, data_cfg: DataConfig):
216
- self.source = Path(data_cfg.source)
217
  self.running = True
218
  self.is_stream = isinstance(self.source, int) or str(self.source).lower().startswith("rtmp://")
219
 
@@ -225,6 +225,7 @@ class StreamDataLoader:
225
 
226
  self.cap = cv2.VideoCapture(self.source)
227
  else:
 
228
  self.queue = Queue()
229
  self.thread = Thread(target=self.load_source)
230
  self.thread.start()
 
111
  logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
112
  return data
113
 
114
+ def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Tensor, None]:
115
  """
116
  Loads and validates bounding box data is [0, 1] from a label file.
117
 
 
119
  label_path (str): The filepath to the label file containing bounding box data.
120
 
121
  Returns:
122
+ Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
123
  """
124
  bboxes = []
125
  for seg_data in seg_data_one_img:
 
145
  indices = torch.randint(0, len(self), (num,))
146
  return [self.get_data(idx)[:2] for idx in indices]
147
 
148
+ def __getitem__(self, idx) -> Tuple[Image.Image, Tensor, Tensor, List[str]]:
149
  img, bboxes, img_path = self.get_data(idx)
150
  img, bboxes, rev_tensor = self.transform(img, bboxes)
151
  return img, bboxes, rev_tensor, img_path
 
170
  collate_fn=self.collate_fn,
171
  )
172
 
173
+ def collate_fn(self, batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]]:
174
  """
175
  A collate function to handle batching of images and their corresponding targets.
176
 
177
  Args:
178
  batch (list of tuples): Each tuple contains:
179
+ - image (Tensor): The image tensor.
180
+ - labels (Tensor): The tensor of labels for the image.
181
 
182
  Returns:
183
+ Tuple[Tensor, List[Tensor]]: A tuple containing:
184
  - A tensor of batched images.
185
  - A list of tensors, each corresponding to bboxes for each image in the batch.
186
  """
 
213
 
214
  class StreamDataLoader:
215
  def __init__(self, data_cfg: DataConfig):
216
+ self.source = data_cfg.source
217
  self.running = True
218
  self.is_stream = isinstance(self.source, int) or str(self.source).lower().startswith("rtmp://")
219
 
 
225
 
226
  self.cap = cv2.VideoCapture(self.source)
227
  else:
228
+ self.source = Path(self.source)
229
  self.queue = Queue()
230
  self.thread = Thread(target=self.load_source)
231
  self.thread.start()
yolo/tools/dataset_preparation.py CHANGED
@@ -82,7 +82,7 @@ def prepare_dataset(dataset_cfg: DatasetConfig, task: str):
82
  logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
83
 
84
 
85
- def prepare_weight(download_link: Optional[str] = None, weight_path: Path = "v9-c.pt"):
86
  weight_name = weight_path.name
87
  if download_link is None:
88
  download_link = "https://github.com/WongKinYiu/yolov9mit/releases/download/v1.0-alpha/"
@@ -97,13 +97,3 @@ def prepare_weight(download_link: Optional[str] = None, weight_path: Path = "v9-
97
  download_file(weight_link, weight_path)
98
  except requests.exceptions.RequestException as e:
99
  logger.warning(f"Failed to download the weight file: {e}")
100
-
101
-
102
- if __name__ == "__main__":
103
- import sys
104
-
105
- sys.path.append("./")
106
- from utils.logging_utils import custom_logger
107
-
108
- custom_logger()
109
- prepare_weight()
 
82
  logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
83
 
84
 
85
+ def prepare_weight(download_link: Optional[str] = None, weight_path: Path = Path("v9-c.pt")):
86
  weight_name = weight_path.name
87
  if download_link is None:
88
  download_link = "https://github.com/WongKinYiu/yolov9mit/releases/download/v1.0-alpha/"
 
97
  download_file(weight_link, weight_path)
98
  except requests.exceptions.RequestException as e:
99
  logger.warning(f"Failed to download the weight file: {e}")
 
 
 
 
 
 
 
 
 
 
yolo/tools/drawer.py CHANGED
@@ -7,6 +7,9 @@ from loguru import logger
7
  from PIL import Image, ImageDraw, ImageFont
8
  from torchvision.transforms.functional import to_pil_image
9
 
 
 
 
10
 
11
  def draw_bboxes(
12
  img: Union[Image.Image, torch.Tensor],
@@ -62,7 +65,7 @@ def draw_bboxes(
62
  return img
63
 
64
 
65
- def draw_model(*, model_cfg=None, model=None, v7_base=False):
66
  from graphviz import Digraph
67
 
68
  if model_cfg:
 
7
  from PIL import Image, ImageDraw, ImageFont
8
  from torchvision.transforms.functional import to_pil_image
9
 
10
+ from yolo.config.config import ModelConfig
11
+ from yolo.model.yolo import YOLO
12
+
13
 
14
  def draw_bboxes(
15
  img: Union[Image.Image, torch.Tensor],
 
65
  return img
66
 
67
 
68
+ def draw_model(*, model_cfg: ModelConfig = None, model: YOLO = None, v7_base=False):
69
  from graphviz import Digraph
70
 
71
  if model_cfg:
yolo/utils/logging_utils.py CHANGED
@@ -138,7 +138,8 @@ class ProgressLogger(Progress):
138
  def finish_train(self):
139
  self.remove_task(self.task_epoch)
140
  self.stop()
141
- self.wandb.finish()
 
142
 
143
 
144
  def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
 
138
  def finish_train(self):
139
  self.remove_task(self.task_epoch)
140
  self.stop()
141
+ if self.use_wandb:
142
+ self.wandb.finish()
143
 
144
 
145
  def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):