File size: 2,135 Bytes
e487426
ee3e372
e482b3f
2f8a81b
ee3e372
2f8a81b
4c802ed
ee3e372
 
 
afaca69
ee3e372
afaca69
e487426
ee3e372
1f5da40
 
 
 
 
 
2f8a81b
ee3e372
2f8a81b
 
 
 
 
1f5da40
 
 
2f8a81b
 
 
 
d69b7ce
e487426
2f8a81b
 
 
1f5da40
2f8a81b
 
 
 
4c802ed
 
 
 
 
 
 
 
 
 
b18ac6b
2f8a81b
4c802ed
1f5da40
4c802ed
1f5da40
b18ac6b
2f8a81b
d69b7ce
1f5da40
 
4c802ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import base64
import contextlib
import io
import os
import shutil
import tempfile

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
from gradio_client import Client, handle_file
from PIL import Image  # Import the Pillow library

# ASGI application object; the upload route below registers itself on it.
app = FastAPI()

# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Single shared Gradio client for the remote app, created once at import time
# and authenticated with the token above.
client = Client(
    "Makhinur/Image_Face_Upscale_Restoration-GFPGAN",
    hf_token=HF_TOKEN,
)

# Translates the model labels sent by the HTML form into the version strings
# passed to the Gradio API.
version_map = dict(M1="v1.2", M2="v1.3", M3="v1.4")

@app.post("/upload/")
async def enhance_image(
    file: UploadFile = File(...),
    version: str = Form(...),
    scale: int = Form(...)
):
    """Send an uploaded image to the Gradio app and return the result as PNG.

    The upload is written to a per-request temporary file, passed to the
    Gradio ``/predict`` endpoint, and the first element of the returned result
    is re-encoded as PNG and embedded in the JSON response as a base64 data
    URL.

    Args:
        file: The uploaded image (multipart form field).
        version: Model label from the form; looked up in ``version_map`` and
            falling back to ``"v1.4"`` when the label is unknown.
        scale: Scale value forwarded unchanged to the Gradio API.

    Returns:
        A ``JSONResponse`` with key ``sketch_image_base64`` on success, or a
        500 response with a generic ``message`` if any step fails.
    """
    # Map version from HTML to Gradio expected value
    gradio_version = version_map.get(version, "v1.4")

    temp_file_path = None
    try:
        # Write the upload to a uniquely named temp file so that concurrent
        # requests cannot overwrite each other's input. The ".png" suffix
        # mirrors the file name the upload was previously stored under.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as buffer:
            temp_file_path = buffer.name
            shutil.copyfileobj(file.file, buffer)

        # Use the Gradio client to process the image.
        # NOTE(review): this call is synchronous and runs on the event loop;
        # consider off-loading it to a worker thread once the client's
        # thread-safety has been confirmed.
        result = client.predict(
            img=handle_file(temp_file_path),
            version=gradio_version,
            scale=scale,
            api_name="/predict"
        )

        # Assuming the Gradio app outputs a WebP file as its first result
        result_image_path = result[0]

        # Re-encode as PNG entirely in memory; no shared output file needed.
        png_buffer = io.BytesIO()
        with Image.open(result_image_path) as img:
            img.save(png_buffer, format="PNG")

        b64_string = base64.b64encode(png_buffer.getvalue()).decode('utf-8')

        return JSONResponse(content={"sketch_image_base64": f"data:image/png;base64,{b64_string}"})

    except Exception as e:
        # Log the error message for debugging
        print(f"Error processing image: {e}")
        return JSONResponse(status_code=500, content={"message": "Internal Server Error"})

    finally:
        # Always remove the temp upload, on success and on failure alike.
        # Cleanup is best-effort: a failed delete must not mask the response.
        if temp_file_path is not None:
            with contextlib.suppress(OSError):
                os.remove(temp_file_path)