#!/usr/bin/env python3 -u
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
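
"""Evaluation entry point for UnIVAL / OFA-style models.

Loads a (possibly sharded) checkpoint ensemble from --path, builds the
evaluation split given by --gen-subset, runs eval_step (or zero_shot_step
when --zero-shot and --noconstraints are set) on every batch, and finally
merges the per-batch scores with merge_results.
"""
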
import logging
import os
import sys
import numpy as np
import torch
from fairseq import distributed_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.utils import reset_logging
from omegaconf import DictConfig
from utils import checkpoint_utils
from utils.eval_utils import eval_step, merge_results
from utils.zero_shot_utils import zero_shot_step

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("ofa.evaluate")
from utils.utils import print_trainable_params_percentage, setup_for_distributed


def apply_half(t):
    # Cast float32 tensors to fp16; used with utils.apply_to_sample when --fp16 is set.
    if t.dtype is torch.float32:
        return t.to(dtype=torch.half)
    return t


def main(cfg: DictConfig, **kwargs):
    utils.import_user_module(cfg.common)

    # Configure logging/printing for distributed evaluation.
    setup_for_distributed(distributed_utils.is_master(cfg.distributed_training))
    reset_logging()

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu
    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    # Load ensemble
    # model_overrides is a string containing a Python dict of config overrides.
    overrides = eval(cfg.common_eval.model_overrides)

    # Deal with beam-search / all-candidate VQA eval
    if cfg.task._name == "vqa_gen":
        overrides['val_inference_type'] = "beamsearch" if kwargs['beam_search_vqa_eval'] else "allcand"

    logger.info("loading model(s) from {}".format(cfg.common_eval.path))

    strict = kwargs.get('strict', True)
    logger.info('load checkpoint, strict: {}'.format(strict))
if kwargs["zero_shot"]:
for arg_name, arg_val in overrides.items():
cfg.task[arg_name] = arg_val
# print("Zero-shot eval", cfg.task, cfg)
if hasattr(cfg.task, "add_caption"):
cfg.task.add_caption = False
print("cfg.task", cfg.task)
task = tasks.setup_task(cfg.task)
# cfg.criterion.sample_patch_num = 776
models, saved_cfg = checkpoint_utils.load_model_ensemble(
utils.split_paths(cfg.common_eval.path),
arg_overrides=overrides,
task=task,
suffix=cfg.checkpoint.checkpoint_suffix,
strict=((cfg.checkpoint.checkpoint_shard_count == 1) and strict),
num_shards=cfg.checkpoint.checkpoint_shard_count,
)
for m in models:
m.encoder.sample_patch_num = 776
saved_cfg.task = cfg.task
# print("saved_cfg", saved_cfg)
else:
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
utils.split_paths(cfg.common_eval.path),
arg_overrides=overrides,
suffix=cfg.checkpoint.checkpoint_suffix,
strict=((cfg.checkpoint.checkpoint_shard_count == 1) and strict),
num_shards=cfg.checkpoint.checkpoint_shard_count,
)
# task.cfg['evaluate_cfg'] = cfg.task
# print(task.cfg)
kwargs['evaluate_cfg'] = cfg.task
# print(kwargs)
# loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    # Move models to GPU, optionally loading EMA weights and casting to fp16.
    for model, ckpt_path in zip(models, utils.split_paths(cfg.common_eval.path)):
        if kwargs['ema_eval']:
            logger.info("loading EMA weights from {}".format(ckpt_path))
            model.load_state_dict(checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model'])
        model.eval()
        logger.info("use fp16: {}".format(use_fp16))
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    generator = task.build_generator(models, cfg.generation)

    # Score accumulators: a single (sum, count) pair for single-metric tasks,
    # or one pair per metric when eval_step returns several score lists.
    device = "cuda" if use_cuda else "cpu"
    score_sum = torch.zeros(1, device=device)
    score_cnt = torch.zeros(1, device=device)
    score_sum_list = []
    score_cnt_list = []

    for sample in progress:
        if "net_input" not in sample:
            continue
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        sample = utils.apply_to_sample(apply_half, sample) if cfg.common.fp16 else sample
        with torch.no_grad():
            if kwargs["zero_shot"] and kwargs['noconstraints']:
                result, scores = zero_shot_step(task, generator, models, sample)
            else:
                result, scores = eval_step(task, generator, models, sample, **kwargs)

        # `scores` is either a flat list of per-sample scores (single metric) or a
        # list of per-metric score lists (e.g. refcoco returns several metrics).
        multi_metric = isinstance(scores, list) and len(scores) > 0
        if multi_metric and not isinstance(scores[0], list):
            try:
                sum(scores[0])  # iterable entries -> one list per metric
            except TypeError:  # scalar entries -> one score per sample
                multi_metric = False

        if multi_metric:
            # In the multi-metric case, `result` is (metric names, per-sample results).
            names = result[0]
            result = result[1]
            if len(score_sum_list) == 0:
                score_sum_list = [torch.zeros(1, device=device) for _ in range(len(scores))]
                score_cnt_list = [torch.zeros(1, device=device) for _ in range(len(scores))]
            for i in range(len(scores)):
                score_sum_list[i] += sum(scores[i]) if scores[i] is not None else 0
                score_cnt_list[i] += len(scores[i]) if scores[i] is not None else 0
        else:
            score_sum += sum(scores) if scores is not None else 0
            score_cnt += len(scores) if scores is not None else 0

        results += result
        progress.log({"sentences": sample["nsentences"]})

    # Merge results, per metric when several metrics were returned.
    if len(score_sum_list) > 0:
        for i in range(len(score_sum_list)):
            logger.info("metric: {}".format(names[i]))
            merge_results(task, cfg, logger, score_cnt_list[i], score_sum_list[i], results)
    else:
        merge_results(task, cfg, logger, score_cnt, score_sum, results)


def cli_main():
    parser = options.get_generation_parser()
    parser.add_argument("--ema-eval", action='store_true',
                        help="Use EMA weights for evaluation.")
    parser.add_argument("--beam-search-vqa-eval", action='store_true',
                        help="Use beam search for VQA evaluation (faster inference but sub-optimal results); "
                             "if not set, scores are computed for every answer in the candidate set, "
                             "which is slower but gives the best results.")
    parser.add_argument("--zero-shot", action='store_true',
                        help="Run zero-shot evaluation.")
    # Note: store_false, so passing --strict actually disables strict checkpoint loading.
    parser.add_argument("--strict", action='store_false',
                        help="Allow a non-exact match between checkpoint and model parameters.")
    parser.add_argument("--noconstraints", action='store_true',
                        help="With --zero-shot, decode with the unconstrained zero_shot_step.")
    args = options.parse_args_and_arch(parser)
    cfg = convert_namespace_to_omegaconf(args)
    distributed_utils.call_main(
        cfg, main, ema_eval=args.ema_eval, beam_search_vqa_eval=args.beam_search_vqa_eval,
        zero_shot=args.zero_shot, strict=args.strict, noconstraints=args.noconstraints,
    )


if __name__ == "__main__":
    cli_main()
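
# Example invocation (a sketch only; the dataset directory, task name, checkpoint
# path and results directory below are placeholders, not paths shipped with UnIVAL):
#
#   python evaluate.py ${data_dir} \
#       --task=caption \
#       --path=${checkpoint} \
#       --gen-subset=test \
#       --batch-size=16 --beam=5 --fp16 \
#       --results-path=${results_dir}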