# Event stream: persistence of serialized events plus asynchronous
# publish/subscribe dispatch to registered callbacks.
import asyncio | |
import queue | |
import threading | |
from concurrent.futures import ThreadPoolExecutor | |
from datetime import datetime | |
from enum import Enum | |
from functools import partial | |
from typing import Any, Callable | |
from openhands.core.logger import openhands_logger as logger | |
from openhands.events.event import Event, EventSource | |
from openhands.events.event_store import EventStore | |
from openhands.events.serialization.event import event_from_dict, event_to_dict | |
from openhands.io import json | |
from openhands.storage import FileStore | |
from openhands.storage.locations import ( | |
get_conversation_dir, | |
) | |
from openhands.utils.async_utils import call_sync_from_async | |
from openhands.utils.shutdown_listener import should_continue | |
class EventStreamSubscriber(str, Enum):
    """Well-known subscriber identifiers for EventStream callbacks.

    Inherits from ``str`` so members compare equal to their plain string
    values and can be used directly as dict keys (see ``_subscribers``
    usage in ``EventStream``).
    """

    AGENT_CONTROLLER = 'agent_controller'
    SECURITY_ANALYZER = 'security_analyzer'
    RESOLVER = 'openhands_resolver'
    SERVER = 'server'
    RUNTIME = 'runtime'
    MEMORY = 'memory'
    MAIN = 'main'
    TEST = 'test'
async def session_exists(
    sid: str, file_store: FileStore, user_id: str | None = None
) -> bool:
    """Return True if a conversation directory exists for ``sid``.

    Probes the store by listing the conversation directory; a
    ``FileNotFoundError`` from the listing means no session exists.
    """
    conversation_path = get_conversation_dir(sid, user_id)
    try:
        await call_sync_from_async(file_store.list, conversation_path)
    except FileNotFoundError:
        return False
    return True
class EventStream(EventStore):
    """Writable event stream with asynchronous publish/subscribe dispatch.

    Events added via :meth:`add_event` are serialized (with secret values
    masked), persisted through the file store, and placed on an internal
    queue.  A dedicated daemon thread drains the queue and hands each event
    to every subscribed callback; each callback runs on its own
    single-worker thread pool so slow callbacks do not block the dispatch
    loop or each other.

    NOTE(review): ``cur_id``, ``cache_size``, ``file_store``, ``user_id``
    and the ``_get_filename_*`` helpers are inherited from ``EventStore``
    (defined outside this block) — verify their semantics there.
    """

    # Secret values to mask in serialized event data (see _replace_secrets).
    secrets: dict[str, str]
    # For each subscriber ID, there is a map of callback functions - useful
    # when there are multiple listeners
    _subscribers: dict[str, dict[str, Callable]]
    # Guards event-ID allocation and the current write page in add_event.
    _lock: threading.Lock
    # Events awaiting dispatch to subscriber callbacks.
    _queue: queue.Queue[Event]
    # Daemon thread running _run_queue_loop until close() / shutdown.
    _queue_thread: threading.Thread
    # Event loop owned by _queue_thread; None before that thread starts it.
    _queue_loop: asyncio.AbstractEventLoop | None
    # subscriber_id -> callback_id -> single-worker pool running that callback.
    _thread_pools: dict[str, dict[str, ThreadPoolExecutor]]
    # subscriber_id -> callback_id -> event loop created inside the pool's
    # worker thread by _init_thread_loop.
    _thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]]
    # Serialized events accumulated for the current cache page; flushed to
    # the store once it reaches cache_size entries (see _store_cache_page).
    _write_page_cache: list[dict]

    def __init__(self, sid: str, file_store: FileStore, user_id: str | None = None):
        """Initialize stream state and start the dispatch thread.

        Args:
            sid: Conversation/session identifier, passed to EventStore.
            file_store: Backing store for persisted events.
            user_id: Optional owner of the conversation.
        """
        super().__init__(sid, file_store, user_id)
        self._stop_flag = threading.Event()
        self._queue: queue.Queue[Event] = queue.Queue()
        self._thread_pools = {}
        self._thread_loops = {}
        self._queue_loop = None
        # Daemon thread so an un-closed stream does not block interpreter exit.
        self._queue_thread = threading.Thread(target=self._run_queue_loop)
        self._queue_thread.daemon = True
        self._queue_thread.start()
        self._subscribers = {}
        self._lock = threading.Lock()
        self.secrets = {}
        self._write_page_cache = []

    def _init_thread_loop(self, subscriber_id: str, callback_id: str) -> None:
        """ThreadPoolExecutor initializer: create an event loop owned by the
        pool's single worker thread and record it for cleanup later
        (see _clean_up_subscriber)."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        if subscriber_id not in self._thread_loops:
            self._thread_loops[subscriber_id] = {}
        self._thread_loops[subscriber_id][callback_id] = loop

    def close(self) -> None:
        """Stop the dispatch thread, tear down all subscribers, drain the queue."""
        self._stop_flag.set()
        if self._queue_thread.is_alive():
            self._queue_thread.join()
        # Iterate over copies: _clean_up_subscriber mutates these dicts.
        subscriber_ids = list(self._subscribers.keys())
        for subscriber_id in subscriber_ids:
            callback_ids = list(self._subscribers[subscriber_id].keys())
            for callback_id in callback_ids:
                self._clean_up_subscriber(subscriber_id, callback_id)
        # Clear queue
        while not self._queue.empty():
            self._queue.get()

    def _clean_up_subscriber(self, subscriber_id: str, callback_id: str) -> None:
        """Release the loop, thread pool and registration for one callback.

        Safe to call for an unknown subscriber/callback (logs a warning and
        returns).
        """
        if subscriber_id not in self._subscribers:
            logger.warning(f'Subscriber not found during cleanup: {subscriber_id}')
            return
        if callback_id not in self._subscribers[subscriber_id]:
            logger.warning(f'Callback not found during cleanup: {callback_id}')
            return
        if (
            subscriber_id in self._thread_loops
            and callback_id in self._thread_loops[subscriber_id]
        ):
            loop = self._thread_loops[subscriber_id][callback_id]
            # NOTE(review): current_task(loop) from a different thread is
            # expected to be None, so all of the loop's tasks get cancelled.
            current_task = asyncio.current_task(loop)
            pending = [
                task for task in asyncio.all_tasks(loop) if task is not current_task
            ]
            for task in pending:
                task.cancel()
            try:
                # NOTE(review): stop()/close() are called from outside the
                # loop's own thread; close() raises if the loop is still
                # running, which the except below swallows — confirm intended.
                loop.stop()
                loop.close()
            except Exception as e:
                logger.warning(
                    f'Error closing loop for {subscriber_id}/{callback_id}: {e}'
                )
            del self._thread_loops[subscriber_id][callback_id]
        if (
            subscriber_id in self._thread_pools
            and callback_id in self._thread_pools[subscriber_id]
        ):
            pool = self._thread_pools[subscriber_id][callback_id]
            pool.shutdown()
            del self._thread_pools[subscriber_id][callback_id]
        del self._subscribers[subscriber_id][callback_id]

    def subscribe(
        self,
        subscriber_id: EventStreamSubscriber,
        callback: Callable[[Event], None],
        callback_id: str,
    ) -> None:
        """Register ``callback`` under (subscriber_id, callback_id).

        Each callback gets a dedicated single-worker pool whose worker
        thread is initialized with its own event loop.

        Raises:
            ValueError: if ``callback_id`` is already registered for this
                subscriber.
        """
        initializer = partial(self._init_thread_loop, subscriber_id, callback_id)
        pool = ThreadPoolExecutor(max_workers=1, initializer=initializer)
        if subscriber_id not in self._subscribers:
            self._subscribers[subscriber_id] = {}
            self._thread_pools[subscriber_id] = {}
        if callback_id in self._subscribers[subscriber_id]:
            raise ValueError(
                f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}'
            )
        self._subscribers[subscriber_id][callback_id] = callback
        self._thread_pools[subscriber_id][callback_id] = pool

    def unsubscribe(
        self, subscriber_id: EventStreamSubscriber, callback_id: str
    ) -> None:
        """Remove a previously registered callback and release its resources.

        Logs a warning (rather than raising) for unknown IDs.
        """
        if subscriber_id not in self._subscribers:
            logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}')
            return
        if callback_id not in self._subscribers[subscriber_id]:
            logger.warning(f'Callback not found during unsubscribe: {callback_id}')
            return
        self._clean_up_subscriber(subscriber_id, callback_id)

    def add_event(self, event: Event, source: EventSource) -> None:
        """Assign an ID to ``event``, persist it, and queue it for dispatch.

        The event is serialized with secrets masked, written to the file
        store (and a cache page when one fills up), then placed on the
        dispatch queue.

        Raises:
            ValueError: if the event already has an ID (guards against a
                handler feeding events back into the stream in a loop).
        """
        if event.id != Event.INVALID_ID:
            raise ValueError(
                f'Event already has an ID:{event.id}. It was probably added back to the EventStream from inside a handler, triggering a loop.'
            )
        event._timestamp = datetime.now().isoformat()
        event._source = source  # type: ignore [attr-defined]
        # The lock protects both ID allocation and the write-page swap so
        # concurrent writers never share IDs or split a cache page.
        with self._lock:
            event._id = self.cur_id  # type: ignore [attr-defined]
            self.cur_id += 1
            # Take a copy of the current write page
            current_write_page = self._write_page_cache
            data = event_to_dict(event)
            data = self._replace_secrets(data)
            # Rebuild the event from the masked data so subscribers never
            # see raw secret values.
            event = event_from_dict(data)
            current_write_page.append(data)
            # If the page is full, create a new page for future events / other threads to use
            if len(current_write_page) == self.cache_size:
                self._write_page_cache = []
        # NOTE(review): event._id is always assigned above, so this check
        # looks always-true — confirm whether Event.id can map to None.
        if event.id is not None:
            # Write the event to the store - this can take some time
            self.file_store.write(
                self._get_filename_for_id(event.id, self.user_id), json.dumps(data)
            )
        # Store the cache page last - if it is not present during reads then it will simply be bypassed.
        self._store_cache_page(current_write_page)
        self._queue.put(event)

    def _store_cache_page(self, current_write_page: list[dict]) -> None:
        """Store a page in the cache. Reading individual events is slow when there are a lot of them, so we use pages."""
        # Only full pages are persisted; partial pages stay in memory.
        if len(current_write_page) < self.cache_size:
            return
        start = current_write_page[0]['id']
        end = start + self.cache_size
        contents = json.dumps(current_write_page)
        cache_filename = self._get_filename_for_cache(start, end)
        self.file_store.write(cache_filename, contents)

    def set_secrets(self, secrets: dict[str, str]) -> None:
        """Replace the full secret map (copied to avoid caller aliasing)."""
        self.secrets = secrets.copy()

    def update_secrets(self, secrets: dict[str, str]) -> None:
        """Merge new entries into the secret map, overwriting existing keys."""
        self.secrets.update(secrets)

    def _replace_secrets(self, data: dict[str, Any]) -> dict[str, Any]:
        """Recursively mask every secret value occurring in string fields.

        Mutates ``data`` in place and also returns it.
        """
        for key in data:
            if isinstance(data[key], dict):
                data[key] = self._replace_secrets(data[key])
            elif isinstance(data[key], str):
                for secret in self.secrets.values():
                    data[key] = data[key].replace(secret, '<secret_hidden>')
        return data

    def _run_queue_loop(self) -> None:
        """Entry point of the dispatch thread: own an event loop and run
        the queue-processing coroutine until shutdown."""
        self._queue_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._queue_loop)
        try:
            self._queue_loop.run_until_complete(self._process_queue())
        finally:
            self._queue_loop.close()

    async def _process_queue(self) -> None:
        """Drain the event queue, fanning each event out to all callbacks.

        Polls with a short timeout so the stop flag / shutdown listener is
        re-checked at least every 0.1s.
        """
        while should_continue() and not self._stop_flag.is_set():
            event = None
            try:
                event = self._queue.get(timeout=0.1)
            except queue.Empty:
                continue
            # pass each event to each callback in order
            for key in sorted(self._subscribers.keys()):
                callbacks = self._subscribers[key]
                # Create a copy of the keys to avoid "dictionary changed size during iteration" error
                callback_ids = list(callbacks.keys())
                for callback_id in callback_ids:
                    # Check if callback_id still exists (might have been removed during iteration)
                    if callback_id in callbacks:
                        callback = callbacks[callback_id]
                        pool = self._thread_pools[key][callback_id]
                        future = pool.submit(callback, event)
                        future.add_done_callback(
                            self._make_error_handler(callback_id, key)
                        )

    def _make_error_handler(
        self, callback_id: str, subscriber_id: str
    ) -> Callable[[Any], None]:
        """Build a done-callback that logs exceptions raised by ``callback_id``."""

        def _handle_callback_error(fut: Any) -> None:
            try:
                # This will raise any exception that occurred during callback execution
                fut.result()
            except Exception as e:
                logger.error(
                    f'Error in event callback {callback_id} for subscriber {subscriber_id}: {str(e)}',
                )
                # NOTE(review): this raise happens inside an
                # add_done_callback handler, which concurrent.futures invokes
                # in the worker thread and whose exceptions it logs and
                # ignores — it does NOT propagate to the main thread. Confirm
                # whether stronger surfacing is intended.
                raise e

        return _handle_callback_error