from ...smp import *  # wildcard import is expected to supply pd, defaultdict, load, etc. used below
from ...utils import can_infer

try:
    from latex2sympy2 import latex2sympy
except ImportError:
    print('Please install latex2sympy2 by running "pip install latex2sympy2"')

FAIL_MSG = 'Failed to obtain answer via API.'

def is_equal(asw: str, gt_asw: str) -> bool:
    """Check whether a predicted answer matches the ground truth, exactly, numerically, or via LaTeX parsing."""
    if not isinstance(asw, str) or not isinstance(gt_asw, str):
        print('Warning: input is not string')
        print(asw, gt_asw)
    asw = str(asw).lower().strip()
    gt_asw = str(gt_asw).lower().strip()
    if gt_asw == asw:
        return True
    # Try a direct numeric comparison.
    try:
        a = eval(gt_asw)
        b = eval(asw)
        if abs(a - b) < 1e-6:
            return True
    except Exception:
        pass
    # Fall back to parsing both answers as LaTeX expressions.
    try:
        a = latex2sympy(gt_asw)
        b = latex2sympy(asw)
        if abs(eval(str(a)) - eval(str(b))) < 1e-6:
            return True
        if abs(a - b) < 1e-6:
            return True
    except Exception:
        pass
    return False
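
# Illustrative check (not part of the module): strings that are numerically or
# symbolically equivalent are expected to compare equal, e.g.
#     is_equal('0.5', '1/2')           -> True  (plain numeric comparison)
#     is_equal('0.5', r'\frac{1}{2}')  -> True  (via latex2sympy)
#     is_equal('0.5', '0.6')           -> False
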
def get_gpt4_ICE():
    example_1 = """
Hint: Please answer the question and provide the final answer at the end.\n
Question: Which number is missing?\n
Model response: The number missing in the sequence is 14.\n
Extracted answer: 14
"""

    example_2 = """
Hint: Please answer the question and provide the final answer at the end.\n
Question: What is the fraction of females facing the camera?\n
Model response: The fraction of females facing the camera is 0.6,
which means that six out of ten females in the group are facing the camera.\n
Extracted answer: 0.6
"""

    example_3 = """
Hint: Please answer the question and provide the final answer at the end.\n
Question: How much money does Luca need to buy a sour apple candy and a butterscotch candy? (Unit: $)\n
Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.\n
Extracted answer: 1.45
"""

    example_4 = """
Hint: Please answer the question and provide the final answer at the end.\n
Question: Between which two years does the line graph saw its maximum peak?\n
Model response: The line graph saw its maximum peak between 2007 and 2008.\n
Extracted answer: [2007, 2008]
"""

    example_5 = """
Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\n
Question: What fraction of the shape is blue?\n
Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5\n
Model response: The correct answer is (B) 8/11.\n
Extracted answer: B
"""

    return [example_1, example_2, example_3, example_4, example_5]

def build_mathv_gpt4_prompt(line):
    """Assemble the GPT-4 answer-extraction prompt for a single record."""
    task_description = """
Please read the following example.
Then extract the answer from the model response and type it at the end of the prompt.\n
"""
    question = line['question']
    prediction = str(line['prediction'])
    prompt = task_description
    examples = get_gpt4_ICE()
    for example in examples:
        prompt += example + '\n'
    prompt += question + '\n'
    prompt += 'Model response: ' + prediction + '\n'
    prompt += 'Extracted answer:'
    return prompt

def list_to_dict(lst):
    return {chr(65 + i): val for i, val in enumerate(lst)}
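
# Illustrative mapping: list_to_dict(['3/11', '8/11', '6/11', '3/5']) yields
# {'A': '3/11', 'B': '8/11', 'C': '6/11', 'D': '3/5'}, matching the option
# letters used in the multiple-choice prompts above.
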
def post_check(line, prefetch=False):
    """Check a record against the ground truth.

    With prefetch=True, try to infer an answer directly from the raw prediction
    and return it; otherwise compare the extracted result in line['res'] against
    line['answer'] and return a bool.
    """
    res = None
    ans = line['answer']
    response = line['prediction'] if prefetch else line['res']
    try:
        if len(eval(line['choices'])) > 0:
            # Multiple-choice question: map the choices to letters and infer the option.
            ans = line['answer']
            choices = list_to_dict(eval(line['choices']))
            res = can_infer(response, choices)
            if prefetch:
                return res
        else:
            # Free-form question: compare the strings directly.
            res = str(response)
            ans = str(ans)
    except ValueError:
        pass
    if is_equal(res, ans):
        return res if prefetch else True
    else:
        return False

def MATH_V_auxeval(model, line):
    """Extract the answer for one record, asking the judge model with up to 5 retries."""
    prompt = build_mathv_gpt4_prompt(line)
    log = ''
    retry = 5
    if post_check(line, prefetch=True):
        res = post_check(line, prefetch=True)
        return dict(log='Prefetch succeed', res=res)
    for i in range(retry):
        prediction = line['prediction']
        res = model.generate(prompt, temperature=i * 0.5)
        if FAIL_MSG in res:
            log += f'Try {i}: output is {prediction}, failed to parse.\n'
        else:
            log += 'Succeed'
            return dict(log=log, res=res)
    log += f'All {retry} retries failed.\n'
    return dict(log=log, res='')

def MATH_V_acc(result_file):
    """Aggregate overall and per-category prefetch rate and accuracy from a judged result file."""
    data = load(result_file)
    tot = defaultdict(lambda: 0)
    fetch = defaultdict(lambda: 0)
    hit = defaultdict(lambda: 0)
    lt = len(data)
    for i in range(lt):
        item = data.iloc[i]
        cate = item['category']
        tot['Overall'] += 1
        tot[cate] += 1
        if item['log'] == 'Prefetch succeed':
            fetch['Overall'] += 1
            fetch[cate] += 1
        if post_check(item, prefetch=False):
            hit['Overall'] += 1
            hit[cate] += 1
    res = defaultdict(list)
    for k in tot.keys():
        res['Subject'].append(k)
        res['tot'].append(tot[k])
        res['prefetch'].append(fetch[k])
        res['hit'].append(hit[k])
        res['prefetch_rate'].append(fetch[k] / tot[k] * 100)
        res['acc'].append(hit[k] / tot[k] * 100)
    res = pd.DataFrame(res).sort_values('Subject', ignore_index=True)
    return res
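
# ---------------------------------------------------------------------------
# Illustrative driver sketch (not part of the module). It assumes `judge` is
# any object exposing `generate(prompt, temperature=...)` as used by
# MATH_V_auxeval, that 'mathv_results.xlsx' is a hypothetical result file whose
# columns match those referenced above ('question', 'prediction', 'answer',
# 'choices', 'category'), and that `dump` is provided by the ...smp wildcard
# import alongside `load`.
#
#     data = load('mathv_results.xlsx')
#     for i in range(len(data)):
#         out = MATH_V_auxeval(judge, data.iloc[i])
#         data.loc[i, 'res'] = out['res']
#         data.loc[i, 'log'] = out['log']
#     dump(data, 'mathv_results_judged.xlsx')
#     print(MATH_V_acc('mathv_results_judged.xlsx'))
# ---------------------------------------------------------------------------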