# flare / chat_session / session.py
# Author: ciyidogan (uploaded in "Upload 134 files", commit edec17e)
"""
Optimized Session Management for Flare Platform
"""
import hashlib
import json
import secrets
import threading
import time
from dataclasses import dataclass, field, fields
from datetime import datetime, timezone
from typing import Any, ClassVar, Dict, List, Optional

from config.config_models import VersionConfig, IntentConfig
from utils.logger import log_debug, log_info, log_error
@dataclass
class Session:
    """Per-conversation state, kept JSON-serializable for future Redis storage.

    Persistent fields are written by ``to_redis`` and restored by
    ``from_redis``. Attributes prefixed with ``_`` hold transient in-process
    objects (config objects, cached auth tokens) and are never serialized.
    Timestamps are ISO-8601 strings produced from naive ``datetime.utcnow()``.
    """

    # Maximum number of chat messages kept in memory and in the serialized
    # payload. A ClassVar (not a dataclass field) so it never shows up in
    # __init__, __eq__, or dataclasses.fields().
    MAX_CHAT_HISTORY: ClassVar[int] = 20

    session_id: str
    project_name: str
    version_no: int
    is_realtime: Optional[bool] = False
    locale: Optional[str] = "tr"

    # State management - string for better debugging
    state: str = "idle"  # idle | collect_params | call_api | humanize

    # Minimal stored data
    current_intent: Optional[str] = None
    variables: Dict[str, str] = field(default_factory=dict)
    project_id: Optional[int] = None
    version_id: Optional[int] = None

    # Chat history - trimmed to the most recent MAX_CHAT_HISTORY messages
    chat_history: List[Dict[str, str]] = field(default_factory=list)

    # Metadata (naive UTC ISO-8601 strings)
    created_at: str = field(default_factory=lambda: datetime.utcnow().isoformat())
    last_activity: str = field(default_factory=lambda: datetime.utcnow().isoformat())

    # Parameter collection state
    awaiting_parameters: List[str] = field(default_factory=list)
    asked_parameters: Dict[str, int] = field(default_factory=dict)
    unanswered_parameters: List[str] = field(default_factory=list)
    parameter_ask_rounds: int = 0

    # Transient data (not serialized to Redis)
    _version_config: Optional[VersionConfig] = field(default=None, init=False, repr=False)
    _intent_config: Optional[IntentConfig] = field(default=None, init=False, repr=False)
    _auth_tokens: Dict[str, Dict] = field(default_factory=dict, init=False, repr=False)

    def add_message(self, role: str, content: str) -> None:
        """Append a timestamped message, trim history, and refresh activity.

        Args:
            role: Speaker role stored under the "role" key.
            content: Message text stored under the "content" key.
        """
        message = {
            "role": role,
            "content": content,
            "timestamp": datetime.utcnow().isoformat()
        }
        self.chat_history.append(message)

        # Keep only the most recent messages
        if len(self.chat_history) > self.MAX_CHAT_HISTORY:
            self.chat_history = self.chat_history[-self.MAX_CHAT_HISTORY:]

        # Any new message counts as session activity
        self.last_activity = datetime.utcnow().isoformat()

        log_debug(
            "Message added to session",
            session_id=self.session_id,
            role=role,
            history_size=len(self.chat_history)
        )

    def add_turn(self, role: str, content: str) -> None:
        """Alias for ``add_message``, kept for compatibility."""
        self.add_message(role, content)

    def set_version_config(self, config: VersionConfig) -> None:
        """Attach the transient (non-serialized) version config."""
        self._version_config = config

    def get_version_config(self) -> Optional[VersionConfig]:
        """Return the transient version config, or None if not set."""
        return self._version_config

    def set_intent_config(self, config: IntentConfig) -> None:
        """Attach the current intent config and mirror its name into ``current_intent``."""
        self._intent_config = config
        self.current_intent = config.name if config else None

    def get_intent_config(self) -> Optional[IntentConfig]:
        """Return the current intent config, or None if not set."""
        return self._intent_config

    def reset_flow(self) -> None:
        """Return the conversation to ``idle`` and clear all intent/parameter state."""
        self.state = "idle"
        self.current_intent = None
        self._intent_config = None
        self.awaiting_parameters = []
        self.asked_parameters = {}
        self.unanswered_parameters = []
        self.parameter_ask_rounds = 0

        log_debug(
            "Session flow reset",
            session_id=self.session_id
        )

    def to_redis(self) -> str:
        """Serialize persistent fields to a JSON string.

        Transient ``_``-prefixed attributes are excluded. ``locale`` is
        included so that a ``from_redis`` round trip preserves it.
        """
        data = {
            'session_id': self.session_id,
            'project_name': self.project_name,
            'version_no': self.version_no,
            'state': self.state,
            'current_intent': self.current_intent,
            'variables': self.variables,
            'project_id': self.project_id,
            'version_id': self.version_id,
            'chat_history': self.chat_history[-self.MAX_CHAT_HISTORY:],
            'created_at': self.created_at,
            'last_activity': self.last_activity,
            'awaiting_parameters': self.awaiting_parameters,
            'asked_parameters': self.asked_parameters,
            'unanswered_parameters': self.unanswered_parameters,
            'parameter_ask_rounds': self.parameter_ask_rounds,
            'is_realtime': self.is_realtime,
            'locale': self.locale
        }
        return json.dumps(data, ensure_ascii=False)

    @classmethod
    def from_redis(cls, data: str) -> 'Session':
        """Rebuild a Session from a JSON string produced by ``to_redis``.

        Keys that are not ``__init__`` fields of this class are ignored, so
        payloads written by a different code version do not raise TypeError.
        Keys missing from the payload fall back to field defaults.
        """
        obj = json.loads(data)
        init_names = {f.name for f in fields(cls) if f.init}
        return cls(**{key: value for key, value in obj.items() if key in init_names})

    def get_state_info(self) -> dict:
        """Return a small debug snapshot of the current flow state."""
        return {
            'state': self.state,
            'intent': self.current_intent,
            'variables': list(self.variables.keys()),
            'history_length': len(self.chat_history),
            'awaiting_params': self.awaiting_parameters,
            'last_activity': self.last_activity
        }

    def get_auth_token(self, api_name: str) -> Optional[Dict]:
        """Return the cached auth token data for ``api_name``, or None."""
        return self._auth_tokens.get(api_name)

    def set_auth_token(self, api_name: str, token_data: Dict) -> None:
        """Cache auth token data for ``api_name`` (transient, not serialized)."""
        self._auth_tokens[api_name] = token_data

    def is_expired(self, timeout_minutes: int = 30) -> bool:
        """Return True if more than ``timeout_minutes`` passed since last activity.

        ``last_activity`` may be naive (as written by this class) or carry a
        UTC offset / trailing 'Z'. Aware values are normalized to naive UTC
        before comparison, because subtracting an aware datetime from the
        naive ``utcnow()`` raises TypeError.
        """
        last_activity_time = datetime.fromisoformat(self.last_activity.replace('Z', '+00:00'))
        if last_activity_time.tzinfo is not None:
            last_activity_time = last_activity_time.astimezone(timezone.utc).replace(tzinfo=None)
        elapsed_minutes = (datetime.utcnow() - last_activity_time).total_seconds() / 60
        return elapsed_minutes > timeout_minutes
def generate_secure_session_id() -> str:
    """Return a new unguessable session identifier.

    The result has the form ``session_<32 lowercase hex chars>``, which is
    128 bits from the OS CSPRNG via ``secrets``. Extra hashing or mixing in a
    timestamp would add no entropy or uniqueness beyond this.
    """
    return f"session_{secrets.token_hex(16)}"
class SessionStore:
    """In-memory session store (to be replaced with Redis).

    Every access to ``_sessions`` is guarded by a single ``threading.Lock``.
    That lock is not reentrant, so no method calls another locking method
    while holding it. Logging happens after the lock is released.
    """

    def __init__(self):
        self._sessions: Dict[str, Session] = {}
        self._lock = threading.Lock()

    def create_session(
        self,
        project_name: str,
        version_no: int,
        is_realtime: bool = False,
        locale: str = "tr"
    ) -> Session:
        """Create, register, and return a new session with a fresh secure ID."""
        session_id = generate_secure_session_id()
        session = Session(
            session_id=session_id,
            project_name=project_name,
            version_no=version_no,
            is_realtime=is_realtime,
            locale=locale
        )

        with self._lock:
            self._sessions[session_id] = session

        log_info(
            "Session created",
            session_id=session_id,
            project=project_name,
            version=version_no,
            is_realtime=is_realtime,
            locale=locale
        )
        return session

    def get_session(self, session_id: str) -> Optional[Session]:
        """Return the session for ``session_id``, or None if unknown or expired.

        An expired session (per ``Session.is_expired()`` with its default
        timeout) is removed from the store before None is returned.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            expired = session is not None and session.is_expired()
            if expired:
                # Remove inline: calling delete_session() here would try to
                # re-acquire the non-reentrant lock and deadlock.
                del self._sessions[session_id]

        if expired:
            log_info("Session expired", session_id=session_id)
            log_info("Session deleted", session_id=session_id)
            return None
        return session

    def update_session(self, session: Session) -> None:
        """Refresh ``last_activity`` and store ``session`` under its ID."""
        session.last_activity = datetime.utcnow().isoformat()
        with self._lock:
            self._sessions[session.session_id] = session

    def delete_session(self, session_id: str) -> None:
        """Remove ``session_id`` from the store; unknown IDs are ignored."""
        with self._lock:
            removed = self._sessions.pop(session_id, None) is not None
        if removed:
            log_info("Session deleted", session_id=session_id)

    def cleanup_expired_sessions(self, timeout_minutes: int = 30) -> int:
        """Remove every session idle longer than ``timeout_minutes``.

        Returns:
            The number of sessions removed.
        """
        with self._lock:
            expired_ids = [
                sid for sid, session in self._sessions.items()
                if session.is_expired(timeout_minutes)
            ]
            for sid in expired_ids:
                del self._sessions[sid]

        expired_count = len(expired_ids)
        if expired_count > 0:
            log_info(
                "Cleaned up expired sessions",
                count=expired_count
            )
        return expired_count

    def get_active_session_count(self) -> int:
        """Return the number of stored sessions (expired ones not yet evicted included)."""
        with self._lock:
            return len(self._sessions)

    def get_session_stats(self) -> Dict[str, Any]:
        """Return total, realtime, and regular session counts."""
        with self._lock:
            total = len(self._sessions)
            realtime_count = sum(1 for s in self._sessions.values() if s.is_realtime)

        return {
            'total_sessions': total,
            'realtime_sessions': realtime_count,
            'regular_sessions': total - realtime_count
        }
# Global session store instance shared by the whole process.
# `threading` is imported with the other stdlib modules at the top of the file.
session_store = SessionStore()
# Session cleanup task

# Strong references to running cleanup tasks. The event loop only keeps weak
# references to tasks, so an unreferenced task could be garbage-collected
# before it finishes.
_cleanup_tasks: set = set()


def start_session_cleanup(interval_minutes: int = 5, timeout_minutes: int = 30):
    """Schedule a background task that periodically evicts expired sessions.

    Must be called while an asyncio event loop is running (e.g. from an
    application startup hook). Otherwise ``asyncio.create_task`` raises
    RuntimeError.

    Args:
        interval_minutes: Minutes to sleep between cleanup passes.
        timeout_minutes: Inactivity threshold passed to
            ``session_store.cleanup_expired_sessions``.

    Returns:
        The scheduled ``asyncio.Task``; cancel it to stop the loop. Callers
        that ignore the return value are still safe, because a strong
        reference is kept until the task completes.
    """
    import asyncio

    async def cleanup_task():
        while True:
            try:
                expired = session_store.cleanup_expired_sessions(timeout_minutes)
                if expired > 0:
                    log_info("Session cleanup completed", expired=expired)
            except Exception as e:
                # Best-effort: log and keep the loop alive for the next pass.
                log_error("Session cleanup error", error=str(e))
            await asyncio.sleep(interval_minutes * 60)

    # Run in background, keeping a reference so the task is not collected.
    task = asyncio.create_task(cleanup_task())
    _cleanup_tasks.add(task)
    task.add_done_callback(_cleanup_tasks.discard)
    return task