File size: 2,135 Bytes
e487426 ee3e372 e482b3f 2f8a81b ee3e372 2f8a81b 4c802ed ee3e372 afaca69 ee3e372 afaca69 e487426 ee3e372 1f5da40 2f8a81b ee3e372 2f8a81b 1f5da40 2f8a81b d69b7ce e487426 2f8a81b 1f5da40 2f8a81b 4c802ed b18ac6b 2f8a81b 4c802ed 1f5da40 4c802ed 1f5da40 b18ac6b 2f8a81b d69b7ce 1f5da40 4c802ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import base64
import io
import os
import shutil
import tempfile

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
from gradio_client import Client, handle_file
from PIL import Image  # Import the Pillow library
# FastAPI application instance that the route decorators attach to.
app = FastAPI()

# Hugging Face access token read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Gradio client for the GFPGAN Space, created once at import time with the token.
client = Client(
    "Makhinur/Image_Face_Upscale_Restoration-GFPGAN",
    hf_token=HF_TOKEN,
)

# Translates the model codes submitted by the HTML form into the version
# strings passed to the Gradio API.
version_map = dict(M1="v1.2", M2="v1.3", M3="v1.4")
@app.post("/upload/")
async def enhance_image(
    file: UploadFile = File(...),
    version: str = Form(...),
    scale: int = Form(...)
):
    """Send an uploaded image to the Gradio endpoint and return the result as PNG.

    Args:
        file: The uploaded image.
        version: Model code from the form ("M1", "M2" or "M3"); any other
            value falls back to "v1.4".
        scale: Scale value forwarded unchanged to the Gradio ``/predict`` call.

    Returns:
        A JSONResponse whose ``sketch_image_base64`` field holds a
        ``data:image/png;base64,...`` URI on success, or a 500 response with
        ``{"message": "Internal Server Error"}`` if any step fails.
    """
    # Map version from HTML to Gradio expected value
    gradio_version = version_map.get(version, "v1.4")

    temp_file_path = None
    try:
        # Save the upload under a unique name so concurrent workers never
        # overwrite each other's input and no writable cwd is required.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as buffer:
            temp_file_path = buffer.name
            shutil.copyfileobj(file.file, buffer)

        # NOTE(review): client.predict is called synchronously inside an async
        # handler, so it holds the event loop while it runs - confirm that is
        # acceptable for the expected load.
        result = client.predict(
            img=handle_file(temp_file_path),
            version=gradio_version,
            scale=scale,
            api_name="/predict"
        )

        # Assumes the first output is a local path to the restored image
        # (presumably WebP) - TODO confirm against the Space's API.
        result_image_path = result[0]

        # Re-encode as PNG entirely in memory; no intermediate output file.
        png_buffer = io.BytesIO()
        with Image.open(result_image_path) as img:
            img.save(png_buffer, format="PNG")
        b64_string = base64.b64encode(png_buffer.getvalue()).decode('utf-8')

        return JSONResponse(content={"sketch_image_base64": f"data:image/png;base64,{b64_string}"})
    except Exception as e:
        # Log the error message for debugging
        print(f"Error processing image: {e}")
        return JSONResponse(status_code=500, content={"message": "Internal Server Error"})
    finally:
        # Always remove the saved upload, including on the failure path.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                # Best-effort cleanup: a missing file must not mask the response.
                pass