Ashrafb commited on
Commit
805c775
·
verified ·
1 Parent(s): ad9892d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +22 -25
main.py CHANGED
@@ -10,33 +10,30 @@ from fastapi.staticfiles import StaticFiles
10
  from fastapi.templating import Jinja2Templates
11
  from fastapi.responses import FileResponse
12
 
 
 
13
  app = FastAPI()
14
 
15
- # Initialize Gradio Client
16
- hf_token = os.environ.get('HF_TOKEN')
17
- client = Client("https://ashrafb-moondream1.hf.space/--replicas/o2fdl/")
18
-
19
- # Function to generate captions for uploaded images
20
- def get_caption(image_path: str, question: str):
21
- try:
22
- result = client.predict(image_path, question, api_name="/predict")
23
- return result
24
- except Exception as e:
25
- error_msg = "Error occurred while generating caption: {}".format(str(e))
26
- raise HTTPException(status_code=500, detail={"error": error_msg})
27
-
28
- @app.post("/predict/")
29
- async def predict(image: UploadFile = File(...), question: str = Form(...)):
30
- try:
31
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_image:
32
- shutil.copyfileobj(image.file, temp_image)
33
- temp_image_path = temp_image.name
34
- result = get_caption(temp_image_path, question)
35
- return JSONResponse(content={"result": result})
36
- except Exception as e:
37
- error_msg = "An unexpected error occurred while processing the request."
38
- detail_msg = str(e)
39
- raise HTTPException(status_code=500, detail={"error": error_msg, "detail": detail_msg})
40
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
41
 
42
  @app.get("/")
 
10
  from fastapi.templating import Jinja2Templates
11
  from fastapi.responses import FileResponse
12
 
13
+
14
+
15
  app = FastAPI()
16
 
17
+ client = Client("Ashrafb/moondream_captioning")
18
+
19
+ # Create the "uploads" directory if it doesn't exist
20
+ os.makedirs("uploads", exist_ok=True)
21
+
22
+ # Define a function to save uploaded file
23
+ async def save_upload_file(upload_file: UploadFile) -> str:
24
+ file_path = os.path.join("uploads", upload_file.filename)
25
+ with open(file_path, "wb") as buffer:
26
+ buffer.write(await upload_file.read())
27
+ return file_path
28
+
29
+ @app.post("/get_caption")
30
+ async def get_caption(image: UploadFile = File(...), context: str = None):
31
+ # Save the uploaded image to a temporary file
32
+ file_path = await save_upload_file(image)
33
+ # Pass the file path to the client for prediction
34
+ result = client.predict(file_path, context, api_name="/get_caption")
35
+ return {"caption": result}
36
+
 
 
 
 
 
37
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
38
 
39
  @app.get("/")