|
import sys |
|
from pathlib import Path |
|
|
|
import hydra |
|
from lightning import Trainer |
|
|
|
# Make the directory two levels above this file importable, so the `yolo`
# package imports that follow can be resolved when run as a script.
_this_file = Path(__file__).resolve()
project_root = _this_file.parents[1]
sys.path.append(str(project_root))
|
|
|
from yolo.config.config import Config |
|
from yolo.tools.solver import InferenceModel, TrainModel, ValidateModel |
|
from yolo.utils.logging_utils import setup |
|
|
|
|
|
@hydra.main(config_path="config", config_name="config", version_base=None)
def main(cfg: Config):
    """Run the task selected by ``cfg.task.task`` with a Lightning ``Trainer``.

    Calls ``setup(cfg)`` to obtain the callbacks, loggers and save path,
    builds the trainer from them, then constructs the solver model matching
    the task and runs ``fit`` / ``validate`` / ``predict`` on it.

    Args:
        cfg: Configuration object passed in by the ``hydra.main`` decorator.

    Raises:
        ValueError: If ``cfg.task.task`` is not ``"train"``, ``"validation"``
            or ``"inference"``.
    """
    callbacks, loggers, save_path = setup(cfg)

    trainer = Trainer(
        accelerator="auto",
        # `epoch` may be absent from the task config; fall back to None then.
        max_epochs=getattr(cfg.task, "epoch", None),
        precision="16-mixed",
        callbacks=callbacks,
        logger=loggers,
        log_every_n_steps=1,
        gradient_clip_val=10,
        gradient_clip_algorithm="value",
        deterministic=True,
        # "quite" is the key exactly as it is read from the config (spelling
        # kept as-is); a truthy value turns the progress bar off.
        enable_progress_bar=not getattr(cfg, "quite", False),
        default_root_dir=save_path,
    )

    # Read the task once and dispatch through mutually exclusive branches so
    # that exactly one action runs per invocation.
    task = cfg.task.task
    if task == "train":
        trainer.fit(TrainModel(cfg))
    elif task == "validation":
        trainer.validate(ValidateModel(cfg))
    elif task == "inference":
        trainer.predict(InferenceModel(cfg))
    else:
        # Previously an unrecognised task fell through and did nothing.
        raise ValueError(
            f"Unsupported task {task!r}; expected 'train', 'validation' or 'inference'."
        )


if __name__ == "__main__":
    main()
|
|