Spaces:
Build error
Build error
import asyncio | |
import json | |
import time | |
from logging import LoggerAdapter | |
from types import MappingProxyType | |
from typing import Callable, cast | |
from openhands.controller import AgentController | |
from openhands.controller.agent import Agent | |
from openhands.controller.replay import ReplayManager | |
from openhands.controller.state.state import State | |
from openhands.core.config import AgentConfig, LLMConfig, OpenHandsConfig | |
from openhands.core.exceptions import AgentRuntimeUnavailableError | |
from openhands.core.logger import OpenHandsLoggerAdapter | |
from openhands.core.schema.agent import AgentState | |
from openhands.events.action import ChangeAgentStateAction, MessageAction | |
from openhands.events.event import Event, EventSource | |
from openhands.events.stream import EventStream | |
from openhands.integrations.provider import ( | |
CUSTOM_SECRETS_TYPE, | |
PROVIDER_TOKEN_TYPE, | |
ProviderHandler, | |
) | |
from openhands.mcp import add_mcp_tools_to_agent | |
from openhands.memory.memory import Memory | |
from openhands.microagent.microagent import BaseMicroagent | |
from openhands.runtime import get_runtime_cls | |
from openhands.runtime.base import Runtime | |
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime | |
from openhands.security import SecurityAnalyzer, options | |
from openhands.storage.data_models.user_secrets import UserSecrets | |
from openhands.storage.files import FileStore | |
from openhands.utils.async_utils import EXECUTOR, call_sync_from_async | |
from openhands.utils.shutdown_listener import should_continue | |
WAIT_TIME_BEFORE_CLOSE = 90 | |
WAIT_TIME_BEFORE_CLOSE_INTERVAL = 5 | |
class AgentSession: | |
"""Represents a session with an Agent | |
Attributes: | |
controller: The AgentController instance for controlling the agent. | |
""" | |
sid: str | |
user_id: str | None | |
event_stream: EventStream | |
file_store: FileStore | |
controller: AgentController | None = None | |
runtime: Runtime | None = None | |
security_analyzer: SecurityAnalyzer | None = None | |
_starting: bool = False | |
_started_at: float = 0 | |
_closed: bool = False | |
loop: asyncio.AbstractEventLoop | None = None | |
logger: LoggerAdapter | |
def __init__( | |
self, | |
sid: str, | |
file_store: FileStore, | |
status_callback: Callable | None = None, | |
user_id: str | None = None, | |
) -> None: | |
"""Initializes a new instance of the Session class | |
Parameters: | |
- sid: The session ID | |
- file_store: Instance of the FileStore | |
""" | |
self.sid = sid | |
self.event_stream = EventStream(sid, file_store, user_id) | |
self.file_store = file_store | |
self._status_callback = status_callback | |
self.user_id = user_id | |
self.logger = OpenHandsLoggerAdapter( | |
extra={'session_id': sid, 'user_id': user_id} | |
) | |
async def start( | |
self, | |
runtime_name: str, | |
config: OpenHandsConfig, | |
agent: Agent, | |
max_iterations: int, | |
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None, | |
custom_secrets: CUSTOM_SECRETS_TYPE | None = None, | |
max_budget_per_task: float | None = None, | |
agent_to_llm_config: dict[str, LLMConfig] | None = None, | |
agent_configs: dict[str, AgentConfig] | None = None, | |
selected_repository: str | None = None, | |
selected_branch: str | None = None, | |
initial_message: MessageAction | None = None, | |
conversation_instructions: str | None = None, | |
replay_json: str | None = None, | |
) -> None: | |
"""Starts the Agent session | |
Parameters: | |
- runtime_name: The name of the runtime associated with the session | |
- config: | |
- agent: | |
- max_iterations: | |
- max_budget_per_task: | |
- agent_to_llm_config: | |
- agent_configs: | |
""" | |
if self.controller or self.runtime: | |
raise RuntimeError( | |
'Session already started. You need to close this session and start a new one.' | |
) | |
if self._closed: | |
self.logger.warning('Session closed before starting') | |
return | |
self._starting = True | |
started_at = time.time() | |
self._started_at = started_at | |
finished = False # For monitoring | |
runtime_connected = False | |
custom_secrets_handler = UserSecrets( | |
custom_secrets=custom_secrets if custom_secrets else {} | |
) | |
try: | |
self._create_security_analyzer(config.security.security_analyzer) | |
runtime_connected = await self._create_runtime( | |
runtime_name=runtime_name, | |
config=config, | |
agent=agent, | |
git_provider_tokens=git_provider_tokens, | |
custom_secrets=custom_secrets, | |
selected_repository=selected_repository, | |
selected_branch=selected_branch, | |
) | |
repo_directory = None | |
if self.runtime and runtime_connected and selected_repository: | |
repo_directory = selected_repository.split('/')[-1] | |
if git_provider_tokens: | |
provider_handler = ProviderHandler(provider_tokens=git_provider_tokens) | |
await provider_handler.set_event_stream_secrets(self.event_stream) | |
if custom_secrets: | |
custom_secrets_handler.set_event_stream_secrets(self.event_stream) | |
self.memory = await self._create_memory( | |
selected_repository=selected_repository, | |
repo_directory=repo_directory, | |
conversation_instructions=conversation_instructions, | |
custom_secrets_descriptions=custom_secrets_handler.get_custom_secrets_descriptions(), | |
) | |
# NOTE: this needs to happen before controller is created | |
# so MCP tools can be included into the SystemMessageAction | |
if self.runtime and runtime_connected and agent.config.enable_mcp: | |
await add_mcp_tools_to_agent(agent, self.runtime, self.memory, config) | |
if replay_json: | |
initial_message = self._run_replay( | |
initial_message, | |
replay_json, | |
agent, | |
config, | |
max_iterations, | |
max_budget_per_task, | |
agent_to_llm_config, | |
agent_configs, | |
) | |
else: | |
self.controller = self._create_controller( | |
agent, | |
config.security.confirmation_mode, | |
max_iterations, | |
max_budget_per_task=max_budget_per_task, | |
agent_to_llm_config=agent_to_llm_config, | |
agent_configs=agent_configs, | |
) | |
if not self._closed: | |
if initial_message: | |
self.event_stream.add_event(initial_message, EventSource.USER) | |
self.event_stream.add_event( | |
ChangeAgentStateAction(AgentState.RUNNING), | |
EventSource.ENVIRONMENT, | |
) | |
else: | |
self.event_stream.add_event( | |
ChangeAgentStateAction(AgentState.AWAITING_USER_INPUT), | |
EventSource.ENVIRONMENT, | |
) | |
finished = True | |
finally: | |
self._starting = False | |
success = finished and runtime_connected | |
self.logger.info( | |
'Agent session start', | |
extra={ | |
'signal': 'agent_session_start', | |
'success': success, | |
'duration': (time.time() - started_at), | |
}, | |
) | |
async def close(self) -> None: | |
"""Closes the Agent session""" | |
if self._closed: | |
return | |
self._closed = True | |
while self._starting and should_continue(): | |
self.logger.debug( | |
f'Waiting for initialization to finish before closing session {self.sid}' | |
) | |
await asyncio.sleep(WAIT_TIME_BEFORE_CLOSE_INTERVAL) | |
if time.time() <= self._started_at + WAIT_TIME_BEFORE_CLOSE: | |
self.logger.error( | |
f'Waited too long for initialization to finish before closing session {self.sid}' | |
) | |
break | |
if self.event_stream is not None: | |
self.event_stream.close() | |
if self.controller is not None: | |
end_state = self.controller.get_state() | |
end_state.save_to_session(self.sid, self.file_store, self.user_id) | |
await self.controller.close() | |
if self.runtime is not None: | |
EXECUTOR.submit(self.runtime.close) | |
if self.security_analyzer is not None: | |
await self.security_analyzer.close() | |
def _run_replay( | |
self, | |
initial_message: MessageAction | None, | |
replay_json: str, | |
agent: Agent, | |
config: OpenHandsConfig, | |
max_iterations: int, | |
max_budget_per_task: float | None, | |
agent_to_llm_config: dict[str, LLMConfig] | None, | |
agent_configs: dict[str, AgentConfig] | None, | |
) -> MessageAction: | |
""" | |
Replays a trajectory from a JSON file. Note that once the replay session | |
finishes, the controller will continue to run with further user instructions, | |
so we still need to pass llm configs, budget, etc., even though the replay | |
itself does not call LLM or cost money. | |
""" | |
assert initial_message is None | |
replay_events = ReplayManager.get_replay_events(json.loads(replay_json)) | |
self.controller = self._create_controller( | |
agent, | |
config.security.confirmation_mode, | |
max_iterations, | |
max_budget_per_task=max_budget_per_task, | |
agent_to_llm_config=agent_to_llm_config, | |
agent_configs=agent_configs, | |
replay_events=replay_events[1:], | |
) | |
assert isinstance(replay_events[0], MessageAction) | |
return replay_events[0] | |
def _create_security_analyzer(self, security_analyzer: str | None) -> None: | |
"""Creates a SecurityAnalyzer instance that will be used to analyze the agent actions | |
Parameters: | |
- security_analyzer: The name of the security analyzer to use | |
""" | |
if security_analyzer: | |
self.logger.debug(f'Using security analyzer: {security_analyzer}') | |
self.security_analyzer = options.SecurityAnalyzers.get( | |
security_analyzer, SecurityAnalyzer | |
)(self.event_stream) | |
def override_provider_tokens_with_custom_secret( | |
self, | |
git_provider_tokens: PROVIDER_TOKEN_TYPE | None, | |
custom_secrets: CUSTOM_SECRETS_TYPE | None, | |
): | |
if git_provider_tokens and custom_secrets: | |
tokens = dict(git_provider_tokens) | |
for provider, _ in tokens.items(): | |
token_name = ProviderHandler.get_provider_env_key(provider) | |
if token_name in custom_secrets or token_name.upper() in custom_secrets: | |
del tokens[provider] | |
return MappingProxyType(tokens) | |
return git_provider_tokens | |
async def _create_runtime( | |
self, | |
runtime_name: str, | |
config: OpenHandsConfig, | |
agent: Agent, | |
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None, | |
custom_secrets: CUSTOM_SECRETS_TYPE | None = None, | |
selected_repository: str | None = None, | |
selected_branch: str | None = None, | |
) -> bool: | |
"""Creates a runtime instance | |
Parameters: | |
- runtime_name: The name of the runtime associated with the session | |
- config: | |
- agent: | |
Return True on successfully connected, False if could not connect. | |
Raises if already created, possibly in other situations. | |
""" | |
if self.runtime is not None: | |
raise RuntimeError('Runtime already created') | |
custom_secrets_handler = UserSecrets(custom_secrets=custom_secrets or {}) | |
env_vars = custom_secrets_handler.get_env_vars() | |
self.logger.debug(f'Initializing runtime `{runtime_name}` now...') | |
runtime_cls = get_runtime_cls(runtime_name) | |
if runtime_cls == RemoteRuntime: | |
# If provider tokens is passed in custom secrets, then remove provider from provider tokens | |
# We prioritize provider tokens set in custom secrets | |
provider_tokens_without_gitlab = ( | |
self.override_provider_tokens_with_custom_secret( | |
git_provider_tokens, custom_secrets | |
) | |
) | |
self.runtime = runtime_cls( | |
config=config, | |
event_stream=self.event_stream, | |
sid=self.sid, | |
plugins=agent.sandbox_plugins, | |
status_callback=self._status_callback, | |
headless_mode=False, | |
attach_to_existing=False, | |
git_provider_tokens=provider_tokens_without_gitlab, | |
env_vars=env_vars, | |
user_id=self.user_id, | |
) | |
else: | |
provider_handler = ProviderHandler( | |
provider_tokens=git_provider_tokens | |
or cast(PROVIDER_TOKEN_TYPE, MappingProxyType({})) | |
) | |
# Merge git provider tokens with custom secrets before passing over to runtime | |
env_vars.update(await provider_handler.get_env_vars(expose_secrets=True)) | |
self.runtime = runtime_cls( | |
config=config, | |
event_stream=self.event_stream, | |
sid=self.sid, | |
plugins=agent.sandbox_plugins, | |
status_callback=self._status_callback, | |
headless_mode=False, | |
attach_to_existing=False, | |
env_vars=env_vars, | |
) | |
# FIXME: this sleep is a terrible hack. | |
# This is to give the websocket a second to connect, so that | |
# the status messages make it through to the frontend. | |
# We should find a better way to plumb status messages through. | |
await asyncio.sleep(1) | |
try: | |
await self.runtime.connect() | |
except AgentRuntimeUnavailableError as e: | |
self.logger.error(f'Runtime initialization failed: {e}') | |
if self._status_callback: | |
self._status_callback( | |
'error', 'STATUS$ERROR_RUNTIME_DISCONNECTED', str(e) | |
) | |
return False | |
await self.runtime.clone_or_init_repo( | |
git_provider_tokens, selected_repository, selected_branch | |
) | |
await call_sync_from_async(self.runtime.maybe_run_setup_script) | |
await call_sync_from_async(self.runtime.maybe_setup_git_hooks) | |
self.logger.debug( | |
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}' | |
) | |
return True | |
def _create_controller( | |
self, | |
agent: Agent, | |
confirmation_mode: bool, | |
max_iterations: int, | |
max_budget_per_task: float | None = None, | |
agent_to_llm_config: dict[str, LLMConfig] | None = None, | |
agent_configs: dict[str, AgentConfig] | None = None, | |
replay_events: list[Event] | None = None, | |
) -> AgentController: | |
"""Creates an AgentController instance | |
Parameters: | |
- agent: | |
- confirmation_mode: Whether to use confirmation mode | |
- max_iterations: | |
- max_budget_per_task: | |
- agent_to_llm_config: | |
- agent_configs: | |
""" | |
if self.controller is not None: | |
raise RuntimeError('Controller already created') | |
if self.runtime is None: | |
raise RuntimeError( | |
'Runtime must be initialized before the agent controller' | |
) | |
msg = ( | |
'\n--------------------------------- OpenHands Configuration ---------------------------------\n' | |
f'LLM: {agent.llm.config.model}\n' | |
f'Base URL: {agent.llm.config.base_url}\n' | |
f'Agent: {agent.name}\n' | |
f'Runtime: {self.runtime.__class__.__name__}\n' | |
f'Plugins: {[p.name for p in agent.sandbox_plugins] if agent.sandbox_plugins else "None"}\n' | |
'-------------------------------------------------------------------------------------------' | |
) | |
self.logger.debug(msg) | |
controller = AgentController( | |
sid=self.sid, | |
event_stream=self.event_stream, | |
agent=agent, | |
max_iterations=int(max_iterations), | |
max_budget_per_task=max_budget_per_task, | |
agent_to_llm_config=agent_to_llm_config, | |
agent_configs=agent_configs, | |
confirmation_mode=confirmation_mode, | |
headless_mode=False, | |
status_callback=self._status_callback, | |
initial_state=self._maybe_restore_state(), | |
replay_events=replay_events, | |
) | |
return controller | |
async def _create_memory( | |
self, | |
selected_repository: str | None, | |
repo_directory: str | None, | |
conversation_instructions: str | None, | |
custom_secrets_descriptions: dict[str, str], | |
) -> Memory: | |
memory = Memory( | |
event_stream=self.event_stream, | |
sid=self.sid, | |
status_callback=self._status_callback, | |
) | |
if self.runtime: | |
# sets available hosts and other runtime info | |
memory.set_runtime_info(self.runtime, custom_secrets_descriptions) | |
memory.set_conversation_instructions(conversation_instructions) | |
# loads microagents from repo/.openhands/microagents | |
microagents: list[BaseMicroagent] = await call_sync_from_async( | |
self.runtime.get_microagents_from_selected_repo, | |
selected_repository or None, | |
) | |
memory.load_user_workspace_microagents(microagents) | |
if selected_repository and repo_directory: | |
memory.set_repository_info(selected_repository, repo_directory) | |
return memory | |
def _maybe_restore_state(self) -> State | None: | |
"""Helper method to handle state restore logic.""" | |
restored_state = None | |
# Attempt to restore the state from session. | |
# Use a heuristic to figure out if we should have a state: | |
# if we have events in the stream. | |
try: | |
restored_state = State.restore_from_session( | |
self.sid, self.file_store, self.user_id | |
) | |
self.logger.debug(f'Restored state from session, sid: {self.sid}') | |
except Exception as e: | |
if self.event_stream.get_latest_event_id() > 0: | |
# if we have events, we should have a state | |
self.logger.warning(f'State could not be restored: {e}') | |
else: | |
self.logger.debug('No events found, no state to restore') | |
return restored_state | |
def get_state(self) -> AgentState | None: | |
controller = self.controller | |
if controller: | |
return controller.state.agent_state | |
if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE: | |
# If 5 minutes have elapsed and we still don't have a controller, something has gone wrong | |
return AgentState.ERROR | |
return None | |
def is_closed(self) -> bool: | |
return self._closed | |