# Source: ai/src/cores/session.py (4.72 kB)
# Author: hadadrjt
# Commit d17e7ef: "ai: Refactor the code."
#
# SPDX-FileCopyrightText: Hadad <[email protected]>
# SPDX-License-Identifier: Apache-2.0
#
import asyncio # Import the asyncio library to handle asynchronous operations and events
import requests # Import the requests library for HTTP requests and session management
import uuid # Import the uuid library to generate unique identifiers
import threading # Import threading to run background timers for delayed operations
from config import LINUX_SERVER_PROVIDER_KEYS_MARKED, LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS # Import configuration variables that track marked provider keys and their failure attempts
class SessionWithID(requests.Session):
    """
    A requests.Session subclass that carries its own identity and cancellation state.

    Every instance is created with three extra attributes:
    - session_id: A random UUID4 rendered as a string, unique to this instance.
    - stop_event: An asyncio.Event (initially unset) intended to signal that the
      session should stop or cancel its operations.
    - cancel_token: A dictionary of the form {"cancelled": False} acting as a
      mutable flag for whether the session's operations have been cancelled.
    """

    def __init__(self):
        # Let requests.Session set up its own state before adding anything to it
        super().__init__()

        # Unique identifier for this session, stored as a plain string
        unique_id = uuid.uuid4()
        self.session_id = str(unique_id)

        # Event object intended to signal a stop/cancel request; starts unset
        self.stop_event = asyncio.Event()

        # Mutable cancellation flag; starts as "not cancelled"
        self.cancel_token = {"cancelled": False}
def create_session():
    """
    Build a brand-new SessionWithID and hand it back to the caller.

    Intended to be called whenever a user session begins or a chat session is reset,
    so that each session gets its own identifier and cancellation controls.

    Returns:
    - A freshly constructed SessionWithID instance.
    """
    fresh_session = SessionWithID()
    return fresh_session
def ensure_stop_event(sess):
    """
    Guarantee that a session object exposes the stop_event and cancel_token attributes.

    Any attribute that already exists on the object is left untouched; only missing
    ones are created (for example on a session restored from storage). The object is
    modified in place and nothing is returned.

    Parameters:
    - sess: The session object to check and, if needed, update.
    """
    # Attribute name paired with a zero-argument factory for its default value,
    # listed in the order the attributes are checked
    required_attributes = (
        ("stop_event", asyncio.Event),
        ("cancel_token", lambda: {"cancelled": False}),
    )
    for attribute_name, make_default in required_attributes:
        if hasattr(sess, attribute_name):
            # Already present: keep whatever value the session carries
            continue
        setattr(sess, attribute_name, make_default())
def marked_item(item, marked, attempts, *, threshold=3, cooldown=300):
    """
    Mark a provider key or host as problematic and count the failure.

    The item is added to the 'marked' set and its entry in 'attempts' is incremented
    (starting from 0 when absent). Once the count is at or above 'threshold', a
    background timer is started that, after 'cooldown' seconds, removes the item from
    'marked' and deletes its entry from 'attempts' so it can be retried.

    Every call made while the count is at or above the threshold starts its own timer.
    Below the threshold no timer is started, so the item stays marked.

    Parameters:
    - item: The provider key or host identifier to mark as problematic.
    - marked: A set containing currently marked items (modified in place).
    - attempts: A dictionary tracking the number of failure attempts per item
      (modified in place).
    - threshold: Failure count at which the automatic unmark timer is scheduled.
      Keyword-only; defaults to 3.
    - cooldown: Delay in seconds before the item is unmarked. Keyword-only;
      defaults to 300 (5 minutes).
    """
    # Record the failure: flag the item and bump its counter
    marked.add(item)
    failure_count = attempts.get(item, 0) + 1
    attempts[item] = failure_count

    if failure_count < threshold:
        # Not enough failures yet to schedule an automatic unmark
        return

    def remove():
        # Clear the mark and the failure count so the item may be retried
        marked.discard(item)
        attempts.pop(item, None)

    timer = threading.Timer(cooldown, remove)
    # Daemon thread: a pending unmark must not keep the interpreter alive at
    # shutdown for up to 'cooldown' seconds; the state it resets is in-memory only
    timer.daemon = True
    timer.start()
def get_model_key(display, MODEL_MAPPING, DEFAULT_MODEL_KEY):
    """
    Resolve a model's display name to its internal model key.

    MODEL_MAPPING is scanned in iteration order and the first key whose value equals
    the display name is returned. When no value matches, DEFAULT_MODEL_KEY is returned.

    Parameters:
    - display: The display name of the model to look up.
    - MODEL_MAPPING: Dictionary mapping internal model keys to display names.
    - DEFAULT_MODEL_KEY: The fallback model key returned when no match is found.

    Returns:
    - The matching internal model key, or DEFAULT_MODEL_KEY if there is none.
    """
    for model_key, display_name in MODEL_MAPPING.items():
        if display_name == display:
            # First match wins; later keys with the same display name are ignored
            return model_key
    # Nothing in the mapping carries this display name
    return DEFAULT_MODEL_KEY