MJobe commited on
Commit
cf71890
·
verified ·
1 Parent(s): ee808b2

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +8 -3
main.py CHANGED
@@ -35,7 +35,7 @@ nlp_qa_v3 = pipeline("question-answering", model="deepset/roberta-base-squad2")
35
  nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
36
  nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
37
  nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
38
- nlp_sequence_classification = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
39
 
40
  description = """
41
  ## Image-based Document QA
@@ -285,14 +285,19 @@ labels = [
285
  @app.post("/fast_classify/", description="Quickly classify text into predefined categories.")
286
  async def fast_classify_text(statement: str = Form(...)):
287
  try:
288
- # Use zero-shot classification to classify statement into one of the provided labels
289
- result = nlp_sequence_classification(statement, labels, multi_label=False)
 
 
 
290
 
291
  # Extract the best label and score
292
  best_label = result["labels"][0]
293
  best_score = result["scores"][0]
294
 
295
  return {"classification": best_label, "confidence": best_score}
 
 
296
  except Exception as e:
297
  return JSONResponse(content=f"Error in classification: {str(e)}", status_code=500)
298
 
 
35
  nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
36
  nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
37
  nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
38
+ nlp_sequence_classification = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-1")
39
 
40
  description = """
41
  ## Image-based Document QA
 
285
  @app.post("/fast_classify/", description="Quickly classify text into predefined categories.")
286
  async def fast_classify_text(statement: str = Form(...)):
287
  try:
288
+ # Use asyncio to set a timeout for the classification
289
+ result = await asyncio.wait_for(
290
+ nlp_sequence_classification(statement, labels, multi_label=False),
291
+ timeout=5 # timeout in seconds
292
+ )
293
 
294
  # Extract the best label and score
295
  best_label = result["labels"][0]
296
  best_score = result["scores"][0]
297
 
298
  return {"classification": best_label, "confidence": best_score}
299
+ except asyncio.TimeoutError:
300
+ return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
301
  except Exception as e:
302
  return JSONResponse(content=f"Error in classification: {str(e)}", status_code=500)
303