Ashrafb commited on
Commit
704df22
·
verified ·
1 Parent(s): a65c154

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +12 -17
main.py CHANGED
@@ -1,26 +1,20 @@
1
  from fastapi import FastAPI, File, UploadFile, Form, HTTPException
2
- from fastapi.responses import HTMLResponse
3
  from fastapi.staticfiles import StaticFiles
4
- from fastapi.templating import Jinja2Templates
5
- from fastapi.responses import FileResponse
6
- from fastapi import FastAPI, UploadFile, File, HTTPException
7
  from gradio_client import Client
8
  import os
9
  import tempfile
10
  import shutil
11
- from fastapi.responses import HTMLResponse
12
- from fastapi.staticfiles import StaticFiles
13
- from fastapi.templating import Jinja2Templates
14
- from fastapi.responses import FileResponse
15
 
16
  app = FastAPI()
17
 
18
  # Initialize Gradio Client
19
  hf_token = os.environ.get('HF_TOKEN')
20
  client = Client("Ashrafb/moondream1", hf_token=hf_token)
 
21
  # Function to generate captions for uploaded images
22
  def get_caption(image, additional_context):
23
- if additional_context.strip(): # Check if additional_context is not empty
24
  context = additional_context
25
  else:
26
  context = "What is this image? Describe this image to someone who is visually impaired."
@@ -30,26 +24,27 @@ def get_caption(image, additional_context):
30
  @app.post("/uploadfile/")
31
  async def upload_file(image: UploadFile = File(...), additional_context: str = Form(...)):
32
  try:
33
- # Create a temporary directory to store the uploaded image
34
  with tempfile.TemporaryDirectory() as temp_dir:
35
  temp_image_path = os.path.join(temp_dir, image.filename)
36
- # Write the uploaded image data to a temporary file
37
  with open(temp_image_path, "wb") as temp_image:
38
  shutil.copyfileobj(image.file, temp_image)
39
- # Read the image data
40
  with open(temp_image_path, "rb") as image_file:
41
  image_data = image_file.read()
42
- # Generate caption using Gradio client
43
  result = get_caption(image_data, additional_context)
44
  return {"result": result}
 
 
 
45
  except Exception as e:
46
- raise HTTPException(status_code=500, detail=str(e))
47
-
 
 
48
 
 
49
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
50
 
 
51
  @app.get("/")
52
  def index() -> FileResponse:
53
  return FileResponse(path="/app/static/index.html", media_type="text/html")
54
-
55
-
 
1
  from fastapi import FastAPI, File, UploadFile, Form, HTTPException
2
+ from fastapi.responses import HTMLResponse, FileResponse
3
  from fastapi.staticfiles import StaticFiles
 
 
 
4
  from gradio_client import Client
5
  import os
6
  import tempfile
7
  import shutil
 
 
 
 
8
 
9
  app = FastAPI()
10
 
11
  # Initialize Gradio Client
12
  hf_token = os.environ.get('HF_TOKEN')
13
  client = Client("Ashrafb/moondream1", hf_token=hf_token)
14
+
15
  # Function to generate captions for uploaded images
16
  def get_caption(image, additional_context):
17
+ if additional_context.strip():
18
  context = additional_context
19
  else:
20
  context = "What is this image? Describe this image to someone who is visually impaired."
 
24
  @app.post("/uploadfile/")
25
  async def upload_file(image: UploadFile = File(...), additional_context: str = Form(...)):
26
  try:
 
27
  with tempfile.TemporaryDirectory() as temp_dir:
28
  temp_image_path = os.path.join(temp_dir, image.filename)
 
29
  with open(temp_image_path, "wb") as temp_image:
30
  shutil.copyfileobj(image.file, temp_image)
 
31
  with open(temp_image_path, "rb") as image_file:
32
  image_data = image_file.read()
 
33
  result = get_caption(image_data, additional_context)
34
  return {"result": result}
35
+ except HTTPException as http_exc:
36
+ # Handle HTTPException raised by FastAPI
37
+ raise http_exc
38
  except Exception as e:
39
+ # Handle other unexpected exceptions
40
+ error_msg = "An unexpected error occurred while processing the request."
41
+ detail_msg = str(e)
42
+ raise HTTPException(status_code=500, detail={"error": error_msg, "detail": detail_msg})
43
 
44
+ # Serve static files
45
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
46
 
47
+ # Serve index.html
48
  @app.get("/")
49
  def index() -> FileResponse:
50
  return FileResponse(path="/app/static/index.html", media_type="text/html")