File size: 1,944 Bytes
8513b9b
54d4cab
 
 
8513b9b
2844d41
 
 
 
d44e115
bc1fd68
ba9ac6e
8513b9b
 
f9294c6
8513b9b
f9294c6
 
 
 
 
 
 
802a594
47080b7
f9294c6
54d4cab
f9294c6
 
 
 
 
 
 
 
 
 
 
 
54d4cab
 
158c342
f9294c6
158c342
 
 
 
 
f9294c6
158c342
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import shutil
import tempfile

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from gradio_client import Client

app = FastAPI()

# Gradio client bound to the hosted "Ashrafb/moondream1" Space. The Hugging
# Face token comes from the HF_TOKEN environment variable; if it is unset,
# None is handed to the client.
hf_token = os.getenv("HF_TOKEN")
client = Client(
    "Ashrafb/moondream1",
    hf_token=hf_token,
)
# Function to generate captions for uploaded images
def get_caption(image, additional_context=None):
    """Ask the moondream1 Gradio Space to describe an image.

    Args:
        image: Forwarded unchanged as the first input of the Space's
            ``/predict`` endpoint. NOTE(review): gradio_client expects a local
            file path or URL for image inputs, not raw bytes — callers should
            pass a path.
        additional_context: Prompt to send with the image. ``None``, ``""`` or
            a whitespace-only string selects the default accessibility-oriented
            prompt instead. A non-blank prompt is sent as given (not stripped).

    Returns:
        Whatever ``client.predict`` returns for ``api_name="/predict"``.
    """
    # Guard against None explicitly: calling .strip() on it would raise
    # AttributeError instead of falling back to the default prompt.
    if additional_context is not None and additional_context.strip():
        context = additional_context
    else:
        context = "What is this image? Describe this image to someone who is visually impaired."
    return client.predict(image, context, api_name="/predict")

def _safe_upload_name(filename):
    """Reduce a client-supplied filename to a bare, non-empty file name."""
    # Never use the client's filename as a path: an absolute name makes
    # os.path.join discard the temp directory entirely, and "../" segments
    # would escape it. Keep only the final component (which preserves the
    # extension) and fall back to a fixed name when nothing usable remains.
    name = os.path.basename(filename or "")
    return name if name not in ("", ".", "..") else "upload"


def _save_and_caption(fileobj, filename, additional_context):
    """Copy an upload into a private temp dir, caption it, then delete it.

    Blocking (local disk copy plus the remote Gradio call), so callers on the
    event loop must run it in a worker thread.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_image_path = os.path.join(temp_dir, filename)
        with open(temp_image_path, "wb") as temp_image:
            shutil.copyfileobj(fileobj, temp_image)
        # gradio_client takes a file path for image inputs (not raw bytes);
        # the call must finish before the temp directory is removed.
        return get_caption(temp_image_path, additional_context)


@app.post("/uploadfile/")
async def upload_file(image: UploadFile = File(...), additional_context: str = Form(...)):
    """Caption an uploaded image via the moondream1 Gradio Space.

    Args:
        image: The multipart file field ``image``.
        additional_context: Required multipart form field; a blank value makes
            ``get_caption`` use its default prompt.

    Returns:
        ``{"result": <value returned by get_caption>}``.

    Raises:
        HTTPException: status 500, with the error message as ``detail``, if
            saving the upload or calling the Space fails.
    """
    try:
        # Offload the blocking copy + network round-trip so the event loop
        # keeps serving other requests while the Space is working.
        result = await run_in_threadpool(
            _save_and_caption,
            image.file,
            _safe_upload_name(image.filename),
            additional_context,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"result": result}


# Register the explicit "/" route BEFORE mounting StaticFiles at "/": routes
# are matched in registration order and a mount at the root path matches every
# request, so a route added after it can never be reached.
@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page, static/index.html."""
    # Resolve against the same relative "static" directory the mount below
    # uses, rather than a hard-coded absolute container path.
    return FileResponse(path=os.path.join("static", "index.html"), media_type="text/html")


app.mount("/", StaticFiles(directory="static", html=True), name="static")