Spaces:
Running
Running
from os.path import join | |
import hydra | |
from omegaconf import DictConfig, OmegaConf | |
from pytorch_lightning import Trainer | |
from pytorch_lightning import seed_everything | |
from pytorch_lightning.loggers import TensorBoardLogger | |
from DenseAV.denseav.data.AVDatasets import AVDataModule | |
from DenseAV.denseav.shared import load_trained_model | |
def my_app(cfg: DictConfig) -> None:
    """Run validation for each listed pretrained checkpoint.

    For every model name in the evaluation list, the checkpoint is loaded,
    a matching ``AVDataModule`` is built, and ``Trainer.validate`` is run
    with a TensorBoard logger writing under
    ``<cfg.output_root>/logs/evaluate/<model_name>``.

    Args:
        cfg: Configuration object. The fields read here are ``load_size``,
            ``batch_size``, ``dataset_name``, ``image_aug``, ``audio_aug``,
            ``pytorch_data_dir``, ``num_workers``, ``output_root`` and
            ``num_gpus``.
    """
    # Resolved at call time rather than at module import.
    from saved_models import saved_model_dict

    seed_everything(0)
    print(OmegaConf.to_yaml(cfg))

    checkpoint_dir = "../checkpoints"
    registry = saved_model_dict(checkpoint_dir)

    for model_name in ("denseav_language", "denseav_sound"):
        model_info = registry[model_name]
        # Optional per-model overrides for the data module arguments.
        data_overrides = model_info.get("data_args", {})

        # These overrides are written into the registry entry itself before
        # it is handed to the loader.
        load_overrides = model_info["extra_args"]
        load_overrides.update(
            output_root="../",
            neg_audio=False,
            image_mixup=0.0,
        )

        checkpoint_path = join(checkpoint_dir, model_info["chkpt_name"])
        model = load_trained_model(checkpoint_path, model_info["extra_args"])
        model.set_full_train(True)

        image_backbone = model.image_model_type
        # dinov2 models are evaluated at twice the configured load size.
        load_size = cfg.load_size * 2 if image_backbone == "dinov2" else cfg.load_size
        # davenet models get half the configured batch size; every other
        # backbone uses the configured value unchanged.
        batch_size = cfg.batch_size // 2 if image_backbone == "davenet" else cfg.batch_size
        print(load_size)

        data_args = {
            "dataset_name": cfg.dataset_name,
            "load_size": load_size,
            "image_aug": cfg.image_aug,
            "audio_aug": cfg.audio_aug,
            "audio_model_type": model.audio_model_type,
            "pytorch_data_dir": cfg.pytorch_data_dir,
            "use_cached_embs": model.use_cached_embs,
            "batch_size": batch_size,
            "num_workers": cfg.num_workers,
            "extra_audio_masking": False,
            "use_original_val_set": False,
            "use_extra_val_sets": True,
            "use_caption": True,
            "data_for_plotting": False,
            "n_frames": None,
            "audio_level": False,
            "neg_audio": False,
            "quad_mixup": 0.0,
            "bg_mixup": 0.0,
            "patch_mixup": 0.0,
            "patch_size": 8,
        }
        # Per-model overrides take precedence over the defaults above.
        data_args.update(data_overrides)
        datamodule = AVDataModule(**data_args)

        log_dir = join(cfg.output_root, "logs", "evaluate", model_name)
        print(log_dir)

        trainer = Trainer(
            accelerator="gpu",
            strategy="ddp",
            devices=cfg.num_gpus,
            logger=TensorBoardLogger(log_dir, default_hp_metric=False),
        )
        trainer.validate(model, datamodule)
# Script entry point: the evaluation only runs when this file is executed
# directly, not when it is imported.
# NOTE(review): my_app is invoked without arguments although it declares a
# `cfg` parameter, and no decorator is visible on its definition here.
# Presumably a Hydra entry-point decorator is expected to supply `cfg` --
# confirm against the upstream file.
if __name__ == "__main__":
    my_app()