File size: 1,969 Bytes
416d779
a0dc8ce
54d4cab
 
1d48320
 
 
 
 
 
 
d44e115
805c775
 
bc1fd68
ba9ac6e
805c775
 
 
 
 
ef1a321
 
 
 
805c775
ef1a321
 
 
 
 
 
805c775
ef1a321
 
805c775
 
1ad680e
805c775
ef1a321
30b9303
322247c
 
 
30b9303
 
 
 
 
ef1a321
30b9303
805c775
 
1d48320
 
 
 
 
ad9892d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from gradio_client import Client
import os
import tempfile
import shutil
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.responses import FileResponse



# ASGI application object; every route decorator below registers on it.
app = FastAPI()

# Gradio client for the captioning Space. It is built once, when this module
# is imported, and the single instance is reused by the /get_caption handler.
client = Client("Ashrafb/moondream_captioning")

# Make sure an "uploads" folder exists in the working directory (no error if
# it is already there).
# NOTE(review): nothing in this file writes to "uploads" — uploaded images go
# to "temp_uploads" — confirm whether this directory is still needed.
os.makedirs(name="uploads", exist_ok=True)

import os
import tempfile

# Define a function to save uploaded file to a temporary file
async def save_upload_file(upload_file: UploadFile) -> str:
    # Create a temporary directory if it doesn't exist
    os.makedirs("temp_uploads", exist_ok=True)
    # Create a temporary file path
    temp_file_path = os.path.join("temp_uploads", tempfile.NamedTemporaryFile().name)
    # Save the uploaded file to the temporary file
    with open(temp_file_path, "wb") as buffer:
        buffer.write(await upload_file.read())
    return temp_file_path


@app.post("/get_caption")
async def get_caption(image: UploadFile = File(...), context: str = Form(...)):
    # Save the uploaded image to a temporary file
    temp_file_path = await save_upload_file(image)
    
    # Debugging: Print the value of additional_context
    print("Additional Context:", context)
    
    # Check if additional context is provided and not None
    if context is not None:
        context = context.strip()
    
    # Pass the temporary file path and context to the Gradio client for prediction
    result = client.predict(temp_file_path, context, api_name="/get_caption")
    
    return {"caption": result}

app.mount("/", StaticFiles(directory="static", html=True), name="static")

@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="/app/static/index.html", media_type="text/html")