|
import os
import shutil
import tempfile

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from gradio_client import Client
|
|
|
app = FastAPI()

# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
hf_token = os.getenv("HF_TOKEN")

# Single Gradio client for the remote moondream1 Space, built once at import time
# and shared by every request.
client = Client("Ashrafb/moondream1", hf_token=hf_token)
|
|
|
# Prompt sent to the captioning model when the caller supplies no usable context.
DEFAULT_CONTEXT = "What is this image? Describe this image to someone who is visually impaired."


def get_caption(image, additional_context, *, default_context=DEFAULT_CONTEXT):
    """Ask the remote Space's ``/predict`` endpoint to describe *image*.

    Args:
        image: First input of the Space's ``/predict`` endpoint, forwarded
            unchanged to ``client.predict``.
        additional_context: Prompt sent together with the image. ``None``, an
            empty string, or a whitespace-only string falls back to
            *default_context*; any other value is sent as-is (not stripped).
        default_context: Prompt used when *additional_context* is not usable.

    Returns:
        Whatever ``client.predict(..., api_name="/predict")`` returns.
    """
    # Guard against None before calling .strip() so a missing prompt degrades
    # to the default instead of raising AttributeError.
    if additional_context and additional_context.strip():
        context = additional_context
    else:
        context = default_context
    return client.predict(image, context, api_name="/predict")
|
|
|
def _caption_uploaded_image(image, additional_context):
    """Save *image* to a private temp file and caption it by path (blocking)."""
    # Keep only the extension of the client-supplied name. Joining the raw
    # filename would let "../x" or an absolute path escape temp_dir, and a
    # missing filename (None) would make os.path.join fail.
    suffix = os.path.splitext(os.path.basename(image.filename or ""))[1]
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_image_path = os.path.join(temp_dir, "upload" + suffix)
        with open(temp_image_path, "wb") as temp_image:
            shutil.copyfileobj(image.file, temp_image)
        # Caption while the file still exists: the directory is deleted on exit.
        # gradio_client takes file inputs as a local path/URL rather than raw
        # bytes (NOTE(review): newer gradio_client versions may want
        # handle_file(path) — confirm against the installed version).
        return get_caption(temp_image_path, additional_context)


@app.post("/uploadfile/")
async def upload_file(image: UploadFile = File(...), additional_context: str = Form(...)):
    """Caption an uploaded image using the remote Space.

    Args:
        image: Uploaded image file (multipart form field ``image``).
        additional_context: Prompt text (form field ``additional_context``);
            a blank value makes ``get_caption`` use its default prompt.

    Returns:
        ``{"result": <caption returned by get_caption>}``.

    Raises:
        HTTPException: status 500 with the error message if saving or
            captioning the image fails.
    """
    try:
        # Both the disk write and client.predict block; run them in a worker
        # thread so this async endpoint does not stall the event loop.
        result = await run_in_threadpool(_caption_uploaded_image, image, additional_context)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"result": result}
|
|
|
|
|
# Directory holding the front-end assets. It is a relative path, so it is
# resolved against the process working directory for both the route and the mount.
STATIC_DIR = "static"


@app.get("/")
def index() -> FileResponse:
    """Serve the front-end page (``index.html`` from ``STATIC_DIR``) at the site root."""
    return FileResponse(path=os.path.join(STATIC_DIR, "index.html"), media_type="text/html")


# Mount last: a Mount at "/" matches every path and routes are tried in
# registration order, so any route added after it would never be reached.
app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")
|
|
|
|
|
|