MJobe commited on
Commit
6fe91f4
·
verified ·
1 Parent(s): 02494dd

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +36 -0
main.py CHANGED
@@ -311,6 +311,42 @@ async def fast_classify_text(statement: str = Form(...)):
311
  except Exception as e:
312
  # Handle general errors
313
  return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
  # Set up CORS middleware
316
  origins = ["*"] # or specify your list of allowed origins
 
311
  except Exception as e:
312
  # Handle general errors
313
  return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
314
+
315
+ app.post("/fast_classify_v2/", description="Quickly classify text into predefined categories with confidence scores.")
316
+ async def fast_classify_text(statement: str = Form(...)):
317
+ try:
318
+ # Use run_in_executor to handle the synchronous model call asynchronously
319
+ loop = asyncio.get_running_loop()
320
+ result = await loop.run_in_executor(
321
+ executor,
322
+ lambda: nlp_sequence_classification(statement, labels, multi_label=False)
323
+ )
324
+
325
+ # Extract all labels and their scores
326
+ all_labels = result["labels"]
327
+ all_scores = result["scores"]
328
+
329
+ # Extract the best label and score
330
+ best_label = all_labels[0]
331
+ best_score = all_scores[0]
332
+
333
+ # Prepare the response
334
+ full_response = {
335
+ "classification": best_label,
336
+ "confidence": best_score,
337
+ "all_labels": {label: score for label, score in zip(all_labels, all_scores)}
338
+ }
339
+
340
+ return full_response
341
+ except asyncio.TimeoutError:
342
+ # Handle timeout
343
+ return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
344
+ except HTTPException as http_exc:
345
+ # Handle HTTP errors
346
+ return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
347
+ except Exception as e:
348
+ # Handle general errors
349
+ return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
350
 
351
  # Set up CORS middleware
352
  origins = ["*"] # or specify your list of allowed origins