Ashrafb commited on
Commit
8513b9b
·
verified ·
1 Parent(s): 455b83d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +22 -31
main.py CHANGED
@@ -1,46 +1,37 @@
1
- from fastapi import FastAPI, File, UploadFile, Form, HTTPException
2
- from fastapi.responses import HTMLResponse
3
- from fastapi.staticfiles import StaticFiles
4
- from fastapi.templating import Jinja2Templates
5
- from fastapi.responses import FileResponse
6
  from gradio_client import Client
7
- import shutil
8
  import os
9
  import tempfile
10
-
11
-
12
 
13
  app = FastAPI()
14
 
 
 
 
15
 
16
-
17
- # Function to make API prediction
18
- def predict_image_description(file_path):
19
- hf_token = os.environ.get('HF_TOKEN')
20
- client = Client("Ashrafb/moondream1", hf_token=hf_token)
21
- result = client.predict(file_path, api_name="/get_caption")
22
  return result
23
 
24
-
25
-
26
-
27
- # Route to handle file upload
28
- @app.post("/uploadfile/")
29
- async def create_upload_file(file: UploadFile = File(...)):
30
  try:
 
31
  with tempfile.NamedTemporaryFile(delete=False) as temp_file:
32
- shutil.copyfileobj(file.file, temp_file)
33
  temp_file_path = temp_file.name
34
- description = predict_image_description(temp_file_path)
 
 
 
 
35
  os.unlink(temp_file_path)
36
- return {"description": description}
 
 
37
  except Exception as e:
38
  raise HTTPException(status_code=500, detail=str(e))
39
-
40
-
41
- app.mount("/", StaticFiles(directory="static", html=True), name="static")
42
-
43
- @app.get("/")
44
- def index() -> FileResponse:
45
- return FileResponse(path="/app/static/index.html", media_type="text/html")
46
-
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
 
 
 
 
2
  from gradio_client import Client
 
3
  import os
4
  import tempfile
5
+ import shutil
 
6
 
7
  app = FastAPI()
8
 
9
+ # Initialize Gradio Client
10
+ hf_token = os.environ.get('HF_TOKEN')
11
+ client = Client("Ashrafb/moondream_captioning", hf_token=hf_token)
12
 
13
+ # Function to generate captions for uploaded images
14
+ def generate_caption(image_path):
15
+ context = "describe" # Fixed context value
16
+ result = client.predict(image_path, context, api_name="/get_caption")
 
 
17
  return result
18
 
19
+ # Route to handle image uploads and generate captions
20
+ @app.post("/generate_caption/")
21
+ async def generate_image_caption(image: UploadFile = File(...)):
 
 
 
22
  try:
23
+ # Save the uploaded image to a temporary file
24
  with tempfile.NamedTemporaryFile(delete=False) as temp_file:
25
+ shutil.copyfileobj(image.file, temp_file)
26
  temp_file_path = temp_file.name
27
+
28
+ # Generate caption for the uploaded image
29
+ caption = generate_caption(temp_file_path)
30
+
31
+ # Clean up temporary file
32
  os.unlink(temp_file_path)
33
+
34
+ return {"caption": caption}
35
+
36
  except Exception as e:
37
  raise HTTPException(status_code=500, detail=str(e))