File size: 1,956 Bytes
416d779
704df22
416d779
54d4cab
 
 
8513b9b
d44e115
bc1fd68
ba9ac6e
8513b9b
 
f9294c6
704df22
8513b9b
f9294c6
66b60c0
 
 
 
 
 
 
 
 
 
802a594
47080b7
f9294c6
54d4cab
f9294c6
 
 
 
 
 
 
 
54d4cab
66b60c0
 
f9294c6
704df22
158c342
 
704df22
158c342
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import shutil
import tempfile

from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from gradio_client import Client

app = FastAPI()

# Gradio client for the hosted "Ashrafb/moondream1" captioning Space.
# The token comes from the HF_TOKEN environment variable; when that variable
# is unset, None is passed through as hf_token.
hf_token = os.getenv("HF_TOKEN")
client = Client("Ashrafb/moondream1", hf_token=hf_token)

# Function to generate captions for uploaded images
def get_caption(image, additional_context):
    """Ask the moondream1 Space to describe an image.

    Args:
        image: image input forwarded unchanged to ``client.predict``.
            Presumably a local file path or URL, which is what the Gradio
            client accepts for image inputs (confirm against the Space's API).
        additional_context: prompt to send with the image. ``None``, an empty
            string, or a whitespace-only string falls back to a default
            prompt asking for a description aimed at a visually impaired user.

    Returns:
        Whatever ``client.predict`` returns for the ``/predict`` endpoint.

    Raises:
        HTTPException: status 500 wrapping any error raised while choosing
            the prompt or calling the Space. The original error is chained
            as ``__cause__`` so server logs keep the real traceback.
    """
    try:
        # Guard against None before calling .strip(); a blank prompt means
        # "use the default".
        if additional_context and additional_context.strip():
            context = additional_context
        else:
            context = "What is this image? Describe this image to someone who is visually impaired."
        return client.predict(image, context, api_name="/predict")
    except Exception as e:
        error_msg = f"Error occurred while generating caption: {e}"
        raise HTTPException(status_code=500, detail={"error": error_msg}) from e

@app.post("/uploadfile/")
async def upload_file(
    image: UploadFile = File(...),
    additional_context: str = Form(...),
):
    """Caption an uploaded image.

    The upload is written to a private temporary directory, and the saved
    file's *path* is passed to ``get_caption``. The Gradio client sends image
    inputs by filepath/URL, not as raw bytes. The temporary directory is
    removed once captioning finishes.

    Args:
        image: the uploaded image file (multipart form field ``image``).
        additional_context: prompt text (form field ``additional_context``).
            A blank value makes ``get_caption`` use its default prompt.

    Returns:
        ``{"result": <caption returned by get_caption>}``.

    Raises:
        HTTPException: status 500 if saving the upload or captioning fails.
    """
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            # The filename is client-controlled. os.path.join discards temp_dir
            # for absolute names and honours "..", so keep only the last path
            # component and fall back to a fixed name when nothing usable remains.
            safe_name = os.path.basename(image.filename or "")
            if safe_name in ("", ".", ".."):
                safe_name = "upload"
            temp_image_path = os.path.join(temp_dir, safe_name)
            with open(temp_image_path, "wb") as temp_image:
                shutil.copyfileobj(image.file, temp_image)
            # client.predict blocks on a network round-trip. Run it in the
            # threadpool so this async endpoint does not stall the event loop.
            result = await run_in_threadpool(
                get_caption, temp_image_path, additional_context
            )
        return {"result": result}
    except HTTPException:
        # get_caption already produced a descriptive 500; don't nest it.
        raise
    except Exception as e:
        error_msg = f"Error occurred while processing image: {e}"
        raise HTTPException(status_code=500, detail={"error": error_msg}) from e

# Directory holding the front-end files, relative to the working directory.
# StaticFiles below requires it to exist at startup, so index.html is served
# from the same place instead of a separate hard-coded absolute path.
_STATIC_DIR = "static"


# Serve index.html at the site root.
# NOTE: this route must be registered BEFORE the root static mount below.
# Starlette tries routes in registration order and a Mount at "/" matches
# every path, so a "/" route registered after it would never be reached.
@app.get("/")
def index() -> FileResponse:
    """Return the front-end entry page, ``static/index.html``."""
    return FileResponse(
        path=os.path.join(_STATIC_DIR, "index.html"), media_type="text/html"
    )


# Serve the remaining front-end assets from the static directory. This must
# stay the last route registration, since it matches all paths.
# html=True makes directory requests resolve to their index.html.
app.mount("/", StaticFiles(directory=_STATIC_DIR, html=True), name="static")