PyTorch
ssl-aasist
custom_code
ash56's picture
Add files using upload-large-folder tool
d28af7f verified
raw
history blame
3.94 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import glob
import argparse
import pprint
import omegaconf
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from mmpt.utils import load_config, set_seed
from mmpt.evaluators import Evaluator
from mmpt.evaluators import predictor as predictor_path
from mmpt.tasks import Task
from mmpt import processors
from mmpt.datasets import MMDataset
def get_dataloader(config, num_workers=6):
    """Build the non-shuffling evaluation DataLoader described by ``config``.

    The four processor classes named in ``config.dataset`` (meta, video, text
    and aligner) are looked up in ``mmpt.processors``, each is constructed
    with ``config.dataset``, and the results are combined into an
    ``MMDataset``. The dataset length and, when the dataset is non-empty, one
    example are printed as a sanity check.

    Args:
        config: loaded task config. Reads ``config.dataset.meta_processor``,
            ``config.dataset.video_processor``, ``config.dataset.text_processor``,
            ``config.dataset.aligner`` and ``config.fairseq.dataset.batch_size``.
        num_workers: number of DataLoader worker processes (defaults to the
            previously hard-coded value of 6).

    Returns:
        A ``DataLoader`` over the dataset, in dataset order, batched with
        ``MMDataset.collater``.
    """
    dataset_cfg = config.dataset

    # Resolve every class first, then instantiate, as before.
    meta_processor_cls = getattr(processors, dataset_cfg.meta_processor)
    video_processor_cls = getattr(processors, dataset_cfg.video_processor)
    text_processor_cls = getattr(processors, dataset_cfg.text_processor)
    aligner_cls = getattr(processors, dataset_cfg.aligner)

    meta_processor = meta_processor_cls(dataset_cfg)
    video_processor = video_processor_cls(dataset_cfg)
    text_processor = text_processor_cls(dataset_cfg)
    aligner = aligner_cls(dataset_cfg)

    test_data = MMDataset(
        meta_processor,
        video_processor,
        text_processor,
        aligner,
    )
    num_examples = len(test_data)
    print("test_len", num_examples)
    # Preview one decoded sample; skipped for an empty dataset, where
    # indexing [0] would otherwise raise IndexError.
    if num_examples > 0:
        test_data.print_example(test_data[0])

    return DataLoader(
        test_data,
        batch_size=config.fairseq.dataset.batch_size,
        shuffle=False,  # evaluation: keep dataset order
        num_workers=num_workers,
        collate_fn=test_data.collater,
    )
def _print_config(config):
    """Dump the loaded configuration to stdout for the run log."""
    if isinstance(config, omegaconf.dictconfig.DictConfig):
        print(OmegaConf.to_yaml(config))
    else:
        # PrettyPrinter's method is ``pprint``; calling ``.print`` (as the
        # code previously did) raises AttributeError.
        pprint.PrettyPrinter(indent=4).pprint(config)


def _pick_best(metric, results):
    """Print every (checkpoint, result) pair and return the highest-scoring one.

    Scores come from ``metric.best_metric(result)``. The first entry seeds the
    comparison, so a pair is always returned for non-empty ``results`` even
    when every score is <= 0. Ties keep the earliest entry.

    Returns:
        The best ``(checkpoint, result)`` tuple, or ``None`` if ``results``
        is empty.
    """
    best_result = None
    best_score = None
    for checkpoint, result in results:
        print(checkpoint)
        metric.print_computed_metrics(result)
        score = metric.best_metric(result)
        if best_score is None or score > best_score:
            best_result = (checkpoint, result)
            best_score = score
    return best_result


def main(args):
    """Evaluate or visualize the model described by ``args.taskconfig``.

    The mode is selected by the basename of the config file:

    * ``test*``: evaluate the checkpoint at ``config.fairseq.common_eval.path``.
      When that path does not contain ``"best"``, first evaluate every
      ``checkpoint*`` file in the directory of ``config.eval.save_path``.
      Then print all results and the one with the highest
      ``metric.best_metric`` score.
    * ``vis*``: load ``config.fairseq.common_eval.path`` and run the predictor
      class named by ``config.predictor`` over the test set.

    Args:
        args: parsed command-line namespace with a ``taskconfig`` attribute.

    Raises:
        ValueError: if the config basename starts with neither prefix.
    """
    config = load_config(args)
    _print_config(config)

    mmtask = Task.config_task(config)
    mmtask.build_model()
    test_dataloader = get_dataloader(config)
    checkpoint_search_path = os.path.dirname(config.eval.save_path)
    results = []

    prefix = os.path.basename(args.taskconfig)
    if prefix.startswith("test"):
        eval_path = config.fairseq.common_eval.path
        # Loop over all checkpoints for datasets without a validation set
        # (no "best" checkpoint was selected during training).
        if "best" not in eval_path:
            print("eval each epoch.")
            # os.path.join avoids globbing "/checkpoint*" (filesystem root)
            # when save_path has no directory part. Sorting makes the order
            # reproducible; glob order is filesystem-dependent.
            pattern = os.path.join(checkpoint_search_path, "checkpoint*")
            for checkpoint in sorted(glob.glob(pattern)):
                model = mmtask.load_checkpoint(checkpoint)
                ckpt = os.path.basename(checkpoint)
                evaluator = Evaluator(config)
                output = evaluator.evaluate(
                    model, test_dataloader, ckpt + "_merged")
                results.append((checkpoint, output))

        # Lastly, evaluate the checkpoint specified by the config.
        model = mmtask.load_checkpoint(eval_path)
        evaluator = Evaluator(config)
        output = evaluator.evaluate(model, test_dataloader)
        results.append((eval_path, output))

        # results is non-empty here, so best_result is never None.
        best_result = _pick_best(evaluator.metric, results)
        print("best results:")
        print(best_result[0])
        evaluator.metric.print_computed_metrics(best_result[1])
    elif prefix.startswith("vis"):
        model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
        predictor_cls = getattr(predictor_path, config.predictor)
        predictor = predictor_cls(config)
        predictor.predict_loop(model, test_dataloader, mmtask, None)
    else:
        raise ValueError("unknown prefix of the config file", args.taskconfig)
if __name__ == "__main__":
    # CLI: a single positional argument, the path to the task config file.
    cli = argparse.ArgumentParser()
    cli.add_argument("taskconfig", type=str)
    main(cli.parse_args())