|
""" |
|
This file comes from: https://github.com/microsoft/ToRA/blob/main/src/utils/parser.py |
|
""" |
|
import re |
|
from typing import Any, Dict |
|
|
|
|
|
def _fix_fracs(string): |
|
substrs = string.split("\\frac") |
|
new_str = substrs[0] |
|
if len(substrs) > 1: |
|
substrs = substrs[1:] |
|
for substr in substrs: |
|
new_str += "\\frac" |
|
if len(substr) > 0 and substr[0] == "{": |
|
new_str += substr |
|
else: |
|
try: |
|
assert len(substr) >= 2 |
|
except: |
|
return string |
|
a = substr[0] |
|
b = substr[1] |
|
if b != "{": |
|
if len(substr) > 2: |
|
post_substr = substr[2:] |
|
new_str += "{" + a + "}{" + b + "}" + post_substr |
|
else: |
|
new_str += "{" + a + "}{" + b + "}" |
|
else: |
|
if len(substr) > 2: |
|
post_substr = substr[2:] |
|
new_str += "{" + a + "}" + b + post_substr |
|
else: |
|
new_str += "{" + a + "}" + b |
|
string = new_str |
|
return string |
|
|
|
|
|
def _fix_a_slash_b(string): |
|
if len(string.split("/")) != 2: |
|
return string |
|
a = string.split("/")[0] |
|
b = string.split("/")[1] |
|
try: |
|
if "sqrt" not in a: |
|
a = int(a) |
|
if "sqrt" not in b: |
|
b = int(b) |
|
assert string == "{}/{}".format(a, b) |
|
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" |
|
return new_string |
|
except: |
|
return string |
|
|
|
|
|
def _fix_sqrt(string): |
|
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) |
|
return _string |
|
|
|
|
|
def strip_string(string):
    r"""Normalize an answer string so equivalent answers compare equal.

    The input is converted with ``str()`` and then put through a fixed
    sequence of textual clean-ups: whitespace/newline and trailing-period
    removal, LaTeX spacing and sizing commands, a trailing ``\text{...}``
    suffix, degree/dollar/percent symbols, infinity spellings, quotes,
    ``j`` -> ``i``, trailing ``.0`` decimals, a short ``x=`` style prefix, and
    finally the ``\sqrt`` / ``\frac`` / ``a/b`` helpers. The order of the
    steps matters, since later steps see the output of earlier ones.

    Args:
        string: Value to normalize; any object accepted by ``str()``.

    Returns:
        The normalized string. An input that becomes empty during clean-up is
        returned as ``""`` before the fraction/sqrt helpers run.
    """
    string = str(string).strip()

    # Line breaks.
    string = string.replace("\n", "")

    # Trailing periods.
    string = string.rstrip(".")

    # LaTeX negative thin space and escaped space.
    string = string.replace("\\!", "")
    string = string.replace("\\ ", "")

    # Collapse doubled backslashes (two passes, as the result of the first
    # pass may still contain a doubled backslash).
    string = string.replace("\\\\", "\\")
    string = string.replace("\\\\", "\\")

    # \tfrac / \dfrac -> \frac.
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # \left / \right delimiters.
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Drop a trailing \text{...} suffix, unless that would leave nothing.
    without_text_suffix = re.sub(r"\\text{.*?}$", "", string).strip()
    if without_text_suffix != "" and without_text_suffix != string:
        string = without_text_suffix

    # Degree markers.
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # Dollar signs.
    string = string.replace("\\$", "")
    string = string.replace("$", "")

    string = string.replace("\\text", "")
    string = string.replace("x\\in", "")

    # Percent signs. The second pass was originally spelled with the invalid
    # escape "\%", which denotes the same two characters as "\\%"; it is kept
    # so the result stays identical while avoiding the SyntaxWarning.
    string = string.replace("\\%", "")
    string = string.replace("\\%", "")
    string = string.replace("%", "")

    # " .5" -> " 0.5" and "{.5" -> "{0.5".
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # \cdot.
    string = string.replace("\\cdot", "")

    # Infinity spellings.
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    # Fixed typo: the pattern used to be "+\\inity", which never matched.
    string = string.replace("+\\infty", "\\infty")

    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # \mbox{...} groups.
    string = re.sub(r"\\mbox{.*?}", "", string)

    # Quotes. str.replace returns a new string; the results were previously
    # discarded, so quotes were never actually removed.
    string = string.replace("'", "")
    string = string.replace("\"", "")

    # Imaginary unit written as j (only when no i is present).
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # Trailing ".0...0" on integers: "2.00x" -> "2x", "2.0" -> "2".
    string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0+$", r"\1", string)

    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # Drop a short left-hand side such as "x=" or "k =".
    equation_sides = string.split("=")
    if len(equation_sides) == 2 and len(equation_sides[0]) <= 2:
        string = equation_sides[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")

    # \frac12 -> \frac{1}{2}, \frac1{2} -> \frac{1}{2}.
    string = _fix_fracs(string)

    # Plain a/b -> \frac{a}{b}.
    string = _fix_a_slash_b(string)

    return string
|
|
|
def extract_answer(pred_str):
    r"""Pull the final answer out of a model response and normalize it.

    Sources are tried in this order:

    1. the text after the last ``boxed`` (brace-matched when it starts with
       ``{``, otherwise everything up to the next ``$``);
    2. the text after the last ``he answer is``;
    3. the last "```output" block, via ``extract_program_output``;
    4. the last number found in the text (commas removed first).

    The candidate is then cut to its first line, a leading ``:`` and a
    trailing ``.`` / ``/`` are removed, and the result is passed through
    ``strip_string``.

    Args:
        pred_str: Full response text.

    Returns:
        The normalized answer string; ``""`` when ``boxed`` is the very last
        thing in the text.
    """
    if 'boxed' in pred_str:
        after_boxed = pred_str.split('boxed')[-1]
        if len(after_boxed) == 0:
            return ""
        if after_boxed[0] == '{':
            # Collect characters until the brace that closes the opening one.
            depth = 1
            collected = []
            for ch in after_boxed[1:]:
                if ch == '{':
                    depth += 1
                elif ch == '}':
                    depth -= 1
                    if depth == 0:
                        break
                collected.append(ch)
            pred = "".join(collected)
        else:
            pred = after_boxed.split('$')[0].strip()
    elif 'he answer is' in pred_str:
        pred = pred_str.split('he answer is')[-1].strip()
    else:
        # Fall back to program output (computed once), then to the last number.
        program_output = extract_program_output(pred_str)
        if program_output != "":
            pred = program_output
        else:
            # Raw string: "\d" and "\." are invalid escapes in a normal literal.
            numbers = re.findall(r'-?\d*\.?\d+', pred_str.replace(",", ""))
            pred = numbers[-1] if numbers else ''

    # Keep only the first line of a multi-line candidate.
    pred = pred.split("\n")[0]
    if pred != "" and pred[0] == ":":
        pred = pred[1:]
    if pred != "" and pred[-1] == ".":
        pred = pred[:-1]
    if pred != "" and pred[-1] == "/":
        pred = pred[:-1]
    return strip_string(pred)
|
|
|
|
|
def extract_program(result: str, last_only=True):
    """
    Collect the lines found inside "```python" ... "```" fences.

    A line starting with "```python" opens a fence and any other line
    starting with "```" closes it. With ``last_only`` true, each opening
    fence discards what was gathered so far, so only the last program is
    returned. Otherwise every program is kept and each one is preceded by a
    "# ========" separator line. Each collected line ends with a newline.
    """
    chunks = []
    inside_fence = False
    for row in result.split("\n"):
        if row.startswith("```python"):
            if last_only:
                chunks.clear()
            else:
                chunks.append("\n# ========\n")
            inside_fence = True
            continue
        if row.startswith("```"):
            inside_fence = False
            continue
        if inside_fence:
            chunks.append(row + "\n")
    return "".join(chunks)
|
|
|
|
|
def extract_program_output(pred_str):
    """
    Return the stripped text of the last "```output" block.

    The result is whatever follows the final "```output" marker, cut at the
    next "```" if there is one, with surrounding whitespace removed. An empty
    string is returned when the marker does not occur at all.
    """
    marker = "```output"
    if marker not in pred_str:
        return ""
    after_marker = pred_str.rpartition(marker)[2]
    block_body = after_marker.partition("```")[0]
    return block_body.strip()
|
|
|
|
|
def parse_ground_truth(example: Dict[str, Any], data_name):
    """Return ``(gt_cot, gt_ans)`` for one dataset record.

    If the record already has a ``gt_cot`` key, that value is returned as-is
    together with ``strip_string(example['gt'])``. Otherwise the fields read
    depend on ``data_name``:

    * ``math`` / ``ocw``: ``solution``; the answer is extracted from it.
    * ``gsm8k``: ``answer``, split on ``####`` into rationale and answer.
    * ``gsm-hard``: ``code`` and ``target``.
    * ``svamp``: ``Equation`` and ``Answer``.
    * ``asdiv``: ``formula`` and ``answer`` with parenthesised parts removed.
    * ``mawps`` / ``bbh``: ``target`` only (no rationale).
    * ``tabmwp``: ``solution`` and ``answer``; numeric answer types are
      converted to a float (fractions, thousands separators and percentages
      are handled).

    Returns:
        ``gt_cot`` converted with ``str()`` and stripped, and ``gt_ans``
        normalized by ``strip_string``.

    Raises:
        NotImplementedError: ``data_name`` is not one of the names above.
    """
    if 'gt_cot' in example:
        return example['gt_cot'], strip_string(example['gt'])

    if data_name in ("math", "ocw"):
        gt_cot = example['solution']
        gt_ans = extract_answer(gt_cot)
    elif data_name == "gsm8k":
        # Unpacking requires exactly one "####" separator in the answer field.
        gt_cot, gt_ans = example['answer'].split("####")
    elif data_name == "gsm-hard":
        gt_cot = example['code']
        gt_ans = example['target']
    elif data_name == "svamp":
        gt_cot = example['Equation']
        gt_ans = example['Answer']
    elif data_name == "asdiv":
        gt_cot = example['formula']
        gt_ans = re.sub(r"\(.*?\)", "", example['answer'])
    elif data_name == "mawps":
        gt_cot = None
        gt_ans = example['target']
    elif data_name == "tabmwp":
        gt_cot = example['solution']
        gt_ans = example['answer']
        if example['ans_type'] in ('integer_number', 'decimal_number'):
            if '/' in gt_ans:
                fraction_parts = gt_ans.split('/')
                gt_ans = int(fraction_parts[0]) / int(fraction_parts[1])
            elif ',' in gt_ans:
                gt_ans = float(gt_ans.replace(',', ''))
            elif '%' in gt_ans:
                gt_ans = float(gt_ans.split('%')[0]) / 100
            else:
                gt_ans = float(gt_ans)
    elif data_name == "bbh":
        gt_cot = None
        gt_ans = example['target']
    else:
        raise NotImplementedError(data_name)

    # Shared post-processing for every dataset branch.
    return str(gt_cot).strip(), strip_string(gt_ans)
|
|
|
|
|
def parse_question(example, data_name):
    """Build the question text for one dataset record.

    * ``asdiv``: stripped ``body`` and ``question`` joined by a space.
    * ``svamp``: stripped ``Body`` (a period is appended if missing) followed
      by the stripped ``Question``.
    * ``tabmwp``: a prompt containing the optional ``table_title``, the
      ``table``, the ``question`` and, when ``choices`` is truthy, the list
      of options.
    * anything else: the value of the first key present among ``question``,
      ``problem``, ``Question`` and ``input``.

    Args:
        example: Mapping holding the record's fields.
        data_name: Dataset name selecting the layout above.

    Returns:
        The question with surrounding whitespace removed.

    Raises:
        AssertionError: The resulting question is empty.
    """
    if data_name == "asdiv":
        question = f"{example['body'].strip()} {example['question'].strip()}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        if not body.endswith("."):
            body += "."
        question = f'{body} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title_str = f'regarding "{example["table_title"]}" ' if example['table_title'] else ""
        question = (
            f'Read the following table {title_str}and answer a question:\n'
            f'{example["table"]}\n{example["question"]}'
        )
        if example['choices']:
            question += f' Please select from the following options: {example["choices"]}'
    else:
        candidate_keys = ('question', 'problem', 'Question', 'input')
        question = next((example[key] for key in candidate_keys if key in example), "")

    # Raised explicitly (rather than via ``assert``) so the check is not
    # stripped under ``python -O``; the exception type is unchanged.
    if question == "":
        raise AssertionError(f"no question text found for dataset {data_name!r}")
    return question.strip()
|
|
|
|
|
def run_execute(executor, result, prompt_type, execute=False):
    """Turn a raw generation into a ``(prediction, report)`` pair.

    A falsy ``result`` or the literal ``'error'`` yields ``(None, None)``.
    Otherwise the prediction comes from one of three places:

    * ``"program_only"`` occurs in ``prompt_type``: the last "```output"
      block of ``result``;
    * ``prompt_type`` is ``"pot"`` or ``"pal"`` and ``execute`` is true: the
      extracted program is passed to ``executor.apply``, whose two return
      values are used as prediction and report;
    * otherwise: the answer extracted from the text.

    The prediction is normalized with ``strip_string`` before it is returned;
    ``report`` stays ``None`` unless the executor branch ran.
    """
    if not result or result == 'error':
        return None, None

    execution_report = None
    if "program_only" in prompt_type:
        raw_prediction = extract_program_output(result)
    elif execute and prompt_type in ("pot", "pal"):
        program = extract_program(result)
        raw_prediction, execution_report = executor.apply(program)
    else:
        raw_prediction = extract_answer(result)

    return strip_string(raw_prediction), execution_report
|
|