Spaces:
Building
Building
File size: 10,609 Bytes
edec17e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 |
"""
Optimized Session Management for Flare Platform
"""
from dataclasses import dataclass, field, fields
from datetime import datetime, timezone
from typing import Dict, List, Optional, Any
import hashlib
import json
import secrets
import threading
import time

from config.config_models import VersionConfig, IntentConfig
from utils.logger import log_debug, log_info, log_error
@dataclass
class Session:
    """Conversation session state, kept JSON-serializable for future Redis storage.

    Persistent fields round-trip through ``to_redis`` / ``from_redis``.
    Attributes prefixed with ``_`` hold transient objects (version/intent
    configs, cached auth tokens) that are never serialized and must be
    re-attached after loading.
    """

    # Upper bound on stored chat messages. init=False keeps it out of
    # __init__ (and to_redis does not persist it), so instances start at 20.
    MAX_CHAT_HISTORY: int = field(default=20, init=False, repr=False)

    session_id: str
    project_name: str
    version_no: int
    is_realtime: Optional[bool] = False
    locale: Optional[str] = "tr"

    # State management - string for better debugging
    state: str = "idle"  # idle | collect_params | call_api | humanize

    # Minimal stored data
    current_intent: Optional[str] = None
    variables: Dict[str, str] = field(default_factory=dict)
    project_id: Optional[int] = None
    version_id: Optional[int] = None

    # Chat history - limited to the most recent MAX_CHAT_HISTORY messages
    chat_history: List[Dict[str, str]] = field(default_factory=list)

    # Metadata: naive UTC timestamps in ISO-8601 format
    created_at: str = field(default_factory=lambda: datetime.utcnow().isoformat())
    last_activity: str = field(default_factory=lambda: datetime.utcnow().isoformat())

    # Parameter collection state
    awaiting_parameters: List[str] = field(default_factory=list)
    asked_parameters: Dict[str, int] = field(default_factory=dict)
    unanswered_parameters: List[str] = field(default_factory=list)
    parameter_ask_rounds: int = 0

    # Transient data (not serialized to Redis)
    _version_config: Optional[VersionConfig] = field(default=None, init=False, repr=False)
    _intent_config: Optional[IntentConfig] = field(default=None, init=False, repr=False)
    _auth_tokens: Dict[str, Dict] = field(default_factory=dict, init=False, repr=False)

    def add_message(self, role: str, content: str) -> None:
        """Append a timestamped message to the chat history.

        Only the newest ``MAX_CHAT_HISTORY`` messages are kept, and
        ``last_activity`` is refreshed.

        Args:
            role: Speaker role, stored verbatim in the message.
            content: Message text.
        """
        message = {
            "role": role,
            "content": content,
            "timestamp": datetime.utcnow().isoformat()
        }
        self.chat_history.append(message)

        # Keep only recent messages
        if len(self.chat_history) > self.MAX_CHAT_HISTORY:
            self.chat_history = self.chat_history[-self.MAX_CHAT_HISTORY:]

        # Update activity
        self.last_activity = datetime.utcnow().isoformat()

        log_debug(
            "Message added to session",
            session_id=self.session_id,
            role=role,
            history_size=len(self.chat_history)
        )

    def add_turn(self, role: str, content: str) -> None:
        """Alias for ``add_message`` kept for compatibility."""
        self.add_message(role, content)

    def set_version_config(self, config: VersionConfig) -> None:
        """Attach the transient (non-serialized) version config."""
        self._version_config = config

    def get_version_config(self) -> Optional[VersionConfig]:
        """Return the transient version config, or None if not attached."""
        return self._version_config

    def set_intent_config(self, config: IntentConfig) -> None:
        """Attach the current intent config and mirror its name into ``current_intent``.

        Passing a falsy ``config`` clears ``current_intent``.
        """
        self._intent_config = config
        self.current_intent = config.name if config else None

    def get_intent_config(self) -> Optional[IntentConfig]:
        """Return the transient intent config, or None if not attached."""
        return self._intent_config

    def reset_flow(self) -> None:
        """Return the conversation to ``idle`` and clear all intent/parameter state.

        ``variables`` and ``chat_history`` are left untouched.
        """
        self.state = "idle"
        self.current_intent = None
        self._intent_config = None
        self.awaiting_parameters = []
        self.asked_parameters = {}
        self.unanswered_parameters = []
        self.parameter_ask_rounds = 0

        log_debug(
            "Session flow reset",
            session_id=self.session_id
        )

    def to_redis(self) -> str:
        """Serialize the persistent fields to a JSON string.

        Transient ``_``-prefixed attributes and ``MAX_CHAT_HISTORY`` are
        omitted. The chat history is truncated to ``MAX_CHAT_HISTORY`` items.
        ``ensure_ascii=False`` keeps non-ASCII text unescaped in the output.
        """
        data = {
            'session_id': self.session_id,
            'project_name': self.project_name,
            'version_no': self.version_no,
            'state': self.state,
            'current_intent': self.current_intent,
            'variables': self.variables,
            'project_id': self.project_id,
            'version_id': self.version_id,
            'chat_history': self.chat_history[-self.MAX_CHAT_HISTORY:],
            'created_at': self.created_at,
            'last_activity': self.last_activity,
            'awaiting_parameters': self.awaiting_parameters,
            'asked_parameters': self.asked_parameters,
            'unanswered_parameters': self.unanswered_parameters,
            'parameter_ask_rounds': self.parameter_ask_rounds,
            'is_realtime': self.is_realtime,
            # Persist locale too; without it from_redis() falls back to the
            # field default and the user's locale is silently lost.
            'locale': self.locale,
        }
        return json.dumps(data, ensure_ascii=False)

    @classmethod
    def from_redis(cls, data: str) -> 'Session':
        """Rebuild a Session from the JSON produced by ``to_redis``.

        Keys that are not ``__init__`` fields of this class are ignored, so
        payloads written by older or newer versions still load. Missing
        optional keys fall back to their field defaults. Transient configs
        and auth tokens are not restored.

        Raises:
            json.JSONDecodeError: if ``data`` is not valid JSON.
            TypeError: if a required field (session_id, project_name,
                version_no) is missing.
        """
        obj = json.loads(data)
        init_names = {f.name for f in fields(cls) if f.init}
        return cls(**{key: value for key, value in obj.items() if key in init_names})

    def get_state_info(self) -> dict:
        """Return a compact debug snapshot of the conversation state."""
        return {
            'state': self.state,
            'intent': self.current_intent,
            'variables': list(self.variables.keys()),
            'history_length': len(self.chat_history),
            'awaiting_params': self.awaiting_parameters,
            'last_activity': self.last_activity
        }

    def get_auth_token(self, api_name: str) -> Optional[Dict]:
        """Return the cached auth token data for ``api_name``, or None."""
        return self._auth_tokens.get(api_name)

    def set_auth_token(self, api_name: str, token_data: Dict) -> None:
        """Cache auth token data for ``api_name`` (transient, not serialized)."""
        self._auth_tokens[api_name] = token_data

    def is_expired(self, timeout_minutes: int = 30) -> bool:
        """Return True if more than ``timeout_minutes`` passed since ``last_activity``.

        ``last_activity`` is normally a naive UTC ISO string. A trailing 'Z'
        or an explicit UTC offset is also accepted.
        """
        last_activity_time = datetime.fromisoformat(self.last_activity.replace('Z', '+00:00'))
        if last_activity_time.tzinfo is not None:
            # Normalize aware timestamps to naive UTC; subtracting an aware
            # datetime from naive utcnow() would raise TypeError.
            last_activity_time = last_activity_time.astimezone(timezone.utc).replace(tzinfo=None)
        elapsed_minutes = (datetime.utcnow() - last_activity_time).total_seconds() / 60
        return elapsed_minutes > timeout_minutes
def generate_secure_session_id() -> str:
    """Return a fresh session identifier of the form ``session_<32 hex chars>``.

    Steps:
      1. Draw 32 bytes from ``secrets.token_bytes`` (CSPRNG).
      2. Append the current time in integer microseconds, as ASCII digits.
      3. Hash with SHA-256 and keep the first 32 hex digits of the digest.
    """
    entropy = secrets.token_bytes(32)
    micros = int(time.time() * 1_000_000)
    digest = hashlib.sha256(entropy + str(micros).encode()).hexdigest()
    return "session_" + digest[:32]
class SessionStore:
    """In-memory session store (to be replaced with Redis).

    Every read/write of ``_sessions`` happens under ``_lock``. ``_lock`` is a
    plain, non-reentrant ``threading.Lock``, so no method may call another
    locking method while holding it. Logging is done after the lock is
    released.
    """

    def __init__(self):
        self._sessions: Dict[str, Session] = {}
        self._lock = threading.Lock()

    def create_session(
        self,
        project_name: str,
        version_no: int,
        is_realtime: bool = False,
        locale: str = "tr"
    ) -> Session:
        """Create a session with a fresh secure ID, register it, and return it."""
        session_id = generate_secure_session_id()
        session = Session(
            session_id=session_id,
            project_name=project_name,
            version_no=version_no,
            is_realtime=is_realtime,
            locale=locale
        )

        with self._lock:
            self._sessions[session_id] = session

        log_info(
            "Session created",
            session_id=session_id,
            project=project_name,
            version=version_no,
            is_realtime=is_realtime,
            locale=locale
        )

        return session

    def get_session(self, session_id: str) -> Optional[Session]:
        """Return the session for ``session_id``, or None if unknown or expired.

        An expired session (default 30-minute timeout) is evicted as a side
        effect.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if session is None:
                return None
            if not session.is_expired():
                return session
            # Evict inline: calling delete_session() here would try to
            # re-acquire the non-reentrant lock and deadlock.
            self._sessions.pop(session_id, None)

        log_info("Session expired", session_id=session_id)
        log_info("Session deleted", session_id=session_id)
        return None

    def update_session(self, session: Session) -> None:
        """Refresh ``session.last_activity`` and (re)store the session under its ID."""
        session.last_activity = datetime.utcnow().isoformat()

        with self._lock:
            self._sessions[session.session_id] = session

    def delete_session(self, session_id: str) -> None:
        """Remove ``session_id`` from the store; a no-op if it is not present."""
        with self._lock:
            removed = self._sessions.pop(session_id, None) is not None

        if removed:
            log_info("Session deleted", session_id=session_id)

    def cleanup_expired_sessions(self, timeout_minutes: int = 30) -> int:
        """Evict every session idle for more than ``timeout_minutes``.

        Returns:
            The number of sessions removed.
        """
        with self._lock:
            expired_ids = [
                sid for sid, session in self._sessions.items()
                if session.is_expired(timeout_minutes)
            ]
            for session_id in expired_ids:
                del self._sessions[session_id]

        expired_count = len(expired_ids)
        if expired_count > 0:
            log_info(
                "Cleaned up expired sessions",
                count=expired_count
            )

        return expired_count

    def get_active_session_count(self) -> int:
        """Return the number of stored sessions (expired ones not yet swept included)."""
        with self._lock:
            return len(self._sessions)

    def get_session_stats(self) -> Dict[str, Any]:
        """Return total, realtime, and regular (non-realtime) session counts."""
        with self._lock:
            total = len(self._sessions)
            realtime_count = sum(
                1 for s in self._sessions.values()
                if s.is_realtime
            )

        return {
            'total_sessions': total,
            'realtime_sessions': realtime_count,
            'regular_sessions': total - realtime_count
        }
# Global session store instance (used by start_session_cleanup).
session_store = SessionStore()
# Session cleanup task

# Strong references to running cleanup tasks. The asyncio event loop keeps
# only weak references to tasks, so an unreferenced task may be
# garbage-collected before it finishes.
_background_tasks: set = set()


def start_session_cleanup(interval_minutes: int = 5, timeout_minutes: int = 30):
    """Schedule a background task that periodically evicts expired sessions.

    Must be called while an asyncio event loop is running;
    ``asyncio.create_task`` raises ``RuntimeError`` otherwise.

    Args:
        interval_minutes: Minutes to sleep between cleanup passes.
        timeout_minutes: Inactivity threshold passed to
            ``session_store.cleanup_expired_sessions``.

    Returns:
        The created ``asyncio.Task``; cancel it to stop the cleanup loop.
    """
    import asyncio

    async def cleanup_task():
        while True:
            try:
                expired = session_store.cleanup_expired_sessions(timeout_minutes)
                if expired > 0:
                    log_info("Session cleanup completed", expired=expired)
            except Exception as e:
                # Log and keep looping so one failed pass doesn't stop cleanup.
                log_error("Session cleanup error", error=str(e))

            await asyncio.sleep(interval_minutes * 60)

    # Run in background, holding a reference until the task completes.
    task = asyncio.create_task(cleanup_task())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task
|