File size: 1,718 Bytes
e487426
ee3e372
e482b3f
2f8a81b
ee3e372
2f8a81b
ee3e372
 
 
1f5da40
afaca69
ee3e372
afaca69
e487426
ee3e372
1f5da40
 
 
 
 
 
2f8a81b
ee3e372
2f8a81b
 
 
 
 
1f5da40
 
 
2f8a81b
 
 
 
d69b7ce
e487426
2f8a81b
 
 
1f5da40
2f8a81b
 
 
 
 
 
 
 
1f5da40
 
 
2f8a81b
 
d69b7ce
1f5da40
 
 
2f8a81b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import base64
import os
import shutil
import tempfile

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from gradio_client import Client, handle_file

app = FastAPI()

# Hugging Face access token taken from the environment; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Hugging Face Space that does the image processing (per its name, a GFPGAN
# face upscale/restoration model). The client is created once at import time
# and shared by all requests.
_GFPGAN_SPACE = "Makhinur/Image_Face_Upscale_Restoration-GFPGAN"
client = Client(_GFPGAN_SPACE, hf_token=HF_TOKEN)

# Model labels sent by the HTML front end -> version strings passed to the Gradio API.
version_map = dict(M1="v1.2", M2="v1.3", M3="v1.4")

def _enhance_sync(upload_fileobj, gradio_version: str, scale: int) -> str:
    """Blocking round-trip to the Gradio Space; return the result image as base64 text.

    Runs in a worker thread (see ``enhance_image``) because it does file I/O
    and a synchronous network call via ``client.predict``.

    Args:
        upload_fileobj: Readable binary file object holding the uploaded image.
        gradio_version: Version string forwarded to the Space's ``version`` input.
        scale: Value forwarded to the Space's ``scale`` input.

    Returns:
        The bytes of the file at ``result[0]`` encoded as base64 (UTF-8 str).

    Raises:
        Whatever ``client.predict`` or the file operations raise; the temporary
        copy of the upload is removed in every case.
    """
    # A unique file per request: a fixed name would let concurrent requests
    # overwrite each other's uploads.
    fd, temp_file_path = tempfile.mkstemp(suffix=".png")
    try:
        with os.fdopen(fd, "wb") as buffer:
            shutil.copyfileobj(upload_fileobj, buffer)

        result = client.predict(
            img=handle_file(temp_file_path),
            version=gradio_version,
            scale=scale,
            api_name="/predict",
        )

        # NOTE(review): result[0] is assumed to be the local path of the output
        # image returned by the Space -- confirm against the Space's API.
        with open(result[0], "rb") as img_file:
            return base64.b64encode(img_file.read()).decode("utf-8")
    finally:
        # Clean up on success AND failure (previously leaked on errors).
        os.remove(temp_file_path)


@app.post("/upload/")
async def enhance_image(
    file: UploadFile = File(...),
    version: str = Form(...),
    scale: int = Form(...)
):
    """Enhance an uploaded image via the Gradio Space and return it as a data URL.

    Form fields:
        file: The image to process.
        version: Front-end model label ("M1", "M2", "M3"); unknown labels
            fall back to "v1.4".
        scale: Passed through unchanged to the Space's ``scale`` input.

    Returns:
        JSON ``{"sketch_image_base64": "data:image/png;base64,..."}`` on success,
        or status 500 with ``{"message": "Internal Server Error"}`` on any failure
        (the error is printed for debugging).
    """
    # Map version from HTML to Gradio expected value.
    gradio_version = version_map.get(version, "v1.4")

    try:
        # The Gradio call is blocking; run it off the event loop so other
        # requests keep being served while it waits.
        b64_string = await run_in_threadpool(
            _enhance_sync, file.file, gradio_version, scale
        )
    except Exception as e:
        # Log the error message for debugging.
        print(f"Error processing image: {e}")
        return JSONResponse(status_code=500, content={"message": "Internal Server Error"})

    return JSONResponse(content={"sketch_image_base64": f"data:image/png;base64,{b64_string}"})