"""Minimal reverse proxy that forwards GET/POST requests to a backend server."""
import json
import logging

import httpx
import uvicorn
from fastapi import FastAPI, Request
from starlette.background import BackgroundTask
from starlette.responses import JSONResponse, Response, StreamingResponse

logger = logging.getLogger(__name__)

app = FastAPI(debug=True)

# Define the base URL of your backend server
BACKEND_BASE_URL = "http://localhost:8000"
# Idle keep-alive timeout (seconds) passed to uvicorn for client connections.
TIMEOUT_KEEP_ALIVE = 5.0
# 5 s default for connect/write/pool; 60 s to wait for each chunk of data
# from the backend, so slow (e.g. streamed) responses are tolerated.
timeout_config = httpx.Timeout(5.0, read=60.0)

# Hop-by-hop headers (RFC 7230 section 6.1) describe a single connection and
# must not be forwarded by a proxy.
_HOP_BY_HOP_HEADERS = frozenset({
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
})
# httpx computes Host and Content-Length for the outgoing request itself.
_EXCLUDED_REQUEST_HEADERS = _HOP_BY_HOP_HEADERS | {"host", "content-length"}
# httpx decodes the body (`.content` / `aiter_bytes()`), so the backend's
# Content-Encoding and Content-Length no longer describe the bytes we relay.
_EXCLUDED_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {"content-length", "content-encoding"}


def _request_headers(request: Request) -> list:
    """Return the incoming request headers that are safe to send to the backend.

    A list of (name, value) pairs is used so repeated headers are preserved.
    """
    return [
        (name, value)
        for name, value in request.headers.items()
        if name.lower() not in _EXCLUDED_REQUEST_HEADERS
    ]


def _response_headers(response: httpx.Response) -> dict:
    """Return the backend response headers that are safe to relay to the client."""
    return {
        name: value
        for name, value in response.headers.items()
        if name.lower() not in _EXCLUDED_RESPONSE_HEADERS
    }


def _wants_stream(body: bytes) -> bool:
    """Return True if *body* is a JSON object with a truthy ``"stream"`` field.

    Empty, non-JSON, or non-object bodies are treated as non-streaming
    instead of raising.
    """
    try:
        payload = json.loads(body)
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return False
    return isinstance(payload, dict) and bool(payload.get("stream"))


def _bad_gateway(exc: httpx.RequestError) -> JSONResponse:
    """Log a transport-level backend failure and build a 502 response."""
    logger.error("Backend request to %s failed: %s", exc.request.url, exc)
    return JSONResponse({"detail": "Bad gateway: backend request failed"}, status_code=502)


async def _close_upstream(response: httpx.Response, client: httpx.AsyncClient) -> None:
    """Close the backend response and the client that owns its connection.

    Both ``aclose`` calls are no-ops when already closed, so this is safe
    to call more than once.
    """
    await response.aclose()
    await client.aclose()


async def _relay_stream(response: httpx.Response, client: httpx.AsyncClient):
    """Yield the backend body chunk by chunk, then release upstream resources.

    The ``finally`` also runs when iteration fails or is cancelled mid-stream.
    """
    try:
        async for chunk in response.aiter_bytes():
            yield chunk
    finally:
        await _close_upstream(response, client)


async def hook(response: httpx.Response) -> None:
    """httpx response event hook: raise ``HTTPStatusError`` on 4xx/5xx.

    The body is read before raising so the error content is still available
    on ``exc.response`` after httpx closes the stream.
    """
    if response.is_error:
        await response.aread()
        response.raise_for_status()


@app.get("/{path:path}")
async def forward_get_request(path: str, request: Request):
    """Forward a GET request to the backend and relay its full response.

    Query parameters and (filtered) headers are forwarded; the backend's
    status code and (filtered) headers are relayed. Transport failures
    yield a 502.
    """
    async with httpx.AsyncClient(timeout=timeout_config) as client:
        try:
            response = await client.get(
                f"{BACKEND_BASE_URL}/{path}",
                params=request.query_params,
                headers=_request_headers(request),
            )
        except httpx.RequestError as exc:
            return _bad_gateway(exc)
    # client.get() (non-streaming) has read the whole body, so the content
    # is still available after the client is closed.
    return Response(
        content=response.content,
        status_code=response.status_code,
        headers=_response_headers(response),
    )


@app.post("/{path:path}")
async def forward_post_request(path: str, request: Request):
    """Forward a POST request to the backend and relay the response.

    If the JSON request body has a truthy ``"stream"`` field, the backend
    response is relayed chunk by chunk; otherwise it is read in full and
    returned with the backend's status code and headers. Backend 4xx/5xx
    responses are relayed with their original status and body; transport
    failures yield a 502.
    """
    body = await request.body()
    stream_requested = _wants_stream(body)

    # Not an `async with`: for streaming responses the client must outlive
    # this handler and is closed once the body has been relayed.
    client = httpx.AsyncClient(event_hooks={"response": [hook]}, timeout=timeout_config)
    handed_off = False
    try:
        upstream_request = client.build_request(
            "POST",
            f"{BACKEND_BASE_URL}/{path}",
            content=body,
            headers=_request_headers(request),
            params=request.query_params,
        )
        response = await client.send(upstream_request, stream=True)

        if stream_requested:
            streamed = StreamingResponse(
                _relay_stream(response, client),
                status_code=response.status_code,
                headers=_response_headers(response),
                # Also closes upstream if iteration never starts (e.g. the
                # client disconnects before the first chunk).
                background=BackgroundTask(_close_upstream, response, client),
            )
            handed_off = True  # the StreamingResponse now owns response and client
            return streamed

        content = await response.aread()
        return Response(
            content=content,
            status_code=response.status_code,
            headers=_response_headers(response),
        )
    except httpx.HTTPStatusError as exc:
        # Raised by `hook`, which has already read the error body.
        return Response(
            content=exc.response.content,
            status_code=exc.response.status_code,
            headers=_response_headers(exc.response),
        )
    except httpx.RequestError as exc:
        return _bad_gateway(exc)
    finally:
        if not handed_off:
            await client.aclose()


if __name__ == "__main__":
    uvicorn.run(app, host='127.0.0.1', port=7860, log_level="debug", timeout_keep_alive=TIMEOUT_KEEP_ALIVE)