|
from contextlib import asynccontextmanager |
|
from dataclasses import asdict, dataclass |
|
from enum import Enum |
|
import re |
|
from typing import ( |
|
TYPE_CHECKING, |
|
Any, |
|
AsyncGenerator, |
|
Dict, |
|
MutableMapping, |
|
Optional, |
|
cast, |
|
) |
|
import uuid |
|
|
|
from asgiref.typing import ( |
|
ASGI3Application, |
|
ASGIReceiveCallable, |
|
ASGIReceiveEvent, |
|
ASGISendCallable, |
|
ASGISendEvent, |
|
Scope as ASGIScope, |
|
) |
|
from loguru import logger |
|
from starlette.requests import Request |
|
|
|
from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE |
|
from open_webui.utils.auth import get_current_user, get_http_authorization_cred |
|
from open_webui.models.users import UserModel |
|
|
|
|
|
if TYPE_CHECKING: |
|
from loguru import Logger |
|
|
|
|
|
@dataclass(frozen=True)
class AuditLogEntry:
    """
    Immutable record describing one audited HTTP exchange.

    Instances are frozen: attributes cannot be reassigned after construction. The first five fields are required; the remaining ones default to ``None`` when they were not captured.
    """

    # Unique identifier of this entry (the middleware in this module uses a UUID4 string).
    id: str
    # Selected fields of the acting user (the middleware stores id, name, email and role).
    user: dict[str, Any]
    # Value of the audit level that was in effect when the entry was produced.
    audit_level: str
    # HTTP method of the request.
    verb: str
    # Full URL of the request, as a string.
    request_uri: str
    # Value of the User-Agent request header, if one was sent.
    user_agent: Optional[str] = None
    # Host of the connecting client, if known.
    source_ip: Optional[str] = None

    # Captured request payload (the middleware stores the decoded, possibly truncated body).
    request_object: Any = None

    # Captured response payload (the middleware stores the decoded, possibly truncated body).
    response_object: Any = None
    # HTTP status code of the response, when it was recorded.
    response_status_code: Optional[int] = None
|
|
|
|
|
class AuditLevel(str, Enum):
    """
    Closed set of audit verbosity levels.

    Members are also ``str`` instances, so each one compares equal to its own name (e.g. ``AuditLevel.REQUEST == "REQUEST"``).
    """

    # NOTE(review): presumably means "auditing disabled" — the middleware in this module
    # checks the AUDIT_LOG_LEVEL setting for "NONE" rather than this member; confirm.
    NONE = "NONE"
    # Neither the request body nor the response is captured by the middleware.
    METADATA = "METADATA"
    # The request body is captured as well.
    REQUEST = "REQUEST"
    # The request body plus the response status code and body are captured.
    REQUEST_RESPONSE = "REQUEST_RESPONSE"
|
|
|
|
|
class AuditLogger:
    """
    Small wrapper that emits audit entries through a Loguru logger.

    The supplied logger is bound with ``auditable=True`` once, at construction time, so that every record written through this class carries that marker and audit records can be filtered on it.

    Parameters:
        logger (Logger): An instance of Loguru's logger.
    """

    def __init__(self, logger: "Logger"):
        # Keep only the bound logger; all writes go through it.
        self.logger = logger.bind(auditable=True)

    def write(
        self,
        audit_entry: AuditLogEntry,
        *,
        log_level: str = "INFO",
        extra: Optional[dict] = None,
    ):
        """
        Emit one audit entry.

        Parameters:
            audit_entry (AuditLogEntry): The entry to log; its fields are flattened into keyword arguments of the log call.
            log_level (str): Level name handed to the logger.
            extra (Optional[dict]): Additional data; when non-empty it is attached under the ``"extra"`` key.
        """
        # The message itself is empty: all information travels as keyword arguments.
        fields = asdict(audit_entry)
        fields.update({"extra": extra} if extra else {})
        self.logger.log(log_level, "", **fields)
|
|
|
|
|
class AuditContext:
    """
    Per-request accumulator for the data that ends up in an audit entry.

    Request and response payloads are collected chunk by chunk, and each buffer stops growing once it holds ``max_body_size`` bytes; anything beyond that limit is dropped.

    Attributes:
        request_body (bytearray): Accumulated request payload.
        response_body (bytearray): Accumulated response payload.
        max_body_size (int): Maximum number of bytes kept in each buffer.
        metadata (Dict[str, Any]): Free-form additional audit data (the middleware stores the response status code here).
    """

    def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE):
        self.request_body = bytearray()
        self.response_body = bytearray()
        self.max_body_size = max_body_size
        self.metadata: Dict[str, Any] = {}

    def _append_capped(self, buffer: bytearray, chunk: bytes) -> None:
        """Append as much of ``chunk`` to ``buffer`` as still fits under ``max_body_size``."""
        room = self.max_body_size - len(buffer)
        if room > 0:
            buffer.extend(chunk[:room])

    def add_request_chunk(self, chunk: bytes):
        """Record a piece of the request body, truncating at the size limit."""
        self._append_capped(self.request_body, chunk)

    def add_response_chunk(self, chunk: bytes):
        """Record a piece of the response body, truncating at the size limit."""
        self._append_capped(self.response_body, chunk)
|
|
|
|
|
class AuditLoggingMiddleware:
    """
    ASGI middleware that intercepts HTTP requests and responses to perform audit logging.

    Only requests that use one of ``AUDITED_METHODS``, carry an ``Authorization`` header and do not target an excluded path are audited. While the wrapped application runs, the request body (levels ``REQUEST`` and ``REQUEST_RESPONSE``) and the response status code and body (level ``REQUEST_RESPONSE`` only) are captured, and one structured audit entry is written when the application finishes or raises.

    Parameters:
        app (ASGI3Application): The wrapped ASGI application.
        excluded_paths (Optional[list[str]]): Path fragments, relative to ``/api/`` or ``/api/v1/``, that are never audited. They are inserted verbatim into a regular expression, so regex metacharacters keep their meaning. Empty entries are ignored.
        max_body_size (int): Maximum number of bytes captured for each of the request and response bodies.
        audit_level (AuditLevel): Controls which parts of the exchange are captured.
    """

    AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"}

    def __init__(
        self,
        app: ASGI3Application,
        *,
        excluded_paths: Optional[list[str]] = None,
        max_body_size: int = MAX_BODY_LOG_SIZE,
        audit_level: AuditLevel = AuditLevel.NONE,
    ) -> None:
        self.app = app
        self.audit_logger = AuditLogger(logger)
        self.excluded_paths = excluded_paths or []
        self.max_body_size = max_body_size
        self.audit_level = audit_level

    async def __call__(
        self,
        scope: ASGIScope,
        receive: ASGIReceiveCallable,
        send: ASGISendCallable,
    ) -> None:
        """Handle one ASGI connection, auditing it when it qualifies."""
        # Anything that is not plain HTTP is passed straight through.
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope=cast(MutableMapping, scope))

        if self._should_skip_auditing(request):
            return await self.app(scope, receive, send)

        async with self._audit_context(request) as context:

            async def send_wrapper(message: ASGISendEvent) -> None:
                # Response status and body are only recorded at the most verbose level.
                if self.audit_level == AuditLevel.REQUEST_RESPONSE:
                    await self._capture_response(message, context)

                await send(message)

            async def receive_wrapper() -> ASGIReceiveEvent:
                message = await receive()

                if self.audit_level in (
                    AuditLevel.REQUEST,
                    AuditLevel.REQUEST_RESPONSE,
                ):
                    await self._capture_request(message, context)

                return message

            await self.app(scope, receive_wrapper, send_wrapper)

    @asynccontextmanager
    async def _audit_context(
        self, request: Request
    ) -> AsyncGenerator[AuditContext, None]:
        """
        Async context manager that ensures that an audit log entry is recorded after the request is processed, even when the wrapped application raises.
        """
        # Honour the limit configured on this middleware instead of the module default.
        context = AuditContext(max_body_size=self.max_body_size)
        try:
            yield context
        finally:
            await self._log_audit_entry(request, context)

    async def _get_authenticated_user(self, request: Request) -> UserModel:
        """
        Resolve the user behind the request from its ``Authorization`` header.

        Raises:
            ValueError: If the request carries no ``Authorization`` header.
        """
        auth_header = request.headers.get("Authorization")
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if not auth_header:
            raise ValueError("Missing Authorization header")

        # NOTE(review): the second positional argument is passed as None, as before —
        # confirm against get_current_user's signature.
        user = get_current_user(request, None, get_http_authorization_cred(auth_header))

        return user

    def _should_skip_auditing(self, request: Request) -> bool:
        """
        Decide whether the request must not be audited.

        A request is skipped when its method is not in ``AUDITED_METHODS``, when the ``AUDIT_LOG_LEVEL`` setting is ``"NONE"``, when it has no ``Authorization`` header, or when its path matches one of the excluded paths under ``/api/`` or ``/api/v1/``.
        """
        if (
            request.method not in self.AUDITED_METHODS
            or AUDIT_LOG_LEVEL == "NONE"
            or not request.headers.get("authorization")
        ):
            return True

        # An empty alternative would turn the group into ``()``, which together with
        # ``\b`` matches every ``/api/<word>...`` path and would silently disable
        # auditing; drop empty entries and only match when something is excluded.
        excluded = [path for path in self.excluded_paths if path]
        if not excluded:
            return False

        pattern = re.compile(r"^/api(?:/v1)?/(" + "|".join(excluded) + r")\b")
        return pattern.match(request.url.path) is not None

    async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
        """Add the body of an ``http.request`` message to the context; ignore other messages."""
        if message["type"] == "http.request":
            body = message.get("body", b"")
            context.add_request_chunk(body)

    async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
        """Record the status code from ``http.response.start`` and body chunks from ``http.response.body``."""
        if message["type"] == "http.response.start":
            context.metadata["response_status_code"] = message["status"]

        elif message["type"] == "http.response.body":
            body = message.get("body", b"")
            context.add_response_chunk(body)

    async def _log_audit_entry(self, request: Request, context: AuditContext):
        """
        Build the audit entry for the finished request and write it.

        Any failure (including user resolution) is logged as an error and swallowed, so that audit logging never breaks the request itself.
        """
        try:
            user = await self._get_authenticated_user(request)

            entry = AuditLogEntry(
                id=str(uuid.uuid4()),
                user=user.model_dump(include={"id", "name", "email", "role"}),
                audit_level=self.audit_level.value,
                verb=request.method,
                request_uri=str(request.url),
                response_status_code=context.metadata.get("response_status_code", None),
                source_ip=request.client.host if request.client else None,
                user_agent=request.headers.get("user-agent"),
                # Bodies may be truncated mid-character; replace undecodable bytes.
                request_object=context.request_body.decode("utf-8", errors="replace"),
                response_object=context.response_body.decode("utf-8", errors="replace"),
            )

            self.audit_logger.write(entry)
        except Exception as e:
            logger.error(f"Failed to log audit entry: {str(e)}")
|
|