Demo750's picture
Upload folder using huggingface_hub
569f484 verified
import torch
import torch.distributed as dist
from vlmeval.config import supported_VLM
from vlmeval.utils import track_progress_rich
from vlmeval.smp import *
# Marker text searched for inside stored predictions to recognise answers
# that came from a failed API call.
FAIL_MSG = 'Failed to obtain answer via API.'
def parse_args():
    """Parse the command-line options of an inference run.

    Returns:
        argparse.Namespace: with ``data`` and ``model`` (non-empty lists of
        strings, both mandatory), ``nproc`` (int, 4 when omitted) and
        ``verbose`` (bool, False unless the flag is given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, nargs='+', required=True)
    parser.add_argument('--model', type=str, nargs='+', required=True)
    # Combining `required=True` with a default made the default unreachable
    # and forced every invocation to spell out --nproc; the flag is optional.
    parser.add_argument('--nproc', type=int, default=4)
    parser.add_argument('--verbose', action='store_true')
    return parser.parse_args()
# Only API model is accepted
def infer_data_api(work_dir, model_name, dataset, index_set=None, api_nproc=4, ignore_failed=False):
    """Query an API model for (a subset of) a dataset and return its answers.

    Args:
        work_dir: Directory in which the temporary ``*_supp.pkl`` cache lives.
        model_name: Key into ``supported_VLM``, or an already-built model.
        dataset: Object exposing ``dataset_name``, ``data`` (a pandas-style
            table with an ``index`` column) and ``build_prompt(row)``.
        index_set: Optional collection of ``index`` values to restrict to.
        api_nproc: Forwarded to ``track_progress_rich`` as both ``nproc`` and
            ``chunksize``.
        ignore_failed: When true, cached answers containing ``FAIL_MSG`` are
            dropped so that those samples are requested again.

    Returns:
        dict: sample index -> model response, limited to ``index_set`` when
        one is given. Empty when there was nothing to infer and no cache.

    Raises:
        AssertionError: when not running as a single process, or when the
            model does not declare ``is_api``.
    """
    rank, world_size = get_rank_and_world_size()
    assert rank == 0 and world_size == 1
    dataset_name = dataset.dataset_name
    data = dataset.data
    if index_set is not None:
        data = data[data['index'].isin(index_set)]

    model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
    assert getattr(model, 'is_api', False)

    lt, indices = len(data), list(data['index'])
    structs = [dataset.build_prompt(data.iloc[i]) for i in range(lt)]

    out_file = f'{work_dir}/{model_name}_{dataset_name}_supp.pkl'
    res = {}
    if osp.exists(out_file):
        res = load(out_file)
        if ignore_failed:
            # Cached values are not guaranteed to be strings; compare on the
            # text form, as infer_data_job does for stored predictions.
            res = {k: v for k, v in res.items() if FAIL_MSG not in str(v)}

    # Only samples without a usable cached answer are sent to the model.
    structs = [s for i, s in zip(indices, structs) if i not in res]
    indices = [i for i in indices if i not in res]

    gen_func = model.generate
    structs = [dict(message=struct, dataset=dataset_name) for struct in structs]

    if structs:
        track_progress_rich(gen_func, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices)

    # With nothing to generate and no earlier cache there is no file to read
    # back or delete; report an empty result instead of crashing.
    if not osp.exists(out_file):
        return {}

    res = load(out_file)
    if index_set is not None:
        res = {k: v for k, v in res.items() if k in index_set}
    os.remove(out_file)
    return res
def infer_data(model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4):
    """Infer this rank's share of a dataset and store the answers in ``out_file``.

    Answers already present in the ``*_PREV.pkl`` file or in ``out_file`` are
    reused. Rows are distributed over ranks round-robin by position.

    Args:
        model_name: Key into ``supported_VLM``, or an already-built model.
        work_dir: Directory holding the ``*_PREV.pkl`` file.
        dataset: Object exposing ``dataset_name``, ``data``, ``build_prompt``
            and ``dump_image``; ``len(dataset)`` gives the number of rows.
        out_file: Pickle path receiving this rank's ``index -> response`` dict.
        verbose: Print every response of a local model as it is produced.
        api_nproc: Forwarded to ``infer_data_api`` for API models.

    Returns:
        ``None`` when every sample of this rank was already answered (no model
        is built), ``model_name`` unchanged for an API model, otherwise the
        model object that was used.
    """
    dataset_name = dataset.dataset_name

    # Collect earlier answers: previous run first, then this rank's own file.
    prev_file = f'{work_dir}/{model_name}_{dataset_name}_PREV.pkl'
    res = {}
    if osp.exists(prev_file):
        res = load(prev_file)
    if osp.exists(out_file):
        res.update(load(out_file))

    rank, world_size = get_rank_and_world_size()
    shard_rows = list(range(rank, len(dataset), world_size))
    data = dataset.data.iloc[shard_rows]
    data_indices = list(data['index'])

    # If finished, will exit without building the model
    if all(data.iloc[pos]['index'] in res for pos in range(len(shard_rows))):
        res = {key: res[key] for key in data_indices}
        dump(res, out_file)
        return

    # Data need to be inferred
    data = data[~data['index'].isin(res)]
    lt = len(data)

    model = model_name
    if isinstance(model_name, str):
        model = supported_VLM[model_name]()

    if getattr(model, 'is_api', False):
        indices = list(data['index'])
        supp = infer_data_api(
            work_dir=work_dir,
            model_name=model_name,
            dataset=dataset,
            index_set=set(indices),
            api_nproc=api_nproc,
        )
        # Every requested sample must have come back with an answer.
        assert all(idx in supp for idx in indices)
        res.update(supp)
        res = {key: res[key] for key in data_indices}
        dump(res, out_file)
        return model_name

    model.set_dump_image(dataset.dump_image)

    for i in tqdm(range(lt)):
        row = data.iloc[i]
        idx = row['index']
        if idx in res:
            continue

        # The model's own prompt builder wins when it claims this dataset.
        if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
            struct = model.build_prompt(row, dataset=dataset_name)
        else:
            struct = dataset.build_prompt(row)

        response = model.generate(message=struct, dataset=dataset_name)
        torch.cuda.empty_cache()

        if verbose:
            print(response, flush=True)

        res[idx] = response
        # Checkpoint every 20 rows so an interrupted run can resume.
        if (i + 1) % 20 == 0:
            dump(res, out_file)

    res = {key: res[key] for key in data_indices}
    dump(res, out_file)
    return model
# A wrapper for infer_data, do the pre & post processing
def infer_data_job(model, work_dir, model_name, dataset, verbose=False, api_nproc=4, ignore_failed=False):
    """Run ``infer_data`` on every rank and merge the shards on rank 0.

    Before inference, rank 0 turns an existing ``.xlsx`` result file into the
    ``*_PREV.pkl`` file so finished samples can be reused. After inference,
    rank 0 gathers all per-rank pickles, writes the predictions into
    ``dataset.data`` (dropping its ``image`` column if present), saves the
    result file and deletes the per-rank pickles.

    Args:
        model: Model name or model object, handed to ``infer_data``.
        work_dir: Directory for the result, previous-answer and shard files.
        model_name: Name used in the result and previous-answer file names.
        dataset: Dataset object exposing ``dataset_name`` and ``data``.
        verbose: Forwarded to ``infer_data``.
        api_nproc: Forwarded to ``infer_data``.
        ignore_failed: When false, earlier predictions containing ``FAIL_MSG``
            are not carried over into the previous-answer file.

    Returns:
        Whatever ``infer_data`` returned for this rank.
    """
    rank, world_size = get_rank_and_world_size()
    is_main = rank == 0
    distributed = world_size > 1
    dataset_name = dataset.dataset_name

    result_file = osp.join(work_dir, f'{model_name}_{dataset_name}.xlsx')
    prev_file = f'{work_dir}/{model_name}_{dataset_name}_PREV.pkl'

    if osp.exists(result_file):
        if is_main:
            previous = load(result_file)
            results = dict(zip(previous['index'], previous['prediction']))
            if not ignore_failed:
                results = {k: v for k, v in results.items() if FAIL_MSG not in str(v)}
            dump(results, prev_file)
        # Other ranks wait until the previous-answer file has been written.
        if distributed:
            dist.barrier()

    # One pickle per rank; the leading '{}' is filled with the rank number.
    tmpl = osp.join(work_dir, '{}' + f'{world_size}_{dataset_name}.pkl')

    model = infer_data(
        model,
        work_dir=work_dir,
        dataset=dataset,
        out_file=tmpl.format(rank),
        verbose=verbose,
        api_nproc=api_nproc,
    )
    if distributed:
        dist.barrier()

    if not is_main:
        return model

    shard_files = [tmpl.format(r) for r in range(world_size)]
    data_all = {}
    for shard in shard_files:
        data_all.update(load(shard))

    data = dataset.data
    for x in data['index']:
        assert x in data_all
    data['prediction'] = [str(data_all[x]) for x in data['index']]
    if 'image' in data:
        data.pop('image')

    dump(data, result_file)
    for shard in shard_files:
        os.remove(shard)
    return model