Ashrafb commited on
Commit
c978c43
·
verified ·
1 Parent(s): d3b357f

Create main1.py

Browse files
Files changed (1) hide show
  1. main1.py +75 -0
main1.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
2
+ from fastapi.responses import JSONResponse
3
+ from gradio_client import Client, handle_file
4
+ import os
5
+ import tempfile
6
+ import shutil
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ import logging
9
+
10
+ # Initialize logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ app = FastAPI()
15
+
16
+ client = Client("Ashrafb/moondream_captioning")
17
+
18
+ # Create the "uploads" directory if it doesn't exist
19
+ os.makedirs("uploads", exist_ok=True)
20
+
21
+ # Define a function to save uploaded file to a temporary file
22
+ async def save_upload_file(upload_file: UploadFile) -> str:
23
+ try:
24
+ # Create a temporary directory if it doesn't exist
25
+ os.makedirs("temp_uploads", exist_ok=True)
26
+ # Create a temporary file path
27
+ temp_file_path = os.path.join("temp_uploads", tempfile.NamedTemporaryFile(delete=False).name)
28
+ # Save the uploaded file to the temporary file
29
+ with open(temp_file_path, "wb") as buffer:
30
+ shutil.copyfileobj(upload_file.file, buffer)
31
+ return temp_file_path
32
+ except Exception as e:
33
+ logger.error(f"Error saving upload file: {e}")
34
+ raise HTTPException(status_code=500, detail=f"Error saving upload file: {e}")
35
+
36
+ app.add_middleware(
37
+ CORSMiddleware,
38
+ allow_origins=["*"], # Adjust as needed, '*' allows requests from any origin
39
+ allow_credentials=True,
40
+ allow_methods=["*"],
41
+ allow_headers=["*"],
42
+ )
43
+
44
+ @app.post("/get_caption")
45
+ async def get_caption(image: UploadFile = File(...), context: str = Form(...)):
46
+ try:
47
+ # Save the uploaded image to a temporary file
48
+ temp_file_path = await save_upload_file(image)
49
+
50
+ # Debugging: Print the value of additional_context
51
+ logger.info(f"Additional Context: {context}")
52
+
53
+ # Check if additional context is provided and not None
54
+ if context is not None:
55
+ context = context.strip()
56
+
57
+ # Log the parameters being passed to the Gradio client
58
+ logger.info(f"Calling client.predict with image={temp_file_path} and context={context}")
59
+
60
+ # Use handle_file to handle the file upload correctly
61
+ result = client.predict(
62
+ image=handle_file(temp_file_path),
63
+ question=context,
64
+ api_name="/get_caption"
65
+ )
66
+
67
+ return {"caption": result}
68
+ except Exception as e:
69
+ logger.error(f"Error in get_caption: {e}")
70
+ raise HTTPException(status_code=500, detail=f"Error in get_caption: {e}")
71
+
72
+ # Serve the app
73
+ if __name__ == "__main__":
74
+ import uvicorn
75
+ uvicorn.run(app, host="0.0.0.0", port=7860)