✅ [Add] test, increase test coverage for dev mode
Browse files- tests/conftest.py +91 -0
- tests/test_tools/test_data_loader.py +71 -0
- tests/test_tools/test_dataset_preparation.py +29 -0
- tests/test_tools/test_drawer.py +29 -0
- tests/test_tools/test_solver.py +39 -83
- yolo/tools/data_loader.py +9 -8
- yolo/tools/dataset_preparation.py +1 -11
- yolo/tools/drawer.py +4 -1
- yolo/utils/logging_utils.py +2 -1
tests/conftest.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import pytest
|
5 |
+
import torch
|
6 |
+
from hydra import compose, initialize
|
7 |
+
|
8 |
+
project_root = Path(__file__).resolve().parent.parent
|
9 |
+
sys.path.append(str(project_root))
|
10 |
+
|
11 |
+
from yolo import Config, Vec2Box, create_model
|
12 |
+
from yolo.model.yolo import YOLO
|
13 |
+
from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader
|
14 |
+
from yolo.utils.logging_utils import ProgressLogger, set_seed
|
15 |
+
|
16 |
+
|
17 |
+
def pytest_configure(config):
|
18 |
+
config.addinivalue_line("markers", "requires_cuda: mark test to run only if CUDA is available")
|
19 |
+
|
20 |
+
|
21 |
+
def get_cfg(overrides=[]) -> Config:
|
22 |
+
config_path = "../yolo/config"
|
23 |
+
with initialize(config_path=config_path, version_base=None):
|
24 |
+
cfg: Config = compose(config_name="config", overrides=overrides)
|
25 |
+
set_seed(cfg.lucky_number)
|
26 |
+
return cfg
|
27 |
+
|
28 |
+
|
29 |
+
@pytest.fixture(scope="session")
|
30 |
+
def train_cfg() -> Config:
|
31 |
+
return get_cfg(overrides=["task=train", "dataset=mock"])
|
32 |
+
|
33 |
+
|
34 |
+
@pytest.fixture(scope="session")
|
35 |
+
def validation_cfg():
|
36 |
+
return get_cfg(overrides=["task=validation", "dataset=mock"])
|
37 |
+
|
38 |
+
|
39 |
+
@pytest.fixture(scope="session")
|
40 |
+
def inference_cfg():
|
41 |
+
return get_cfg(overrides=["task=inference"])
|
42 |
+
|
43 |
+
|
44 |
+
@pytest.fixture(scope="session")
|
45 |
+
def device():
|
46 |
+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
47 |
+
|
48 |
+
|
49 |
+
@pytest.fixture(scope="session")
|
50 |
+
def train_progress_logger(train_cfg: Config):
|
51 |
+
progress_logger = ProgressLogger(train_cfg, exp_name=train_cfg.name)
|
52 |
+
return progress_logger
|
53 |
+
|
54 |
+
|
55 |
+
@pytest.fixture(scope="session")
|
56 |
+
def validation_progress_logger(validation_cfg: Config):
|
57 |
+
progress_logger = ProgressLogger(validation_cfg, exp_name=validation_cfg.name)
|
58 |
+
return progress_logger
|
59 |
+
|
60 |
+
|
61 |
+
@pytest.fixture(scope="session")
|
62 |
+
def model(train_cfg: Config, device) -> YOLO:
|
63 |
+
model = create_model(train_cfg.model)
|
64 |
+
return model.to(device)
|
65 |
+
|
66 |
+
|
67 |
+
@pytest.fixture(scope="session")
|
68 |
+
def vec2box(train_cfg: Config, model: YOLO, device) -> Vec2Box:
|
69 |
+
vec2box = Vec2Box(model, train_cfg.image_size, device)
|
70 |
+
return vec2box
|
71 |
+
|
72 |
+
|
73 |
+
@pytest.fixture(scope="session")
|
74 |
+
def train_dataloader(train_cfg: Config):
|
75 |
+
return YoloDataLoader(train_cfg.task.data, train_cfg.dataset, train_cfg.task.task)
|
76 |
+
|
77 |
+
|
78 |
+
@pytest.fixture(scope="session")
|
79 |
+
def validation_dataloader(validation_cfg: Config):
|
80 |
+
return YoloDataLoader(validation_cfg.task.data, validation_cfg.dataset, validation_cfg.task.task)
|
81 |
+
|
82 |
+
|
83 |
+
@pytest.fixture(scope="session")
|
84 |
+
def file_stream_data_loader(inference_cfg: Config):
|
85 |
+
return StreamDataLoader(inference_cfg.task.data)
|
86 |
+
|
87 |
+
|
88 |
+
@pytest.fixture(scope="session")
|
89 |
+
def directory_stream_data_loader(inference_cfg: Config):
|
90 |
+
inference_cfg.task.data.source = "tests/data/images/train"
|
91 |
+
return StreamDataLoader(inference_cfg.task.data)
|
tests/test_tools/test_data_loader.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import pytest
|
5 |
+
|
6 |
+
project_root = Path(__file__).resolve().parent.parent.parent
|
7 |
+
sys.path.append(str(project_root))
|
8 |
+
|
9 |
+
from yolo.config.config import Config, TrainConfig
|
10 |
+
from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader, create_dataloader
|
11 |
+
|
12 |
+
|
13 |
+
def test_create_dataloader_cache(train_cfg: Config):
|
14 |
+
train_cfg.task.data.shuffle = False
|
15 |
+
train_cfg.task.data.batch_size = 2
|
16 |
+
|
17 |
+
cache_file = Path("tests/data/train.cache")
|
18 |
+
cache_file.unlink(missing_ok=True)
|
19 |
+
|
20 |
+
make_cache_loader = create_dataloader(train_cfg.task.data, train_cfg.dataset)
|
21 |
+
load_cache_loader = create_dataloader(train_cfg.task.data, train_cfg.dataset)
|
22 |
+
m_batch_size, m_images, _, m_reverse_tensors, m_image_paths = next(iter(make_cache_loader))
|
23 |
+
l_batch_size, l_images, _, l_reverse_tensors, l_image_paths = next(iter(load_cache_loader))
|
24 |
+
assert m_batch_size == l_batch_size
|
25 |
+
assert m_images.shape == l_images.shape
|
26 |
+
assert m_reverse_tensors.shape == l_reverse_tensors.shape
|
27 |
+
assert m_image_paths == l_image_paths
|
28 |
+
|
29 |
+
|
30 |
+
def test_training_data_loader_correctness(train_dataloader: YoloDataLoader):
|
31 |
+
"""Test that the training data loader produces correctly shaped data and metadata."""
|
32 |
+
batch_size, images, _, reverse_tensors, image_paths = next(iter(train_dataloader))
|
33 |
+
assert batch_size == 2
|
34 |
+
assert images.shape == (2, 3, 640, 640)
|
35 |
+
assert reverse_tensors.shape == (2, 5)
|
36 |
+
expected_paths = [
|
37 |
+
Path("tests/data/images/train/000000050725.jpg"),
|
38 |
+
Path("tests/data/images/train/000000167848.jpg"),
|
39 |
+
]
|
40 |
+
assert image_paths == expected_paths
|
41 |
+
|
42 |
+
|
43 |
+
def test_validation_data_loader_correctness(validation_dataloader: YoloDataLoader):
|
44 |
+
batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader))
|
45 |
+
assert batch_size == 4
|
46 |
+
assert images.shape == (4, 3, 640, 640)
|
47 |
+
assert targets.shape == (4, 18, 5)
|
48 |
+
assert reverse_tensors.shape == (4, 5)
|
49 |
+
expected_paths = [
|
50 |
+
Path("tests/data/images/val/000000151480.jpg"),
|
51 |
+
Path("tests/data/images/val/000000284106.jpg"),
|
52 |
+
Path("tests/data/images/val/000000323571.jpg"),
|
53 |
+
Path("tests/data/images/val/000000570456.jpg"),
|
54 |
+
]
|
55 |
+
assert image_paths == expected_paths
|
56 |
+
|
57 |
+
|
58 |
+
def test_file_stream_data_loader_frame(file_stream_data_loader: StreamDataLoader):
|
59 |
+
"""Test the frame output from the file stream data loader."""
|
60 |
+
frame, rev_tensor, origin_frame = next(iter(file_stream_data_loader))
|
61 |
+
assert frame.shape == (1, 3, 640, 640)
|
62 |
+
assert rev_tensor.shape == (1, 5)
|
63 |
+
assert origin_frame.size == (1024, 768)
|
64 |
+
|
65 |
+
|
66 |
+
def test_directory_stream_data_loader_frame(directory_stream_data_loader: StreamDataLoader):
|
67 |
+
"""Test the frame output from the directory stream data loader."""
|
68 |
+
frame, rev_tensor, origin_frame = next(iter(directory_stream_data_loader))
|
69 |
+
assert frame.shape == (1, 3, 640, 640)
|
70 |
+
assert rev_tensor.shape == (1, 5)
|
71 |
+
assert origin_frame.size == (480, 640)
|
tests/test_tools/test_dataset_preparation.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import sys
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
project_root = Path(__file__).resolve().parent.parent.parent
|
7 |
+
sys.path.append(str(project_root))
|
8 |
+
|
9 |
+
from yolo.config.config import Config
|
10 |
+
from yolo.tools.dataset_preparation import prepare_dataset, prepare_weight
|
11 |
+
|
12 |
+
|
13 |
+
def test_prepare_dataset(train_cfg: Config):
|
14 |
+
dataset_path = Path("tests/data")
|
15 |
+
if dataset_path.exists():
|
16 |
+
shutil.rmtree(dataset_path)
|
17 |
+
prepare_dataset(train_cfg.dataset, task="train")
|
18 |
+
prepare_dataset(train_cfg.dataset, task="val")
|
19 |
+
|
20 |
+
images_path = Path("tests/data/images")
|
21 |
+
for data_type in images_path.iterdir():
|
22 |
+
assert len(os.listdir(data_type)) == 5
|
23 |
+
|
24 |
+
annotations_path = Path("tests/data/annotations")
|
25 |
+
assert os.listdir(annotations_path) == ["instances_val.json", "instances_train.json"]
|
26 |
+
|
27 |
+
|
28 |
+
def test_prepare_weight():
|
29 |
+
prepare_weight()
|
tests/test_tools/test_drawer.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
from PIL import Image
|
5 |
+
from torch import tensor
|
6 |
+
|
7 |
+
project_root = Path(__file__).resolve().parent.parent.parent
|
8 |
+
sys.path.append(str(project_root))
|
9 |
+
|
10 |
+
from yolo.config.config import Config
|
11 |
+
from yolo.model.yolo import YOLO
|
12 |
+
from yolo.tools.drawer import draw_bboxes, draw_model
|
13 |
+
|
14 |
+
|
15 |
+
def test_draw_model_by_config(train_cfg: Config):
|
16 |
+
"""Test the drawing of a model based on a configuration."""
|
17 |
+
draw_model(model_cfg=train_cfg.model)
|
18 |
+
|
19 |
+
|
20 |
+
def test_draw_model_by_model(model: YOLO):
|
21 |
+
"""Test the drawing of a YOLO model."""
|
22 |
+
draw_model(model=model)
|
23 |
+
|
24 |
+
|
25 |
+
def test_draw_bboxes():
|
26 |
+
"""Test drawing bounding boxes on an image."""
|
27 |
+
predictions = tensor([[0, 60, 60, 160, 160, 0.5], [0, 40, 40, 120, 120, 0.5]])
|
28 |
+
pil_image = Image.open("tests/data/images/train/000000050725.jpg")
|
29 |
+
draw_bboxes(pil_image, [predictions])
|
tests/test_tools/test_solver.py
CHANGED
@@ -1,114 +1,70 @@
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
-
from unittest.mock import MagicMock, patch
|
4 |
|
5 |
import pytest
|
6 |
-
import
|
7 |
-
from hydra import compose, initialize
|
8 |
|
9 |
project_root = Path(__file__).resolve().parent.parent.parent
|
10 |
sys.path.append(str(project_root))
|
11 |
|
12 |
-
from yolo.config.config import
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
TrainConfig,
|
17 |
-
ValidationConfig,
|
18 |
-
)
|
19 |
-
from yolo.model.yolo import YOLO, create_model
|
20 |
-
from yolo.tools.data_loader import create_dataloader
|
21 |
-
from yolo.tools.loss_functions import create_loss_function
|
22 |
-
from yolo.tools.solver import ( # Adjust the import to your module
|
23 |
-
ModelTester,
|
24 |
-
ModelTrainer,
|
25 |
-
ModelValidator,
|
26 |
-
)
|
27 |
from yolo.utils.bounding_box_utils import Vec2Box
|
28 |
-
from yolo.utils.logging_utils import ProgressLogger
|
29 |
-
from yolo.utils.model_utils import (
|
30 |
-
ExponentialMovingAverage,
|
31 |
-
create_optimizer,
|
32 |
-
create_scheduler,
|
33 |
-
)
|
34 |
|
35 |
|
36 |
@pytest.fixture
|
37 |
-
def
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
return
|
42 |
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
cfg: Config = compose(config_name="config", overrides=["task=validation"])
|
48 |
-
cfg.weight = None
|
49 |
-
return cfg
|
50 |
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
|
59 |
|
60 |
@pytest.fixture
|
61 |
-
def
|
62 |
-
|
63 |
-
return
|
64 |
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
return model.to(device)
|
70 |
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
model = create_model(cfg.model, weight_path=None).to(device)
|
75 |
-
vec2box = Vec2Box(model, cfg.image_size, device)
|
76 |
-
return vec2box
|
77 |
|
78 |
|
79 |
@pytest.fixture
|
80 |
-
def
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
# def test_model_trainer_initialization(cfg: Config, model: YOLO, vec2box: Vec2Box, progress_logger, device):
|
86 |
-
# trainer = ModelTrainer(cfg, model, vec2box, progress_logger, device, use_ddp=False)
|
87 |
-
# assert trainer.model == model
|
88 |
-
# assert trainer.device == device
|
89 |
-
# assert trainer.optimizer is not None
|
90 |
-
# assert trainer.scheduler is not None
|
91 |
-
# assert trainer.loss_fn is not None
|
92 |
-
# assert trainer.progress == progress_logger
|
93 |
-
|
94 |
|
95 |
-
# def test_model_trainer_train_one_batch(config, model, vec2box, progress_logger, device):
|
96 |
-
# trainer = ModelTrainer(config, model, vec2box, progress_logger, device, use_ddp=False)
|
97 |
-
# images = torch.rand(1, 3, 224, 224)
|
98 |
-
# targets = torch.rand(1, 5)
|
99 |
-
# loss_item = trainer.train_one_batch(images, targets)
|
100 |
-
# assert isinstance(loss_item, dict)
|
101 |
|
|
|
102 |
|
103 |
-
|
104 |
-
|
105 |
-
assert
|
106 |
-
assert
|
107 |
-
assert
|
108 |
|
109 |
|
110 |
-
def
|
111 |
-
|
112 |
-
assert tester.model == model
|
113 |
-
assert tester.device == device
|
114 |
-
assert tester.progress == progress_logger
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
|
|
3 |
|
4 |
import pytest
|
5 |
+
from torch import allclose, tensor
|
|
|
6 |
|
7 |
project_root = Path(__file__).resolve().parent.parent.parent
|
8 |
sys.path.append(str(project_root))
|
9 |
|
10 |
+
from yolo.config.config import Config
|
11 |
+
from yolo.model.yolo import YOLO
|
12 |
+
from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader
|
13 |
+
from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
from yolo.utils.bounding_box_utils import Vec2Box
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
@pytest.fixture
|
18 |
+
def model_validator(validation_cfg: Config, model: YOLO, vec2box: Vec2Box, validation_progress_logger, device):
|
19 |
+
validator = ModelValidator(
|
20 |
+
validation_cfg.task, validation_cfg.dataset, model, vec2box, validation_progress_logger, device
|
21 |
+
)
|
22 |
+
return validator
|
23 |
|
24 |
|
25 |
+
def test_model_validator_initialization(model_validator: ModelValidator):
|
26 |
+
assert isinstance(model_validator.model, YOLO)
|
27 |
+
assert hasattr(model_validator, "solve")
|
|
|
|
|
|
|
28 |
|
29 |
|
30 |
+
def test_model_validator_solve_mock_dataset(model_validator: ModelValidator, validation_dataloader: YoloDataLoader):
|
31 |
+
mAPs = model_validator.solve(validation_dataloader)
|
32 |
+
except_mAPs = {"mAP.5": tensor(0.6969), "mAP.5:.95": tensor(0.4195)}
|
33 |
+
assert allclose(mAPs["mAP.5"], except_mAPs["mAP.5"], rtol=1e-4)
|
34 |
+
print(mAPs)
|
35 |
+
assert allclose(mAPs["mAP.5:.95"], except_mAPs["mAP.5:.95"], rtol=1e-4)
|
36 |
|
37 |
|
38 |
@pytest.fixture
|
39 |
+
def model_tester(inference_cfg: Config, model: YOLO, vec2box: Vec2Box, validation_progress_logger, device):
|
40 |
+
tester = ModelTester(inference_cfg, model, vec2box, validation_progress_logger, device)
|
41 |
+
return tester
|
42 |
|
43 |
|
44 |
+
def test_model_tester_initialization(model_tester: ModelTester):
|
45 |
+
assert isinstance(model_tester.model, YOLO)
|
46 |
+
assert hasattr(model_tester, "solve")
|
|
|
47 |
|
48 |
|
49 |
+
def test_model_tester_solve_single_image(model_tester: ModelTester, file_stream_data_loader: StreamDataLoader):
|
50 |
+
model_tester.solve(file_stream_data_loader)
|
|
|
|
|
|
|
51 |
|
52 |
|
53 |
@pytest.fixture
|
54 |
+
def model_trainer(train_cfg: Config, model: YOLO, vec2box: Vec2Box, train_progress_logger, device):
|
55 |
+
train_cfg.task.epoch = 2
|
56 |
+
trainer = ModelTrainer(train_cfg, model, vec2box, train_progress_logger, device, use_ddp=False)
|
57 |
+
return trainer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
+
def test_model_trainer_initialization(model_trainer: ModelTrainer):
|
61 |
|
62 |
+
assert isinstance(model_trainer.model, YOLO)
|
63 |
+
assert hasattr(model_trainer, "solve")
|
64 |
+
assert model_trainer.optimizer is not None
|
65 |
+
assert model_trainer.scheduler is not None
|
66 |
+
assert model_trainer.loss_fn is not None
|
67 |
|
68 |
|
69 |
+
def test_model_trainer_solve_mock_dataset(model_trainer: ModelTrainer, train_dataloader: YoloDataLoader):
|
70 |
+
model_trainer.solve(train_dataloader)
|
|
|
|
|
|
yolo/tools/data_loader.py
CHANGED
@@ -111,7 +111,7 @@ class YoloDataset(Dataset):
|
|
111 |
logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
|
112 |
return data
|
113 |
|
114 |
-
def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[
|
115 |
"""
|
116 |
Loads and validates bounding box data is [0, 1] from a label file.
|
117 |
|
@@ -119,7 +119,7 @@ class YoloDataset(Dataset):
|
|
119 |
label_path (str): The filepath to the label file containing bounding box data.
|
120 |
|
121 |
Returns:
|
122 |
-
|
123 |
"""
|
124 |
bboxes = []
|
125 |
for seg_data in seg_data_one_img:
|
@@ -145,7 +145,7 @@ class YoloDataset(Dataset):
|
|
145 |
indices = torch.randint(0, len(self), (num,))
|
146 |
return [self.get_data(idx)[:2] for idx in indices]
|
147 |
|
148 |
-
def __getitem__(self, idx) ->
|
149 |
img, bboxes, img_path = self.get_data(idx)
|
150 |
img, bboxes, rev_tensor = self.transform(img, bboxes)
|
151 |
return img, bboxes, rev_tensor, img_path
|
@@ -170,17 +170,17 @@ class YoloDataLoader(DataLoader):
|
|
170 |
collate_fn=self.collate_fn,
|
171 |
)
|
172 |
|
173 |
-
def collate_fn(self, batch: List[Tuple[
|
174 |
"""
|
175 |
A collate function to handle batching of images and their corresponding targets.
|
176 |
|
177 |
Args:
|
178 |
batch (list of tuples): Each tuple contains:
|
179 |
-
- image (
|
180 |
-
- labels (
|
181 |
|
182 |
Returns:
|
183 |
-
Tuple[
|
184 |
- A tensor of batched images.
|
185 |
- A list of tensors, each corresponding to bboxes for each image in the batch.
|
186 |
"""
|
@@ -213,7 +213,7 @@ def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: st
|
|
213 |
|
214 |
class StreamDataLoader:
|
215 |
def __init__(self, data_cfg: DataConfig):
|
216 |
-
self.source =
|
217 |
self.running = True
|
218 |
self.is_stream = isinstance(self.source, int) or str(self.source).lower().startswith("rtmp://")
|
219 |
|
@@ -225,6 +225,7 @@ class StreamDataLoader:
|
|
225 |
|
226 |
self.cap = cv2.VideoCapture(self.source)
|
227 |
else:
|
|
|
228 |
self.queue = Queue()
|
229 |
self.thread = Thread(target=self.load_source)
|
230 |
self.thread.start()
|
|
|
111 |
logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
|
112 |
return data
|
113 |
|
114 |
+
def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Tensor, None]:
|
115 |
"""
|
116 |
Loads and validates bounding box data is [0, 1] from a label file.
|
117 |
|
|
|
119 |
label_path (str): The filepath to the label file containing bounding box data.
|
120 |
|
121 |
Returns:
|
122 |
+
Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
|
123 |
"""
|
124 |
bboxes = []
|
125 |
for seg_data in seg_data_one_img:
|
|
|
145 |
indices = torch.randint(0, len(self), (num,))
|
146 |
return [self.get_data(idx)[:2] for idx in indices]
|
147 |
|
148 |
+
def __getitem__(self, idx) -> Tuple[Image.Image, Tensor, Tensor, List[str]]:
|
149 |
img, bboxes, img_path = self.get_data(idx)
|
150 |
img, bboxes, rev_tensor = self.transform(img, bboxes)
|
151 |
return img, bboxes, rev_tensor, img_path
|
|
|
170 |
collate_fn=self.collate_fn,
|
171 |
)
|
172 |
|
173 |
+
def collate_fn(self, batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]]:
|
174 |
"""
|
175 |
A collate function to handle batching of images and their corresponding targets.
|
176 |
|
177 |
Args:
|
178 |
batch (list of tuples): Each tuple contains:
|
179 |
+
- image (Tensor): The image tensor.
|
180 |
+
- labels (Tensor): The tensor of labels for the image.
|
181 |
|
182 |
Returns:
|
183 |
+
Tuple[Tensor, List[Tensor]]: A tuple containing:
|
184 |
- A tensor of batched images.
|
185 |
- A list of tensors, each corresponding to bboxes for each image in the batch.
|
186 |
"""
|
|
|
213 |
|
214 |
class StreamDataLoader:
|
215 |
def __init__(self, data_cfg: DataConfig):
|
216 |
+
self.source = data_cfg.source
|
217 |
self.running = True
|
218 |
self.is_stream = isinstance(self.source, int) or str(self.source).lower().startswith("rtmp://")
|
219 |
|
|
|
225 |
|
226 |
self.cap = cv2.VideoCapture(self.source)
|
227 |
else:
|
228 |
+
self.source = Path(self.source)
|
229 |
self.queue = Queue()
|
230 |
self.thread = Thread(target=self.load_source)
|
231 |
self.thread.start()
|
yolo/tools/dataset_preparation.py
CHANGED
@@ -82,7 +82,7 @@ def prepare_dataset(dataset_cfg: DatasetConfig, task: str):
|
|
82 |
logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
|
83 |
|
84 |
|
85 |
-
def prepare_weight(download_link: Optional[str] = None, weight_path: Path = "v9-c.pt"):
|
86 |
weight_name = weight_path.name
|
87 |
if download_link is None:
|
88 |
download_link = "https://github.com/WongKinYiu/yolov9mit/releases/download/v1.0-alpha/"
|
@@ -97,13 +97,3 @@ def prepare_weight(download_link: Optional[str] = None, weight_path: Path = "v9-
|
|
97 |
download_file(weight_link, weight_path)
|
98 |
except requests.exceptions.RequestException as e:
|
99 |
logger.warning(f"Failed to download the weight file: {e}")
|
100 |
-
|
101 |
-
|
102 |
-
if __name__ == "__main__":
|
103 |
-
import sys
|
104 |
-
|
105 |
-
sys.path.append("./")
|
106 |
-
from utils.logging_utils import custom_logger
|
107 |
-
|
108 |
-
custom_logger()
|
109 |
-
prepare_weight()
|
|
|
82 |
logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
|
83 |
|
84 |
|
85 |
+
def prepare_weight(download_link: Optional[str] = None, weight_path: Path = Path("v9-c.pt")):
|
86 |
weight_name = weight_path.name
|
87 |
if download_link is None:
|
88 |
download_link = "https://github.com/WongKinYiu/yolov9mit/releases/download/v1.0-alpha/"
|
|
|
97 |
download_file(weight_link, weight_path)
|
98 |
except requests.exceptions.RequestException as e:
|
99 |
logger.warning(f"Failed to download the weight file: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
yolo/tools/drawer.py
CHANGED
@@ -7,6 +7,9 @@ from loguru import logger
|
|
7 |
from PIL import Image, ImageDraw, ImageFont
|
8 |
from torchvision.transforms.functional import to_pil_image
|
9 |
|
|
|
|
|
|
|
10 |
|
11 |
def draw_bboxes(
|
12 |
img: Union[Image.Image, torch.Tensor],
|
@@ -62,7 +65,7 @@ def draw_bboxes(
|
|
62 |
return img
|
63 |
|
64 |
|
65 |
-
def draw_model(*, model_cfg=None, model=None, v7_base=False):
|
66 |
from graphviz import Digraph
|
67 |
|
68 |
if model_cfg:
|
|
|
7 |
from PIL import Image, ImageDraw, ImageFont
|
8 |
from torchvision.transforms.functional import to_pil_image
|
9 |
|
10 |
+
from yolo.config.config import ModelConfig
|
11 |
+
from yolo.model.yolo import YOLO
|
12 |
+
|
13 |
|
14 |
def draw_bboxes(
|
15 |
img: Union[Image.Image, torch.Tensor],
|
|
|
65 |
return img
|
66 |
|
67 |
|
68 |
+
def draw_model(*, model_cfg: ModelConfig = None, model: YOLO = None, v7_base=False):
|
69 |
from graphviz import Digraph
|
70 |
|
71 |
if model_cfg:
|
yolo/utils/logging_utils.py
CHANGED
@@ -138,7 +138,8 @@ class ProgressLogger(Progress):
|
|
138 |
def finish_train(self):
|
139 |
self.remove_task(self.task_epoch)
|
140 |
self.stop()
|
141 |
-
self.
|
|
|
142 |
|
143 |
|
144 |
def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
|
|
|
138 |
def finish_train(self):
|
139 |
self.remove_task(self.task_epoch)
|
140 |
self.stop()
|
141 |
+
if self.use_wandb:
|
142 |
+
self.wandb.finish()
|
143 |
|
144 |
|
145 |
def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
|