File size: 1,867 Bytes
ee3e372
 
 
e482b3f
ee3e372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e482b3f
 
ee3e372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import base64
import mimetypes
import os
import tempfile

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from gradio_client import Client, handle_file

app = FastAPI()

# Setup CORS
# NOTE(review): FastAPI's CORS documentation says allow_origins, allow_methods
# and allow_headers should be listed explicitly (not ["*"]) when
# allow_credentials=True. Restrict these to the real front-end origins before
# relying on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust as needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize Gradio client.
# The Space is referenced by its Hugging Face id ("owner/space-name") so
# gradio_client resolves the host itself. A URL of the form
# "https://<owner>/<space>.hf.space/" would treat "<owner>" as the hostname.
# HF_TOKEN is optional: os.environ.get yields None when it is unset.
hf_token = os.environ.get('HF_TOKEN')
client = Client("Makhinur/Image_Face_Upscale_Restoration-GFPGAN", hf_token=hf_token)

@app.post("/upload/")
async def upload_file(
    file: UploadFile = File(...),
    version: str = Form(...),
    scale: int = Form(...),
):
    """Send an uploaded image to the Gradio Space and return its output as JSON.

    The upload is written to a temporary file. That file is passed to the
    Space's ``/predict`` endpoint via ``handle_file``, together with
    ``version`` and ``scale``. The temporary file is always removed afterwards.

    Args:
        file: Uploaded image (multipart form field ``file``).
        version: Forwarded unchanged as the second ``/predict`` input.
        scale: Forwarded unchanged as the third ``/predict`` input.

    Returns:
        On success: ``{"sketch_image_base64": <data URL of result[0]>,
        "result_file": result[1]}``.
        On failure (unexpected result shape or any exception): ``{"error": ...}``
        with HTTP status 500.
    """
    temp_file_path = None
    try:
        contents = await file.read()

        # Keep a short alphanumeric extension from the client's filename so the
        # temp file looks like an image file. NOTE(review): assumed to help the
        # remote side detect the format — confirm against the Space.
        suffix = os.path.splitext(file.filename or "")[1]
        if not (suffix[1:].isalnum() and len(suffix) <= 10):
            suffix = ""

        # Record the path before writing, so the finally clause cleans up even
        # if the write itself fails (the file is created with delete=False).
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(contents)

        # client.predict blocks until the remote job finishes. Run it in a
        # worker thread so the event loop keeps serving other requests.
        result = await run_in_threadpool(
            client.predict,
            handle_file(temp_file_path),
            version,
            scale,
            api_name="/predict",
        )

        # Expect exactly two outputs: an image file path, plus a second item
        # that is returned to the caller verbatim.
        if not result or len(result) != 2:
            return JSONResponse({"error": "Invalid result from the prediction API."}, status_code=500)

        image_path = result[0]
        with open(image_path, "rb") as image_file:
            image_data = base64.b64encode(image_file.read()).decode("utf-8")

        # Label the data URL with the file's actual image type. Fall back to
        # PNG when no image type can be guessed from the extension.
        mime_type = mimetypes.guess_type(image_path)[0]
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/png"

        return JSONResponse({
            "sketch_image_base64": f"data:{mime_type};base64,{image_data}",
            "result_file": result[1],
        })
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        # Clean up the temporary upload, if it was created.
        if temp_file_path and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)