from pydantic import BaseModel, Field

from openhands.core.logger import openhands_logger as logger
from openhands.events.action import (
    Action,
    ChangeAgentStateAction,
    MessageAction,
    NullAction,
)
from openhands.events.event import EventSource
from openhands.events.observation import (
    AgentStateChangedObservation,
    NullObservation,
    Observation,
)
from openhands.events.serialization.event import event_to_dict
from openhands.security.invariant.nodes import Function, Message, ToolCall, ToolOutput

# Any element that may appear in a parsed trace.
TraceElement = Message | ToolCall | ToolOutput | Function


def get_next_id(trace: list[TraceElement]) -> str:
    """Return the smallest positive integer, as a string, not yet used as a tool-call id.

    Only ``ToolCall`` elements of ``trace`` are considered; other element
    kinds are ignored.

    Args:
        trace: The trace elements collected so far.

    Returns:
        The lowest ``str(i)`` with ``i >= 1`` that is not the ``id`` of any
        ``ToolCall`` in ``trace``.
    """
    # A set gives O(1) membership tests; the previous list made every call
    # quadratic in the number of tool calls.
    used_ids = {el.id for el in trace if isinstance(el, ToolCall)}
    # len(used_ids) + 1 candidates are tried, so at least one must be free.
    for i in range(1, len(used_ids) + 2):
        candidate = str(i)
        if candidate not in used_ids:
            return candidate
    # Not reachable given the candidate count above; kept as a fallback.
    return '1'


def get_last_id(
    trace: list[TraceElement],
) -> str | None:
    """Return the id of the most recent ``ToolCall`` in ``trace``.

    Args:
        trace: The trace elements collected so far.

    Returns:
        The ``id`` of the last ``ToolCall`` found when scanning from the end
        of ``trace``, or ``None`` when the trace contains no ``ToolCall``.
    """
    for el in reversed(trace):
        if isinstance(el, ToolCall):
            return el.id
    return None


def parse_action(trace: list[TraceElement], action: Action) -> list[TraceElement]:
    """Convert a single action into the trace elements that represent it.

    The mapping is:

    * ``MessageAction`` -> one ``Message`` whose role is ``'user'`` when the
      action's source is ``EventSource.USER`` and ``'assistant'`` otherwise.
    * ``NullAction`` / ``ChangeAgentStateAction`` -> nothing.
    * Any other action with a non-``None`` ``action`` attribute -> an optional
      assistant ``Message`` carrying the action's ``thought`` followed by a
      ``ToolCall`` whose function name is ``action.action`` and whose
      arguments are the serialized ``args`` without ``thought``.
    * Anything else -> an error is logged and nothing is produced.

    Args:
        trace: The existing trace; used only to pick an unused tool-call id.
        action: The action to convert.

    Returns:
        The new trace elements for ``action`` (possibly empty). ``trace``
        itself is not modified.
    """
    inv_trace: list[TraceElement] = []
    if isinstance(action, MessageAction):
        if action.source == EventSource.USER:
            inv_trace.append(Message(role='user', content=action.content))
        else:
            inv_trace.append(Message(role='assistant', content=action.content))
    elif isinstance(action, (NullAction, ChangeAgentStateAction)):
        pass
    elif hasattr(action, 'action') and action.action is not None:
        event_dict = event_to_dict(action)
        args = event_dict.get('args', {})
        # The thought is emitted as its own assistant message, so it is
        # removed from the function arguments.
        thought = args.pop('thought', None)
        function = Function(name=action.action, arguments=args)
        if thought is not None:
            inv_trace.append(Message(role='assistant', content=thought))
        # Only tool calls need an id, so the trace is scanned only here.
        next_id = get_next_id(trace)
        inv_trace.append(ToolCall(id=next_id, type='function', function=function))
    else:
        logger.error(f'Unknown action type: {type(action)}')
    return inv_trace


def parse_observation(
    trace: list[TraceElement], obs: Observation
) -> list[TraceElement]:
    """Convert a single observation into the trace elements that represent it.

    ``NullObservation`` and ``AgentStateChangedObservation`` produce nothing.
    Any other observation with a non-``None`` ``content`` attribute becomes a
    ``ToolOutput`` linked to the most recent ``ToolCall`` in ``trace`` (the
    link is ``None`` when the trace has no tool call). Anything else is
    logged as an error and produces nothing.

    Args:
        trace: The existing trace; used to find the latest tool-call id.
        obs: The observation to convert.

    Returns:
        A list with zero or one ``ToolOutput``. ``trace`` is not modified.
    """
    if isinstance(obs, (NullObservation, AgentStateChangedObservation)):
        return []
    if hasattr(obs, 'content') and obs.content is not None:
        # Looked up only when an output is actually emitted.
        last_id = get_last_id(trace)
        return [ToolOutput(role='tool', content=obs.content, tool_call_id=last_id)]
    logger.error(f'Unknown observation type: {type(obs)}')
    return []


def parse_element(
    trace: list[TraceElement], element: Action | Observation
) -> list[TraceElement]:
    """Convert an action or an observation into trace elements.

    Args:
        trace: The existing trace, passed through to the specific parser.
        element: The event to convert.

    Returns:
        The result of ``parse_action`` when ``element`` is an ``Action``,
        otherwise the result of ``parse_observation``.
    """
    if isinstance(element, Action):
        return parse_action(trace, element)
    return parse_observation(trace, element)


def parse_trace(trace: list[tuple[Action, Observation]]) -> list[TraceElement]:
    """Convert a sequence of (action, observation) pairs into a flat trace.

    Each pair is parsed against the trace built so far, so tool-call ids and
    tool-output links account for all earlier pairs.

    Args:
        trace: The (action, observation) pairs, in order.

    Returns:
        The concatenated trace elements for every pair.
    """
    inv_trace: list[TraceElement] = []
    for action, obs in trace:
        inv_trace.extend(parse_action(inv_trace, action))
        inv_trace.extend(parse_observation(inv_trace, obs))
    return inv_trace


class InvariantState(BaseModel):
    """Pydantic model holding an incrementally built trace."""

    # Trace elements accumulated so far; each instance gets its own list.
    trace: list[TraceElement] = Field(default_factory=list)

    def add_action(self, action: Action) -> None:
        """Parse ``action`` against the current trace and append the result."""
        self.trace.extend(parse_action(self.trace, action))

    def add_observation(self, obs: Observation) -> None:
        """Parse ``obs`` against the current trace and append the result."""
        self.trace.extend(parse_observation(self.trace, obs))

    def concatenate(self, other: 'InvariantState') -> None:
        """Append every element of ``other.trace`` to this state's trace."""
        self.trace.extend(other.trace)