import argparse
import os

import torch
import torch.distributed as dist

from vlmeval.config import supported_VLM
from vlmeval.utils import track_progress_rich
# The wildcard import below supplies the helpers used throughout this file:
# osp, np, tqdm, load, dump, get_rank_and_world_size, ...
from vlmeval.smp import *

FAIL_MSG = 'Failed to obtain answer via API.'


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, nargs='+', required=True)
    parser.add_argument('--model', type=str, nargs='+', required=True)
    # `required=True` together with a default is contradictory; keep the default
    parser.add_argument('--nproc', type=int, default=4)
    parser.add_argument('--verbose', action='store_true')
    args = parser.parse_args()
    return args


# Only API models are accepted by this helper
def infer_data_api(work_dir, model_name, dataset, nframe=8, pack=False, samples_dict=None, api_nproc=4):
    rank, world_size = get_rank_and_world_size()
    assert rank == 0 and world_size == 1
    # Avoid the mutable-default-argument pitfall
    samples_dict = {} if samples_dict is None else samples_dict
    dataset_name = dataset.dataset_name
    model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
    assert getattr(model, 'is_api', False)

    indices = list(samples_dict.keys())
    structs = [dataset.build_prompt(samples_dict[idx], num_frames=nframe,
               video_llm=getattr(model, 'VIDEO_LLM', False)) for idx in indices]

    packstr = 'pack' if pack else 'nopack'
    out_file = f'{work_dir}/{model_name}_{dataset_name}_{nframe}frame_{packstr}_supp.pkl'
    res = load(out_file) if osp.exists(out_file) else {}

    # Drop samples that already have cached answers
    structs = [s for i, s in zip(indices, structs) if i not in res]
    indices = [i for i in indices if i not in res]

    gen_func = model.generate
    structs = [dict(message=struct, dataset=dataset_name) for struct in structs]

    if len(structs):
        track_progress_rich(gen_func, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices)

    res = load(out_file)
    return res


def infer_data(model_name, work_dir, dataset, out_file, nframe=8, pack=False, verbose=False, api_nproc=4):
    res = load(out_file) if osp.exists(out_file) else {}
    rank, world_size = get_rank_and_world_size()
    dataset_name = dataset.dataset_name

    sample_indices = list(dataset.videos) if pack else list(dataset.data['index'])
    samples = list(dataset.videos) if pack else list(range(len(dataset.data)))
    sample_map = {i: s for i, s in zip(sample_indices, samples)}

    # Round-robin shard: this rank handles sample_indices[rank::world_size]
    sample_indices_sub = sample_indices[rank::world_size]
    if np.all([idx in res for idx in sample_indices_sub]):
        return model_name
    sample_indices_subrem = [x for x in sample_indices_sub if x not in res]

    model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name

    is_api = getattr(model, 'is_api', False)
    if is_api:
        assert world_size == 1
        supp = infer_data_api(
            work_dir=work_dir,
            model_name=model_name,
            dataset=dataset,
            nframe=nframe,
            pack=pack,
            samples_dict={k: sample_map[k] for k in sample_indices_subrem},
            api_nproc=api_nproc)
        for k in sample_indices_subrem:
            assert k in supp
        res.update(supp)
        dump(res, out_file)
        return model_name

    # Adapt to the model's preferred frame count first, if it declares one;
    # this is constant across the loop, so hoist it out
    nframe = getattr(model, 'nframe', 0) if getattr(model, 'nframe', 0) > 0 else nframe

    for i, idx in enumerate(tqdm(sample_indices_subrem)):
        if idx in res:
            continue
        # For a video LLM, build_prompt returns video + question;
        # otherwise, several sampled frames + question
        struct = dataset.build_prompt(
            sample_map[idx], num_frames=nframe, video_llm=getattr(model, 'VIDEO_LLM', False))
        response = model.generate(message=struct, dataset=dataset_name)
        torch.cuda.empty_cache()

        if verbose:
            print(response, flush=True)

        res[idx] = response
        # Checkpoint every 20 samples so an interrupted run can resume
        if (i + 1) % 20 == 0:
            dump(res, out_file)

    res = {k: res[k] for k in sample_indices_sub}
    dump(res, out_file)
    return model
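
# Illustrative sketch (not part of the upstream API): infer_data shards work
# across ranks with the round-robin slice `sample_indices[rank::world_size]`,
# so every rank owns a disjoint subset and the union covers all samples. The
# helper below only demonstrates that property; the names are hypothetical.
def _demo_round_robin_sharding(indices, world_size):
    """Return the per-rank shards produced by the rank::world_size slice."""
    shards = [indices[rank::world_size] for rank in range(world_size)]
    # Sanity check: shards are disjoint and jointly cover all indices
    assert sorted(sum(shards, [])) == sorted(indices)
    return shards

# Example: _demo_round_robin_sharding(list(range(10)), 4)
# -> [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
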
# A wrapper around infer_data that handles the pre- & post-processing
def infer_data_job_video(
        model, work_dir, model_name, dataset, nframe=8, pack=False,
        verbose=False, subtitle=False, api_nproc=4):
    dataset_name = dataset.dataset_name
    packstr = 'pack' if pack else 'nopack'
    rank, world_size = get_rank_and_world_size()
    result_file = osp.join(work_dir, f'{model_name}_{dataset_name}_{nframe}frame_{packstr}.xlsx')
    if dataset_name == 'Video-MME':
        subtitle_str = 'subs' if subtitle else 'nosubs'
        result_file = result_file.replace('.xlsx', f'_{subtitle_str}.xlsx')

    # If the final result file already exists, skip inference entirely
    if osp.exists(result_file):
        return model_name

    # Template for the per-rank temporary prediction files; '{}' is filled with the rank
    tmpl = osp.join(work_dir, '{}' + f'{world_size}_{dataset_name}_{nframe}frame_{packstr}.pkl')
    if dataset_name == 'Video-MME':
        subtitle_str = 'subs' if subtitle else 'nosubs'
        tmpl = tmpl.replace('.pkl', f'_{subtitle_str}.pkl')
    out_file = tmpl.format(rank)

    model = infer_data(
        model, work_dir=work_dir, dataset=dataset, nframe=nframe, pack=pack,
        out_file=out_file, verbose=verbose, api_nproc=api_nproc)

    if world_size > 1:
        dist.barrier()

    if rank == 0:
        # Merge the per-rank predictions into a single table
        data_all = {}
        for i in range(world_size):
            data_all.update(load(tmpl.format(i)))

        meta = dataset.data
        if dataset_name == 'MMBench-Video' and pack:
            meta, vstats = dataset.load_pack_answers(data_all)
            print(f'Statistics of Pack Video Inference: {vstats}')
        else:
            for x in meta['index']:
                assert x in data_all
            meta['prediction'] = [str(data_all[x]) for x in meta['index']]
            if 'image' in meta:
                meta.pop('image')

        dump(meta, result_file)
        # Clean up the per-rank temporary files
        for i in range(world_size):
            os.remove(tmpl.format(i))
    return model
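
# Minimal usage sketch (illustrative, not upstream code). Assumes a
# torchrun-style launch so that get_rank_and_world_size() reflects the
# distributed job, and a hypothetical build_video_dataset() helper standing in
# for however the caller constructs a dataset object; neither assumption is
# defined in this file. Kept as comments so nothing runs on import.
#
#   args = parse_args()
#   for model_name in args.model:
#       model = model_name  # infer_data_job_video accepts a name or an instance
#       for dataset_name in args.data:
#           dataset = build_video_dataset(dataset_name)  # hypothetical helper
#           model = infer_data_job_video(
#               model, work_dir='outputs', model_name=model_name,
#               dataset=dataset, nframe=8, verbose=args.verbose,
#               api_nproc=args.nproc)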