YOLO / yolo /lazy.py
henry000's picture
🔨 [Update] gradient clip strategy
ee709f5
import sys
from pathlib import Path
import hydra
from lightning import Trainer
# Append the directory two levels above this file (the parent of this file's
# own directory) to sys.path; presumably this lets the absolute `yolo.*`
# imports below resolve when the file is executed directly as a script.
project_root = Path(__file__).resolve().parent.parent
sys.path.append(str(project_root))
from yolo.config.config import Config
from yolo.tools.solver import InferenceModel, TrainModel, ValidateModel
from yolo.utils.logging_utils import setup
@hydra.main(config_path="config", config_name="config", version_base=None)
def main(cfg: Config):
    """Build a Lightning ``Trainer`` from ``cfg`` and run the configured task.

    ``cfg.task.task`` selects the workflow: ``"train"`` fits a ``TrainModel``,
    ``"validation"`` validates a ``ValidateModel`` and ``"inference"`` runs
    prediction with an ``InferenceModel``. For any other value the trainer is
    still constructed, but nothing is run and the function simply returns.

    Args:
        cfg: Configuration object supplied by the ``hydra.main`` decorator
            (config name ``config`` under the ``config`` path).
    """
    callbacks, loggers, save_path = setup(cfg)

    # Task configs without an `epoch` field leave `max_epochs` as None.
    epoch_limit = getattr(cfg.task, "epoch", None)
    # The progress bar stays on unless the config carries a truthy `quite` flag.
    # NOTE(review): `quite` reads like a misspelling of `quiet`, but it is a
    # runtime config key -- confirm against the config schema before renaming.
    show_progress = not getattr(cfg, "quite", False)

    trainer = Trainer(
        accelerator="auto",
        max_epochs=epoch_limit,
        precision="16-mixed",
        callbacks=callbacks,
        logger=loggers,
        log_every_n_steps=1,
        # Gradients are clipped element-wise ("value" algorithm) at 10.
        gradient_clip_val=10,
        gradient_clip_algorithm="value",
        deterministic=True,
        enable_progress_bar=show_progress,
        default_root_dir=save_path,
    )

    # Task name -> (model class to instantiate, trainer entry point to call).
    runners = {
        "train": (TrainModel, trainer.fit),
        "validation": (ValidateModel, trainer.validate),
        "inference": (InferenceModel, trainer.predict),
    }
    selected = runners.get(cfg.task.task)
    if selected is not None:
        model_cls, run = selected
        run(model_cls(cfg))
# Script entry point: `main` is called without arguments because the
# `hydra.main` decorator applied to it is what provides the `cfg` argument.
if __name__ == "__main__":
    main()