#
# SPDX-FileCopyrightText: Hadad
# SPDX-License-Identifier: Apache-2.0
#

import asyncio  # asyncio.Event is used as each session's stop signal
import threading  # threading.Timer schedules the delayed unmarking of items
import uuid  # uuid4 gives every session a unique identifier

import requests  # requests.Session is the base class of SessionWithID

# Configuration variables tracking marked provider keys and their failure attempts.
# NOTE(review): not referenced by the functions in this module; presumably
# imported for callers or side effects — confirm before removing.
from config import LINUX_SERVER_PROVIDER_KEYS_MARKED, LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS


class SessionWithID(requests.Session):
    """
    requests.Session subclass that adds a unique session ID and
    cancellation controls.

    Each instance carries its own ``session_id`` so individual user sessions
    can be tracked. It also carries a ``stop_event`` and a ``cancel_token``,
    which callers use to signal that the session's operations should stop.
    """

    def __init__(self):
        super().__init__()
        # Unique string identifier for this session instance.
        self.session_id = str(uuid.uuid4())
        # Event signalled when the session should stop or cancel operations.
        self.stop_event = asyncio.Event()
        # Dictionary flag indicating whether this session's operations were cancelled.
        self.cancel_token = {"cancelled": False}


def create_session():
    """
    Create and return a new SessionWithID instance.

    Call this whenever a new user session starts or a chat session is reset,
    so that each session gets its own unique ID and cancellation controls.
    """
    return SessionWithID()


def ensure_stop_event(sess):
    """
    Ensure a session object has the async control attributes
    ``stop_event`` and ``cancel_token``.

    Missing attributes are added with fresh defaults, e.g. for sessions
    restored from storage that were not built by SessionWithID. Existing
    attributes are left untouched.

    Parameters:
    - sess: The session object to check and update in place.
    """
    if not hasattr(sess, "stop_event"):
        sess.stop_event = asyncio.Event()
    if not hasattr(sess, "cancel_token"):
        sess.cancel_token = {"cancelled": False}


def marked_item(item, marked, attempts):
    """
    Mark a provider key or host as temporarily problematic and record the failure.

    Every call adds ``item`` to ``marked`` and increments its failure count
    in ``attempts``. The first time the count reaches 3, a single background
    timer is scheduled. After 300 seconds (5 minutes) that timer removes the
    item from ``marked`` and clears its failure count, so the item can be
    retried.

    Parameters:
    - item: The provider key or host identifier to mark as problematic.
    - marked: A set containing currently marked items (mutated in place).
    - attempts: A dictionary tracking the number of failure attempts per item
      (mutated in place).
    """
    marked.add(item)
    attempts[item] = attempts.get(item, 0) + 1

    # Schedule the unmark exactly once, when the threshold is first reached.
    # A ">=" test would start one extra timer for every further failure.
    # A leftover timer could then unmark an item that had been re-marked
    # after an earlier timer already reset its count.
    if attempts[item] == 3:

        def remove():
            marked.discard(item)  # Allow the item to be used again
            attempts.pop(item, None)  # Reset its failure count

        timer = threading.Timer(300, remove)
        # Daemon thread, so a pending cooldown does not keep the interpreter
        # alive for up to 5 minutes at shutdown.
        timer.daemon = True
        timer.start()


def get_model_key(display, MODEL_MAPPING, DEFAULT_MODEL_KEY):
    """
    Translate a human-readable model display name into its internal model key.

    Scans MODEL_MAPPING for the first key whose value equals ``display``.

    Parameters:
    - display: The display name of the model as a string.
    - MODEL_MAPPING: Dictionary mapping internal model keys to display names.
    - DEFAULT_MODEL_KEY: The fallback model key to return if no match is found.

    Returns:
    - The internal model key corresponding to the display name, or
      DEFAULT_MODEL_KEY when no value matches.
    """
    return next((k for k, v in MODEL_MAPPING.items() if v == display), DEFAULT_MODEL_KEY)