File size: 1,539 Bytes
2c151eb
 
f787272
 
2c151eb
d480925
 
d571b45
2c151eb
 
d571b45
e72c238
05761be
2c151eb
 
05761be
2c151eb
d480925
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f787272
 
 
 
 
2c151eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from fastapi import FastAPI, File, UploadFile
from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
from gradio_client import Client
import os
import io
# Application object that the route handlers below register on.
app = FastAPI()

# Hugging Face token taken from the environment; None when HF_TOKEN is not set.
hf_token = os.getenv('HF_TOKEN')
# Module-level Gradio client for the image-to-sketch Space, built once at import.
client = Client(
    "https://ashrafb-image-to-sketch.hf.space/",
    hf_token=hf_token,
)

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Forward an uploaded image to the Gradio Space's ``/predict`` endpoint.

    The upload is spooled to a temporary file on disk, because the Gradio
    client is called with a file path. The two values the Space returns are
    relayed as JSON.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        JSONResponse with keys ``sketch_image_url`` and ``result_file_url``,
        taken in order from the Space's two-element result. (Presumably URLs
        or local paths to the generated files; confirm against the Space's API.)

    Raises:
        HTTPException: status 500 with the error text if any step fails.
    """
    # Imported locally because the module-level import block does not
    # provide these names (the original code referenced them unimported).
    import shutil
    from tempfile import NamedTemporaryFile

    from fastapi.concurrency import run_in_threadpool

    tmp_path = None
    try:
        # Keep the client's file extension on the temp file, since downstream
        # handling may infer the image type from it (TODO confirm it is needed).
        suffix = os.path.splitext(file.filename or "")[1]
        with NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the path before copying, so a failed copy is still cleaned up.
            tmp_path = tmp.name
            shutil.copyfileobj(file.file, tmp)

        # client.predict is a blocking call. Run it in a worker thread so it
        # does not stall the event loop for other requests.
        result = await run_in_threadpool(client.predict, tmp_path, api_name="/predict")

        sketch_image_url, result_file_url = result
        return JSONResponse(
            content={"sketch_image_url": sketch_image_url, "result_file_url": result_file_url},
            status_code=200,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # delete=False means nothing removes the file automatically. Always
        # remove it, whether the request succeeded or failed.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except FileNotFoundError:
                pass

@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page from the same ``static`` directory as the mount."""
    # The path is relative, like the StaticFiles mount below, so both resolve
    # against the same working directory.
    return FileResponse(path="static/index.html", media_type="text/html")


# Registered last on purpose. A mount at "/" matches every request path, so
# any route added after it (previously, index above) would never be reached.
app.mount("/", StaticFiles(directory="static", html=True), name="static")