Makhinur commited on
Commit
2f8a81b
·
verified ·
1 Parent(s): f114b05

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +28 -37
main.py CHANGED
@@ -1,53 +1,44 @@
1
  from fastapi import FastAPI, File, UploadFile, Form
2
  from fastapi.responses import JSONResponse
3
- from fastapi.middleware.cors import CORSMiddleware
4
  from gradio_client import Client, handle_file
5
- import os
6
- import tempfile
7
  import base64
 
8
 
9
  app = FastAPI()
10
 
11
- # Setup CORS
12
- app.add_middleware(
13
- CORSMiddleware,
14
- allow_origins=["*"], # Adjust as needed
15
- allow_credentials=True,
16
- allow_methods=["*"],
17
- allow_headers=["*"],
18
- )
19
-
20
  HF_TOKEN = os.getenv("HF_TOKEN")
21
 
22
  # Initialize the Gradio client with the token
23
  client = Client("Makhinur/Image_Face_Upscale_Restoration-GFPGAN", hf_token=HF_TOKEN)
24
 
 
25
  @app.post("/upload/")
26
- async def upload_file(file: UploadFile = File(...), version: str = Form(...), scale: int = Form(...)):
27
- # Save the uploaded file temporarily
28
- with tempfile.NamedTemporaryFile(delete=False) as temp_file:
29
- temp_file.write(await file.read())
30
- temp_file_path = temp_file.name
 
 
 
 
31
 
32
  try:
33
- # Use handle_file to prepare the file for the Gradio client
34
- result = client.predict(handle_file(temp_file_path), version, scale, api_name="/predict")
35
-
36
- # Check if the result is valid
37
- if result and len(result) == 2:
38
- # Convert the image data to a base64 string
39
- with open(result[0], "rb") as image_file:
40
- image_data = base64.b64encode(image_file.read()).decode("utf-8")
41
-
42
- return JSONResponse({
43
- "sketch_image_base64": f"data:image/png;base64,{image_data}",
44
- "result_file": result[1]
45
- })
46
- else:
47
- return JSONResponse({"error": "Invalid result from the prediction API."}, status_code=500)
48
  except Exception as e:
49
- return JSONResponse({"error": str(e)}, status_code=500)
50
- finally:
51
- # Clean up the temporary file
52
- if os.path.exists(temp_file_path):
53
- os.unlink(temp_file_path)
 
1
  from fastapi import FastAPI, File, UploadFile, Form
2
  from fastapi.responses import JSONResponse
 
3
  from gradio_client import Client, handle_file
4
+ import shutil
 
5
  import base64
6
+ import os
7
 
8
  app = FastAPI()
9
 
 
 
 
 
 
 
 
 
 
10
  HF_TOKEN = os.getenv("HF_TOKEN")
11
 
12
  # Initialize the Gradio client with the token
13
  client = Client("Makhinur/Image_Face_Upscale_Restoration-GFPGAN", hf_token=HF_TOKEN)
14
 
15
+
16
  @app.post("/upload/")
17
+ async def enhance_image(
18
+ file: UploadFile = File(...),
19
+ version: str = Form(...),
20
+ scale: int = Form(...)
21
+ ):
22
+ # Save the uploaded image to a temporary file
23
+ temp_file_path = "temp_image.png"
24
+ with open(temp_file_path, "wb") as buffer:
25
+ shutil.copyfileobj(file.file, buffer)
26
 
27
  try:
28
+ # Use the Gradio client to process the image
29
+ result = client.predict(
30
+ img=handle_file(temp_file_path),
31
+ version=version,
32
+ scale=scale,
33
+ api_name="/predict"
34
+ )
35
+
36
+ # Read the result image and encode it in base64
37
+ with open(result[0], "rb") as img_file:
38
+ b64_string = base64.b64encode(img_file.read()).decode('utf-8')
39
+
40
+ return JSONResponse(content={"sketch_image_base64": f"data:image/png;base64,{b64_string}"})
41
+
 
42
  except Exception as e:
43
+ return JSONResponse(status_code=500, content={"message": str(e)})
44
+