# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Command-line script that evaluates ("test*" configs) or runs a predictor
("vis*" configs) for the task described by a config file."""
import os
import glob
import argparse
import pprint
import omegaconf

from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from mmpt.utils import load_config, set_seed
from mmpt.evaluators import Evaluator
from mmpt.evaluators import predictor as predictor_path
from mmpt.tasks import Task
from mmpt import processors
from mmpt.datasets import MMDataset


def get_dataloader(config):
    """Build the test ``DataLoader`` described by ``config``.

    The four processor classes named in ``config.dataset`` (``meta_processor``,
    ``video_processor``, ``text_processor`` and ``aligner``) are looked up by
    name in ``mmpt.processors``, each instantiated with ``config.dataset``, and
    combined into an ``MMDataset``. The dataset length and its first example
    are printed as a quick sanity check.

    Args:
        config: configuration object exposing ``dataset.*`` processor names
            and ``fairseq.dataset.batch_size``.

    Returns:
        A non-shuffling ``DataLoader`` over the dataset that batches with the
        dataset's own ``collater``.
    """
    dataset_config = config.dataset
    meta_processor_cls = getattr(processors, dataset_config.meta_processor)
    video_processor_cls = getattr(processors, dataset_config.video_processor)
    text_processor_cls = getattr(processors, dataset_config.text_processor)
    aligner_cls = getattr(processors, dataset_config.aligner)

    meta_processor = meta_processor_cls(dataset_config)
    video_processor = video_processor_cls(dataset_config)
    text_processor = text_processor_cls(dataset_config)
    aligner = aligner_cls(dataset_config)

    test_data = MMDataset(
        meta_processor,
        video_processor,
        text_processor,
        aligner,
    )
    print("test_len", len(test_data))
    # Fetch and print one example so that data problems surface before any
    # checkpoint is loaded.
    output = test_data[0]
    test_data.print_example(output)

    test_dataloader = DataLoader(
        test_data,
        batch_size=config.fairseq.dataset.batch_size,
        shuffle=False,
        num_workers=6,
        collate_fn=test_data.collater,
    )
    return test_dataloader


def _print_config(config):
    """Print ``config`` as YAML for OmegaConf dict configs, else pretty-print it."""
    if isinstance(config, omegaconf.dictconfig.DictConfig):
        print(OmegaConf.to_yaml(config))
    else:
        pp = pprint.PrettyPrinter(indent=4)
        # PrettyPrinter exposes ``pprint``; it has no ``print`` method.
        pp.pprint(config)


def _evaluate_and_report(config, mmtask, test_dataloader):
    """Evaluate the configured checkpoint(s) and print every result plus the best."""
    checkpoint_search_path = os.path.dirname(config.eval.save_path)
    results = []

    # Loop over all checkpoints for datasets without a validation set.
    if "best" not in config.fairseq.common_eval.path:
        print("eval each epoch.")
        # os.path.join keeps the pattern relative when the directory part is
        # empty (plain concatenation would yield "/checkpoint*" and scan the
        # filesystem root); sorting makes the evaluation order reproducible.
        pattern = os.path.join(checkpoint_search_path, "checkpoint*")
        for checkpoint in sorted(glob.glob(pattern)):
            model = mmtask.load_checkpoint(checkpoint)
            ckpt = os.path.basename(checkpoint)
            evaluator = Evaluator(config)
            output = evaluator.evaluate(
                model, test_dataloader, ckpt + "_merged")
            results.append((checkpoint, output))

    # Evaluate the checkpoint specified by the config last.
    model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
    evaluator = Evaluator(config)
    output = evaluator.evaluate(model, test_dataloader)
    results.append((config.fairseq.common_eval.path, output))

    best_result = None
    best_metric = None
    for checkpoint, result in results:
        print(checkpoint)
        evaluator.metric.print_computed_metrics(result)
        best_score = evaluator.metric.best_metric(result)
        # Seed from the first result so that a best result always exists, even
        # when every score is <= 0; ties keep the earliest checkpoint.
        if best_result is None or best_score > best_metric:
            best_result = (checkpoint, result)
            best_metric = best_score

    print("best results:")
    print(best_result[0])
    evaluator.metric.print_computed_metrics(best_result[1])


def main(args):
    """Load the task config and run evaluation or prediction.

    The mode is chosen from the base name of ``args.taskconfig``: names
    starting with ``"test"`` evaluate checkpoints and print their metrics,
    names starting with ``"vis"`` run the predictor class named by
    ``config.predictor``.

    Args:
        args: parsed command-line arguments; ``args.taskconfig`` is the path
            of the task config file and ``args`` is passed to ``load_config``.

    Raises:
        ValueError: if the config file name starts with neither ``"test"``
            nor ``"vis"``.
    """
    config = load_config(args)
    _print_config(config)

    mmtask = Task.config_task(config)
    mmtask.build_model()

    test_dataloader = get_dataloader(config)

    prefix = os.path.basename(args.taskconfig)
    if prefix.startswith("test"):
        _evaluate_and_report(config, mmtask, test_dataloader)
    elif prefix.startswith("vis"):
        model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
        predictor_cls = getattr(predictor_path, config.predictor)
        predictor = predictor_cls(config)
        predictor.predict_loop(model, test_dataloader, mmtask, None)
    else:
        raise ValueError(
            "unknown prefix of the config file", args.taskconfig)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("taskconfig", type=str)
    args = parser.parse_args()
    main(args)