import datetime
import shutil
import time

import hydra
import lightning as L
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

from src.test.utils import evaluate
from src.tools.files import json_dump
from src.tools.utils import calculate_model_params


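# Entry point: Hydra composes the run config from configs/train.yaml;
# Lightning Fabric handles device placement and multi-process launch.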
@hydra.main(version_base=None, config_path="configs", config_name="train")
def main(cfg: DictConfig):
    L.seed_everything(cfg.seed, workers=True)
    fabric = instantiate(cfg.trainer.fabric)
    fabric.launch()
    fabric.logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True))

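    # Persist the resolved config once, from the rank-0 process only.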
    if fabric.global_rank == 0:
        json_dump(OmegaConf.to_container(cfg, resolve=True), "hydra.json")

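    # Build the datamodule and let Fabric prepare the dataloaders
    # (device placement and, when distributed, sampler injection).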
    data = instantiate(cfg.data)
    loader_train = fabric.setup_dataloaders(data.train_dataloader())
    if cfg.val:
        loader_val = fabric.setup_dataloaders(data.val_dataloader())

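    # Instantiate the model; calculate_model_params presumably prints or logs
    # the parameter count (its return value is unused).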
    model = instantiate(cfg.model)
    calculate_model_params(model)

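    # Build the optimizer on the raw model's parameters, then let fabric.setup
    # wrap model and optimizer together for the configured strategy/precision.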
    optimizer = instantiate(
        cfg.model.optimizer, params=model.parameters(), _partial_=False
    )
    model, optimizer = fabric.setup(model, optimizer)

    scheduler = instantiate(cfg.model.scheduler)

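    # Epoch loop. The scheduler is a plain callable applied once per epoch,
    # presumably adjusting the optimizer's learning rate in place
    # (it is not a torch.optim scheduler).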
fabric.print("Start training") |
|
start_time = time.time() |
|
for epoch in range(cfg.trainer.max_epochs): |
|
scheduler(optimizer, epoch) |
|
|
|
columns = shutil.get_terminal_size().columns |
|
fabric.print("-" * columns) |
|
fabric.print(f"Epoch {epoch + 1}/{cfg.trainer.max_epochs}".center(columns)) |
|
|
|
train(model, loader_train, optimizer, fabric, epoch, cfg) |
|
        if cfg.val:
            fabric.print("Evaluate")
            evaluate(model, loader_val, fabric=fabric)

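        # Save the full training state each epoch: "all" keeps one checkpoint
        # per epoch, "last" overwrites a single rolling file.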
        state = {
            "epoch": epoch,
            "model": model,
            "optimizer": optimizer,
            "scheduler": scheduler,
        }
        if cfg.trainer.save_ckpt == "all":
            fabric.save(f"ckpt_{epoch}.ckpt", state)
        elif cfg.trainer.save_ckpt == "last":
            fabric.save("ckpt_last.ckpt", state)

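        # Synchronize all ranks before starting the next epoch.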
        fabric.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    fabric.print(f"Training time {total_time_str}")

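    # Test the final model on every dataset configured under cfg.test.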
    for dataset in cfg.test:
        columns = shutil.get_terminal_size().columns
        fabric.print("-" * columns)
        fabric.print(f"Testing on {cfg.test[dataset].dataname}".center(columns))

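        # Each test entry provides its own datamodule and evaluation callable.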
        data = instantiate(cfg.test[dataset])
        test_loader = fabric.setup_dataloaders(data.test_dataloader())

        test = instantiate(cfg.test[dataset].test)
        test(model, test_loader, fabric=fabric)

    fabric.logger.finalize("success")
    fabric.print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))


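# One training epoch: iterate over batches, optimize, and report progress
# at the configured intervals.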
def train(model, train_loader, optimizer, fabric, epoch, cfg):
    model.train()

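    # Per-batch step: reset gradients, forward pass (the model returns its own
    # loss), backward via Fabric so strategy/precision hooks apply, then update.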
    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()
        loss = model(batch, fabric)
        fabric.backward(loss)
        optimizer.step()

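        # Periodic console progress and metric logging (loss, learning rate,
        # epoch).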
        if batch_idx % cfg.trainer.print_interval == 0:
            fabric.print(
                f"[{100.0 * batch_idx / len(train_loader):.0f}%]\tLoss: {loss.item():.6f}"
            )
        if batch_idx % cfg.trainer.log_interval == 0:
            fabric.log_dict(
                {
                    "loss": loss.item(),
                    "lr": optimizer.param_groups[0]["lr"],
                    "epoch": epoch,
                }
            )


if __name__ == "__main__":
    main()