Spaces:
Running
Running
# | |
# SPDX-FileCopyrightText: Hadad <[email protected]> | |
# SPDX-License-Identifier: Apache-2.0 | |
# | |
import asyncio # Import the asyncio library to handle asynchronous operations and events | |
import requests # Import the requests library for HTTP requests and session management | |
import uuid # Import the uuid library to generate unique identifiers | |
import threading # Import threading to run background timers for delayed operations | |
from config import LINUX_SERVER_PROVIDER_KEYS_MARKED, LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS # Import configuration variables that track marked provider keys and their failure attempts | |
class SessionWithID(requests.Session):
    """
    A requests.Session subclass that additionally carries a per-instance identifier
    and cancellation state.

    Attributes created by the constructor:
    - session_id: A random UUID4 rendered as a string.
    - stop_event: A newly created asyncio.Event (initially not set).
    - cancel_token: A dictionary of the form {"cancelled": False}.

    The class only creates these attributes; setting the event or flipping the
    flag is left to whatever code uses the session.
    """

    def __init__(self):
        # Let requests.Session build its own state before adding the extra attributes
        super().__init__()

        # Random UUID4, stored as text so it can be used directly as an identifier
        generated_id = uuid.uuid4()
        self.session_id = str(generated_id)

        # Event object available for signalling that the session should stop
        self.stop_event = asyncio.Event()

        # Mutable flag holder; starts out in the "not cancelled" state
        self.cancel_token = {"cancelled": False}
def create_session():
    """
    Build a brand-new SessionWithID and hand it back to the caller.

    Every call constructs a separate instance, so each returned session has its
    own session_id, stop_event and cancel_token.

    Returns:
    - A freshly constructed SessionWithID instance.
    """
    fresh_session = SessionWithID()
    return fresh_session
def ensure_stop_event(sess):
    """
    Guarantee that a session object exposes both cancellation attributes.

    Any attribute that is already present is left exactly as it is; only missing
    ones are created, using the same defaults a new SessionWithID would receive.

    Parameters:
    - sess: The session object to inspect and, if needed, patch in place.

    Returns:
    - None. The object is modified in place.
    """
    # Attribute name paired with a zero-argument factory producing its default value
    required_attributes = (
        ("stop_event", asyncio.Event),
        ("cancel_token", lambda: {"cancelled": False}),
    )
    for attribute_name, build_default in required_attributes:
        if hasattr(sess, attribute_name):
            # Already present: never overwrite an existing value
            continue
        setattr(sess, attribute_name, build_default())
def marked_item(item, marked, attempts, *, threshold=3, cooldown=300):
    """
    Mark a provider key or host as problematic and record the failure.

    The item is added to the 'marked' set on every call and its failure count in
    'attempts' is incremented by one. Once the count is at or above 'threshold',
    a background timer is started that, after 'cooldown' seconds, removes the item
    from 'marked' and deletes its entry from 'attempts' so it can be retried.

    Notes:
    - This function itself never unmarks an item whose count is below the threshold.
    - Every call made while the count is at or above the threshold starts its own
      timer; each of those timers performs the same removal when it fires.
    - The timer thread is a daemon, so a pending cooldown does not keep the
      interpreter alive at shutdown.

    Parameters:
    - item: The provider key or host identifier to mark as problematic.
    - marked: A set containing currently marked items (modified in place).
    - attempts: A dictionary tracking the number of failure attempts per item
      (modified in place).
    - threshold: Failure count at which the automatic unmark timer is scheduled.
      Defaults to 3.
    - cooldown: Delay in seconds before the item is unmarked. Defaults to 300
      (5 minutes).

    Returns:
    - None.
    """
    # Flag the item immediately, regardless of how many times it has failed
    marked.add(item)

    # Count this failure, starting from zero for items seen for the first time
    failure_count = attempts.get(item, 0) + 1
    attempts[item] = failure_count

    if failure_count < threshold:
        # Not enough failures yet to schedule the automatic reset
        return

    def remove():
        # Allow the item to be used again and forget its failure history
        marked.discard(item)
        attempts.pop(item, None)

    timer = threading.Timer(cooldown, remove)
    # Must be set before start(); prevents a pending cooldown from blocking process exit
    timer.daemon = True
    timer.start()
def get_model_key(display, MODEL_MAPPING, DEFAULT_MODEL_KEY):
    """
    Resolve a model's display name back to the key it is stored under.

    MODEL_MAPPING is scanned in its iteration order and the first key whose value
    equals the given display name wins. When no value matches, the supplied
    fallback is returned unchanged.

    Parameters:
    - display: The display name to look up.
    - MODEL_MAPPING: Dictionary mapping internal model keys to display names.
    - DEFAULT_MODEL_KEY: Value returned when no entry matches the display name.

    Returns:
    - The first matching key from MODEL_MAPPING, or DEFAULT_MODEL_KEY if none matches.
    """
    for model_key, display_name in MODEL_MAPPING.items():
        if display_name == display:
            # First match wins; later entries with the same display name are ignored
            return model_key
    # Nothing matched, so fall back to the caller-provided default
    return DEFAULT_MODEL_KEY