ar08's picture
Upload 1040 files
246d201 verified
raw
history blame
9.15 kB
import asyncio
import time
from copy import deepcopy
import socketio
from openhands.controller.agent import Agent
from openhands.core.config import AppConfig
from openhands.core.const.guide_url import TROUBLESHOOTING_URL
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import AgentState
from openhands.events.action import MessageAction, NullAction
from openhands.events.event import Event, EventSource
from openhands.events.observation import (
AgentStateChangedObservation,
CmdOutputObservation,
NullObservation,
)
from openhands.events.observation.error import ErrorObservation
from openhands.events.serialization import event_from_dict, event_to_dict
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.server.session.agent_session import AgentSession
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.settings import Settings
from openhands.storage.files import FileStore
# Template for the Socket.IO room name a session emits to; fill it with
# ROOM_KEY.format(sid=<session id>).
ROOM_KEY = 'room:{sid}'
class Session:
    """Server-side state for a single conversation.

    Owns an ``AgentSession`` and relays events from its event stream to the
    Socket.IO room ``ROOM_KEY.format(sid=sid)``. Client messages are pushed
    back into the same event stream through ``dispatch``.
    """

    sid: str
    # Socket.IO server used to emit to the client; None disables emitting.
    sio: socketio.AsyncServer | None
    # int(time.time()) at construction or at the most recent send that got
    # past the liveness check in ``_send``.
    last_active_ts: int = 0
    # Cleared by ``close`` or after a RuntimeError while sending; once False,
    # ``_send`` emits nothing.
    is_alive: bool = True
    agent_session: AgentSession
    # Loop captured at construction; sends issued from other loops/threads
    # are scheduled onto it.
    loop: asyncio.AbstractEventLoop
    config: AppConfig
    file_store: FileStore
    user_id: str | None

    def __init__(
        self,
        sid: str,
        config: AppConfig,
        file_store: FileStore,
        sio: socketio.AsyncServer | None,
        user_id: str | None = None,
    ):
        """Create the session and subscribe it to its agent session's events.

        Args:
            sid: Session id; also the event-stream callback id and the key for
                the Socket.IO room name.
            config: Application config. It is deep-copied so per-session
                changes are not applied to the shared global configuration.
            file_store: Storage backend handed to the ``AgentSession``.
            sio: Socket.IO server used for emitting, or ``None``.
            user_id: Optional id of the owning user.
        """
        self.sid = sid
        self.sio = sio
        self.last_active_ts = int(time.time())
        self.file_store = file_store
        self.user_id = user_id
        # Copying this means that when we update variables they are not applied to the shared global configuration!
        self.config = deepcopy(config)
        # Capture the loop before registering any callbacks. Both the status
        # callback and the event-stream subscription read ``self.loop`` and
        # may be invoked from other threads (queue_status_message is written
        # for that), so the attribute must already exist when they fire.
        self.loop = asyncio.get_event_loop()
        self.agent_session = AgentSession(
            sid, file_store, status_callback=self.queue_status_message
        )
        self.agent_session.event_stream.subscribe(
            EventStreamSubscriber.SERVER, self.on_event, self.sid
        )

    async def close(self):
        """Notify the client that the agent stopped, then shut the session down.

        The session is marked dead and the agent session is closed even if
        the final emit raises, so whatever ``AgentSession.close`` releases is
        not leaked. Any exception from the emit still propagates afterwards.
        """
        try:
            if self.sio:
                await self.sio.emit(
                    'oh_event',
                    event_to_dict(
                        AgentStateChangedObservation('', AgentState.STOPPED.value)
                    ),
                    to=ROOM_KEY.format(sid=self.sid),
                )
        finally:
            self.is_alive = False
            await self.agent_session.close()

    async def initialize_agent(self, settings: Settings, initial_user_msg: str | None):
        """Build the agent described by ``settings`` and start the agent session.

        Emits a LOADING state change first. Confirmation mode, security
        analyzer, agent class and max iterations from ``settings`` override
        the session's config copy when provided. The default LLM config's
        model, API key and base URL are always replaced by the ``settings``
        values. When ``settings`` is a ``ConversationInitData``, its GitHub
        token and selected repository are forwarded as well.

        If ``agent_session.start`` raises, the error is logged and reported to
        the client instead of propagating. Errors while building the agent
        itself are not caught here.

        Args:
            settings: User settings (or conversation init data).
            initial_user_msg: Optional first user message for the agent session.
        """
        self.agent_session.event_stream.add_event(
            AgentStateChangedObservation('', AgentState.LOADING),
            EventSource.ENVIRONMENT,
        )
        agent_cls = settings.agent or self.config.default_agent
        self.config.security.confirmation_mode = (
            self.config.security.confirmation_mode
            if settings.confirmation_mode is None
            else settings.confirmation_mode
        )
        self.config.security.security_analyzer = (
            settings.security_analyzer or self.config.security.security_analyzer
        )
        max_iterations = settings.max_iterations or self.config.max_iterations
        # This is a shallow copy of the default LLM config, so changes here will
        # persist if we retrieve the default LLM config again when constructing
        # the agent
        default_llm_config = self.config.get_llm_config()
        default_llm_config.model = settings.llm_model or ''
        default_llm_config.api_key = settings.llm_api_key
        default_llm_config.base_url = settings.llm_base_url
        # TODO: override other LLM config & agent config groups (#2075)
        llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
        agent_config = self.config.get_agent_config(agent_cls)
        agent = Agent.get_cls(agent_cls)(llm, agent_config)
        github_token = None
        selected_repository = None
        if isinstance(settings, ConversationInitData):
            github_token = settings.github_token
            selected_repository = settings.selected_repository
        try:
            await self.agent_session.start(
                runtime_name=self.config.runtime,
                config=self.config,
                agent=agent,
                max_iterations=max_iterations,
                max_budget_per_task=self.config.max_budget_per_task,
                agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
                agent_configs=self.config.get_agent_configs(),
                github_token=github_token,
                selected_repository=selected_repository,
                initial_user_msg=initial_user_msg,
            )
        except Exception as e:
            logger.exception(f'Error creating agent_session: {e}')
            await self.send_error(
                f'Error creating agent_session. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
            )
            return

    def on_event(self, event: Event):
        """Synchronous event-stream callback that runs ``_on_event`` to completion.

        Uses ``run_until_complete`` on the current thread's event loop, which
        raises if that loop is already running. Presumably the event stream
        invokes this from a worker thread that owns its own loop. TODO: confirm.
        """
        asyncio.get_event_loop().run_until_complete(self._on_event(event))

    async def _on_event(self, event: Event):
        """Forward an event-stream event to the client if the UI should see it.

        Null actions and observations are dropped. Agent and user events are
        sent as-is. Command-output and agent-state-change observations from the
        environment, and error observations from any source, are sent with
        their ``source`` rewritten to the agent.

        Args:
            event: The event (action or observation) from the stream.
        """
        if isinstance(event, NullAction):
            return
        if isinstance(event, NullObservation):
            return
        if event.source == EventSource.AGENT:
            await self.send(event_to_dict(event))
        elif event.source == EventSource.USER:
            await self.send(event_to_dict(event))
        # NOTE: ipython observations are not sent here currently
        elif event.source == EventSource.ENVIRONMENT and isinstance(
            event, (CmdOutputObservation, AgentStateChangedObservation)
        ):
            # feedback from the environment to agent actions is understood as agent events by the UI
            event_dict = event_to_dict(event)
            event_dict['source'] = EventSource.AGENT
            await self.send(event_dict)
        elif isinstance(event, ErrorObservation):
            # send error events as agent events to the UI
            event_dict = event_to_dict(event)
            event_dict['source'] = EventSource.AGENT
            await self.send(event_dict)

    async def dispatch(self, data: dict):
        """Deserialize a client payload and add it to the event stream as a user event.

        A ``MessageAction`` carrying image URLs is rejected with an error
        message when a controller exists and its LLM has vision disabled or
        inactive.

        Args:
            data: Serialized event from the client. It is copied before
                deserializing, so the caller's dict is not mutated.
        """
        event = event_from_dict(data.copy())
        # This checks if the model supports images
        if isinstance(event, MessageAction) and event.image_urls:
            controller = self.agent_session.controller
            if controller:
                if controller.agent.llm.config.disable_vision:
                    await self.send_error(
                        'Support for images is disabled for this model, try without an image.'
                    )
                    return
                if not controller.agent.llm.vision_is_active():
                    await self.send_error(
                        'Model does not support image upload, change to a different model or try without an image.'
                    )
                    return
        self.agent_session.event_stream.add_event(event, EventSource.USER)

    async def send(self, data: dict[str, object]):
        """Send ``data`` to the client, always emitting on ``self.loop``.

        When called from a different running loop (another thread), the send
        is handed to ``self.loop`` with ``run_coroutine_threadsafe``. In that
        case this method returns without waiting for the send to finish.
        ``loop.create_task`` must not be used here: it is not thread-safe and
        does not wake a loop that is blocked in another thread.
        """
        if asyncio.get_running_loop() != self.loop:
            # Fire-and-forget: the returned concurrent future is not awaited.
            asyncio.run_coroutine_threadsafe(self._send(data), self.loop)
            return
        await self._send(data)

    async def _send(self, data: dict[str, object]) -> bool:
        """Emit ``data`` as an ``oh_event`` to the session's room.

        Returns:
            False if the session is no longer alive or a RuntimeError occurred
            (which also marks the session dead); True otherwise, including
            when no Socket.IO server is attached.
        """
        try:
            if not self.is_alive:
                return False
            if self.sio:
                await self.sio.emit('oh_event', data, to=ROOM_KEY.format(sid=self.sid))
            await asyncio.sleep(0.001)  # This flushes the data to the client
            self.last_active_ts = int(time.time())
            return True
        except RuntimeError as e:
            logger.error(f'Error sending data to websocket: {str(e)}')
            self.is_alive = False
            return False

    async def send_error(self, message: str):
        """Sends an error message to the client."""
        await self.send({'error': True, 'message': message})

    async def _send_status_message(self, msg_type: str, id: str, message: str):
        """Send a status update to the client.

        For ``msg_type == 'error'`` the agent loop is stopped first via
        ``agent_session.stop_agent_loop_for_error()``.
        """
        if msg_type == 'error':
            await self.agent_session.stop_agent_loop_for_error()
        await self.send(
            {'status_update': True, 'type': msg_type, 'id': id, 'message': message}
        )

    def queue_status_message(self, msg_type: str, id: str, message: str):
        """Schedule ``_send_status_message`` on ``self.loop`` without waiting.

        Safe to call from any thread, because it uses ``run_coroutine_threadsafe``.
        """
        asyncio.run_coroutine_threadsafe(
            self._send_status_message(msg_type, id, message), self.loop
        )