Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -286,21 +286,29 @@ labels = [
|
|
286 |
@app.post("/fast_classify/", description="Quickly classify text into predefined categories.")
|
287 |
async def fast_classify_text(statement: str = Form(...)):
|
288 |
try:
|
289 |
-
#
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
|
|
|
|
|
|
294 |
|
295 |
-
# Extract the best label and score
|
296 |
best_label = result["labels"][0]
|
297 |
best_score = result["scores"][0]
|
298 |
|
299 |
return {"classification": best_label, "confidence": best_score}
|
300 |
except asyncio.TimeoutError:
|
|
|
301 |
return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
|
|
|
|
|
|
|
302 |
except Exception as e:
|
303 |
-
|
|
|
304 |
|
305 |
# Set up CORS middleware
|
306 |
origins = ["*"] # or specify your list of allowed origins
|
|
|
286 |
@app.post("/fast_classify/", description="Quickly classify text into predefined categories.")
|
287 |
async def fast_classify_text(statement: str = Form(...)):
|
288 |
try:
|
289 |
+
# Create a thread pool executor for running the synchronous function asynchronously
|
290 |
+
loop = asyncio.get_running_loop()
|
291 |
+
with ThreadPoolExecutor() as executor:
|
292 |
+
# Run the classification pipeline in a separate thread
|
293 |
+
result = await loop.run_in_executor(
|
294 |
+
executor,
|
295 |
+
lambda: nlp_sequence_classification(statement, labels, multi_label=False)
|
296 |
+
)
|
297 |
|
298 |
+
# Extract the best label and score from the result
|
299 |
best_label = result["labels"][0]
|
300 |
best_score = result["scores"][0]
|
301 |
|
302 |
return {"classification": best_label, "confidence": best_score}
|
303 |
except asyncio.TimeoutError:
|
304 |
+
logging.error("Classification request timed out.")
|
305 |
return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
|
306 |
+
except HTTPException as http_exc:
|
307 |
+
logging.error(f"HTTP exception: {http_exc.detail}")
|
308 |
+
return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
|
309 |
except Exception as e:
|
310 |
+
logging.error(f"Error in classification pipeline: {str(e)}")
|
311 |
+
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
312 |
|
313 |
# Set up CORS middleware
|
314 |
origins = ["*"] # or specify your list of allowed origins
|