|
|
|
|
|
|
|
|
|
import os |
|
import glob |
|
import argparse |
|
import pprint |
|
import omegaconf |
|
|
|
from omegaconf import OmegaConf |
|
from torch.utils.data import DataLoader |
|
|
|
from mmpt.utils import load_config, set_seed |
|
from mmpt.evaluators import Evaluator |
|
from mmpt.evaluators import predictor as predictor_path |
|
from mmpt.tasks import Task |
|
from mmpt import processors |
|
from mmpt.datasets import MMDataset |
|
|
|
|
|
def get_dataloader(config):
    """Build the test ``DataLoader`` described by ``config``.

    The four dataset components (meta processor, video processor, text
    processor, aligner) are looked up by name in the ``processors`` module
    using the names stored under ``config.dataset``, and each is constructed
    with ``config.dataset``.  They are combined into an ``MMDataset``; the
    dataset length and its first example are printed as a sanity check.

    Args:
        config: configuration object exposing ``dataset`` (component names
            and their settings) and ``fairseq.dataset.batch_size``.

    Returns:
        A non-shuffling ``DataLoader`` over the dataset that uses the
        dataset's own ``collater`` and 6 worker processes.
    """
    dataset_cfg = config.dataset

    # Resolve all four classes first, then instantiate them in the same
    # order: meta processor, video processor, text processor, aligner.
    component_classes = [
        getattr(processors, getattr(dataset_cfg, key))
        for key in (
            "meta_processor", "video_processor", "text_processor", "aligner"
        )
    ]
    components = [cls(dataset_cfg) for cls in component_classes]

    dataset = MMDataset(*components)

    # Sanity output: dataset size and the first example.
    print("test_len", len(dataset))
    first_example = dataset[0]
    dataset.print_example(first_example)

    return DataLoader(
        dataset,
        batch_size=config.fairseq.dataset.batch_size,
        shuffle=False,
        num_workers=6,
        collate_fn=dataset.collater,
    )
|
|
|
|
|
def main(args):
    """Run evaluation or prediction for the task config named in ``args``.

    The mode is chosen from the basename of ``args.taskconfig``:

    * prefix ``test``: evaluate checkpoints and report the best one.  When
      ``config.fairseq.common_eval.path`` does not contain ``"best"``, every
      ``checkpoint*`` file next to ``config.eval.save_path`` is evaluated
      first; the checkpoint named by the config is always evaluated last.
    * prefix ``vis``: load the configured checkpoint and run the predictor
      class named by ``config.predictor``.

    Args:
        args: parsed command-line namespace; must provide ``taskconfig``.

    Raises:
        ValueError: if the config file name starts with neither ``test``
            nor ``vis``.
    """
    config = load_config(args)

    if isinstance(config, omegaconf.dictconfig.DictConfig):
        print(OmegaConf.to_yaml(config))
    else:
        pp = pprint.PrettyPrinter(indent=4)
        # PrettyPrinter has no ``print`` method; ``pprint`` is the printing
        # entry point.
        pp.pprint(config)

    mmtask = Task.config_task(config)
    mmtask.build_model()

    test_dataloader = get_dataloader(config)
    checkpoint_search_path = os.path.dirname(config.eval.save_path)
    results = []

    prefix = os.path.basename(args.taskconfig)
    if prefix.startswith("test"):
        eval_path = config.fairseq.common_eval.path

        if "best" not in eval_path:
            print("eval each epoch.")
            # glob returns matches in arbitrary order; sort so that the
            # evaluation order (and tie-breaking below) is deterministic.
            checkpoints = sorted(
                glob.glob(checkpoint_search_path + "/checkpoint*")
            )
            for checkpoint in checkpoints:
                model = mmtask.load_checkpoint(checkpoint)
                ckpt = os.path.basename(checkpoint)
                evaluator = Evaluator(config)
                output = evaluator.evaluate(
                    model, test_dataloader, ckpt + "_merged")
                results.append((checkpoint, output))

        # Always evaluate the checkpoint named by the config, last.
        model = mmtask.load_checkpoint(eval_path)
        evaluator = Evaluator(config)
        output = evaluator.evaluate(model, test_dataloader)
        results.append((eval_path, output))

        # Start from "no best yet" rather than a 0.0 floor, so that a best
        # result is still selected when no score is greater than zero
        # (previously ``best_result`` stayed None and indexing it crashed).
        best_result = None
        best_metric = None
        for checkpoint, result in results:
            print(checkpoint)
            evaluator.metric.print_computed_metrics(result)
            best_score = evaluator.metric.best_metric(result)
            if best_metric is None or best_score > best_metric:
                best_result = (checkpoint, result)
                best_metric = best_score
        print("best results:")
        print(best_result[0])
        evaluator.metric.print_computed_metrics(best_result[1])

    elif prefix.startswith("vis"):
        model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
        predictor_cls = getattr(predictor_path, config.predictor)
        predictor = predictor_cls(config)
        predictor.predict_loop(model, test_dataloader, mmtask, None)
    else:
        raise ValueError("unknown prefix of the config file", args.taskconfig)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: one positional argument, the task config
    # path, which is handed to main() as the parsed namespace.
    cli = argparse.ArgumentParser()
    cli.add_argument("taskconfig", type=str)
    main(cli.parse_args())
|
|