Makhinur commited on
Commit
1f5da40
·
verified ·
1 Parent(s): 2f8a81b

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +17 -2
main.py CHANGED
@@ -7,11 +7,18 @@ import os
7
 
8
  app = FastAPI()
9
 
 
10
  HF_TOKEN = os.getenv("HF_TOKEN")
11
 
12
  # Initialize the Gradio client with the token
13
  client = Client("Makhinur/Image_Face_Upscale_Restoration-GFPGAN", hf_token=HF_TOKEN)
14
 
 
 
 
 
 
 
15
 
16
  @app.post("/upload/")
17
  async def enhance_image(
@@ -19,6 +26,9 @@ async def enhance_image(
19
  version: str = Form(...),
20
  scale: int = Form(...)
21
  ):
 
 
 
22
  # Save the uploaded image to a temporary file
23
  temp_file_path = "temp_image.png"
24
  with open(temp_file_path, "wb") as buffer:
@@ -28,7 +38,7 @@ async def enhance_image(
28
  # Use the Gradio client to process the image
29
  result = client.predict(
30
  img=handle_file(temp_file_path),
31
- version=version,
32
  scale=scale,
33
  api_name="/predict"
34
  )
@@ -37,8 +47,13 @@ async def enhance_image(
37
  with open(result[0], "rb") as img_file:
38
  b64_string = base64.b64encode(img_file.read()).decode('utf-8')
39
 
 
 
 
40
  return JSONResponse(content={"sketch_image_base64": f"data:image/png;base64,{b64_string}"})
41
 
42
  except Exception as e:
43
- return JSONResponse(status_code=500, content={"message": str(e)})
 
 
44
 
 
7
 
8
  app = FastAPI()
9
 
10
+
11
  HF_TOKEN = os.getenv("HF_TOKEN")
12
 
13
  # Initialize the Gradio client with the token
14
  client = Client("Makhinur/Image_Face_Upscale_Restoration-GFPGAN", hf_token=HF_TOKEN)
15
 
16
+ # Version mapping from HTML to Gradio API
17
+ version_map = {
18
+ "M1": "v1.2",
19
+ "M2": "v1.3",
20
+ "M3": "v1.4"
21
+ }
22
 
23
  @app.post("/upload/")
24
  async def enhance_image(
 
26
  version: str = Form(...),
27
  scale: int = Form(...)
28
  ):
29
+ # Map version from HTML to Gradio expected value
30
+ gradio_version = version_map.get(version, "v1.4")
31
+
32
  # Save the uploaded image to a temporary file
33
  temp_file_path = "temp_image.png"
34
  with open(temp_file_path, "wb") as buffer:
 
38
  # Use the Gradio client to process the image
39
  result = client.predict(
40
  img=handle_file(temp_file_path),
41
+ version=gradio_version,
42
  scale=scale,
43
  api_name="/predict"
44
  )
 
47
  with open(result[0], "rb") as img_file:
48
  b64_string = base64.b64encode(img_file.read()).decode('utf-8')
49
 
50
+ # Clean up the temporary file
51
+ os.remove(temp_file_path)
52
+
53
  return JSONResponse(content={"sketch_image_base64": f"data:image/png;base64,{b64_string}"})
54
 
55
  except Exception as e:
56
+ # Log the error message for debugging
57
+ print(f"Error processing image: {e}")
58
+ return JSONResponse(status_code=500, content={"message": "Internal Server Error"})
59