File size: 1,802 Bytes
2c151eb
 
f787272
 
2c151eb
d480925
 
d571b45
2c151eb
 
d571b45
e72c238
05761be
2c151eb
 
05761be
2371a34
53afd1c
215cdaf
 
b396701
 
29da619
 
2371a34
 
 
 
 
 
b396701
ae9a263
bff39e2
ae9a263
 
 
 
bff39e2
 
 
 
ae9a263
bff39e2
215cdaf
 
2371a34
215cdaf
 
f787272
b396701
f787272
 
 
 
2c151eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from fastapi import FastAPI, File, UploadFile
from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
from gradio_client import Client
import os
import io
# The ASGI application; every route in this module is registered on it.
app = FastAPI()

# Optional Hugging Face access token taken from the environment (None when unset).
hf_token = os.getenv('HF_TOKEN')

# Handle to the remote Gradio Space that performs the image-to-sketch conversion.
client = Client(
    "https://ashrafb-image-to-sketch.hf.space/",
    hf_token=hf_token,
)

import tempfile



import base64

@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Convert an uploaded image to a sketch using the remote Gradio Space.

    The upload is written to a temporary file, which keeps the upload's file
    extension. That path is passed to ``client.predict(..., api_name="/predict")``,
    and the temporary file is always removed afterwards.

    Returns:
        On success, a dict with ``"sketch_image_base64"`` (a data URI holding
        the contents of ``result[0]``) and ``"result_file"`` (``result[1]`` as
        returned by the prediction API).
        On any failure, ``{"error": <message>}``. The HTTP status stays 200,
        matching the previous behaviour that front-end callers may rely on.
    """
    import mimetypes

    from fastapi.concurrency import run_in_threadpool

    # Keep the original extension; some image pipelines infer format from it.
    suffix = os.path.splitext(file.filename or "")[1]
    temp_file_path = None
    try:
        # Creation and write happen inside the try so the finally clause also
        # cleans up if reading the upload or writing the file fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(await file.read())

        # client.predict is a blocking network call; run it in a worker thread
        # so the event loop can keep serving other requests meanwhile.
        result = await run_in_threadpool(
            client.predict, temp_file_path, api_name="/predict"
        )

        # The API is expected to return exactly two items: (image path, extra file).
        if not (result and len(result) == 2):
            return {"error": "Invalid result from the prediction API."}

        sketch_path, result_file = result[0], result[1]
        with open(sketch_path, "rb") as image_file:
            image_data = base64.b64encode(image_file.read()).decode("utf-8")

        # Label the data URI with the real image type, falling back to PNG.
        mime_type = mimetypes.guess_type(sketch_path)[0] or "image/png"
        return {
            "sketch_image_base64": f"data:{mime_type};base64,{image_data}",
            "result_file": result_file,
        }
    except Exception as e:
        return {"error": str(e)}
    finally:
        if temp_file_path and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)


# Directory holding the front-end assets, resolved relative to the working
# directory (the same location the static mount below serves from).
_STATIC_DIR = "static"


@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page at the site root.

    This route is registered *before* the catch-all static mount. Starlette
    matches routes in registration order, and a Mount at "/" matches every
    path, so a route registered after it would never be reached.
    """
    return FileResponse(
        path=os.path.join(_STATIC_DIR, "index.html"), media_type="text/html"
    )


# Serve all remaining paths (scripts, styles, images) from the static directory.
# html=True makes directory requests fall back to their index.html.
app.mount("/", StaticFiles(directory=_STATIC_DIR, html=True), name="static")