from .text_base import TextBaseDataset
from .utils import build_judge, DEBUG_MESSAGE
from ..smp import *
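
# Text-only multiple-choice (MCQ) datasets: build per-question prompts from dataset rows and
# evaluate predictions, using an LLM judge when available and exact matching otherwise.
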
class TextMCQDataset(TextBaseDataset):
    TYPE = 'MCQ'

    DATASET_URL = {}

    DATASET_MD5 = {}

    def build_prompt(self, line):
        if isinstance(line, int):
            line = self.data.iloc[line]

        question = line['question']
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        options_prompt = 'Options:\n'
        for key, item in options.items():
            options_prompt += f'{key}. {item}\n'
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        prompt = ''
        if hint is not None:
            prompt += f'Hint: {hint}\n'
        prompt += f'Question: {question}\n'
        if len(options):
            prompt += options_prompt
            prompt += 'Please select the correct answer from the options above. \n'

        msgs = []
        msgs.append(dict(type='text', value=prompt))

        return msgs
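
    # Illustrative output (traced from the logic above, not part of the source): for a row with
    # question='What is 2 + 2?', A='3', B='4' and no hint, build_prompt returns
    #   [{'type': 'text',
    #     'value': 'Question: What is 2 + 2?\nOptions:\nA. 3\nB. 4\n'
    #              'Please select the correct answer from the options above. \n'}]
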
    def evaluate(self, eval_file, **judge_kwargs):
        from .utils.multiple_choice import report_acc, report_acc_MMT, mcq_circular_eval, mcq_vanilla_eval

        # Normalize MMBench TEST split names to their base dataset names
        dataset_map = {
            'MMBench_TEST_EN': 'MMBench', 'MMBench_TEST_EN_V11': 'MMBench_V11',
            'MMBench_TEST_CN': 'MMBench_CN', 'MMBench_TEST_CN_V11': 'MMBench_CN_V11'
        }
        dataset = self.dataset_name
        if dataset in dataset_map:
            dataset = dataset_map[dataset]

        nproc = judge_kwargs.pop('nproc', 4)
        circular = False

        suffix = eval_file.split('.')[-1]
        model = judge_kwargs.get('model', 'exact_matching')
        assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
        name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
        name_str = name_str_map[model] if model in name_str_map else model

        # Resolve the judge: exact matching needs no model; otherwise try to build an LLM judge
        # and fall back to exact matching if the API is unavailable.
        if model == 'exact_matching':
            model = None
        elif gpt_key_set():
            model = build_judge(**judge_kwargs)
            if not model.working():
                warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
                warnings.warn(DEBUG_MESSAGE)
                model = None
        else:
            warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
            model = None

        result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')

        data = load(eval_file)
        data = data.sort_values(by='index')
        data['prediction'] = [str(x) for x in data['prediction']]
        # Lower-case all column names except single-letter choice labels (A, B, C, ...)
        for k in data.keys():
            data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)

        # Every prediction in eval_file must correspond to a question in the dataset
        meta = self.data
        meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
        data_map = {x: y for x, y in zip(data['index'], data['question'])}
        for k in data_map:
            assert k in meta_q_map, (
                f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
            )

        if circular:
            data = mcq_circular_eval(model, data, meta, nproc, result_file, self.dataset_name)
        else:
            data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)

        # Dump the judged per-question results, then reload them for score reporting
        dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
        data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))

        # Different datasets may use different accuracy reports
        if 'MMT' in dataset:
            acc = report_acc_MMT(data)
        else:
            acc = report_acc(data)

        score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
        dump(acc, score_file)

        return acc

class CustomTextMCQDataset(TextMCQDataset):

    def load_data(self, dataset):
        data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')

        if file_size(data_path, 'GB') > 1:
            local_path = data_path.replace('.tsv', '_local.tsv')
            if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None):
                from ..tools import LOCALIZE
                LOCALIZE(data_path, local_path)
            data_path = local_path
        return load(data_path)
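
# A minimal usage sketch (hypothetical, not part of the source; it assumes the TextBaseDataset
# constructor accepts the dataset name and that inference has already produced an eval file
# with 'index', 'question', and 'prediction' columns):
#
#     dataset = CustomTextMCQDataset('my_text_mcq')   # reads <LMUDataRoot()>/my_text_mcq.tsv
#     msgs = dataset.build_prompt(0)                  # prompt messages for the first row
#     acc = dataset.evaluate('MyModel_my_text_mcq.xlsx', model='exact_matching')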