Ashrafb commited on
Commit
d571b45
·
verified ·
1 Parent(s): 4d1d14d

Update main.py

Browse files

from fastapi import FastAPI, File, UploadFile
from fastapi.responses import StreamingResponse
from gradio_client import Client

app = FastAPI()

client = Client("https://ashrafb-image-to-sketch.hf.space/")



@app
.post("/predict")
async def predict_sketch(file: UploadFile = File(...)):
content = await file.read()
# Call the Gradio client to get the sketch result
result = client.predict(content, api_name="/predict")
sketch_image_bytes = result[0].encode('utf-8') # Assuming result[0] is a base64 encoded image
return StreamingResponse(io.BytesIO(sketch_image_bytes), media_type="image/png")

Files changed (1) hide show
  1. main.py +4 -2
main.py CHANGED
@@ -3,9 +3,10 @@ from fastapi import FastAPI, File, UploadFile, Form, Request
3
  from fastapi.responses import HTMLResponse, FileResponse
4
  from fastapi.staticfiles import StaticFiles
5
  from fastapi.templating import Jinja2Templates
 
6
  from gradio_client import Client
7
  import os
8
-
9
  app = FastAPI()
10
 
11
  hf_token = os.environ.get('HF_TOKEN')
@@ -16,7 +17,8 @@ async def predict_sketch(file: UploadFile = File(...)):
16
  content = await file.read()
17
  # Call the Gradio client to get the sketch result
18
  result = client.predict(content, api_name="/predict")
19
- return {"sketch_image": result[0], "result_file": result[1]}
 
20
 
21
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
22
 
 
3
  from fastapi.responses import HTMLResponse, FileResponse
4
  from fastapi.staticfiles import StaticFiles
5
  from fastapi.templating import Jinja2Templates
6
+ from fastapi.responses import StreamingResponse
7
  from gradio_client import Client
8
  import os
9
+ import io
10
  app = FastAPI()
11
 
12
  hf_token = os.environ.get('HF_TOKEN')
 
17
  content = await file.read()
18
  # Call the Gradio client to get the sketch result
19
  result = client.predict(content, api_name="/predict")
20
+ sketch_image_bytes = result[0].encode('utf-8') # Assuming result[0] is a base64 encoded image
21
+ return StreamingResponse(io.BytesIO(sketch_image_bytes), media_type="image/png")
22
 
23
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
24