Ashrafb commited on
Commit
d480925
1 Parent(s): d571b45

Update main.py

Browse files

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from gradio_client import Client

app = FastAPI()
client = Client("https://ashrafb-image-to-sketch.hf.space/")



@app

.post("/predict")
async def predict(file: UploadFile = File(...)):
try:
# Save the uploaded image to a temporary file
with NamedTemporaryFile(delete=False) as tmp:
shutil.copyfileobj(file.file, tmp)
tmp_path = tmp.name

# Call the Gradio client to predict
result = client.predict(tmp_path, api_name="/predict")

# Delete the temporary file
os.unlink(tmp_path)

# Parse and return the result
sketch_image_url, result_file_url = result
return JSONResponse(content={"sketch_image_url": sketch_image_url, "result_file_url": result_file_url}, status_code=200)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

Files changed (1) hide show
  1. main.py +20 -6
main.py CHANGED
@@ -3,6 +3,8 @@ from fastapi import FastAPI, File, UploadFile, Form, Request
3
  from fastapi.responses import HTMLResponse, FileResponse
4
  from fastapi.staticfiles import StaticFiles
5
  from fastapi.templating import Jinja2Templates
 
 
6
  from fastapi.responses import StreamingResponse
7
  from gradio_client import Client
8
  import os
@@ -13,12 +15,24 @@ hf_token = os.environ.get('HF_TOKEN')
13
  client = Client("https://ashrafb-image-to-sketch.hf.space/", hf_token=hf_token)
14
 
15
  @app.post("/predict")
16
- async def predict_sketch(file: UploadFile = File(...)):
17
- content = await file.read()
18
- # Call the Gradio client to get the sketch result
19
- result = client.predict(content, api_name="/predict")
20
- sketch_image_bytes = result[0].encode('utf-8') # Assuming result[0] is a base64 encoded image
21
- return StreamingResponse(io.BytesIO(sketch_image_bytes), media_type="image/png")
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
24
 
 
3
  from fastapi.responses import HTMLResponse, FileResponse
4
  from fastapi.staticfiles import StaticFiles
5
  from fastapi.templating import Jinja2Templates
6
+ from fastapi import FastAPI, File, UploadFile, HTTPException
7
+ from fastapi.responses import JSONResponse
8
  from fastapi.responses import StreamingResponse
9
  from gradio_client import Client
10
  import os
 
15
  client = Client("https://ashrafb-image-to-sketch.hf.space/", hf_token=hf_token)
16
 
17
  @app.post("/predict")
18
+ async def predict(file: UploadFile = File(...)):
19
+ try:
20
+ # Save the uploaded image to a temporary file
21
+ with NamedTemporaryFile(delete=False) as tmp:
22
+ shutil.copyfileobj(file.file, tmp)
23
+ tmp_path = tmp.name
24
+
25
+ # Call the Gradio client to predict
26
+ result = client.predict(tmp_path, api_name="/predict")
27
+
28
+ # Delete the temporary file
29
+ os.unlink(tmp_path)
30
+
31
+ # Parse and return the result
32
+ sketch_image_url, result_file_url = result
33
+ return JSONResponse(content={"sketch_image_url": sketch_image_url, "result_file_url": result_file_url}, status_code=200)
34
+ except Exception as e:
35
+ raise HTTPException(status_code=500, detail=str(e))
36
 
37
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
38