import os
import shutil
import tempfile

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from gradio_client import Client


# ASGI application instance; the route decorators and the static mount
# further down in this file attach to it.
app = FastAPI()

# Read from the environment but not referenced again in the visible code.
# NOTE(review): presumably meant to authenticate the Gradio client -- confirm
# whether it should be passed to Client or can be removed.
hf_token = os.environ.get('HF_TOKEN')

# Gradio client used by get_caption(); constructed once at import time.
# NOTE(review): the URL appears to pin one specific Space replica -- confirm
# that it stays valid across Space restarts.
client = Client("https://ashrafb-moondream1.hf.space/--replicas/o2fdl/")


def get_caption(image_path: str, question: str):
    """Ask the remote Gradio endpoint a question about an image.

    Args:
        image_path: Path of the image file handed to ``client.predict``.
        question: Question text forwarded with the image.

    Returns:
        Whatever ``client.predict`` returns for the ``/predict`` endpoint.

    Raises:
        HTTPException: With status 500 when the client call raises; the
            original exception is kept as ``__cause__``.
    """
    try:
        return client.predict(image_path, question, api_name="/predict")
    except Exception as e:
        error_msg = "Error occurred while generating caption: {}".format(str(e))
        # Chain the cause so the upstream traceback is not lost in logs.
        raise HTTPException(status_code=500, detail={"error": error_msg}) from e


@app.post("/predict/")
async def predict(image: UploadFile = File(...), question: str = Form(...)):
    """Answer a question about an uploaded image.

    The upload is copied to a temporary ``.jpg`` file, passed together with
    the question to ``get_caption``, and the temporary file is removed again
    once the request has been handled.

    Args:
        image: Uploaded image file (multipart form field ``image``).
        question: Question text (form field ``question``).

    Returns:
        JSONResponse whose body is ``{"result": <get_caption result>}``.

    Raises:
        HTTPException: Status 500. An HTTPException raised by ``get_caption``
            is passed through unchanged; any other failure is wrapped with a
            generic message plus the exception text.
    """
    temp_image_path = None
    try:
        # delete=False: the file must outlive the `with` block so it can be
        # read by path after it has been closed (and therefore flushed).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_image:
            shutil.copyfileobj(image.file, temp_image)
            temp_image_path = temp_image.name

        # get_caption blocks on a network call; run it in the thread pool so
        # the event loop stays free to serve other requests.
        result = await run_in_threadpool(get_caption, temp_image_path, question)
        return JSONResponse(content={"result": result})
    except HTTPException:
        # Already a well-formed HTTP error; do not wrap it a second time.
        raise
    except Exception as e:
        error_msg = "An unexpected error occurred while processing the request."
        detail_msg = str(e)
        raise HTTPException(
            status_code=500,
            detail={"error": error_msg, "detail": detail_msg},
        ) from e
    finally:
        # Best-effort cleanup: the temp file was created with delete=False and
        # would otherwise accumulate on disk with every request.
        if temp_image_path is not None:
            try:
                os.remove(temp_image_path)
            except OSError:
                pass


# Serve the contents of the "static" directory at the site root; html=True
# makes StaticFiles answer directory requests with index.html.
# NOTE(review): this mount is registered before the index() route defined
# below. Starlette matches routes in registration order, so the mount at "/"
# likely handles GET "/" and index() is never reached. The order is kept
# as-is to preserve current behaviour -- confirm before reordering.
app.mount("/", StaticFiles(directory="static", html=True), name="static")


@app.get("/")
def index() -> FileResponse:
    """Return the landing page as an HTML file response.

    Returns:
        FileResponse for ``/app/static/index.html`` with media type
        ``text/html``.

    NOTE(review): the path is absolute and hard-coded; confirm it matches the
    directory the application is actually deployed in.
    """
    return FileResponse(path="/app/static/index.html", media_type="text/html")