Update main.py
Browse files
main.py
CHANGED
@@ -38,8 +38,14 @@ async def save_upload_file(upload_file: UploadFile) -> str:
|
|
38 |
async def get_caption(image: UploadFile = File(...), context: str = None):
|
39 |
# Save the uploaded image to a temporary file
|
40 |
temp_file_path = await save_upload_file(image)
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
42 |
result = client.predict(temp_file_path, context, api_name="/get_caption")
|
|
|
43 |
return {"caption": result}
|
44 |
|
45 |
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
|
|
38 |
async def get_caption(image: UploadFile = File(...), context: str = None):
|
39 |
# Save the uploaded image to a temporary file
|
40 |
temp_file_path = await save_upload_file(image)
|
41 |
+
|
42 |
+
# Check if additional context is provided and not None
|
43 |
+
if context is not None:
|
44 |
+
context = context.strip()
|
45 |
+
|
46 |
+
# Pass the temporary file path and context to the Gradio client for prediction
|
47 |
result = client.predict(temp_file_path, context, api_name="/get_caption")
|
48 |
+
|
49 |
return {"caption": result}
|
50 |
|
51 |
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|