|
from fastapi import FastAPI, File, UploadFile, Form, HTTPException |
|
from fastapi.responses import JSONResponse |
|
from gradio_client import Client |
|
import os |
|
import tempfile |
|
import shutil |
|
from fastapi import FastAPI, File, UploadFile, Form, HTTPException |
|
from fastapi.responses import HTMLResponse |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.templating import Jinja2Templates |
|
from fastapi.responses import FileResponse |
|
|
|
|
|
|
|
# ASGI application object; the caption endpoint, the index route and the
# static-file mount further down are all registered on it.
app = FastAPI()

# Gradio client for the hosted captioning Space. It is created once, at import
# time, and shared by every request handled by this process.
client = Client("Ashrafb/moondream_captioning")

# Make sure the "uploads" directory exists; exist_ok avoids an error on restart.
# NOTE(review): nothing visible in this file writes to "uploads" — the upload
# helper uses "temp_uploads". Confirm whether this directory is still needed.
os.makedirs(name="uploads", exist_ok=True)
|
|
|
import os |
|
import tempfile |
|
|
|
|
|
async def save_upload_file(upload_file: UploadFile) -> str:
    """Write an uploaded file to a new, uniquely named file under ``temp_uploads``.

    The extension of the client-supplied filename (if any) is kept on the
    temporary file, so consumers that infer the type from the suffix still work.

    Args:
        upload_file: The multipart upload received by a FastAPI endpoint.

    Returns:
        Path of the newly written file. The caller owns the file and is
        responsible for deleting it when it is no longer needed.
    """
    os.makedirs("temp_uploads", exist_ok=True)

    # basename() drops any directory components a client may have put in the
    # filename; only the extension (e.g. ".png") is reused.
    suffix = os.path.splitext(os.path.basename(upload_file.filename or ""))[1]

    # mkstemp atomically creates a uniquely named file *inside* temp_uploads.
    # (Joining "temp_uploads" with NamedTemporaryFile().name does not do that:
    # the generated name is absolute, so os.path.join discards the prefix.)
    fd, temp_file_path = tempfile.mkstemp(suffix=suffix, dir="temp_uploads")
    try:
        with os.fdopen(fd, "wb") as buffer:
            # Copy in 1 MiB chunks so large uploads are not held in memory at once.
            while True:
                chunk = await upload_file.read(1024 * 1024)
                if not chunk:
                    break
                buffer.write(chunk)
    except BaseException:
        # Do not leave a partially written file behind on failure.
        os.remove(temp_file_path)
        raise
    return temp_file_path
|
|
|
|
|
@app.post("/get_caption")
async def get_caption(image: UploadFile = File(...), context: str = Form(...)):
    """Caption an uploaded image using the remote Gradio Space.

    Args:
        image: Uploaded image (multipart form field ``image``).
        context: Extra text sent to the Space along with the image
            (multipart form field ``context``); surrounding whitespace is
            stripped before it is forwarded.

    Returns:
        ``{"caption": result}``, where ``result`` is whatever
        ``client.predict(..., api_name="/get_caption")`` returned.
    """
    from fastapi.concurrency import run_in_threadpool

    temp_file_path = await save_upload_file(image)
    try:
        print("Additional Context:", context)

        if context is not None:
            context = context.strip()

        # client.predict is a synchronous call that waits on the remote Space;
        # run it in a worker thread so this async handler does not block the
        # event loop (and every other in-flight request) while it waits.
        result = await run_in_threadpool(
            client.predict, temp_file_path, context, api_name="/get_caption"
        )
    finally:
        # The saved upload is only needed for this one prediction; remove it
        # so temp files do not accumulate on disk with every request.
        try:
            os.remove(temp_file_path)
        except OSError:
            # Best-effort cleanup: never let a failed delete mask the real
            # result or the original error.
            pass

    return {"caption": result}
|
|
|
@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page (``static/index.html``).

    This route must be registered *before* the static mount below. Starlette
    matches routes in registration order, and a mount at "/" matches every
    path, so a route added after it is never reached.
    """
    # Resolve the page from the same cwd-relative "static" directory that the
    # mount below serves, instead of a hard-coded absolute "/app/..." path.
    return FileResponse(
        path=os.path.join("static", "index.html"), media_type="text/html"
    )


# Serve all remaining front-end assets from ./static. With html=True, requests
# for a directory are answered with that directory's index.html if it exists.
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
|
|
|