File size: 5,404 Bytes
ef1ad9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# --- Library Imports ---
from typing import Union
from fastapi import Request
from fastapi.exceptions import RequestValidationError, HTTPException
from fastapi.exception_handlers import (
request_validation_exception_handler as _request_validation_exception_handler,
)
from fastapi.responses import JSONResponse
from fastapi.responses import PlainTextResponse
from fastapi.responses import Response
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace.tracer import Tracer
from opencensus.trace.samplers import ProbabilitySampler
from opencensus.trace.span import SpanKind
from opencensus.trace.attributes_helper import COMMON_ATTRIBUTES
import traceback
# ---
# --- User Imports ---
from app.config.env import env
from app.utils.app_logger.logger import logger
# ---
# --- Constants ---
# Tracer shared by every exception handler in this module; spans are exported
# to Azure Application Insights via the connection string from env.
# ProbabilitySampler(1.0) uses a sampling probability of 1.0, so every trace
# is recorded.
# NOTE(review): this Tracer is constructed once at import time and reused for
# all requests — confirm whether it reuses a single trace id for every span it
# emits (which would group unrelated errors under one operation).
OPENAI_TRACER = Tracer(
    exporter=AzureExporter(
        connection_string=env.APPLICATIONINSIGHTS_CONNECTION_STRING,
    ),
    sampler=ProbabilitySampler(1.0),
)

# OpenCensus common attribute keys, looked up once so handlers can use them
# directly when annotating spans.
HTTP_URL = COMMON_ATTRIBUTES["HTTP_URL"]
HTTP_STATUS_CODE = COMMON_ATTRIBUTES["HTTP_STATUS_CODE"]
ERROR_NAME = COMMON_ATTRIBUTES["ERROR_NAME"]
ERROR_MESSAGE = COMMON_ATTRIBUTES["ERROR_MESSAGE"]
HTTP_METHOD = COMMON_ATTRIBUTES["HTTP_METHOD"]
HTTP_PATH = COMMON_ATTRIBUTES["HTTP_PATH"]
STACKTRACE = COMMON_ATTRIBUTES["STACKTRACE"]
# ---
def _resolve_error_message(exc: Exception) -> str:
    """Return ``exc.detail`` if truthy, else ``exc.args[0]``, else a generic fallback, as a string."""
    detail = getattr(exc, 'detail', None)
    if detail:
        return str(detail)
    args = getattr(exc, 'args', None)
    if args:
        return str(args[0])
    return 'Something went wrong'


def add_trace_to_azure_appinsight(request: Request, exc: Exception):
    """
    Record an exception as a server span in Azure Application Insights.

    Opens a span named "main" on OPENAI_TRACER, marks it as SpanKind.SERVER and
    attaches the HTTP status code, the request URL, method and path, the
    exception class name, an error message and a depth-limited stack trace.

    Args:
        request: The request during which the exception occurred.
        exc: The exception being handled. Any exception type is accepted:
            ``status_code`` (default 500) and ``detail`` are read when present,
            as on HTTPException; otherwise ``args[0]`` supplies the message.
    """
    error_message = _resolve_error_message(exc)
    # Format from the exception object itself: traceback.format_exc() only
    # works while an exception is actively being handled and otherwise
    # produces "NoneType: None".
    # NOTE(review): limit=2 keeps only the two outermost traceback entries,
    # which may not include the frame that raised; consider a negative limit.
    exception_traceback = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__, limit=2)
    )
    attributes = {
        HTTP_STATUS_CODE: getattr(exc, 'status_code', 500),
        HTTP_URL: str(request.url),
        ERROR_NAME: type(exc).__name__,
        # Stringified so a dict/list ``detail`` or an int ``args[0]`` still
        # yields a plain-text attribute value.
        ERROR_MESSAGE: error_message,
        HTTP_METHOD: str(request.method),
        HTTP_PATH: str(request.url.path),
        STACKTRACE: exception_traceback,
    }
    with OPENAI_TRACER.span("main") as span:
        span.span_kind = SpanKind.SERVER
        # Annotate the span we just opened directly instead of looking up the
        # "current" span through the tracer's execution context.
        for key, value in attributes.items():
            span.add_attribute(key, value)
async def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """
    Wrap FastAPI's default RequestValidationError handler with extra logging.

    Called when client input fails validation. Logs (at INFO) the validation
    errors together with the raw request body and the query parameters, then
    returns whatever FastAPI's built-in handler returns for ``exc``.
    """
    logger.debug("Our custom request_validation_exception_handler was called")
    body = await request.body()
    # Public equivalent of the former ``query_params._dict`` access: one value
    # per key, taken from the same underlying mapping.
    query_params = dict(request.query_params)
    detail = {
        "errors": exc.errors(),
        # The body is not guaranteed to be UTF-8 (binary uploads, other
        # charsets); a strict decode would raise inside this handler and turn
        # a 422 validation response into a 500.
        "body": body.decode(errors="replace"),
        "query_params": query_params,
    }
    logger.info(detail)
    return await _request_validation_exception_handler(request, exc)
async def http_exception_handler(request: Request, exc: HTTPException) -> Union[JSONResponse, Response]:
    """
    Wrap FastAPI's HTTPException handling with logging and App Insights tracing.

    Called when an HTTPException is explicitly raised. Logs the status, detail,
    URL (including the query string), method and traceback, records the error
    via add_trace_to_azure_appinsight, and builds the response.

    Returns:
        A JSONResponse ``{"status", "message", "success": False}`` carrying the
        exception's status code and headers. For status codes that must not
        carry a body (1xx, 204, 205, 304), it returns a bare Response with the
        same status and headers instead.
    """
    # Taken from the exception object so it does not depend on ambient
    # exception state.
    exception_traceback = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__, limit=2)
    )
    url = (
        f"{request.url.path}?{request.query_params}"
        if request.query_params
        else request.url.path
    )
    logger.exception({
        "status": exc.status_code,
        "message": exc.detail,
        "url": url,
        "method": request.method,
        "trace": exception_traceback,
    })
    add_trace_to_azure_appinsight(request, exc)
    # Keep headers attached to the exception (e.g. WWW-Authenticate on a 401),
    # as FastAPI's default handler does; dropping them breaks auth challenges.
    headers = getattr(exc, "headers", None)
    if exc.status_code < 200 or exc.status_code in {204, 205, 304}:
        # These statuses forbid a response body, so a JSON payload would be invalid.
        return Response(status_code=exc.status_code, headers=headers)
    return JSONResponse(
        status_code=exc.status_code,
        content={"status": exc.status_code,
                 "message": exc.detail, "success": False},
        headers=headers,
    )
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """
    Log, trace and convert any exception that no more specific handler caught.

    Unhandled exceptions are all exceptions that are not HTTPExceptions or
    RequestValidationErrors. The exception is logged with the URL (including
    the query string), method and traceback, and recorded via
    add_trace_to_azure_appinsight. It is then turned into a 500 JSON response
    ``{"status": 500, "message": str(exc), "success": False}``.
    """
    # Taken from the exception object so it does not depend on ambient
    # exception state.
    exception_traceback = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__, limit=2)
    )
    url = (
        f"{request.url.path}?{request.query_params}"
        if request.query_params
        else request.url.path
    )
    logger.exception({
        "status": 500,
        "message": exc,
        "url": url,
        "method": request.method,
        "trace": exception_traceback,
    })
    add_trace_to_azure_appinsight(request, exc)
    # NOTE(review): str(exc) is sent back to the client and may expose
    # internal details of the failure; consider a generic message instead.
    return JSONResponse(
        status_code=500,
        content={"status": 500,
                 "message": str(exc), "success": False},
    )
|