Spaces:
Sleeping
Sleeping
File size: 2,762 Bytes
e2d4dfc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import json

import httpx
import uvicorn
from fastapi import FastAPI, Request
from starlette.background import BackgroundTask
from starlette.responses import JSONResponse, Response, StreamingResponse
# NOTE: debug=True makes Starlette return tracebacks in 500 responses.
app = FastAPI(debug=True)

# Seconds uvicorn keeps an idle client connection open (passed to uvicorn.run below).
TIMEOUT_KEEP_ALIVE = 5.0

# Base URL of the backend server that every proxied request is sent to.
BACKEND_BASE_URL = "http://localhost:8000"

# Timeouts for upstream httpx calls: 5 s to connect, write, and acquire a pooled
# connection; up to 60 s may pass between reads (long-running backend replies).
timeout_config = httpx.Timeout(connect=5.0, read=60.0, write=5.0, pool=5.0)
async def hook(response: httpx.Response) -> None:
    """httpx response event hook that turns 4xx/5xx replies into ``HTTPStatusError``.

    The body is read before raising so the exception's ``response`` still
    carries the backend's error payload, even when the request was sent with
    ``stream=True``.
    """
    if not response.is_error:
        return
    await response.aread()
    response.raise_for_status()
@app.get("/{path:path}")
async def forward_get_request(path: str, request: Request):
    """Proxy a GET request to the backend and relay its response.

    Query parameters are forwarded unchanged, including repeated keys. Client
    headers are forwarded except for these:

    * ``Host`` and ``Content-Length``: httpx recomputes them.
    * Hop-by-hop framing headers.
    * ``Accept-Encoding``: httpx negotiates compression itself and decodes the
      body before it is relayed.

    The backend's status code and ``Content-Type`` are preserved, including
    for 4xx/5xx replies. If the backend cannot be reached, a 502 is returned.
    """
    skipped = {"host", "content-length", "transfer-encoding", "connection", "accept-encoding"}
    headers = {k: v for k, v in request.headers.items() if k.lower() not in skipped}
    async with httpx.AsyncClient(timeout=timeout_config) as client:
        try:
            response = await client.get(
                f"{BACKEND_BASE_URL}/{path}",
                # multi_items() keeps duplicate keys such as ?a=1&a=2.
                params=request.query_params.multi_items(),
                headers=headers,
            )
        except httpx.RequestError as exc:
            return JSONResponse({"detail": f"Backend request failed: {exc}"}, status_code=502)
    # A non-streaming client.get() has already buffered the whole body, so it
    # is still available after the client has been closed.
    return Response(
        content=response.content,
        status_code=response.status_code,
        # .get(): a backend reply may legitimately have no Content-Type (e.g. 204).
        media_type=response.headers.get("content-type"),
    )
def _wants_stream(body: bytes) -> bool:
    """Return True when *body* is a JSON object whose ``"stream"`` field is truthy.

    Empty, non-UTF/non-JSON, or non-object bodies count as non-streaming
    requests instead of raising.
    """
    try:
        payload = json.loads(body)
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return False
    return isinstance(payload, dict) and bool(payload.get("stream"))


@app.post("/{path:path}")
async def forward_post_request(path: str, request: Request):
    """Proxy a POST request to the backend and relay its response.

    The raw request body is forwarded unchanged. Client headers are forwarded
    except for these:

    * ``Host`` and ``Content-Length``: httpx recomputes them for the upstream request.
    * Hop-by-hop framing headers.
    * ``Accept-Encoding``: httpx negotiates compression itself and decodes the
      body before it is relayed.

    If the JSON body has a truthy ``"stream"`` field, the backend reply is
    relayed chunk by chunk. Otherwise it is read in full and returned verbatim.
    In both cases the backend's status code and ``Content-Type`` are kept.

    Backend 4xx/5xx replies, which ``hook`` raises as ``HTTPStatusError``, are
    relayed with their own status and body. If the backend cannot be reached at
    all, a 502 is returned.
    """
    body = await request.body()
    stream = _wants_stream(body)
    skipped = {"host", "content-length", "transfer-encoding", "connection", "accept-encoding"}
    headers = {k: v for k, v in request.headers.items() if k.lower() not in skipped}

    # Deliberately not ``async with``: for a streamed reply the client must stay
    # open after this function returns, until the last chunk has been relayed.
    client = httpx.AsyncClient(event_hooks={"response": [hook]}, timeout=timeout_config)
    upstream = client.build_request(
        "POST", f"{BACKEND_BASE_URL}/{path}", content=body, headers=headers
    )
    response = None
    try:
        response = await client.send(upstream, stream=True)
    except httpx.HTTPStatusError as exc:
        # hook() read the error body before raising, so .content is available.
        error = exc.response
        return Response(
            content=error.content,
            status_code=error.status_code,
            media_type=error.headers.get("content-type"),
        )
    except httpx.RequestError as exc:
        return JSONResponse({"detail": f"Backend request failed: {exc}"}, status_code=502)
    finally:
        # On any failure path the client is no longer needed.
        if response is None:
            await client.aclose()

    media_type = response.headers.get("content-type")

    if stream:
        async def relay():
            # Release the upstream response and client once the last chunk is
            # sent, or as soon as relaying fails.
            try:
                async for chunk in response.aiter_bytes():
                    yield chunk
            finally:
                await response.aclose()
                await client.aclose()

        return StreamingResponse(relay(), status_code=response.status_code, media_type=media_type)

    try:
        content = await response.aread()
    finally:
        await response.aclose()
        await client.aclose()
    # Relay the raw bytes; JSONResponse would try to json-encode the bytes object.
    return Response(content=content, status_code=response.status_code, media_type=media_type)
if __name__ == "__main__":
    # Serve the proxy directly with uvicorn on port 7860.
    # NOTE(review): binding to 127.0.0.1 only accepts connections from the same
    # host/container — confirm that is intended for this deployment.
    server_options = {
        "host": "127.0.0.1",
        "port": 7860,
        "log_level": "debug",
        "timeout_keep_alive": TIMEOUT_KEEP_ALIVE,
    }
    uvicorn.run(app, **server_options)
|