|
import asyncio
|
|
import time
|
|
from copy import deepcopy
|
|
|
|
import socketio
|
|
|
|
from openhands.controller.agent import Agent
|
|
from openhands.core.config import AppConfig
|
|
from openhands.core.const.guide_url import TROUBLESHOOTING_URL
|
|
from openhands.core.logger import openhands_logger as logger
|
|
from openhands.core.schema import AgentState
|
|
from openhands.events.action import MessageAction, NullAction
|
|
from openhands.events.event import Event, EventSource
|
|
from openhands.events.observation import (
|
|
AgentStateChangedObservation,
|
|
CmdOutputObservation,
|
|
NullObservation,
|
|
)
|
|
from openhands.events.observation.error import ErrorObservation
|
|
from openhands.events.serialization import event_from_dict, event_to_dict
|
|
from openhands.events.stream import EventStreamSubscriber
|
|
from openhands.llm.llm import LLM
|
|
from openhands.server.session.agent_session import AgentSession
|
|
from openhands.server.session.conversation_init_data import ConversationInitData
|
|
from openhands.server.settings import Settings
|
|
from openhands.storage.files import FileStore
|
|
|
|
ROOM_KEY = 'room:{sid}'
|
|
|
|
|
|
class Session:
|
|
sid: str
|
|
sio: socketio.AsyncServer | None
|
|
last_active_ts: int = 0
|
|
is_alive: bool = True
|
|
agent_session: AgentSession
|
|
loop: asyncio.AbstractEventLoop
|
|
config: AppConfig
|
|
file_store: FileStore
|
|
user_id: str | None
|
|
|
|
def __init__(
|
|
self,
|
|
sid: str,
|
|
config: AppConfig,
|
|
file_store: FileStore,
|
|
sio: socketio.AsyncServer | None,
|
|
user_id: str | None = None,
|
|
):
|
|
self.sid = sid
|
|
self.sio = sio
|
|
self.last_active_ts = int(time.time())
|
|
self.file_store = file_store
|
|
self.agent_session = AgentSession(
|
|
sid, file_store, status_callback=self.queue_status_message
|
|
)
|
|
self.agent_session.event_stream.subscribe(
|
|
EventStreamSubscriber.SERVER, self.on_event, self.sid
|
|
)
|
|
|
|
self.config = deepcopy(config)
|
|
self.loop = asyncio.get_event_loop()
|
|
self.user_id = user_id
|
|
|
|
async def close(self):
|
|
if self.sio:
|
|
await self.sio.emit(
|
|
'oh_event',
|
|
event_to_dict(
|
|
AgentStateChangedObservation('', AgentState.STOPPED.value)
|
|
),
|
|
to=ROOM_KEY.format(sid=self.sid),
|
|
)
|
|
self.is_alive = False
|
|
await self.agent_session.close()
|
|
|
|
async def initialize_agent(self, settings: Settings, initial_user_msg: str | None):
|
|
self.agent_session.event_stream.add_event(
|
|
AgentStateChangedObservation('', AgentState.LOADING),
|
|
EventSource.ENVIRONMENT,
|
|
)
|
|
|
|
agent_cls = settings.agent or self.config.default_agent
|
|
self.config.security.confirmation_mode = (
|
|
self.config.security.confirmation_mode
|
|
if settings.confirmation_mode is None
|
|
else settings.confirmation_mode
|
|
)
|
|
self.config.security.security_analyzer = (
|
|
settings.security_analyzer or self.config.security.security_analyzer
|
|
)
|
|
max_iterations = settings.max_iterations or self.config.max_iterations
|
|
|
|
|
|
|
|
|
|
default_llm_config = self.config.get_llm_config()
|
|
default_llm_config.model = settings.llm_model or ''
|
|
default_llm_config.api_key = settings.llm_api_key
|
|
default_llm_config.base_url = settings.llm_base_url
|
|
|
|
|
|
|
|
llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
|
|
agent_config = self.config.get_agent_config(agent_cls)
|
|
agent = Agent.get_cls(agent_cls)(llm, agent_config)
|
|
|
|
github_token = None
|
|
selected_repository = None
|
|
if isinstance(settings, ConversationInitData):
|
|
github_token = settings.github_token
|
|
selected_repository = settings.selected_repository
|
|
|
|
try:
|
|
await self.agent_session.start(
|
|
runtime_name=self.config.runtime,
|
|
config=self.config,
|
|
agent=agent,
|
|
max_iterations=max_iterations,
|
|
max_budget_per_task=self.config.max_budget_per_task,
|
|
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
|
|
agent_configs=self.config.get_agent_configs(),
|
|
github_token=github_token,
|
|
selected_repository=selected_repository,
|
|
initial_user_msg=initial_user_msg,
|
|
)
|
|
except Exception as e:
|
|
logger.exception(f'Error creating agent_session: {e}')
|
|
await self.send_error(
|
|
f'Error creating agent_session. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
|
|
)
|
|
return
|
|
|
|
def on_event(self, event: Event):
|
|
asyncio.get_event_loop().run_until_complete(self._on_event(event))
|
|
|
|
async def _on_event(self, event: Event):
|
|
"""Callback function for events that mainly come from the agent.
|
|
Event is the base class for any agent action and observation.
|
|
|
|
Args:
|
|
event: The agent event (Observation or Action).
|
|
"""
|
|
if isinstance(event, NullAction):
|
|
return
|
|
if isinstance(event, NullObservation):
|
|
return
|
|
if event.source == EventSource.AGENT:
|
|
await self.send(event_to_dict(event))
|
|
elif event.source == EventSource.USER:
|
|
await self.send(event_to_dict(event))
|
|
|
|
elif event.source == EventSource.ENVIRONMENT and isinstance(
|
|
event, (CmdOutputObservation, AgentStateChangedObservation)
|
|
):
|
|
|
|
event_dict = event_to_dict(event)
|
|
event_dict['source'] = EventSource.AGENT
|
|
await self.send(event_dict)
|
|
elif isinstance(event, ErrorObservation):
|
|
|
|
event_dict = event_to_dict(event)
|
|
event_dict['source'] = EventSource.AGENT
|
|
await self.send(event_dict)
|
|
|
|
async def dispatch(self, data: dict):
|
|
event = event_from_dict(data.copy())
|
|
|
|
if isinstance(event, MessageAction) and event.image_urls:
|
|
controller = self.agent_session.controller
|
|
if controller:
|
|
if controller.agent.llm.config.disable_vision:
|
|
await self.send_error(
|
|
'Support for images is disabled for this model, try without an image.'
|
|
)
|
|
return
|
|
if not controller.agent.llm.vision_is_active():
|
|
await self.send_error(
|
|
'Model does not support image upload, change to a different model or try without an image.'
|
|
)
|
|
return
|
|
self.agent_session.event_stream.add_event(event, EventSource.USER)
|
|
|
|
async def send(self, data: dict[str, object]):
|
|
if asyncio.get_running_loop() != self.loop:
|
|
self.loop.create_task(self._send(data))
|
|
return
|
|
await self._send(data)
|
|
|
|
async def _send(self, data: dict[str, object]) -> bool:
|
|
try:
|
|
if not self.is_alive:
|
|
return False
|
|
if self.sio:
|
|
await self.sio.emit('oh_event', data, to=ROOM_KEY.format(sid=self.sid))
|
|
await asyncio.sleep(0.001)
|
|
self.last_active_ts = int(time.time())
|
|
return True
|
|
except RuntimeError as e:
|
|
logger.error(f'Error sending data to websocket: {str(e)}')
|
|
self.is_alive = False
|
|
return False
|
|
|
|
async def send_error(self, message: str):
|
|
"""Sends an error message to the client."""
|
|
await self.send({'error': True, 'message': message})
|
|
|
|
async def _send_status_message(self, msg_type: str, id: str, message: str):
|
|
"""Sends a status message to the client."""
|
|
if msg_type == 'error':
|
|
await self.agent_session.stop_agent_loop_for_error()
|
|
|
|
await self.send(
|
|
{'status_update': True, 'type': msg_type, 'id': id, 'message': message}
|
|
)
|
|
|
|
def queue_status_message(self, msg_type: str, id: str, message: str):
|
|
"""Queues a status message to be sent asynchronously."""
|
|
asyncio.run_coroutine_threadsafe(
|
|
self._send_status_message(msg_type, id, message), self.loop
|
|
)
|
|
|