from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse, JSONResponse
from loguru import logger

class OpenAIError(Exception):
    pass

class APIError(OpenAIError):
    message: str
    code: str = None
    param: str = None
    type: str = None

    def __init__(self, message: str, code: int = 500, param: str = None, internal_message: str = ''):
        super().__init__(message)
        self.message = message
        self.code = code
        self.param = param
        self.type = self.__class__.__name__,
        self.internal_message = internal_message

    def __repr__(self):
        return "%s(message=%r, code=%d, param=%s)" % (
            self.__class__.__name__,
            self.message,
            self.code,
            self.param,
        )

class InternalServerError(APIError):
    pass

class ServiceUnavailableError(APIError):
    def __init__(self, message="Service unavailable, please try again later.", code=503, internal_message=''):
        super().__init__(message, code, internal_message)

class APIStatusError(APIError):
    status_code: int = 400
    
    def __init__(self, message: str, param: str = None, internal_message: str = ''):
        super().__init__(message, self.status_code, param, internal_message)

class BadRequestError(APIStatusError):
    status_code: int = 400

class AuthenticationError(APIStatusError):
    status_code: int = 401

class PermissionDeniedError(APIStatusError):
    status_code: int = 403

class NotFoundError(APIStatusError):
    status_code: int = 404

class ConflictError(APIStatusError):
    status_code: int = 409

class UnprocessableEntityError(APIStatusError):
    status_code: int = 422

class RateLimitError(APIStatusError):
    status_code: int = 429

class OpenAIStub(FastAPI):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.models = {}

        self.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"]
        )

        @self.exception_handler(Exception)
        def openai_exception_handler(request: Request, exc: Exception) -> JSONResponse:
            # Generic server errors
            #logger.opt(exception=exc).error("Logging exception traceback")

            return JSONResponse(status_code=500, content={
                'message': 'InternalServerError',
                'code': 500,
            })

        @self.exception_handler(APIError)
        def openai_apierror_handler(request: Request, exc: APIError) -> JSONResponse:
            # Server error
            logger.opt(exception=exc).error("Logging exception traceback")

            if exc.internal_message:
                logger.info(exc.internal_message)

            return JSONResponse(status_code = exc.code, content={
                'message': exc.message,
                'code': exc.code,
                'type': exc.__class__.__name__,
                'param': exc.param,
            })

        @self.exception_handler(APIStatusError)
        def openai_statuserror_handler(request: Request, exc: APIStatusError) -> JSONResponse:
            # Client side error
            logger.info(repr(exc))

            if exc.internal_message:
                logger.info(exc.internal_message)

            return JSONResponse(status_code = exc.code, content={
                'message': exc.message,
                'code': exc.code,
                'type': exc.__class__.__name__,
                'param': exc.param,
            })

        @self.middleware("http")
        async def log_requests(request: Request, call_next):
            logger.debug(f"Request path: {request.url.path}")
            logger.debug(f"Request method: {request.method}")
            logger.debug(f"Request headers: {request.headers}")
            logger.debug(f"Request query params: {request.query_params}")
            logger.debug(f"Request body: {await request.body()}")

            response = await call_next(request)

            logger.debug(f"Response status code: {response.status_code}")
            logger.debug(f"Response headers: {response.headers}")

            return response

        @self.get('/v1/billing/usage')
        @self.get('/v1/dashboard/billing/usage')
        async def handle_billing_usage():
            return { 'total_usage': 0 }

        @self.get("/", response_class=PlainTextResponse)
        @self.head("/", response_class=PlainTextResponse)
        @self.options("/", response_class=PlainTextResponse)
        async def root():
            return PlainTextResponse(content="", status_code=200 if self.models else 503)

        @self.get("/health")
        async def health():
            return {"status": "ok" if self.models else "unk" }

        @self.get("/v1/models")
        async def get_model_list():
            return self.model_list()

        @self.get("/v1/models/{model}")
        async def get_model_info(model_id: str):
            return self.model_info(model_id)

    def register_model(self, name: str, model: str = None) -> None:
        self.models[name] = model if model else name

    def deregister_model(self, name: str) -> None:
        if name in self.models:
            del self.models[name]

    def model_info(self, model: str) -> dict:
        result = {
            "id": model,
            "object": "model",
            "created": 0,
            "owned_by": "user"
        }
        return result

    def model_list(self) -> dict:
        if not self.models:
            return {}
        
        result = {
            "object": "list",
            "data": [ self.model_info(model) for model in list(set(self.models.keys() | self.models.values())) if model ]
        }

        return result