import ast
import logging
import math  # noqa: F401 - referenced by the 'math.pow(...)' strings passed to eval()
import re
import traceback
from typing import Any

import numpy as np
from sympy import Rational

from tasks.base import Task

LOGGER = logging.getLogger('MINT')


class ReasoningTask(Task):
    """Free-form reasoning task graded against a single reference string."""

    task_name = 'reasoning'

    def __init__(self, id: str, prompt: str, reference: str, **kwargs):
        """Store the task id, the stripped prompt and the normalized reference.

        All extra keyword arguments are forwarded to ``Task.__init__``.
        """
        super().__init__(**kwargs)
        self._id = id
        self._prompt = prompt.strip()
        # The reference is compared case-insensitively, so normalize it once.
        self._reference = str(reference).strip().lower()

    def extract_answer(self, solution: str) -> str | None:
        """Extract the answer from the given solution."""
        return solution.lower().strip()

    def compare_w_digits(self, reference: str, answer: str) -> bool:
        """Compare the reference and answer, numerically when possible.

        If both values can be converted by ``float()``, the answer matches
        when it lies within 5% of the reference. If either conversion raises
        ``ValueError``, fall back to a substring test (``reference in answer``).

        Raises:
            ValueError: if the conversion fails for any other reason
                (for example ``answer`` is ``None``).
        """
        try:
            reference_value = float(reference)
            answer_value = float(answer)
        except ValueError:
            return reference in answer
        except Exception as err:
            raise ValueError(f'Cannot compare {reference} and {answer}') from err
        return abs(reference_value - answer_value) <= 0.05 * abs(reference_value)

    def success(self, solution: str) -> bool:
        """Return True when the extracted answer matches the reference."""
        answer = self.extract_answer(solution)
        return self.compare_w_digits(self._reference, answer)


class MultipleChoiceTask(Task):
    """Subclass of Task for multiple choice tasks."""

    task_name = 'reasoning'

    def __init__(self, id, prompt: str, reference: str, **kwargs):
        """Store the prompt, the reference letter and the parsed options.

        ``hide_options`` is read from ``kwargs`` (default ``False``); when set,
        everything from ``'Options:'`` onwards is removed from the prompt.
        The parsed options are also published under ``self.metadata['options']``.
        """
        super().__init__(**kwargs)
        self._id = id
        self.hide_options = kwargs.get('hide_options', False)
        if self.hide_options:
            self._prompt = prompt.split('Options:')[0].strip()
        else:
            self._prompt = prompt
        self._reference = reference.strip().lower()
        self._options = self.extract_options(prompt)
        # If all options can be converted to float, strictly perform hide options.
        # NOTE(review): the prompt has already been built at this point, so this
        # only flips the flag - confirm whether the prompt should be re-trimmed.
        try:
            for option in self._options.values():
                float(option)
            self.hide_options = True
        except ValueError:
            pass
        self.metadata.update({'options': self._options})

    def extract_answer(self, solution: str) -> str | None:
        """Return the selected option letter, or the normalized solution.

        The solution is lower-cased and stripped; the first letter (in
        alphabetical order) that appears as ``'x)'`` or ``'x )'`` is returned.
        If no letter is found, the normalized solution itself is returned.
        """
        solution = solution.lower().strip()
        for letter in 'abcdefghijklmnopqrstuvwxyz':
            if f'{letter})' in solution or f'{letter} )' in solution:
                LOGGER.debug('SOLUTION %s', letter)
                return letter
        LOGGER.debug('SOLUTION %s', solution)
        return solution

    def compare_w_digits(self, reference: str, answer: str) -> bool:
        """Compare two strings, numerically when both consist only of digits.

        Digit-only strings match when the answer is within 5% of the
        reference; otherwise this is a substring test (``reference in answer``).
        """
        if reference.isdigit() and answer.isdigit():
            return abs(float(reference) - float(answer)) <= 0.05 * float(reference)
        return reference in answer

    def success(self, solution: str) -> bool:
        """Return True when the solution selects the correct option.

        The answer is accepted if it matches the reference letter directly.
        Otherwise it is compared against the option contents: it is rejected
        if it matches any wrong option, and accepted only if it matches the
        content of the correct option.
        """
        answer = self.extract_answer(solution)
        if self.compare_w_digits(self._reference, answer):
            return True

        correct_option = self._options.get(self._reference)
        if correct_option is None:
            # The reference letter is not among the parsed options, so there
            # is no option content to fall back on.
            return False

        # Options contained in the correct option's text (including the
        # correct option itself) must not count as wrong answers. Build a new
        # list instead of removing items while iterating, which skips elements.
        wrong_options = [
            option
            for option in self._options.values()
            if option not in correct_option
        ]
        LOGGER.debug('OPTIONS %s %s', correct_option, wrong_options)
        LOGGER.debug('ANSWER %s', answer)

        for option in wrong_options:
            if self.compare_w_digits(option, answer) or option in answer:
                return False
        return self.compare_w_digits(correct_option, answer) or (
            correct_option in answer
        )

    def extract_options(self, prompt: str) -> dict:
        """Parse ``letter ) content`` options from the prompt.

        The text after the last ``'Options: '`` is split on ``' , '``. Each
        fragment is split at its first ``')'`` into a lower-cased letter and a
        lower-cased content string. Fragments without a ``')'`` are skipped.

        Returns:
            A dict mapping option letter to option content.
        """
        prompt = prompt.split('Options: ')[-1]
        options = {}
        for fragment in prompt.split(' , '):
            letter, separator, content = fragment.strip("[]' ").partition(')')
            if not separator:
                continue  # not a "letter ) content" fragment
            # The content is lower-cased first, so the trailing question that
            # follows the last option must be matched in lower case as well.
            content = (
                content.lower()
                .strip('.')
                .replace('. which option is correct?', '')
                .replace('. which one is correct?', '')
                .strip()
            )
            options[letter.lower().strip()] = content
        return options


# ==== TheoremQA ====


def compare_two_numbers(p, gt):
    """Compare a predicted number against the ground truth.

    A float ground truth is compared with a relative tolerance (see
    ``within_eps``); any other ground truth is compared with ``round(p)``.

    Returns False when ``p`` is a complex, dict, list, str or tuple, or when
    ``p`` cannot be rounded (NaN or infinity).

    Raises:
        ValueError: if ``p`` is of any other unsupported type.
    """
    if not isinstance(p, (int, float)):
        if isinstance(p, (bool, complex, dict, list, str, tuple)):
            return False
        raise ValueError(p)

    if isinstance(gt, float):
        return within_eps(pred=p, gt=gt)
    try:
        return round(p) == gt
    except (ValueError, OverflowError):
        # round() raises on NaN and on infinities; such a prediction is wrong.
        return False


def compare_two_list(pred, gt):
    """Compare a predicted list of numbers against the ground truth.

    Both sequences are sorted and compared element-wise with
    ``compare_two_numbers``. Returns False if ``pred`` is not a list, if
    ``gt`` has no length, if the lengths differ, or if ``pred`` contains a
    non-numeric element.
    """
    if not isinstance(pred, list):
        return False
    try:
        expected_length = len(gt)
    except TypeError:
        # e.g. the reference could not be parsed and is None
        return False
    if len(pred) != expected_length:
        return False
    if any(not isinstance(item, (int, float)) for item in pred):
        return False
    return all(
        compare_two_numbers(p, g) for p, g in zip(sorted(pred), sorted(gt))
    )


def within_eps(pred: float, gt: float):
    """Return True when ``pred`` is within 4% of ``gt`` (relative to ``gt``)."""
    eps = abs(gt) * 0.04
    return gt - eps <= pred <= gt + eps


def parse_number_list(s: str):
    """Parse a Python literal from ``s`` using ``ast.literal_eval``."""
    return ast.literal_eval(s)


def is_number(string):
    """Return True for plain decimal numbers, optionally comma-grouped."""
    pattern = r'^[-+]?(\d{1,3}(,\d{3})*|(\d+))(\.\d+)?$'
    return bool(re.match(pattern, string))


def is_scientific_number(string):
    """Return True for numbers written like ``1.5e-3`` (lower-case ``e``)."""
    pattern = r'^[-+]?\d+(\.\d+)?e[-]?\d+$'
    return bool(re.match(pattern, string))


def contain_num_and_str(string):
    """Return True when the string has both an ASCII letter and a digit."""
    pattern_str = r'[a-zA-Z]'
    pattern_num = r'[0-9]'
    return bool(re.search(pattern_str, string) and re.search(pattern_num, string))


# Patterns that isolate a number from an adjacent unit. They are applied in
# order with re.match, and group 1 is the number that is kept.
_UNIT_STRIPPING_PATTERNS = (
    r'([-+]?(?:[\d,]*\.*\d+)) [^0-9 ]+$',  # number, space, unit
    r'[^0-9 ]+ ([-+]?(?:[\d,]*\.*\d+))$',  # unit, space, number
    r'([-+]?(?:[\d,]*\.*\d+))[^\d]{1,2}$',  # number + 1-2 trailing characters
    r'[^-+\d]{1,2}((?:[\d,]*\.*\d+))$',  # 1-2 leading characters + number
)


def _preprocess_prediction(prediction: Any) -> str:
    """Turn a raw TheoremQA answer into a string that can be evaluated."""
    # Following the preprocessing steps from TheoremQA
    # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L170
    # Preprocessing the string [Stage 1]
    if not isinstance(prediction, str):
        prediction = str(prediction) if prediction is not None else '0'

    # Replace special tokens
    if '=' in prediction:
        prediction = prediction.split('=')[-1].strip()
    if '≈' in prediction:
        prediction = prediction.split('≈')[-1].strip()
    for token in ('`', '$', '°'):
        prediction = prediction.replace(token, '')

    # Detect the boolean keyword in the generation
    if prediction in ('true', 'yes', 'false', 'no'):
        prediction = 'True' if prediction in ('true', 'yes') else 'False'
    if 'True' in prediction or 'False' in prediction:
        prediction = 'True' if 'True' in prediction else 'False'

    # Detect the approximation keyword
    if 'approximately' in prediction:
        prediction = prediction.replace('approximately', '').strip()
    if ' or ' in prediction:
        prediction = prediction.split(' or ')[0]

    # Drop the units before and after the number
    for pattern in _UNIT_STRIPPING_PATTERNS:
        unit_match = re.match(pattern, prediction)
        if unit_match:
            prediction = unit_match.group(1)

    # Preprocessing the number [Stage 2]
    if '10^' in prediction:
        prediction = re.sub(r'10\^(-?\d+)', r'math.pow(10, \1)', prediction)
    if ' x ' in prediction:
        prediction = prediction.replace(' x ', '*')
    if ' × ' in prediction:
        prediction = prediction.replace(' × ', '*')
    if is_number(prediction):
        prediction = prediction.replace(',', '')

    # Preprocessing the option [Stage 3]
    for letter in 'abcd':
        if (
            f'{letter})' in prediction
            or f'{letter} )' in prediction
            or prediction.lower().strip() == letter
        ):
            prediction = f'({letter})'
    option_match = re.search(r'\([a-d]\)', prediction)
    if option_match:
        # Quote the option so that evaluating it yields the string '(x)'.
        prediction = '"' + option_match.group(0) + '"'

    # If the prediction is empty, use dummy '0'
    if not prediction:
        prediction = '0'
    return prediction


def _normalize_parsed_prediction(prediction: Any) -> Any:
    """Convert an evaluated answer to plain lists, floats and real parts."""
    if isinstance(prediction, (set, tuple)):
        prediction = list(prediction)
        # Only the first element decides the conversion; an empty collection
        # is returned as an empty list.
        if prediction and isinstance(prediction[0], complex):
            prediction = [item.real for item in prediction]
        elif prediction and isinstance(prediction[0], Rational):
            prediction = [float(item) for item in prediction]
    elif isinstance(prediction, np.ndarray):
        prediction = prediction.tolist()
    elif isinstance(prediction, complex):
        prediction = prediction.real
    elif isinstance(prediction, Rational):
        prediction = float(prediction)
    return prediction


class TheoremqaTask(Task):
    """TheoremQA task: answers are numbers, lists of numbers, bools or options."""

    task_name = 'reasoning'

    def __init__(self, id: str, prompt: str, reference: str, **kwargs):
        """Store the id, the instruction-prefixed prompt and the reference.

        ``answer_type`` is read from ``kwargs`` and selects the comparison
        used by ``success``. All keyword arguments are forwarded to
        ``Task.__init__``.
        """
        super().__init__(**kwargs)
        self._id = id
        self._prompt = (
            'Answer the following question with a number, a list of numbers or True or False. '
            + prompt.strip()
        )
        self._reference = reference
        self._answer_type = kwargs.get('answer_type')

    def extract_answer(self, solution: str) -> Any:
        """Extract the answer from the given solution.

        The solution is cleaned up as text, evaluated as a Python expression
        and converted to plain Python types.

        Returns:
            The parsed value, or None if the cleaned text cannot be evaluated.
        """
        prediction = _preprocess_prediction(solution)

        # Converting the string answer to a number/list/bool/option
        try:
            # SECURITY(review): eval() executes arbitrary expressions taken
            # from the solution text. Only run this on trusted or sandboxed
            # input, or replace it with a restricted evaluator.
            prediction = eval(prediction)
        except Exception:
            LOGGER.warning(
                f'[TASK] Failed to convert the answer: {prediction}\n{traceback.format_exc()}'
            )
            return None  # failed to convert the answer

        # Performing common type conversion
        return _normalize_parsed_prediction(prediction)

    def success(self, solution: str) -> bool:
        """This checks whether the given solution can complete the current task."""
        # Follow the implementation from TheoremQA
        # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L301C9-L317C1
        prediction = self.extract_answer(solution)
        LOGGER.info(f'TheoremQA Parsed Prediction: {prediction}')
        answer_type = self._answer_type
        gt = self.extract_answer(self.reference)

        # An unparsable prediction (None) or an unsupported type is wrong.
        if not isinstance(prediction, (str, int, float, list)):
            return False

        # Comparing prediction against the reference
        if answer_type in ('bool', 'option', 'Option'):
            return bool(prediction == f'({gt})' or prediction == gt)
        if answer_type in ('integer', 'float'):
            return bool(compare_two_numbers(prediction, gt))
        if answer_type in ('list of integer', 'list of float'):
            return bool(compare_two_list(prediction, gt))
        # Unknown answer type: nothing to compare against.
        return False