import sys | |
from pathlib import Path | |
import hydra | |
# Make the directory two levels above this script importable, so the `yolo`
# package imported below can be resolved from the source tree.
project_root = Path(__file__).resolve().parent.parent
sys.path.append(f"{project_root}")
from yolo import ( | |
Config, | |
ModelTrainer, | |
ProgressLogger, | |
create_converter, | |
create_dataloader, | |
create_model, | |
) | |
from yolo.utils.model_utils import get_device | |
# NOTE(review): assumes the Hydra config directory is <project_root>/yolo/config,
# i.e. "../yolo/config" relative to this file (Hydra resolves relative
# config_path values against the declaring file). `version_base` requires
# Hydra >= 1.2. Confirm both against the repository layout.
@hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
def main(cfg: Config):
    """Build the data pipeline, model and trainer from ``cfg``, then train.

    ``@hydra.main`` composes ``cfg`` from the config files and any
    command-line overrides, then passes it in. This lets the script's
    ``__main__`` guard call ``main()`` with no arguments. Without the
    decorator, that call fails with ``TypeError`` (missing ``cfg``).

    Args:
        cfg: Composed experiment configuration. The fields read here are
            ``name``, ``device``, ``task.data``, ``task.task``, ``dataset``,
            ``dataset.class_num``, ``model``, ``model.name``,
            ``model.anchor``, ``weight`` and ``image_size``.
    """
    progress = ProgressLogger(cfg, exp_name=cfg.name)

    # get_device also reports whether distributed data parallel is in use;
    # the dataloader needs that flag to set itself up accordingly.
    device, use_ddp = get_device(cfg.device)
    dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task, use_ddp)

    model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
    model = model.to(device)

    # The converter is built from the already-placed model, so it is created
    # after the model has been moved to the target device.
    converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)
    solver = ModelTrainer(cfg, model, converter, progress, device)

    progress.start()
    solver.solve(dataloader)
if __name__ == "__main__":
    # Launch training only when this file is executed directly, not on import.
    main()