Mnddoke / main.py
Ashrafb's picture
Update main.py
805c775 verified
raw
history blame
1.42 kB
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from gradio_client import Client
import os
import tempfile
import shutil
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.responses import FileResponse
# FastAPI application; the caption route, the index route and the static-file
# mount are all registered on it further down in this module.
app = FastAPI()
# Gradio client for the hosted "Ashrafb/moondream_captioning" app, used by
# get_caption() to run inference. It is created at import time, so
# (presumably) importing this module requires that app to be reachable — confirm.
client = Client("Ashrafb/moondream_captioning")
# Create the "uploads" directory (where uploaded images are saved) if it
# doesn't exist.
os.makedirs("uploads", exist_ok=True)
async def save_upload_file(upload_file: UploadFile, upload_dir: str = "uploads") -> str:
    """Persist an uploaded file under ``upload_dir`` and return its path.

    The client-supplied filename is untrusted, so it is never used as a path.
    Only its extension is kept (when it is a plain alphanumeric ASCII suffix),
    and the file is stored under a unique, server-generated name. This blocks
    path traversal such as ``"../../x"`` or ``"/etc/passwd"``. It also stops two
    uploads with the same name from overwriting each other.

    Args:
        upload_file: The uploaded file. Its whole content is read into memory.
        upload_dir: Directory to store the file in. Created if missing.

    Returns:
        Path of the newly written file, located directly inside ``upload_dir``.
    """
    os.makedirs(upload_dir, exist_ok=True)

    # Keep the original extension, since downstream consumers may use it to
    # detect the image format. Drop anything that is not a simple suffix.
    # ``filename`` may be None or empty.
    _, suffix = os.path.splitext(os.path.basename(upload_file.filename or ""))
    ext = suffix[1:]
    if not (ext.isascii() and ext.isalnum()):
        suffix = ""

    data = await upload_file.read()
    # mkstemp atomically creates a new, uniquely named file inside upload_dir.
    fd, file_path = tempfile.mkstemp(suffix=suffix, dir=upload_dir)
    with os.fdopen(fd, "wb") as buffer:
        buffer.write(data)
    return file_path
@app.post("/get_caption")
async def get_caption(image: UploadFile = File(...), context: str = None):
    """Caption an uploaded image using the remote moondream captioning app.

    Args:
        image: Uploaded image (multipart field ``image``).
        context: Optional text forwarded to the remote model. It is declared
            without ``Form(...)``, so FastAPI treats it as a query parameter,
            not a form field.
            NOTE(review): confirm this matches what the frontend sends.

    Returns:
        ``{"caption": result}``, where ``result`` is whatever the remote
        ``/get_caption`` endpoint returns.
    """
    # Save the uploaded image to a temporary file the client can upload.
    file_path = await save_upload_file(image)
    try:
        # NOTE(review): client.predict is synchronous. Inside an async handler
        # it stalls the event loop until the remote call returns. Moving it to
        # a worker thread is only safe when every request's file path is unique.
        result = client.predict(file_path, context, api_name="/get_caption")
    finally:
        # The file is only needed for the predict call. Remove it so the
        # "uploads" directory does not grow without bound. Cleanup is
        # best-effort and must not mask the prediction's own result or error.
        try:
            os.remove(file_path)
        except OSError:
            pass
    return {"caption": result}
@app.get("/")
def index() -> FileResponse:
    """Serve the frontend's entry page, ``static/index.html``."""
    # Resolved relative to the working directory, exactly like the
    # StaticFiles mount below. StaticFiles raises at startup if "static" is
    # missing, so this path always resolves when the app is running.
    return FileResponse(path=os.path.join("static", "index.html"), media_type="text/html")


# Mounted last on purpose: a mount at "/" matches every request path, and
# Starlette dispatches to the first matching route. Any route registered after
# this line would therefore never be reached.
app.mount("/", StaticFiles(directory="static", html=True), name="static")