"""Train and evaluate ``LitMLP`` on the ``birds`` data module, logging to W&B."""
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb

from lit_mlp import LitMLP
from data_module import birds, samples, batch_size
from loggings import ImagePredictionLogger

# Number of output classes given to LitMLP.
# NOTE(review): presumably the label count of the ``birds`` dataset; confirm
# against data_module.
N_CLASSES = 18

# Global RNG seed. ``deterministic=True`` on the Trainer only selects
# deterministic algorithms; without an explicit seed, initialisation and
# shuffling still differ between runs.
SEED = 42


def main(seed: int = SEED) -> None:
    """Fit the model, run the test loop, then close the W&B run.

    Args:
        seed: Seed passed to ``pl.seed_everything`` before the model is built,
            so that repeated runs are reproducible.
    """
    # Seed before constructing the model so its weight init is repeatable.
    pl.seed_everything(seed)

    wandb_logger = WandbLogger(project="lit-wandb")
    trainer = pl.Trainer(
        logger=wandb_logger,
        log_every_n_steps=50,
        # NOTE(review): `gpus` is deprecated in newer Lightning releases;
        # use `accelerator="cpu"` when upgrading.
        gpus=0,
        max_epochs=100,
        deterministic=True,
        # Logs predictions on the fixed `samples` batch during training.
        callbacks=[ImagePredictionLogger(samples)],
    )

    model = LitMLP(n_classes=N_CLASSES, batch_size=batch_size)
    trainer.fit(model, birds)

    # NOTE(review): no model is passed here. Which weights get evaluated
    # (current vs. best checkpoint) depends on this Lightning version's
    # `ckpt_path` handling. Confirm this is the intended checkpoint.
    trainer.test(datamodule=birds, ckpt_path=None)

    wandb.finish()


if __name__ == "__main__":
    main()