import json

import httpx
import uvicorn
from fastapi import FastAPI, Request
from starlette.background import BackgroundTask
from starlette.responses import JSONResponse, Response, StreamingResponse

# NOTE(review): debug=True makes the app return tracebacks in error responses;
# keep it for local development only.
app = FastAPI(debug=True)

# Backend server that every proxied request is forwarded to.
BACKEND_BASE_URL = "http://localhost:8000"
# Idle keep-alive timeout (seconds) handed to uvicorn in the __main__ block.
TIMEOUT_KEEP_ALIVE = 5.0
# httpx timeouts: 5 s default (connect/write/pool), 60 s read timeout so slow
# backend replies are tolerated.
timeout_config = httpx.Timeout(timeout=5.0, read=60.0)


async def hook(response: httpx.Response) -> None:
    """httpx response event hook that raises for 4xx/5xx backend replies.

    The body is read before raising so that whoever catches the resulting
    ``httpx.HTTPStatusError`` can still access ``exc.response.content``, even
    when the request was sent with ``stream=True``.
    """
    if not response.is_error:
        return
    await response.aread()
    response.raise_for_status()


@app.get("/{path:path}")
async def forward_get_request(path: str, request: Request):
    """Forward a GET request to ``BACKEND_BASE_URL/<path>`` and relay the reply.

    Query parameters are passed through, including repeated keys. The
    backend's status code, body and ``Content-Type`` (when present) are
    returned to the caller. A transport failure while talking to the backend
    (connection refused, timeout, ...) is reported as 502 Bad Gateway.
    """
    target_url = f"{BACKEND_BASE_URL}/{path}"
    try:
        async with httpx.AsyncClient(timeout=timeout_config) as client:
            # multi_items() keeps every value of a repeated key (?a=1&a=2);
            # passing the QueryParams mapping itself would keep only the last.
            upstream = await client.get(
                target_url, params=request.query_params.multi_items()
            )
    except httpx.RequestError as exc:
        return JSONResponse(
            {"detail": f"Backend request failed: {type(exc).__name__}: {exc}"},
            status_code=502,
        )

    # Without stream=True, client.get() has already read the whole body, so
    # it remains available after the client has been closed.
    return Response(
        content=upstream.content,
        status_code=upstream.status_code,
        media_type=upstream.headers.get("content-type"),
    )


def _wants_stream(body: bytes) -> bool:
    """Return True when *body* is a JSON object with a truthy ``"stream"`` field.

    Anything else (an empty body, invalid UTF-8, non-JSON text, or JSON that
    is not an object) is treated as a non-streaming request rather than
    raising.
    """
    try:
        payload = json.loads(body.decode("utf-8"))
    except ValueError:  # covers UnicodeDecodeError and json.JSONDecodeError
        return False
    return isinstance(payload, dict) and bool(payload.get("stream"))


async def _close_upstream(response: httpx.Response, client: httpx.AsyncClient) -> None:
    """Close a streamed backend response, then the client that owns its connection."""
    try:
        await response.aclose()
    finally:
        await client.aclose()


@app.post("/{path:path}")
async def forward_post_request(path: str, request: Request):
    """Forward a POST request to ``BACKEND_BASE_URL/<path>`` and relay the reply.

    The request body and headers are sent unchanged, except ``Host`` and
    ``Content-Length``, which httpx recomputes for the backend. If the body is
    a JSON object with a truthy ``"stream"`` field, the backend reply is
    streamed through chunk by chunk. Otherwise it is read in full and returned.

    The backend's status code, body and ``Content-Type`` are relayed. This
    includes 4xx/5xx replies: ``hook`` reads their body and then raises, and
    that exception is caught here. A transport failure while talking to the
    backend is reported as 502 Bad Gateway.
    """
    body = await request.body()

    headers = {
        k: v
        for k, v in request.headers.items()
        if k.lower() not in ("host", "content-length")
    }

    client = httpx.AsyncClient(event_hooks={"response": [hook]}, timeout=timeout_config)
    # Becomes True once closing `client` is delegated to the StreamingResponse's
    # background task; until then this handler must close it itself.
    handed_off = False
    try:
        outgoing = client.build_request(
            "POST", f"{BACKEND_BASE_URL}/{path}", content=body, headers=headers
        )
        upstream = await client.send(outgoing, stream=True)
        media_type = upstream.headers.get("content-type")

        if _wants_stream(body):
            # The stream is consumed after this handler returns, so the client
            # must stay open until then. The background task closes the
            # response and the client once the StreamingResponse has finished.
            handed_off = True
            return StreamingResponse(
                upstream.aiter_bytes(),
                status_code=upstream.status_code,
                media_type=media_type,
                background=BackgroundTask(_close_upstream, upstream, client),
            )

        content = await upstream.aread()
        return Response(
            content=content, status_code=upstream.status_code, media_type=media_type
        )
    except httpx.HTTPStatusError as exc:
        # `hook` read the error body before raising, so .content is populated.
        failed = exc.response
        return Response(
            content=failed.content,
            status_code=failed.status_code,
            media_type=failed.headers.get("content-type"),
        )
    except httpx.RequestError as exc:
        return JSONResponse(
            {"detail": f"Backend request failed: {type(exc).__name__}: {exc}"},
            status_code=502,
        )
    finally:
        if not handed_off:
            await client.aclose()


if __name__ == "__main__":
    # Serve the proxy locally with verbose logging and a short keep-alive.
    server_options = {
        "host": '127.0.0.1',
        "port": 7860,
        "log_level": "debug",
        "timeout_keep_alive": TIMEOUT_KEEP_ALIVE,
    }
    uvicorn.run(app, **server_options)