File size: 1,814 Bytes
416d779
a0dc8ce
54d4cab
 
1d48320
 
 
 
 
 
 
d44e115
bc1fd68
ba9ac6e
8513b9b
 
a0dc8ce
704df22
1d48320
 
66b60c0
a0dc8ce
1d48320
 
 
 
 
 
 
 
 
 
 
 
a0dc8ce
66b60c0
a0dc8ce
 
 
1d48320
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from gradio_client import Client
import os
import tempfile
import shutil
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.responses import FileResponse

app = FastAPI()

# Initialize Gradio Client.
# HF_TOKEN is optional: when present it is forwarded to the Gradio client so
# requests to the Space can be authenticated; when absent the client is built
# without a token.
hf_token = os.environ.get('HF_TOKEN')
# NOTE(review): this looks like a replica-specific URL ("--replicas/o2fdl/");
# replica ids may change when the Space restarts — consider using the Space id
# instead. TODO confirm.
_SPACE_URL = "https://ashrafb-moondream1.hf.space/--replicas/o2fdl/"
# The client is constructed at import time, so a failure here fails app startup.
client = Client(_SPACE_URL, hf_token=hf_token) if hf_token else Client(_SPACE_URL)

# Function to generate captions for uploaded images
def get_caption(image_path: str, question: str):
    """Ask the remote Gradio Space a question about an image.

    Args:
        image_path: Local filesystem path of the image sent to the Space.
        question: Text prompt forwarded alongside the image.

    Returns:
        The value returned by ``client.predict`` for the ``/predict``
        endpoint, unchanged.

    Raises:
        HTTPException: status 500 with ``detail={"error": ...}`` if the remote
            call raises; the original exception is chained as ``__cause__``
            so the real failure stays visible in logs/tracebacks.
    """
    try:
        return client.predict(image_path, question, api_name="/predict")
    except Exception as e:
        error_msg = f"Error occurred while generating caption: {e}"
        raise HTTPException(status_code=500, detail={"error": error_msg}) from e

@app.post("/predict/")
async def predict(image: UploadFile = File(...), question: str = Form(...)):
    """Answer ``question`` about an uploaded image via the remote Space.

    The upload is spooled to a temporary ``.jpg`` file because ``get_caption``
    takes a filesystem path. That file is removed when the request finishes,
    whether it succeeds or fails.

    Returns:
        JSONResponse with body ``{"result": <value from get_caption>}``.

    Raises:
        HTTPException: status 500. An ``HTTPException`` raised by
            ``get_caption`` propagates unchanged; any other failure is wrapped
            as ``{"error": ..., "detail": str(e)}``.
    """
    from fastapi.concurrency import run_in_threadpool

    temp_image_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_image:
            # Record the path before copying so a failed copy is still cleaned up.
            temp_image_path = temp_image.name
            shutil.copyfileobj(image.file, temp_image)
        # client.predict is a blocking call; run it in a worker thread so it
        # does not stall the event loop for other requests.
        result = await run_in_threadpool(get_caption, temp_image_path, question)
        return JSONResponse(content={"result": result})
    except HTTPException:
        # Already carries a meaningful status and detail; don't re-wrap it.
        raise
    except Exception as e:
        error_msg = "An unexpected error occurred while processing the request."
        detail_msg = str(e)
        raise HTTPException(status_code=500, detail={"error": error_msg, "detail": detail_msg}) from e
    finally:
        # delete=False above means nobody else removes this file.
        if temp_image_path is not None:
            try:
                os.remove(temp_image_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not replace
                # the actual response or error sent to the client.
                pass
# Directory holding the front-end assets. It is shared by the explicit "/"
# route and the static mount so both resolve the same files.
STATIC_DIR = "static"


@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page (``index.html`` in ``STATIC_DIR``).

    This route is registered before the static mount below. Routes are matched
    in registration order, and a mount at "/" also matches "/", so a route
    added after the mount would never be reached.
    """
    return FileResponse(path=os.path.join(STATIC_DIR, "index.html"), media_type="text/html")


# Catch-all static file serving; registered last so the explicit routes
# above (including "/predict/" and "/") take precedence.
app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")