from os.path import join import hydra from omegaconf import DictConfig, OmegaConf from pytorch_lightning import Trainer from pytorch_lightning import seed_everything from pytorch_lightning.loggers import TensorBoardLogger from DenseAV.denseav.data.AVDatasets import AVDataModule from DenseAV.denseav.shared import load_trained_model @hydra.main(config_path="configs", config_name="av_align.yaml") def my_app(cfg: DictConfig) -> None: from saved_models import saved_model_dict seed_everything(0) print(OmegaConf.to_yaml(cfg)) models_to_eval = [ "denseav_language", "denseav_sound", ] checkpoint_dir = "../checkpoints" saved_models = saved_model_dict(checkpoint_dir) for model_name in models_to_eval: model_info = saved_models[model_name] extra_data_args = model_info["data_args"] if "data_args" in model_info else {} model_info["extra_args"]["output_root"] = "../" model_info["extra_args"]["neg_audio"] = False model_info["extra_args"]["image_mixup"] = 0.0 model = load_trained_model(join(checkpoint_dir, model_info["chkpt_name"]), model_info["extra_args"]) model.set_full_train(True) if model.image_model_type == "dinov2": load_size = cfg.load_size * 2 else: load_size = cfg.load_size if model.image_model_type == "davenet": batch_size = cfg.batch_size // 2 elif model.image_model_type == "imagebind": batch_size = cfg.batch_size else: batch_size = cfg.batch_size print(load_size) data_args = dict( dataset_name=cfg.dataset_name, load_size=load_size, image_aug=cfg.image_aug, audio_aug=cfg.audio_aug, audio_model_type=model.audio_model_type, pytorch_data_dir=cfg.pytorch_data_dir, use_cached_embs=model.use_cached_embs, batch_size=batch_size, num_workers=cfg.num_workers, extra_audio_masking=False, use_original_val_set=False, use_extra_val_sets=True, use_caption=True, data_for_plotting=False, n_frames=None, audio_level=False, neg_audio=False, quad_mixup=0.0, bg_mixup=0.0, patch_mixup=0.0, patch_size=8, ) data_args = {**data_args, **extra_data_args} datamodule = AVDataModule(**data_args) log_dir = join(cfg.output_root, "logs", "evaluate", model_name) print(log_dir) tb_logger = TensorBoardLogger(log_dir, default_hp_metric=False) trainer = Trainer( accelerator='gpu', strategy="ddp", devices=cfg.num_gpus, logger=tb_logger) trainer.validate(model, datamodule) if __name__ == "__main__": my_app()