Spaces:
Build error
Build error
import re | |
import traceback | |
from datatypes import ParseError, StepOutput, TaskState | |
from tasks.base import Task | |
from openhands.controller.state.state import State | |
class SimplifiedEnv: | |
INVALID_INPUT_MESSAGE = ( | |
"I don't understand your input. \n" | |
'If you want to execute code, please use <execute_ipython> YOUR_CODE_HERE </execute_ipython>.\n' | |
'If you want to give me an answer, please use <solution> YOUR_SOLUTION_HERE </solution>.\n' | |
'For example: The answer to the question is <solution> 42 </solution>. \n' | |
) | |
def __init__(self, agent_state: State, task: Task, task_config: dict[str, int]): | |
self.agent_state = agent_state | |
self.task = task | |
agent_action_count = { | |
'propose_solution': 0, | |
'use_tool': 0, | |
'invalid_action': 0, | |
} | |
# check if agent_state has attribute turn_info set | |
if hasattr(self.agent_state, 'propose_solution_count'): | |
agent_action_count['propose_solution'] = ( | |
self.agent_state.propose_solution_count | |
) | |
self.task_state = TaskState(agent_action_count=agent_action_count) | |
self.task_config = task_config | |
def step(self, lm_message: str): | |
observation = self.handle_propose_solution(lm_message) | |
self.check_max_iteration() | |
turn_info = ( | |
self.task_config['max_iterations'] - self.agent_state.iteration, | |
self.task_config['max_propose_solution'] | |
- self.task_state.agent_action_count['propose_solution'], | |
) | |
output = StepOutput( | |
observation=observation, | |
success=self.task_state.success, | |
turn_info=turn_info, | |
) | |
self.agent_state.propose_solution_count = self.task_state.agent_action_count[ | |
'propose_solution' | |
] | |
self.log_output(output) | |
return self.task_state | |
def handle_propose_solution(self, lm_message) -> str | None: | |
"""Propose answer to check the task success. | |
It might set self.state.finished = True if the task is successful. | |
""" | |
self.task_state.agent_action_count['propose_solution'] += 1 | |
try: | |
parsed = self.parse_propose_solution(lm_message) | |
task_success = self.check_task_success(parsed['answer']) | |
if task_success: | |
self.task_state.finished = True | |
self.task_state.success = True | |
self.task_state.terminate_reason = 'task_success' | |
# NOTE: should not return the function now, because we need to log the output | |
# Set state.finished = True will terminate the episode | |
except ParseError: | |
return SimplifiedEnv.INVALID_INPUT_MESSAGE | |
except Exception: | |
error_traceback = traceback.format_exc() | |
return f'{error_traceback}' | |
def parse_propose_solution(self, lm_message: str) -> dict: | |
"""Define the parsing logic.""" | |
lm_output = '\n' + lm_message + '\n' | |
answer = '\n'.join( | |
[ | |
i.strip() | |
for i in re.findall(r'<solution>(.*?)</solution>', lm_output, re.DOTALL) | |
] | |
) | |
if answer == '': | |
raise ParseError('No answer found.') | |
return {'answer': answer} | |
def log_output(self, output: StepOutput) -> None: | |
if self.task_state.finished: | |
return | |
content = output.to_str() | |
self.task_state.latest_output = output.to_dict() | |
self.task_state.latest_output['content'] = content | |
def check_task_success(self, answer: str) -> bool: | |
# log_message.info(f"STUDENT ANSWER: [{answer}]") | |
# log_message.info(f"REFERENCE ANSWER: [{self.task.reference}]") | |
return self.task.success(answer) | |
def check_max_iteration(self): | |
"""Check if the agent has reached the max iteration limit. | |
It might set self.state.finished = True if the agent has reached the max iteration limit. | |
""" | |
if self.task_state.finished: | |
# ignore if the episode is already finished (e.g., task success) | |
return | |
if ( | |
# propose solution > max output solution | |
self.task_state.agent_action_count['propose_solution'] | |
>= self.task_config['max_propose_solution'] | |
): | |
self.task_state.finished = True | |
self.task_state.success = False | |
self.task_state.terminate_reason = 'max_propose_steps' | |
elif self.agent_state.iteration >= self.task_config['max_iterations']: | |
self.task_state.finished = True | |
self.task_state.success = False | |
self.task_state.terminate_reason = 'max_iterations' | |