Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -35,7 +35,7 @@ nlp_qa_v3 = pipeline("question-answering", model="deepset/roberta-base-squad2")
|
|
35 |
nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
|
36 |
nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
|
37 |
nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
|
38 |
-
nlp_sequence_classification = pipeline("zero-shot-classification", model="
|
39 |
|
40 |
description = """
|
41 |
## Image-based Document QA
|
@@ -285,14 +285,19 @@ labels = [
|
|
285 |
@app.post("/fast_classify/", description="Quickly classify text into predefined categories.")
|
286 |
async def fast_classify_text(statement: str = Form(...)):
|
287 |
try:
|
288 |
-
# Use
|
289 |
-
result =
|
|
|
|
|
|
|
290 |
|
291 |
# Extract the best label and score
|
292 |
best_label = result["labels"][0]
|
293 |
best_score = result["scores"][0]
|
294 |
|
295 |
return {"classification": best_label, "confidence": best_score}
|
|
|
|
|
296 |
except Exception as e:
|
297 |
return JSONResponse(content=f"Error in classification: {str(e)}", status_code=500)
|
298 |
|
|
|
35 |
nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
|
36 |
nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
|
37 |
nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
|
38 |
+
nlp_sequence_classification = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-1")
|
39 |
|
40 |
description = """
|
41 |
## Image-based Document QA
|
|
|
285 |
@app.post("/fast_classify/", description="Quickly classify text into predefined categories.")
|
286 |
async def fast_classify_text(statement: str = Form(...)):
|
287 |
try:
|
288 |
+
# Use asyncio to set a timeout for the classification
|
289 |
+
result = await asyncio.wait_for(
|
290 |
+
nlp_sequence_classification(statement, labels, multi_label=False),
|
291 |
+
timeout=5 # timeout in seconds
|
292 |
+
)
|
293 |
|
294 |
# Extract the best label and score
|
295 |
best_label = result["labels"][0]
|
296 |
best_score = result["scores"][0]
|
297 |
|
298 |
return {"classification": best_label, "confidence": best_score}
|
299 |
+
except asyncio.TimeoutError:
|
300 |
+
return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
|
301 |
except Exception as e:
|
302 |
return JSONResponse(content=f"Error in classification: {str(e)}", status_code=500)
|
303 |
|