MJobe commited on
Commit
31d9e37
·
verified ·
1 Parent(s): fe58c62

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +89 -0
main.py CHANGED
@@ -39,6 +39,8 @@ nlp_classification = pipeline("text-classification", model="distilbert/distilber
39
  nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
40
  nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
41
  nlp_sequence_classification = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
 
 
42
  description = """
43
  ## Image-based Document QA
44
  This API performs document question answering using a LayoutLMv2-based model.
@@ -365,6 +367,93 @@ async def fast_classify_text(statement: str = Form(...)):
365
  except Exception as e:
366
  # Handle general errors
367
  return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
 
369
  # Set up CORS middleware
370
  origins = ["*"] # or specify your list of allowed origins
 
39
  nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
40
  nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
41
  nlp_sequence_classification = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
42
+ nlp_main_classification = pipeline("zero-shot-classification", model="roberta-large-mnli")
43
+
44
  description = """
45
  ## Image-based Document QA
46
  This API performs document question answering using a LayoutLMv2-based model.
 
367
  except Exception as e:
368
  # Handle general errors
369
  return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
370
+
371
+ # Labels for main classification
372
+ labels = [
373
+ "Change to quote",
374
+ "Copy quote requested",
375
+ "Expired Quote",
376
+ "Notes not clear"
377
+ ]
378
+
379
+ # Keywords for sub-classifications
380
+ keyword_map = {
381
+ "MRSP": ["MSRP", "MRSP copy quote", "msrp only"],
382
+ "Direct": ["Direct quote", "send directly"],
383
+ "All": ["All Pricing", "all pricing"],
384
+ "MRSP & All": ["MSRP & All Pricing", "msrp only with all pricing"]
385
+ }
386
+
387
+ # Function to detect if input is blank or vague
388
+ def is_blank_or_vague(text):
389
+ # Checks for empty or only contains general filler words (adjust as needed)
390
+ return not text.strip() or re.match(r'^\s*(please|send|quote|request|thank you|thanks)\s*$', text, re.IGNORECASE)
391
+
392
+ # Function to identify sub-classifications based on keywords
393
+ def get_sub_classification(text):
394
+ sub_labels = []
395
+ for sub_class, keywords in keyword_map.items():
396
+ if any(keyword.lower() in text.lower() for keyword in keywords):
397
+ sub_labels.append(sub_class)
398
+ return sub_labels if sub_labels else ["Uncategorized"]
399
+
400
+ @app.post("/classify_text/")
401
+ async def classify_text(statement: str = Form(...)):
402
+ try:
403
+ # Handle blank or vague text as "Notes not clear"
404
+ if is_blank_or_vague(statement):
405
+ return {
406
+ "main_classification": {
407
+ "label": "Notes not clear",
408
+ "confidence": 1.0,
409
+ "scores": {"Notes not clear": 1.0}
410
+ },
411
+ "sub_classification": {
412
+ "labels": ["Uncategorized"],
413
+ "scores": {"Uncategorized": 1.0}
414
+ }
415
+ }
416
+
417
+ # Run main classification in executor for async handling
418
+ loop = asyncio.get_running_loop()
419
+ main_classification_task = loop.run_in_executor(
420
+ None,
421
+ lambda: nlp_main_classification(statement, labels)
422
+ )
423
+
424
+ # Await result
425
+ main_class_result = await main_classification_task
426
+
427
+ # Extract main classification label and scores
428
+ main_class_scores = {label: score for label, score in zip(main_class_result["labels"], main_class_result["scores"])}
429
+ best_main_classification = main_class_result["labels"][0]
430
+ best_main_score = main_class_result["scores"][0]
431
+
432
+ # Detect sub-classifications using keywords
433
+ sub_classification = get_sub_classification(statement)
434
+
435
+ # Assign default high confidence for keyword-based sub-classification
436
+ sub_class_scores = {sub: 1.0 for sub in sub_classification}
437
+
438
+ # Return results
439
+ return {
440
+ "main_classification": {
441
+ "label": best_main_classification,
442
+ "confidence": best_main_score,
443
+ "scores": main_class_scores
444
+ },
445
+ "sub_classification": {
446
+ "labels": sub_classification,
447
+ "scores": sub_class_scores
448
+ }
449
+ }
450
+
451
+ except asyncio.TimeoutError:
452
+ return JSONResponse(content="Classification timed out.", status_code=504)
453
+ except HTTPException as http_exc:
454
+ return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
455
+ except Exception as e:
456
+ return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
457
 
458
  # Set up CORS middleware
459
  origins = ["*"] # or specify your list of allowed origins