Update main.py
Browse files
main.py
CHANGED
@@ -7,11 +7,18 @@ import os
|
|
7 |
|
8 |
app = FastAPI()
|
9 |
|
|
|
10 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
11 |
|
12 |
# Initialize the Gradio client with the token
|
13 |
client = Client("Makhinur/Image_Face_Upscale_Restoration-GFPGAN", hf_token=HF_TOKEN)
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
@app.post("/upload/")
|
17 |
async def enhance_image(
|
@@ -19,6 +26,9 @@ async def enhance_image(
|
|
19 |
version: str = Form(...),
|
20 |
scale: int = Form(...)
|
21 |
):
|
|
|
|
|
|
|
22 |
# Save the uploaded image to a temporary file
|
23 |
temp_file_path = "temp_image.png"
|
24 |
with open(temp_file_path, "wb") as buffer:
|
@@ -28,7 +38,7 @@ async def enhance_image(
|
|
28 |
# Use the Gradio client to process the image
|
29 |
result = client.predict(
|
30 |
img=handle_file(temp_file_path),
|
31 |
-
version=
|
32 |
scale=scale,
|
33 |
api_name="/predict"
|
34 |
)
|
@@ -37,8 +47,13 @@ async def enhance_image(
|
|
37 |
with open(result[0], "rb") as img_file:
|
38 |
b64_string = base64.b64encode(img_file.read()).decode('utf-8')
|
39 |
|
|
|
|
|
|
|
40 |
return JSONResponse(content={"sketch_image_base64": f"data:image/png;base64,{b64_string}"})
|
41 |
|
42 |
except Exception as e:
|
43 |
-
|
|
|
|
|
44 |
|
|
|
7 |
|
8 |
app = FastAPI()
|
9 |
|
10 |
+
|
11 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
12 |
|
13 |
# Initialize the Gradio client with the token
|
14 |
client = Client("Makhinur/Image_Face_Upscale_Restoration-GFPGAN", hf_token=HF_TOKEN)
|
15 |
|
16 |
+
# Version mapping from HTML to Gradio API
|
17 |
+
version_map = {
|
18 |
+
"M1": "v1.2",
|
19 |
+
"M2": "v1.3",
|
20 |
+
"M3": "v1.4"
|
21 |
+
}
|
22 |
|
23 |
@app.post("/upload/")
|
24 |
async def enhance_image(
|
|
|
26 |
version: str = Form(...),
|
27 |
scale: int = Form(...)
|
28 |
):
|
29 |
+
# Map version from HTML to Gradio expected value
|
30 |
+
gradio_version = version_map.get(version, "v1.4")
|
31 |
+
|
32 |
# Save the uploaded image to a temporary file
|
33 |
temp_file_path = "temp_image.png"
|
34 |
with open(temp_file_path, "wb") as buffer:
|
|
|
38 |
# Use the Gradio client to process the image
|
39 |
result = client.predict(
|
40 |
img=handle_file(temp_file_path),
|
41 |
+
version=gradio_version,
|
42 |
scale=scale,
|
43 |
api_name="/predict"
|
44 |
)
|
|
|
47 |
with open(result[0], "rb") as img_file:
|
48 |
b64_string = base64.b64encode(img_file.read()).decode('utf-8')
|
49 |
|
50 |
+
# Clean up the temporary file
|
51 |
+
os.remove(temp_file_path)
|
52 |
+
|
53 |
return JSONResponse(content={"sketch_image_base64": f"data:image/png;base64,{b64_string}"})
|
54 |
|
55 |
except Exception as e:
|
56 |
+
# Log the error message for debugging
|
57 |
+
print(f"Error processing image: {e}")
|
58 |
+
return JSONResponse(status_code=500, content={"message": "Internal Server Error"})
|
59 |
|