Ashrafb commited on
Commit
f9294c6
·
verified ·
1 Parent(s): 3df8281

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +23 -21
main.py CHANGED
@@ -12,37 +12,39 @@ app = FastAPI()
12
 
13
  # Initialize Gradio Client
14
  hf_token = os.environ.get('HF_TOKEN')
15
- client = Client("Ashrafb/moondream_captioning", hf_token=hf_token)
16
-
17
  # Function to generate captions for uploaded images
18
- def generate_caption(image_path):
19
- context = "describe" # Fixed context value
20
- result = client.predict(image_path, context, api_name="/get_caption")
21
- return result
 
 
 
22
 
23
- # Route to handle image uploads and generate captions
24
  @app.post("/uploadfile/")
25
- async def generate_image_caption(file: UploadFile = File(...)):
26
  try:
27
- # Save the uploaded image to a temporary file
28
- with tempfile.NamedTemporaryFile(delete=False) as temp_file:
29
- shutil.copyfileobj(file.file, temp_file)
30
- temp_file_path = temp_file.name
31
-
32
- # Generate caption for the uploaded image
33
- caption = generate_caption(temp_file_path)
34
-
35
- # Clean up temporary file
36
- os.unlink(temp_file_path)
37
-
38
- return {"caption": caption}
39
-
40
  except Exception as e:
41
  raise HTTPException(status_code=500, detail=str(e))
42
 
 
43
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
44
 
45
  @app.get("/")
46
  def index() -> FileResponse:
47
  return FileResponse(path="/app/static/index.html", media_type="text/html")
 
48
 
 
12
 
13
  # Initialize Gradio Client
14
  hf_token = os.environ.get('HF_TOKEN')
15
+ client = Client("Ashrafb/moondream1", hf_token=hf_token)
 
16
  # Function to generate captions for uploaded images
17
+ def get_caption(image, additional_context):
18
+ if additional_context.strip(): # Check if additional_context is not empty
19
+ context = additional_context
20
+ else:
21
+ context = "What is this image? Describe this image to someone who is visually impaired."
22
+ result = client.predict(image, context, api_name="/predict")
23
+ return result
24
 
 
25
  @app.post("/uploadfile/")
26
+ async def upload_file(image: UploadFile = File(...), additional_context: str = Form(...)):
27
  try:
28
+ # Create a temporary directory to store the uploaded image
29
+ with tempfile.TemporaryDirectory() as temp_dir:
30
+ temp_image_path = os.path.join(temp_dir, image.filename)
31
+ # Write the uploaded image data to a temporary file
32
+ with open(temp_image_path, "wb") as temp_image:
33
+ shutil.copyfileobj(image.file, temp_image)
34
+ # Read the image data
35
+ with open(temp_image_path, "rb") as image_file:
36
+ image_data = image_file.read()
37
+ # Generate caption using Gradio client
38
+ result = get_caption(image_data, additional_context)
39
+ return {"result": result}
 
40
  except Exception as e:
41
  raise HTTPException(status_code=500, detail=str(e))
42
 
43
+
44
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
45
 
46
  @app.get("/")
47
  def index() -> FileResponse:
48
  return FileResponse(path="/app/static/index.html", media_type="text/html")
49
+
50