Spaces:
Building
Building
""" | |
Event Bus Implementation for Flare | |
================================== | |
Provides async event publishing and subscription mechanism | |
""" | |
import asyncio
import sys
import traceback
from collections import defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Optional

from utils.logger import log_debug, log_error, log_info, log_warning
class EventType(Enum):
    """Every kind of event that can be published on the bus.

    Each member's value is simply its name in lower case; ``Event.to_dict``
    serializes an event's type using this value.
    """

    # -- session / conversation lifecycle --
    SESSION_STARTED = "session_started"
    SESSION_ENDED = "session_ended"
    CONVERSATION_STARTED = "conversation_started"
    CONVERSATION_ENDED = "conversation_ended"

    # -- speech-to-text --
    STT_STARTED = "stt_started"
    STT_STOPPED = "stt_stopped"
    STT_RESULT = "stt_result"
    STT_ERROR = "stt_error"
    STT_READY = "stt_ready"

    # -- text-to-speech --
    TTS_STARTED = "tts_started"
    TTS_CHUNK_READY = "tts_chunk_ready"
    TTS_COMPLETED = "tts_completed"
    TTS_ERROR = "tts_error"
    TTS_STOPPED = "tts_stopped"

    # -- audio transport / playback --
    AUDIO_PLAYBACK_STARTED = "audio_playback_started"
    AUDIO_PLAYBACK_COMPLETED = "audio_playback_completed"
    AUDIO_BUFFER_LOW = "audio_buffer_low"
    AUDIO_CHUNK_RECEIVED = "audio_chunk_received"

    # -- language model --
    LLM_PROCESSING_STARTED = "llm_processing_started"
    LLM_RESPONSE_READY = "llm_response_ready"
    LLM_ERROR = "llm_error"

    # -- errors, split by severity --
    CRITICAL_ERROR = "critical_error"
    RECOVERABLE_ERROR = "recoverable_error"

    # -- state machine --
    STATE_TRANSITION = "state_transition"
    STATE_ROLLBACK = "state_rollback"

    # -- websocket connection --
    WEBSOCKET_CONNECTED = "websocket_connected"
    WEBSOCKET_DISCONNECTED = "websocket_disconnected"
    WEBSOCKET_MESSAGE = "websocket_message"
    WEBSOCKET_ERROR = "websocket_error"
@dataclass
class Event:
    """Event data structure carried by the EventBus.

    ``@dataclass`` generates the keyword constructor used throughout the
    module (``Event(type=..., session_id=..., data=..., priority=...)``) and
    makes ``field(default_factory=...)`` / ``__post_init__`` take effect.
    Dataclass leaves the ``__eq__`` defined below in place and, with the
    default ``order=False``, generates no ordering methods, so the
    hand-written comparisons are what is used.

    Attributes:
        type: Which ``EventType`` this event represents.
        data: Free-form payload handed to subscribers.
        session_id: Owning session, or ``None`` for a global event.
        timestamp: Creation time from ``datetime.utcnow`` (naive UTC).
            NOTE(review): ``utcnow`` is deprecated since Python 3.12.
        priority: Relative urgency; ``EventBus.publish`` enqueues
            ``-priority``, so larger values are dispatched first.
    """

    type: EventType
    data: Dict[str, Any]
    session_id: Optional[str] = None
    timestamp: datetime = field(default_factory=datetime.utcnow)
    priority: int = 0

    def __lt__(self, other):
        """Order by priority, then by timestamp (older first).

        Acts as the tie-breaker when two PriorityQueue entries share the
        same ``-priority`` key.
        """
        if not isinstance(other, Event):
            return NotImplemented
        # Compare by priority first
        if self.priority != other.priority:
            return self.priority < other.priority
        # Equal priority: fall back to creation time
        return self.timestamp < other.timestamp

    def __eq__(self, other):
        """Events are equal when priority, timestamp and type all match.

        NOTE(review): ``data`` and ``session_id`` are not compared, and
        ``__lt__`` ignores ``type``, so two events differing only in type are
        neither equal nor ordered; heap operations only rely on ``__lt__``.
        """
        if not isinstance(other, Event):
            return NotImplemented
        return (self.priority == other.priority and
                self.timestamp == other.timestamp and
                self.type == other.type)

    def __le__(self, other):
        """Less than or equal, derived from ``__eq__`` and ``__lt__``."""
        return self == other or self < other

    def __gt__(self, other):
        """Greater than, defined as the negation of ``__le__``."""
        return not self <= other

    def __ge__(self, other):
        """Greater than or equal, defined as the negation of ``__lt__``."""
        return not self < other

    def __post_init__(self):
        # An explicit ``timestamp=None`` is replaced with the current time.
        if self.timestamp is None:
            self.timestamp = datetime.utcnow()

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dict for serialization.

        ``type`` becomes its enum value and ``timestamp`` an ISO-8601 string;
        ``data`` is included as-is.
        """
        return {
            "type": self.type.value,
            "session_id": self.session_id,
            "data": self.data,
            "timestamp": self.timestamp.isoformat(),
            "priority": self.priority
        }
class EventBus:
    """Central event bus for component communication with session isolation.

    Events without a ``session_id`` flow through one global priority queue.
    Events with a ``session_id`` flow through a per-session priority queue,
    each drained by its own task, so sessions are processed in parallel.
    Handlers are registered globally (``subscribe``) or for a single session
    (``subscribe_session``).
    """

    def __init__(self):
        self._subscribers: Dict[EventType, List[Callable]] = defaultdict(list)
        self._session_handlers: Dict[str, Dict[EventType, List[Callable]]] = defaultdict(lambda: defaultdict(list))
        # Session-specific queues for parallel processing
        self._session_queues: Dict[str, asyncio.PriorityQueue] = {}
        self._session_processors: Dict[str, asyncio.Task] = {}
        # Global queue for non-session events
        self._global_queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
        self._global_processor: Optional[asyncio.Task] = None
        self._running = False
        self._max_history_size = 1000
        # Bounded history: deque(maxlen) evicts the oldest event in O(1),
        # unlike list.pop(0).
        self._event_history: deque = deque(maxlen=self._max_history_size)

    async def start(self):
        """Start the global event processor (warns and returns if already running).

        Session processors are created lazily by ``publish``.
        """
        if self._running:
            log_warning("EventBus already running")
            return
        self._running = True
        self._global_processor = asyncio.create_task(self._process_global_events())
        log_info("✅ EventBus started")

    async def stop(self):
        """Stop every session processor and the global processor.

        Session processors are cancelled and awaited for up to 2 s each; the
        global processor is woken with a sentinel and awaited for up to 5 s
        before being cancelled.
        """
        self._running = False

        # Stop all session processors
        for task in list(self._session_processors.values()):
            task.cancel()
            try:
                await asyncio.wait_for(task, timeout=2.0)
            except (asyncio.TimeoutError, asyncio.CancelledError):
                pass
        # A task cancelled before it ever started never runs its own cleanup,
        # so drop leftover bookkeeping here; otherwise a later start() +
        # publish() would enqueue into a queue that no processor drains.
        self._session_queues.clear()
        self._session_processors.clear()

        # Stop global processor
        if self._global_processor:
            await self._global_queue.put((999, None))  # Wake-up sentinel
            try:
                await asyncio.wait_for(self._global_processor, timeout=5.0)
            except asyncio.TimeoutError:
                log_warning("EventBus global processor timeout, cancelling")
                self._global_processor.cancel()
            self._global_processor = None

        log_info("✅ EventBus stopped")

    async def publish(self, event: Event):
        """Record ``event`` in history and enqueue it for dispatch.

        Queue keys are ``-event.priority``, so higher-priority events are
        dispatched first. Events published while the bus is stopped are
        logged and dropped.
        """
        if not self._running:
            log_error("EventBus not running, cannot publish event", event_type=event.type.value)
            return

        # deque(maxlen) discards the oldest entry once full
        self._event_history.append(event)

        if event.session_id:
            # Lazily create the per-session queue and its processor
            if event.session_id not in self._session_queues:
                await self._create_session_processor(event.session_id)
            await self._session_queues[event.session_id].put((-event.priority, event))
        else:
            await self._global_queue.put((-event.priority, event))

    async def _create_session_processor(self, session_id: str):
        """Create the queue and processor task for ``session_id`` if missing."""
        if session_id in self._session_processors:
            return
        queue = asyncio.PriorityQueue()
        self._session_queues[session_id] = queue
        # Hand the queue to the task directly so it never has to look it up
        # (the session may be cleared before the task first runs).
        task = asyncio.create_task(self._process_session_events(session_id, queue))
        self._session_processors[session_id] = task
        log_debug("📌 Created session processor", session_id=session_id)

    async def _process_session_events(self, session_id: str, queue: Optional[asyncio.PriorityQueue] = None):
        """Drain one session's queue until the bus stops or the session goes idle.

        After 60 s without events, the processor exits if the session has no
        registered handlers.
        """
        if queue is None:
            queue = self._session_queues.get(session_id)
            if queue is None:
                return  # session was cleared before this processor started
        log_info("🔄 Session event processor started", session_id=session_id)
        try:
            while self._running:
                try:
                    _priority, event = await asyncio.wait_for(
                        queue.get(),
                        timeout=60.0  # Longer timeout for sessions
                    )
                    # Check for session cleanup
                    if event is None:
                        break
                    await self._dispatch_event(event)
                except asyncio.TimeoutError:
                    # Check if session is still active
                    if session_id not in self._session_handlers:
                        log_info("Session inactive, stopping processor", session_id=session_id)
                        break
                    continue
                except Exception as e:
                    log_error(
                        "❌ Error processing session event",
                        session_id=session_id,
                        error=str(e),
                        traceback=traceback.format_exc()
                    )
        finally:
            # Runs on normal exit AND on cancellation. Only remove entries that
            # still belong to this processor: the session may have been cleared
            # and recreated while this task was shutting down.
            if self._session_queues.get(session_id) is queue:
                del self._session_queues[session_id]
            if self._session_processors.get(session_id) is asyncio.current_task():
                del self._session_processors[session_id]
            log_info("🔄 Session event processor stopped", session_id=session_id)

    async def _process_global_events(self):
        """Drain the global queue (events without a session_id) until the bus stops."""
        log_info("🔄 Global event processor started")
        while self._running:
            try:
                _priority, event = await asyncio.wait_for(
                    self._global_queue.get(),
                    timeout=1.0
                )
                if event is None:
                    # Wake-up sentinel from stop(). The loop condition decides
                    # whether to exit, so a stale sentinel left in the queue by
                    # an earlier stop() cannot kill a restarted processor.
                    continue
                await self._dispatch_event(event)
            except asyncio.TimeoutError:
                continue
            except Exception as e:
                log_error(
                    "❌ Error processing global event",
                    error=str(e),
                    traceback=traceback.format_exc()
                )
        log_info("🔄 Global event processor stopped")

    def subscribe(self, event_type: EventType, handler: Callable):
        """Subscribe ``handler`` to ``event_type`` for all sessions."""
        self._subscribers[event_type].append(handler)
        log_debug("📌 Global subscription added", event_type=event_type.value)

    def subscribe_session(self, session_id: str, event_type: EventType, handler: Callable):
        """Subscribe ``handler`` to ``event_type`` for one session only."""
        self._session_handlers[session_id][event_type].append(handler)
        log_debug(
            "📌 Session subscription added",
            event_type=event_type.value,
            session_id=session_id
        )

    def unsubscribe(self, event_type: EventType, handler: Callable):
        """Remove a global subscription; unknown handlers are ignored."""
        # .get() avoids creating an empty defaultdict entry as a side effect
        handlers = self._subscribers.get(event_type)
        if handlers and handler in handlers:
            handlers.remove(handler)
            log_debug("📌 Global subscription removed", event_type=event_type.value)

    def unsubscribe_session(self, session_id: str, event_type: Optional[EventType] = None):
        """Remove a session's handlers for one event type, or all of them."""
        if event_type:
            # Remove specific event type for session
            if session_id in self._session_handlers and event_type in self._session_handlers[session_id]:
                del self._session_handlers[session_id][event_type]
        elif session_id in self._session_handlers:
            # Remove all handlers for session
            del self._session_handlers[session_id]
            log_debug("📌 All session subscriptions removed", session_id=session_id)

    async def _dispatch_event(self, event: Event):
        """Run all global and session-specific handlers for ``event`` concurrently.

        Coroutine functions are scheduled as tasks; other callables run in a
        worker thread via ``asyncio.to_thread``. Exceptions raised by handlers
        are logged rather than re-raised.
        """
        try:
            handlers: List[Callable] = []
            # Global handlers (membership test does not create defaultdict keys)
            if event.type in self._subscribers:
                handlers.extend(self._subscribers[event.type])
            # Session-specific handlers
            if event.session_id in self._session_handlers:
                session_map = self._session_handlers[event.session_id]
                if event.type in session_map:
                    handlers.extend(session_map[event.type])

            if not handlers:
                log_debug(
                    "📭 No handlers for event",
                    event_type=event.type.value,
                    session_id=event.session_id
                )
                return

            tasks = []
            for handler in handlers:
                if asyncio.iscoroutinefunction(handler):
                    tasks.append(asyncio.create_task(handler(event)))
                else:
                    # Keep blocking sync handlers off the event loop
                    tasks.append(asyncio.create_task(asyncio.to_thread(handler, event)))

            results = await asyncio.gather(*tasks, return_exceptions=True)

            for handler, result in zip(handlers, results):
                if isinstance(result, Exception):
                    log_error(
                        "❌ Handler error",
                        # functools.partial / callable objects have no __name__
                        handler=getattr(handler, "__name__", repr(handler)),
                        event_type=event.type.value,
                        error=str(result),
                        traceback=traceback.format_exception(type(result), result, result.__traceback__)
                    )
        except Exception as e:
            log_error(
                "❌ Error dispatching event",
                event_type=event.type.value,
                error=str(e),
                traceback=traceback.format_exc()
            )

    def get_event_history(self, session_id: Optional[str] = None, event_type: Optional[EventType] = None) -> List[Event]:
        """Return recorded events (oldest first), optionally filtered.

        Always returns a new list, so callers cannot mutate the bus's history.
        """
        return [
            e for e in self._event_history
            if (not session_id or e.session_id == session_id)
            and (not event_type or e.type == event_type)
        ]

    def clear_session_data(self, session_id: str):
        """Drop a session's handlers, cancel its processor, and purge its history."""
        self.unsubscribe_session(session_id)

        task = self._session_processors.pop(session_id, None)
        if task is not None:
            task.cancel()
        self._session_queues.pop(session_id, None)

        # Rebuild history without this session, keeping the size bound
        self._event_history = deque(
            (e for e in self._event_history if e.session_id != session_id),
            maxlen=self._max_history_size,
        )
        log_debug("🧹 Session data cleared", session_id=session_id)
# Process-wide shared bus: components publish to and subscribe on this single
# instance, and the helper coroutines in this module publish through it.
event_bus = EventBus()
# Helper functions for common event publishing patterns
async def publish_error(session_id: str, error_type: str, error_message: str, details: Optional[Dict[str, Any]] = None):
    """Publish an error event for a session on the shared ``event_bus``.

    Args:
        session_id: Session the error belongs to.
        error_type: Error category; exactly ``"critical"`` selects
            ``EventType.CRITICAL_ERROR``, any other value selects
            ``EventType.RECOVERABLE_ERROR``.
        error_message: Human-readable description, stored under ``"message"``.
        details: Optional extra context; ``None`` (or an empty dict) is
            published as ``{}``.
    """
    is_critical = error_type == "critical"
    event = Event(
        type=EventType.CRITICAL_ERROR if is_critical else EventType.RECOVERABLE_ERROR,
        session_id=session_id,
        data={
            "error_type": error_type,
            "message": error_message,
            "details": details or {}
        },
        priority=10  # High priority for errors
    )
    await event_bus.publish(event)
async def publish_state_transition(session_id: str, from_state: str, to_state: str, reason: Optional[str] = None):
    """Publish a STATE_TRANSITION event for a session on the shared ``event_bus``.

    Args:
        session_id: Session whose state changed.
        from_state: State being left.
        to_state: State being entered.
        reason: Optional explanation, published as-is (may be ``None``).
    """
    transition = {
        "from_state": from_state,
        "to_state": to_state,
        "reason": reason
    }
    event = Event(
        type=EventType.STATE_TRANSITION,
        session_id=session_id,
        data=transition,
        priority=5  # Medium priority for state changes
    )
    await event_bus.publish(event)