Spaces:
Sleeping
Sleeping
File size: 2,990 Bytes
e6d4b46 099ac14 e6d4b46 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
from os.path import join
import hydra
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from DenseAV.denseav.data.AVDatasets import AVDataModule
from DenseAV.denseav.shared import load_trained_model
@hydra.main(config_path="configs", config_name="av_align.yaml")
def my_app(cfg: DictConfig) -> None:
    """Run the validation loop for each listed pretrained checkpoint.

    For every model name in the evaluation list, the matching entry returned
    by ``saved_model_dict`` is loaded with ``load_trained_model``, an
    ``AVDataModule`` is built from ``cfg`` plus any per-model ``data_args``,
    and ``Trainer.validate`` is called with a TensorBoard logger rooted at
    ``<cfg.output_root>/logs/evaluate/<model_name>``.

    Args:
        cfg: Hydra config. This function reads ``load_size``, ``batch_size``,
            ``dataset_name``, ``image_aug``, ``audio_aug``,
            ``pytorch_data_dir``, ``num_workers``, ``output_root`` and
            ``num_gpus`` from it.
    """
    from saved_models import saved_model_dict

    seed_everything(0)
    print(OmegaConf.to_yaml(cfg))

    checkpoint_dir = "../checkpoints"
    registry = saved_model_dict(checkpoint_dir)

    for model_name in ("denseav_language", "denseav_sound"):
        model_info = registry[model_name]
        # Optional per-model overrides for the data module arguments.
        extra_data_args = model_info.get("data_args", {})

        # Force these settings on the stored model arguments before loading.
        load_overrides = model_info["extra_args"]
        load_overrides["output_root"] = "../"
        load_overrides["neg_audio"] = False
        load_overrides["image_mixup"] = 0.0

        checkpoint_path = join(checkpoint_dir, model_info["chkpt_name"])
        model = load_trained_model(checkpoint_path, load_overrides)
        model.set_full_train(True)

        image_backbone = model.image_model_type
        # dinov2 models are fed images at twice the configured load size.
        load_size = cfg.load_size * 2 if image_backbone == "dinov2" else cfg.load_size
        # davenet models run at half the configured batch size; every other
        # image model type uses the configured value unchanged.
        batch_size = cfg.batch_size // 2 if image_backbone == "davenet" else cfg.batch_size
        print(load_size)

        # Per-model entries come last so they win over the defaults above them.
        data_args = {
            "dataset_name": cfg.dataset_name,
            "load_size": load_size,
            "image_aug": cfg.image_aug,
            "audio_aug": cfg.audio_aug,
            "audio_model_type": model.audio_model_type,
            "pytorch_data_dir": cfg.pytorch_data_dir,
            "use_cached_embs": model.use_cached_embs,
            "batch_size": batch_size,
            "num_workers": cfg.num_workers,
            "extra_audio_masking": False,
            "use_original_val_set": False,
            "use_extra_val_sets": True,
            "use_caption": True,
            "data_for_plotting": False,
            "n_frames": None,
            "audio_level": False,
            "neg_audio": False,
            "quad_mixup": 0.0,
            "bg_mixup": 0.0,
            "patch_mixup": 0.0,
            "patch_size": 8,
            **extra_data_args,
        }
        datamodule = AVDataModule(**data_args)

        log_dir = join(cfg.output_root, "logs", "evaluate", model_name)
        print(log_dir)

        trainer = Trainer(
            accelerator='gpu',
            strategy="ddp",
            devices=cfg.num_gpus,
            logger=TensorBoardLogger(log_dir, default_hp_metric=False),
        )
        trainer.validate(model, datamodule)
# Script entry point: Hydra supplies ``cfg`` to ``my_app`` via its decorator.
if __name__ == "__main__":
    my_app()
|