Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -311,6 +311,42 @@ async def fast_classify_text(statement: str = Form(...)):
|
|
311 |
except Exception as e:
|
312 |
# Handle general errors
|
313 |
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
|
315 |
# Set up CORS middleware
|
316 |
origins = ["*"] # or specify your list of allowed origins
|
|
|
311 |
except Exception as e:
|
312 |
# Handle general errors
|
313 |
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
314 |
+
|
315 |
+
app.post("/fast_classify_v2/", description="Quickly classify text into predefined categories with confidence scores.")
|
316 |
+
async def fast_classify_text(statement: str = Form(...)):
|
317 |
+
try:
|
318 |
+
# Use run_in_executor to handle the synchronous model call asynchronously
|
319 |
+
loop = asyncio.get_running_loop()
|
320 |
+
result = await loop.run_in_executor(
|
321 |
+
executor,
|
322 |
+
lambda: nlp_sequence_classification(statement, labels, multi_label=False)
|
323 |
+
)
|
324 |
+
|
325 |
+
# Extract all labels and their scores
|
326 |
+
all_labels = result["labels"]
|
327 |
+
all_scores = result["scores"]
|
328 |
+
|
329 |
+
# Extract the best label and score
|
330 |
+
best_label = all_labels[0]
|
331 |
+
best_score = all_scores[0]
|
332 |
+
|
333 |
+
# Prepare the response
|
334 |
+
full_response = {
|
335 |
+
"classification": best_label,
|
336 |
+
"confidence": best_score,
|
337 |
+
"all_labels": {label: score for label, score in zip(all_labels, all_scores)}
|
338 |
+
}
|
339 |
+
|
340 |
+
return full_response
|
341 |
+
except asyncio.TimeoutError:
|
342 |
+
# Handle timeout
|
343 |
+
return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
|
344 |
+
except HTTPException as http_exc:
|
345 |
+
# Handle HTTP errors
|
346 |
+
return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
|
347 |
+
except Exception as e:
|
348 |
+
# Handle general errors
|
349 |
+
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
350 |
|
351 |
# Set up CORS middleware
|
352 |
origins = ["*"] # or specify your list of allowed origins
|