#!/usr/bin/env python3 -u
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import logging
import os
import sys

import numpy as np
import torch
from fairseq import distributed_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.utils import reset_logging
from omegaconf import DictConfig

from utils import checkpoint_utils
from utils.eval_utils import eval_step, merge_results
from utils.zero_shot_utils import zero_shot_step
from utils.utils import print_trainable_params_percentage, setup_for_distributed

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("ofa.evaluate")


def apply_half(t):
    # Cast float32 tensors to half precision; leave all other dtypes untouched.
    if t.dtype is torch.float32:
        return t.to(dtype=torch.half)
    return t


def main(cfg: DictConfig, **kwargs):
    utils.import_user_module(cfg.common)

    setup_for_distributed(distributed_utils.is_master(cfg.distributed_training))
    reset_logging()

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    # Load ensemble
    overrides = eval(cfg.common_eval.model_overrides)

    # Deal with beam-search / all-candidate VQA eval
    if cfg.task._name == "vqa_gen":
        overrides['val_inference_type'] = "beamsearch" if kwargs['beam_search_vqa_eval'] else "allcand"

    logger.info("loading model(s) from {}".format(cfg.common_eval.path))

    strict = kwargs['strict']
    logger.info("load checkpoint, strict: {}".format(strict))

    if kwargs["zero_shot"]:
        # Zero-shot evaluation: build the task from the evaluation config,
        # applying the model overrides before loading the checkpoint.
        for arg_name, arg_val in overrides.items():
            cfg.task[arg_name] = arg_val
        if hasattr(cfg.task, "add_caption"):
            cfg.task.add_caption = False
        logger.info("cfg.task: {}".format(cfg.task))
        task = tasks.setup_task(cfg.task)
        models, saved_cfg = checkpoint_utils.load_model_ensemble(
            utils.split_paths(cfg.common_eval.path),
            arg_overrides=overrides,
            task=task,
            suffix=cfg.checkpoint.checkpoint_suffix,
            strict=((cfg.checkpoint.checkpoint_shard_count == 1) and strict),
            num_shards=cfg.checkpoint.checkpoint_shard_count,
        )
        for m in models:
            m.encoder.sample_patch_num = 776
        saved_cfg.task = cfg.task
    else:
        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            utils.split_paths(cfg.common_eval.path),
            arg_overrides=overrides,
            suffix=cfg.checkpoint.checkpoint_suffix,
            strict=((cfg.checkpoint.checkpoint_shard_count == 1) and strict),
            num_shards=cfg.checkpoint.checkpoint_shard_count,
        )

    kwargs['evaluate_cfg'] = cfg.task

    # Loading the dataset should happen after the checkpoint has been loaded,
    # so we can give it the saved task config.
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    # Optionally load EMA weights, then move models to GPU
    for model, ckpt_path in zip(models, utils.split_paths(cfg.common_eval.path)):
        if kwargs['ema_eval']:
            logger.info("loading EMA weights from {}".format(ckpt_path))
            model.load_state_dict(checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model'])
        model.eval()
        logger.info("use fp16: {}".format(use_fp16))
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    generator = task.build_generator(models, cfg.generation)

    results = []
    score_sum = torch.FloatTensor([0]).cuda()
    score_cnt = torch.FloatTensor([0]).cuda()
    score_sum_list = []
    score_cnt_list = []

    for sample in progress:
        if "net_input" not in sample:
            continue
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        sample = utils.apply_to_sample(apply_half, sample) if cfg.common.fp16 else sample
        with torch.no_grad():
            if kwargs["zero_shot"] and kwargs['noconstraints']:
                result, scores = zero_shot_step(task, generator, models, sample)
            else:
                # For refcoco-style tasks, eval_step may return per-metric scores.
                result, scores = eval_step(task, generator, models, sample, **kwargs)

        # Decide whether `scores` holds scalar scores (one per sample) or a
        # list of per-metric score lists.
        scalar = False
        if isinstance(scores, list) and not isinstance(scores[0], list):
            try:
                sum(scores[0])
            except TypeError:
                scalar = True

        if isinstance(scores, list) and not scalar:
            # Per-metric scores: result is (metric_names, per_metric_results).
            names = result[0]
            result = result[1]
            if len(score_sum_list) == 0:
                score_sum_list = [torch.FloatTensor([0]).cuda() for _ in range(len(scores))]
                score_cnt_list = [torch.FloatTensor([0]).cuda() for _ in range(len(scores))]
            for i in range(len(scores)):
                score_sum_list[i] += sum(scores[i]) if scores[i] is not None else 0
                score_cnt_list[i] += len(scores[i]) if scores[i] is not None else 0
        else:
            score_sum += sum(scores) if scores is not None else 0
            score_cnt += len(scores) if scores is not None else 0

        results += result
        progress.log({"sentences": sample["nsentences"]})

    # Merge results, per metric if multiple metrics were collected
    if len(score_sum_list) > 0:
        logger.info("metrics: {} ({})".format(names, len(score_sum_list)))
        for i in range(len(score_sum_list)):
            logger.info(names[i])
            merge_results(task, cfg, logger, score_cnt_list[i], score_sum_list[i], results)
    else:
        merge_results(task, cfg, logger, score_cnt, score_sum, results)


def cli_main():
    parser = options.get_generation_parser()
    parser.add_argument(
        "--ema-eval",
        action='store_true',
        help="Use EMA weights for evaluation.",
    )
    parser.add_argument(
        "--beam-search-vqa-eval",
        action='store_true',
        help="Use beam search for VQA evaluation (faster inference but sub-optimal results); "
             "if not set, scores are computed for every answer in the candidate set, "
             "which is slower but obtains the best results.",
    )
    parser.add_argument("--zero-shot", action='store_true')
    # Note: passing --strict sets strict=False, i.e. disables strict checkpoint loading.
    parser.add_argument("--strict", action='store_false')
    parser.add_argument("--noconstraints", action='store_true')
    args = options.parse_args_and_arch(parser)
    cfg = convert_namespace_to_omegaconf(args)
    distributed_utils.call_main(
        cfg,
        main,
        ema_eval=args.ema_eval,
        beam_search_vqa_eval=args.beam_search_vqa_eval,
        zero_shot=args.zero_shot,
        strict=args.strict,
        noconstraints=args.noconstraints,
    )


if __name__ == "__main__":
    cli_main()
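# Example invocation (a sketch only: the script name, data path, checkpoint
# path, and task below are hypothetical placeholders, and the exact flags
# required depend on the task and fairseq version):
#
#   python evaluate.py /path/to/data \
#       --task vqa_gen \
#       --path /path/to/checkpoint.pt \
#       --batch-size 16 \
#       --gen-subset val \
#       --beam-search-vqa-eval \
#       --ema-eval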