|
import base64
import logging
import os
import tempfile

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from gradio_client import Client, handle_file
|
|
|
# The ASGI application; routes are attached to it with decorators.
app = FastAPI()

# Cross-origin policy: any origin, method and request header is accepted,
# and credentialed cross-origin requests are permitted.
# NOTE(review): the FastAPI CORS docs state that allow_origins / allow_methods /
# allow_headers should not be ["*"] when allow_credentials=True. The /upload/
# endpoint visible in this file reads no cookies or auth headers — confirm
# whether credentials are actually needed, and list explicit origins if so.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
|
|
|
|
|
# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
hf_token = os.environ.get('HF_TOKEN')

# Shared Gradio client for the GFPGAN face restoration / upscaling Space.
# The Space is referenced by its "owner/name" id: gradio_client treats any
# `src` beginning with "http://" or "https://" as a literal app URL, so a
# string like "https://Makhinur/..." would target a host called "Makhinur".
# NOTE(review): the client is built at import time; presumably startup fails
# if the Space cannot be reached — confirm that is acceptable.
client = Client("Makhinur/Image_Face_Upscale_Restoration-GFPGAN", hf_token=hf_token)
|
|
|
_logger = logging.getLogger(__name__)


def _predict_and_encode(image_path, version, scale):
    """Call the Space's ``/predict`` endpoint and base64-encode its first output.

    This does blocking network and disk I/O, so the async route runs it in a
    worker thread instead of on the event loop.

    Args:
        image_path: Path of the local image file sent to the Space.
        version: Forwarded unchanged as the prediction's version argument.
        scale: Forwarded unchanged as the prediction's scale argument.

    Returns:
        ``(data_url, result_file)`` when the prediction yields exactly two
        items, otherwise ``None``. ``data_url`` embeds the bytes of the file at
        ``result[0]``; ``result_file`` is ``result[1]`` as returned by the client.
    """
    result = client.predict(handle_file(image_path), version, scale, api_name="/predict")

    if not result or len(result) != 2:
        return None

    with open(result[0], "rb") as image_file:
        image_data = base64.b64encode(image_file.read()).decode("utf-8")

    # NOTE(review): the data URL is always labelled image/png whatever format
    # result[0] actually is; kept as-is because clients may match this prefix.
    # Files downloaded by gradio_client for the outputs are not removed here —
    # confirm whether they accumulate on disk.
    return f"data:image/png;base64,{image_data}", result[1]


@app.post("/upload/")
async def upload_file(file: UploadFile = File(...), version: str = Form(...), scale: int = Form(...)):
    """Restore/upscale an uploaded image through the remote GFPGAN Space.

    The upload is written to a temporary file, which is sent to the Space's
    ``/predict`` endpoint together with ``version`` and ``scale``. The
    temporary file is always deleted afterwards.

    Returns:
        On success, JSON with ``sketch_image_base64`` (the first prediction
        output as a base64 data URL) and ``result_file`` (the second output,
        presumably a path on this server as downloaded by gradio_client).
        If the prediction does not return exactly two items, or any step fails,
        a 500 response whose ``error`` field carries a message.
    """
    contents = await file.read()

    # Keep a plain extension from the client's filename so the temp file has a
    # conventional image name; anything unusual (empty, non-ASCII, symbols) is dropped.
    suffix = os.path.splitext(file.filename or "")[1]
    extension = suffix[1:]
    if not (extension.isascii() and extension.isalnum()):
        suffix = ""

    temp_file_path = None
    try:
        # delete=False: the file must outlive this `with` so it can be read by
        # path; `finally` removes it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path before writing so a failed write is still cleaned up.
            temp_file_path = temp_file.name
            temp_file.write(contents)

        # client.predict blocks on network I/O; run it off the event loop so
        # other requests keep being served meanwhile.
        prediction = await run_in_threadpool(_predict_and_encode, temp_file_path, version, scale)
        if prediction is None:
            return JSONResponse({"error": "Invalid result from the prediction API."}, status_code=500)

        data_url, result_file = prediction
        return JSONResponse({
            "sketch_image_base64": data_url,
            "result_file": result_file,
        })
    except Exception as e:
        _logger.exception("Image restoration request failed")
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except FileNotFoundError:
                pass