import os | |
import shutil | |
import sys | |
from pathlib import Path | |
project_root = Path(__file__).resolve().parent.parent.parent | |
sys.path.append(str(project_root)) | |
from yolo.config.config import Config | |
from yolo.tools.dataset_preparation import prepare_dataset, prepare_weight | |
def test_prepare_dataset(train_cfg: Config): | |
dataset_path = Path("tests/data") | |
if dataset_path.exists(): | |
shutil.rmtree(dataset_path) | |
prepare_dataset(train_cfg.dataset, task="train") | |
prepare_dataset(train_cfg.dataset, task="val") | |
images_path = Path("tests/data/images") | |
for data_type in images_path.iterdir(): | |
assert len(os.listdir(data_type)) == 5 | |
annotations_path = Path("tests/data/annotations") | |
assert "instances_val.json" in os.listdir(annotations_path) | |
assert "instances_train.json" in os.listdir(annotations_path) | |
def test_prepare_weight(): | |
prepare_weight() | |