"""FastAPI front end that captions uploaded images via a hosted Gradio Space.

POST /get_caption forwards an uploaded image plus a context string to the
"Ashrafb/moondream_captioning" Space and returns its prediction. GET / and all
other paths are served from the local ``static`` directory.
"""

import logging
import os
import shutil
import tempfile

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from gradio_client import Client

logger = logging.getLogger(__name__)

app = FastAPI()

# Remote Gradio Space that performs the captioning. Constructed at import time;
# presumably this contacts the Space, so startup may need network access.
client = Client("Ashrafb/moondream_captioning")

# Directory holding short-lived copies of uploaded images.
TEMP_UPLOAD_DIR = "temp_uploads"
# Directory of front-end assets, relative to the working directory.
STATIC_DIR = "static"

# Create the "uploads" directory if it doesn't exist (not written to by this
# module; kept so anything relying on its existence still works).
os.makedirs("uploads", exist_ok=True)


async def save_upload_file(upload_file: UploadFile) -> str:
    """Write *upload_file* to a uniquely named file and return its path.

    The file is created inside ``temp_uploads`` (created if missing) and keeps
    the upload's extension when it is a simple alphanumeric one, so consumers
    can infer the image type from the name. The caller owns the returned file
    and is responsible for deleting it.
    """
    os.makedirs(TEMP_UPLOAD_DIR, exist_ok=True)

    suffix = os.path.splitext(upload_file.filename or "")[1]
    # The filename is client-supplied; only keep a plain ".ext" suffix.
    if not suffix[1:].isalnum():
        suffix = ""

    # Let tempfile create the file directly in the target directory.
    # NamedTemporaryFile().name is an absolute path, so joining it onto a
    # directory with os.path.join would discard the directory entirely.
    # delete=False: the file must outlive this handle so it can be read by path.
    with tempfile.NamedTemporaryFile(
        dir=TEMP_UPLOAD_DIR, suffix=suffix, delete=False
    ) as buffer:
        buffer.write(await upload_file.read())
        return buffer.name


def _remove_quietly(path: str) -> None:
    """Best-effort delete of *path*; failures are logged, not raised."""
    try:
        os.remove(path)
    except OSError:
        logger.warning("Could not remove temporary file %s", path, exc_info=True)


@app.post("/get_caption")
async def get_caption(image: UploadFile = File(...), context: str = Form(...)):
    """Caption *image* with the remote model, guided by *context*.

    Returns ``{"caption": <prediction>}``. Raises ``HTTPException`` (502) when
    the call to the Gradio Space fails. The temporary copy of the image is
    always deleted afterwards.
    """
    temp_file_path = await save_upload_file(image)
    try:
        logger.debug("Additional Context: %r", context)
        # Form(...) makes the field required, so context is always a str here.
        context = context.strip()

        try:
            # client.predict is blocking; run it in a worker thread so it does
            # not stall the event loop for concurrent requests.
            result = await run_in_threadpool(
                client.predict, temp_file_path, context, api_name="/get_caption"
            )
        except Exception as exc:  # boundary: any remote failure -> 502
            logger.exception("Caption request to Gradio Space failed")
            raise HTTPException(
                status_code=502, detail="Captioning service failed"
            ) from exc
    finally:
        _remove_quietly(temp_file_path)

    return {"caption": result}


@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page from the static directory.

    Registered before the static mount below: a mount at "/" matches every
    path, so any route registered after it would never be reached.
    """
    return FileResponse(
        path=os.path.join(STATIC_DIR, "index.html"), media_type="text/html"
    )


# Serve all remaining paths (JS, CSS, images) from ./static. Must stay last.
app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")