File size: 1,233 Bytes
a0c7025 8b3b3ef a0c7025 7f8235a 8b3b3ef a0c7025 3ebbbd9 8b3b3ef bd0409b 8b3b3ef 32405d5 ee709f5 4b8ec68 3ebbbd9 8b3b3ef bd0409b 3441a79 a0c7025 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import sys
from pathlib import Path
import hydra
from lightning import Trainer
# Repository root: the directory two levels above this script. It is put on the
# import path so the project-local `yolo` imports below resolve when this file
# is executed directly as a script.
project_root = Path(__file__).resolve().parent.parent
# NOTE(review): appended rather than prepended, so an earlier sys.path entry
# that already provides `yolo` (e.g. an installed copy) would take precedence.
sys.path.append(f"{project_root}")
from yolo.config.config import Config
from yolo.tools.solver import InferenceModel, TrainModel, ValidateModel
from yolo.utils.logging_utils import setup
@hydra.main(config_path="config", config_name="config", version_base=None)
def main(cfg: Config):
    """Run the task selected by ``cfg.task.task`` under a Lightning ``Trainer``.

    Hydra composes ``cfg`` from the ``config`` entry in the ``config``
    directory plus any command-line overrides. Supported tasks:

    * ``"train"``      -> ``TrainModel``     driven by ``Trainer.fit``
    * ``"validation"`` -> ``ValidateModel``  driven by ``Trainer.validate``
    * ``"inference"``  -> ``InferenceModel`` driven by ``Trainer.predict``

    Args:
        cfg: The composed run configuration.

    Raises:
        ValueError: If ``cfg.task.task`` is not one of the supported tasks.
    """
    # Each task maps to the model wrapper that implements it and the (unbound)
    # Trainer entry point that drives it.
    task_runners = {
        "train": (TrainModel, Trainer.fit),
        "validation": (ValidateModel, Trainer.validate),
        "inference": (InferenceModel, Trainer.predict),
    }
    task = cfg.task.task
    if task not in task_runners:
        # Fail fast, before setup() runs, instead of silently doing nothing
        # when the task name is misspelled or unsupported.
        raise ValueError(
            f"Unsupported task {task!r}; expected one of {sorted(task_runners)}"
        )

    callbacks, loggers, save_path = setup(cfg)
    trainer = Trainer(
        accelerator="auto",
        # Read defensively: presumably only some task configs define `epoch`.
        max_epochs=getattr(cfg.task, "epoch", None),
        # Automatic mixed precision with float16.
        precision="16-mixed",
        callbacks=callbacks,
        logger=loggers,
        log_every_n_steps=1,
        # Clip every gradient element into [-10, 10] (value clipping, not norm).
        gradient_clip_val=10,
        gradient_clip_algorithm="value",
        deterministic=True,
        # NOTE: "quite" (sic) is the config key this script reads; renaming it
        # here would silently stop honouring existing configs.
        enable_progress_bar=not getattr(cfg, "quite", False),
        default_root_dir=save_path,
    )

    model_cls, run = task_runners[task]
    run(trainer, model_cls(cfg))
# Script entry point. The @hydra.main decorator on main() builds `cfg` from the
# config files and command-line overrides, so main() is called with no arguments.
if __name__ == "__main__":
    main()
|