Spaces:
Sleeping
Sleeping
File size: 2,275 Bytes
e2d4dfc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
from abc import ABC, abstractmethod
import logging
from typing import List, Union
from fastapi import FastAPI
from fastapi.responses import Response
from aiproxy.queueclient import QueueClientBase
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Base classes for request and response filters
class RequestFilterBase(ABC):
    """Abstract interface for request filters.

    Concrete filters subclass this and implement the async ``filter`` hook;
    the class cannot be instantiated until that hook is overridden.
    """

    @abstractmethod
    async def filter(self, request_id: str, request_json: dict, request_headers: dict) -> Union[str, None]:
        """Examine one request.

        Args:
            request_id: Identifier of the request being filtered.
            request_json: Parsed request body.
            request_headers: Request headers.

        Returns:
            A string or None. NOTE(review): what each outcome means is decided
            by the calling proxy implementation - confirm there.
        """
class ResponseFilterBase(ABC):
    """Abstract interface for response filters.

    Concrete filters subclass this and implement the async ``filter`` hook;
    the class cannot be instantiated until that hook is overridden.
    """

    @abstractmethod
    async def filter(self, request_id: str, response_json: dict) -> Union[dict, None]:
        """Examine one response.

        Args:
            request_id: Identifier of the request the response belongs to.
            response_json: Parsed response body.

        Returns:
            A dict or None. NOTE(review): whether a dict replaces the response
            is decided by the calling proxy implementation - confirm there.
        """
class FilterException(Exception):
    """Base error for filters, carrying a message and a status code.

    Args:
        message: Human-readable reason. It is also passed to ``Exception``,
            so it becomes ``str(exc)`` and ``exc.args[0]``.
        status_code: Status code associated with the error (default 400).
            NOTE(review): presumably an HTTP status - confirm at use sites.
    """

    def __init__(self, message: str, status_code: int = 400) -> None:
        # Forward the message to Exception. Without this, exc.args is empty,
        # str(exc) is "", and unpickling fails because the class is re-created
        # with no arguments while __init__ requires `message`.
        super().__init__(message)
        self.message = message
        self.status_code = status_code


class RequestFilterException(FilterException):
    """Request-side variant of FilterException."""


class ResponseFilterException(FilterException):
    """Response-side variant of FilterException."""
class ProxyBase(ABC):
    """Abstract base for proxies holding request/response filters and an access-log queue.

    Subclasses must implement ``add_route`` to register their endpoint(s) on a
    FastAPI application.
    """

    def __init__(
        self,
        *,
        request_filters: Union[List[RequestFilterBase], None] = None,
        response_filters: Union[List[ResponseFilterBase], None] = None,
        access_logger_queue: QueueClientBase
    ):
        """Store the initial filters and the access-log queue client.

        Args:
            request_filters: Initial request filters. They are copied into a
                new list, so later ``add_filter`` calls never mutate the
                caller's sequence.
            response_filters: Initial response filters, copied likewise.
            access_logger_queue: Queue client kept for access logging.
        """
        # Copy rather than `x or []`. That form aliased a non-empty caller-owned
        # list, so add_filter() appended into the caller's list and into every
        # other proxy constructed with the same list.
        self.request_filters: List[RequestFilterBase] = list(request_filters or [])
        self.response_filters: List[ResponseFilterBase] = list(response_filters or [])
        # Access logger queue
        self.access_logger_queue = access_logger_queue

    def add_filter(self, filter: Union[RequestFilterBase, ResponseFilterBase]):
        """Append ``filter`` to the matching filter list.

        The request-filter check runs first, so an object that subclasses both
        bases is registered only as a request filter. Any other object is
        ignored after logging a warning; no exception is raised.
        """
        # NOTE: `filter` shadows the builtin; kept because it is part of the
        # public signature (callers may pass it by keyword).
        if isinstance(filter, RequestFilterBase):
            self.request_filters.append(filter)
            logger.info("request filter: %s", filter.__class__.__name__)
        elif isinstance(filter, ResponseFilterBase):
            self.response_filters.append(filter)
            logger.info("response filter: %s", filter.__class__.__name__)
        else:
            logger.warning("Invalid filter: %s", filter.__class__.__name__)

    def add_response_headers(self, response: Response, request_id: str, headers: Union[dict, None] = None):
        """Set ``X-AIProxy-Request-Id`` on ``response``, then copy ``headers`` onto it.

        The extra headers are written after the request-id header, so an entry
        for the same header name overwrites it.
        """
        response.headers["X-AIProxy-Request-Id"] = request_id
        if headers:
            for name, value in headers.items():
                response.headers[name] = value

    @abstractmethod
    def add_route(self, app: FastAPI, base_url: str):
        """Register this proxy's endpoint(s) on ``app``.

        ``base_url`` is passed through to the implementation (presumably the
        route prefix - confirm in subclasses).
        """
        ...