Spaces:
Building
Building
""" | |
Optimized Session Management for Flare Platform | |
""" | |
from dataclasses import dataclass, field | |
from typing import Dict, List, Optional, Any | |
from datetime import datetime | |
import json | |
import secrets | |
import hashlib | |
import time | |
from config.config_models import VersionConfig, IntentConfig | |
from utils.logger import log_debug, log_info | |
@dataclass
class Session:
    """Conversation session state, kept small so it can later live in Redis.

    The persisted part of the state is whatever ``to_redis`` writes; the
    underscore-prefixed attributes are transient and are never serialized.
    Annotations for the project config models are written as forward
    references so defining this class does not evaluate them.
    """

    # Maximum number of chat messages retained. ``init=False`` keeps it out
    # of the constructor signature, so it never shifts positional arguments.
    MAX_CHAT_HISTORY: int = field(default=20, init=False, repr=False)

    session_id: str
    project_name: str
    version_no: int
    is_realtime: Optional[bool] = False
    locale: Optional[str] = "tr"

    # State management - string for better debugging
    state: str = "idle"  # idle | collect_params | call_api | humanize

    # Minimal stored data
    current_intent: Optional[str] = None
    variables: Dict[str, str] = field(default_factory=dict)
    project_id: Optional[int] = None
    version_id: Optional[int] = None

    # Chat history - limited to recent messages
    chat_history: List[Dict[str, str]] = field(default_factory=list)

    # Metadata (naive UTC timestamps in ISO-8601 format)
    created_at: str = field(default_factory=lambda: datetime.utcnow().isoformat())
    last_activity: str = field(default_factory=lambda: datetime.utcnow().isoformat())

    # Parameter collection state
    awaiting_parameters: List[str] = field(default_factory=list)
    asked_parameters: Dict[str, int] = field(default_factory=dict)
    unanswered_parameters: List[str] = field(default_factory=list)
    parameter_ask_rounds: int = 0

    # Transient data (not serialized to Redis)
    _version_config: Optional["VersionConfig"] = field(default=None, init=False, repr=False)
    _intent_config: Optional["IntentConfig"] = field(default=None, init=False, repr=False)
    _auth_tokens: Dict[str, Dict] = field(default_factory=dict, init=False, repr=False)

    def add_message(self, role: str, content: str) -> None:
        """Append a message to the chat history and refresh ``last_activity``.

        Args:
            role: Speaker label stored under the ``"role"`` key.
            content: Message text stored under the ``"content"`` key.
        """
        # One timestamp for both the message and the activity marker.
        now = datetime.utcnow().isoformat()
        self.chat_history.append({
            "role": role,
            "content": content,
            "timestamp": now,
        })

        # Keep only the most recent messages.
        if len(self.chat_history) > self.MAX_CHAT_HISTORY:
            self.chat_history = self.chat_history[-self.MAX_CHAT_HISTORY:]

        self.last_activity = now

        log_debug(
            "Message added to session",
            session_id=self.session_id,
            role=role,
            history_size=len(self.chat_history)
        )

    def add_turn(self, role: str, content: str) -> None:
        """Alias for ``add_message``, kept for compatibility."""
        self.add_message(role, content)

    def set_version_config(self, config: "VersionConfig") -> None:
        """Store the transient version config."""
        self._version_config = config

    def get_version_config(self) -> Optional["VersionConfig"]:
        """Return the transient version config, or None if unset."""
        return self._version_config

    def set_intent_config(self, config: "IntentConfig") -> None:
        """Store the intent config and mirror its name into ``current_intent``.

        A falsy ``config`` clears ``current_intent``.
        """
        self._intent_config = config
        self.current_intent = config.name if config else None

    def get_intent_config(self) -> Optional["IntentConfig"]:
        """Return the current intent config, or None if unset."""
        return self._intent_config

    def reset_flow(self) -> None:
        """Reset the conversation flow to idle and clear parameter collection."""
        self.state = "idle"
        self.current_intent = None
        self._intent_config = None
        self.awaiting_parameters = []
        self.asked_parameters = {}
        self.unanswered_parameters = []
        self.parameter_ask_rounds = 0

        log_debug(
            "Session flow reset",
            session_id=self.session_id
        )

    def to_redis(self) -> str:
        """Serialize the persistent fields to a JSON string.

        Transient attributes (configs, auth tokens) are left out. ``locale``
        is included so it survives a ``to_redis``/``from_redis`` round-trip.
        """
        data = {
            'session_id': self.session_id,
            'project_name': self.project_name,
            'version_no': self.version_no,
            'state': self.state,
            'current_intent': self.current_intent,
            'variables': self.variables,
            'project_id': self.project_id,
            'version_id': self.version_id,
            'chat_history': self.chat_history[-self.MAX_CHAT_HISTORY:],
            'created_at': self.created_at,
            'last_activity': self.last_activity,
            'awaiting_parameters': self.awaiting_parameters,
            'asked_parameters': self.asked_parameters,
            'unanswered_parameters': self.unanswered_parameters,
            'parameter_ask_rounds': self.parameter_ask_rounds,
            'is_realtime': self.is_realtime,
            'locale': self.locale
        }
        return json.dumps(data, ensure_ascii=False)

    @classmethod
    def from_redis(cls, data: str) -> 'Session':
        """Rebuild a session from a JSON string produced by ``to_redis``.

        Keys missing from the payload (for example ``locale`` in payloads
        written before it was serialized) fall back to the field defaults.
        """
        obj = json.loads(data)
        return cls(**obj)

    def get_state_info(self) -> dict:
        """Return a small dict describing the current state, for debugging."""
        return {
            'state': self.state,
            'intent': self.current_intent,
            'variables': list(self.variables.keys()),
            'history_length': len(self.chat_history),
            'awaiting_params': self.awaiting_parameters,
            'last_activity': self.last_activity
        }

    def get_auth_token(self, api_name: str) -> Optional[Dict]:
        """Return the cached auth token data for ``api_name``, or None."""
        return self._auth_tokens.get(api_name)

    def set_auth_token(self, api_name: str, token_data: Dict) -> None:
        """Cache auth token data for ``api_name``."""
        self._auth_tokens[api_name] = token_data

    def is_expired(self, timeout_minutes: int = 30) -> bool:
        """Return True when more than ``timeout_minutes`` passed since last activity.

        ``last_activity`` may be a naive UTC timestamp (what this class
        writes) or carry a ``Z``/numeric offset; offset-bearing values are
        converted to naive UTC before comparing.
        """
        last_seen = datetime.fromisoformat(self.last_activity.replace('Z', '+00:00'))
        if last_seen.tzinfo is not None:
            # Subtracting an aware datetime from naive utcnow() raises
            # TypeError, so shift to UTC wall-clock time and drop the tzinfo.
            last_seen = (last_seen - last_seen.utcoffset()).replace(tzinfo=None)

        elapsed_minutes = (datetime.utcnow() - last_seen).total_seconds() / 60
        return elapsed_minutes > timeout_minutes
def generate_secure_session_id() -> str:
    """Build a session ID of the form ``session_<32 hex chars>``.

    The ID is the first 32 hex digits of a SHA-256 digest taken over
    32 bytes from ``secrets`` followed by the current time in microseconds.
    """
    hasher = hashlib.sha256()
    # Entropy source: 32 bytes from the secrets module.
    hasher.update(secrets.token_bytes(32))
    # Microsecond timestamp, appended as ASCII digits.
    hasher.update(str(int(time.time() * 1000000)).encode())
    return "session_" + hasher.hexdigest()[:32]
class SessionStore:
    """In-memory session store (to be replaced with Redis).

    All access to the internal dict goes through one ``threading.Lock``.
    That lock is not reentrant, so no method calls another locking method
    while holding it; log calls are made after the lock is released.
    """

    def __init__(self):
        self._sessions: Dict[str, Session] = {}
        self._lock = threading.Lock()

    def create_session(
        self,
        project_name: str,
        version_no: int,
        is_realtime: bool = False,
        locale: str = "tr"
    ) -> Session:
        """Create a session with a freshly generated ID, store and return it."""
        session_id = generate_secure_session_id()

        session = Session(
            session_id=session_id,
            project_name=project_name,
            version_no=version_no,
            is_realtime=is_realtime,
            locale=locale
        )

        with self._lock:
            self._sessions[session_id] = session

        log_info(
            "Session created",
            session_id=session_id,
            project=project_name,
            version=version_no,
            is_realtime=is_realtime,
            locale=locale
        )

        return session

    def get_session(self, session_id: str) -> Optional[Session]:
        """Return the session for ``session_id``, or None if unknown or expired.

        An expired session is removed from the store as a side effect.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if not session:
                return session
            if not session.is_expired():
                return session
            # Remove inline: calling delete_session() here would try to
            # re-acquire the non-reentrant lock and block forever.
            del self._sessions[session_id]

        log_info("Session expired", session_id=session_id)
        log_info("Session deleted", session_id=session_id)
        return None

    def update_session(self, session: Session) -> None:
        """Refresh ``last_activity`` and (re)store the session under its ID."""
        session.last_activity = datetime.utcnow().isoformat()

        with self._lock:
            self._sessions[session.session_id] = session

    def delete_session(self, session_id: str) -> None:
        """Remove the session if present; unknown IDs are ignored."""
        with self._lock:
            removed = self._sessions.pop(session_id, None) is not None

        if removed:
            log_info("Session deleted", session_id=session_id)

    def cleanup_expired_sessions(self, timeout_minutes: int = 30) -> int:
        """Delete every session idle for longer than ``timeout_minutes``.

        Returns:
            The number of sessions removed.
        """
        with self._lock:
            # Collect IDs first so the dict is not mutated while iterating.
            expired_ids = [
                sid for sid, session in self._sessions.items()
                if session.is_expired(timeout_minutes)
            ]
            for session_id in expired_ids:
                del self._sessions[session_id]

        expired_count = len(expired_ids)
        if expired_count > 0:
            log_info(
                "Cleaned up expired sessions",
                count=expired_count
            )

        return expired_count

    def get_active_session_count(self) -> int:
        """Return the number of sessions currently held in the store."""
        with self._lock:
            return len(self._sessions)

    def get_session_stats(self) -> Dict[str, Any]:
        """Return total, realtime and regular session counts."""
        with self._lock:
            total = len(self._sessions)
            realtime_count = sum(
                1 for s in self._sessions.values()
                if s.is_realtime
            )

        return {
            'total_sessions': total,
            'realtime_sessions': realtime_count,
            'regular_sessions': total - realtime_count
        }
# Shared, module-level SessionStore instance.
# threading is imported here, ahead of the instantiation below, because
# SessionStore.__init__ creates a threading.Lock.
import threading

session_store: SessionStore = SessionStore()
# Session cleanup task | |
def start_session_cleanup(interval_minutes: int = 5, timeout_minutes: int = 30):
    """Start a background asyncio task that periodically purges expired sessions.

    Must be called while an event loop is running.

    Args:
        interval_minutes: Pause between cleanup passes, in minutes.
        timeout_minutes: Idle time after which a session counts as expired;
            passed to ``session_store.cleanup_expired_sessions``.

    Returns:
        The created ``asyncio.Task``, so the caller can cancel or await it.

    Raises:
        RuntimeError: If no event loop is running in the current thread.
    """
    import asyncio

    # log_error is used in the error path but is not imported at module level.
    # NOTE(review): assumes utils.logger exports log_error next to
    # log_debug/log_info - confirm.
    from utils.logger import log_error

    async def cleanup_task():
        while True:
            try:
                expired = session_store.cleanup_expired_sessions(timeout_minutes)
                if expired > 0:
                    log_info("Session cleanup completed", expired=expired)
            except Exception as e:
                # Keep the loop alive; a failed pass is retried next interval.
                log_error("Session cleanup error", error=str(e))

            await asyncio.sleep(interval_minutes * 60)

    # Resolve the loop before creating the coroutine, so a missing loop does
    # not leave a never-awaited coroutine object behind.
    loop = asyncio.get_running_loop()
    task = loop.create_task(cleanup_task())

    # Hold a strong reference even if the caller discards the return value.
    start_session_cleanup._task = task
    return task