Update main.py
Browse files
main.py
CHANGED
@@ -1,46 +1,37 @@
|
|
1 |
-
from fastapi import FastAPI,
|
2 |
-
from fastapi.responses import HTMLResponse
|
3 |
-
from fastapi.staticfiles import StaticFiles
|
4 |
-
from fastapi.templating import Jinja2Templates
|
5 |
-
from fastapi.responses import FileResponse
|
6 |
from gradio_client import Client
|
7 |
-
import shutil
|
8 |
import os
|
9 |
import tempfile
|
10 |
-
|
11 |
-
|
12 |
|
13 |
app = FastAPI()
|
14 |
|
|
|
|
|
|
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
client = Client("Ashrafb/moondream1", hf_token=hf_token)
|
21 |
-
result = client.predict(file_path, api_name="/get_caption")
|
22 |
return result
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
# Route to handle file upload
|
28 |
-
@app.post("/uploadfile/")
|
29 |
-
async def create_upload_file(file: UploadFile = File(...)):
|
30 |
try:
|
|
|
31 |
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
32 |
-
shutil.copyfileobj(
|
33 |
temp_file_path = temp_file.name
|
34 |
-
|
|
|
|
|
|
|
|
|
35 |
os.unlink(temp_file_path)
|
36 |
-
|
|
|
|
|
37 |
except Exception as e:
|
38 |
raise HTTPException(status_code=500, detail=str(e))
|
39 |
-
|
40 |
-
|
41 |
-
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
42 |
-
|
43 |
-
@app.get("/")
|
44 |
-
def index() -> FileResponse:
|
45 |
-
return FileResponse(path="/app/static/index.html", media_type="text/html")
|
46 |
-
|
|
|
1 |
+
from fastapi import FastAPI, UploadFile, File, HTTPException
|
|
|
|
|
|
|
|
|
2 |
from gradio_client import Client
|
|
|
3 |
import os
|
4 |
import tempfile
|
5 |
+
import shutil
|
|
|
6 |
|
7 |
app = FastAPI()
|
8 |
|
9 |
+
# Initialize Gradio Client
|
10 |
+
hf_token = os.environ.get('HF_TOKEN')
|
11 |
+
client = Client("Ashrafb/moondream_captioning", hf_token=hf_token)
|
12 |
|
13 |
+
# Function to generate captions for uploaded images
|
14 |
+
def generate_caption(image_path):
|
15 |
+
context = "describe" # Fixed context value
|
16 |
+
result = client.predict(image_path, context, api_name="/get_caption")
|
|
|
|
|
17 |
return result
|
18 |
|
19 |
+
# Route to handle image uploads and generate captions
|
20 |
+
@app.post("/generate_caption/")
|
21 |
+
async def generate_image_caption(image: UploadFile = File(...)):
|
|
|
|
|
|
|
22 |
try:
|
23 |
+
# Save the uploaded image to a temporary file
|
24 |
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
25 |
+
shutil.copyfileobj(image.file, temp_file)
|
26 |
temp_file_path = temp_file.name
|
27 |
+
|
28 |
+
# Generate caption for the uploaded image
|
29 |
+
caption = generate_caption(temp_file_path)
|
30 |
+
|
31 |
+
# Clean up temporary file
|
32 |
os.unlink(temp_file_path)
|
33 |
+
|
34 |
+
return {"caption": caption}
|
35 |
+
|
36 |
except Exception as e:
|
37 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|