MJobe commited on
Commit
1d4a335
·
verified ·
1 Parent(s): 0cd7858

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +15 -7
main.py CHANGED
@@ -286,21 +286,29 @@ labels = [
286
  @app.post("/fast_classify/", description="Quickly classify text into predefined categories.")
287
  async def fast_classify_text(statement: str = Form(...)):
288
  try:
289
- # Use asyncio to set a timeout for the classification
290
- result = await asyncio.wait_for(
291
- nlp_sequence_classification(statement, labels, multi_label=False),
292
- timeout=5 # timeout in seconds
293
- )
 
 
 
294
 
295
- # Extract the best label and score
296
  best_label = result["labels"][0]
297
  best_score = result["scores"][0]
298
 
299
  return {"classification": best_label, "confidence": best_score}
300
  except asyncio.TimeoutError:
 
301
  return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
 
 
 
302
  except Exception as e:
303
- return JSONResponse(content=f"Error in classification: {str(e)}", status_code=500)
 
304
 
305
  # Set up CORS middleware
306
  origins = ["*"] # or specify your list of allowed origins
 
286
  @app.post("/fast_classify/", description="Quickly classify text into predefined categories.")
287
  async def fast_classify_text(statement: str = Form(...)):
288
  try:
289
+ # Create a thread pool executor for running the synchronous function asynchronously
290
+ loop = asyncio.get_running_loop()
291
+ with ThreadPoolExecutor() as executor:
292
+ # Run the classification pipeline in a separate thread
293
+ result = await loop.run_in_executor(
294
+ executor,
295
+ lambda: nlp_sequence_classification(statement, labels, multi_label=False)
296
+ )
297
 
298
+ # Extract the best label and score from the result
299
  best_label = result["labels"][0]
300
  best_score = result["scores"][0]
301
 
302
  return {"classification": best_label, "confidence": best_score}
303
  except asyncio.TimeoutError:
304
+ logging.error("Classification request timed out.")
305
  return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
306
+ except HTTPException as http_exc:
307
+ logging.error(f"HTTP exception: {http_exc.detail}")
308
+ return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
309
  except Exception as e:
310
+ logging.error(f"Error in classification pipeline: {str(e)}")
311
+ return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
312
 
313
  # Set up CORS middleware
314
  origins = ["*"] # or specify your list of allowed origins