tejash300 committed on
Commit
6bbd6b4
·
verified ·
1 Parent(s): deeb866

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -14
app.py CHANGED
@@ -13,7 +13,7 @@ import numpy as np
13
  import json
14
  import tempfile
15
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form, BackgroundTasks
16
- from fastapi.responses import FileResponse, JSONResponse, HTMLResponse # Added HTMLResponse
17
  from fastapi.middleware.cors import CORSMiddleware
18
  from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
19
  from sentence_transformers import SentenceTransformer
@@ -31,6 +31,9 @@ from starlette.concurrency import run_in_threadpool
31
  import gensim
32
  from gensim import corpora, models
33
 
 
 
 
34
  # Global cache for analysis results based on file hash
35
  analysis_cache = {}
36
 
@@ -197,15 +200,13 @@ try:
197
  nlp = spacy.load("en_core_web_sm")
198
  print("✅ Loading NLP models...")
199
 
200
- # Update summarizer to use the LED model for long-document summarization
201
- from transformers import LEDTokenizer
202
  summarizer = pipeline(
203
  "summarization",
204
- model="allenai/led-large-16384",
205
- tokenizer="allenai/led-large-16384",
206
  device=0 if torch.cuda.is_available() else -1
207
  )
208
- # Optionally convert summarizer model to FP16 for faster inference on GPU (if supported)
209
  if device == "cuda":
210
  try:
211
  summarizer.model.half()
@@ -235,8 +236,6 @@ except Exception as e:
235
 
236
  from transformers import pipeline
237
  qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
238
-
239
- # Initialize sentiment-analysis pipeline
240
  sentiment_pipeline = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english", device=0 if torch.cuda.is_available() else -1)
241
 
242
  def legal_chatbot(user_input, context):
@@ -263,10 +262,8 @@ async def process_video_to_text(video_file_path):
263
  "-acodec", "pcm_s16le", "-ar", "44100", "-ac", "2",
264
  temp_audio_path, "-y"
265
  ]
266
- # Run ffmpeg in a separate thread
267
  await run_in_threadpool(subprocess.run, cmd, check=True)
268
  print(f"Audio extracted to {temp_audio_path}")
269
- # Run speech-to-text in threadpool
270
  result = await run_in_threadpool(speech_to_text, temp_audio_path)
271
  transcript = result["text"]
272
  print(f"Transcription completed: {len(transcript)} characters")
@@ -326,11 +323,61 @@ def get_enhanced_context_info(text):
326
  enhanced["topics"] = analyze_topics(text, num_topics=5)
327
  return enhanced
328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  def analyze_risk_enhanced(text):
330
  enhanced = get_enhanced_context_info(text)
331
  avg_sentiment = enhanced["average_sentiment"]
332
  risk_score = abs(avg_sentiment) if avg_sentiment < 0 else 0
333
- return {"risk_score": risk_score, "average_sentiment": avg_sentiment, "topics": enhanced["topics"]}
 
 
 
 
 
 
 
334
 
335
  def analyze_contract_clauses(text):
336
  max_length = 512
@@ -370,7 +417,6 @@ async def analyze_legal_document(file: UploadFile = File(...)):
370
  try:
371
  content = await file.read()
372
  file_hash = compute_md5(content)
373
- # Return cached result if available
374
  if file_hash in analysis_cache:
375
  return analysis_cache[file_hash]
376
  text = await run_in_threadpool(extract_text_from_pdf, io.BytesIO(content))
@@ -594,10 +640,8 @@ async def download_clause_radar_chart(task_id: str):
594
  clauses = analyze_contract_clauses(text)
595
  if not clauses:
596
  raise HTTPException(status_code=404, detail="No clauses detected.")
597
- # For radar chart, use clause types and their confidence scores
598
  labels = [c["type"] for c in clauses]
599
  values = [c["confidence"] for c in clauses]
600
- # To close the radar chart, repeat the first value and label
601
  labels += labels[:1]
602
  values += values[:1]
603
  angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
 
13
  import json
14
  import tempfile
15
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form, BackgroundTasks
16
+ from fastapi.responses import FileResponse, JSONResponse, HTMLResponse
17
  from fastapi.middleware.cors import CORSMiddleware
18
  from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
19
  from sentence_transformers import SentenceTransformer
 
31
  import gensim
32
  from gensim import corpora, models
33
 
34
+ # Import spacy stop words
35
+ from spacy.lang.en.stop_words import STOP_WORDS
36
+
37
  # Global cache for analysis results based on file hash
38
  analysis_cache = {}
39
 
 
200
  nlp = spacy.load("en_core_web_sm")
201
  print("✅ Loading NLP models...")
202
 
203
+ # Update summarizer to use facebook/bart-large-cnn for summarization
 
204
  summarizer = pipeline(
205
  "summarization",
206
+ model="facebook/bart-large-cnn",
207
+ tokenizer="facebook/bart-large-cnn",
208
  device=0 if torch.cuda.is_available() else -1
209
  )
 
210
  if device == "cuda":
211
  try:
212
  summarizer.model.half()
 
236
 
237
  from transformers import pipeline
238
  qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
 
 
239
  sentiment_pipeline = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english", device=0 if torch.cuda.is_available() else -1)
240
 
241
  def legal_chatbot(user_input, context):
 
262
  "-acodec", "pcm_s16le", "-ar", "44100", "-ac", "2",
263
  temp_audio_path, "-y"
264
  ]
 
265
  await run_in_threadpool(subprocess.run, cmd, check=True)
266
  print(f"Audio extracted to {temp_audio_path}")
 
267
  result = await run_in_threadpool(speech_to_text, temp_audio_path)
268
  transcript = result["text"]
269
  print(f"Transcription completed: {len(transcript)} characters")
 
323
  enhanced["topics"] = analyze_topics(text, num_topics=5)
324
  return enhanced
325
 
326
+ # New function to create a detailed, dynamic explanation for each topic
327
def explain_topics(topics, stop_words=None):
    """Create a plain-language explanation for each topic.

    Args:
        topics: Iterable of ``(topic_idx, topic_str)`` pairs. Each
            ``topic_str`` is a ``'+'``-separated list of ``weight*"word"``
            terms (NOTE(review): presumably the string format emitted by
            gensim's ``print_topics`` -- confirm against ``analyze_topics``).
            ``None`` or an empty iterable yields an empty result.
        stop_words: Optional collection of lowercase words to exclude from
            the explanation. Defaults to the module-level ``STOP_WORDS``.

    Returns:
        A dict mapping each ``topic_idx`` to a dict with:
            ``"label"``: a coarse risk category chosen from the topic's words,
            ``"explanation"``: a sentence listing up to five dominant terms,
            ``"terms"``: all retained ``(weight, word)`` tuples, sorted by
            descending weight.
    """
    if stop_words is None:
        stop_words = STOP_WORDS

    default_label = "General Risk Language"
    # Ordered: the first keyword found in any retained term decides the label.
    label_rules = (
        ("liability", "Liability & Penalty Risk"),
        ("termination", "Termination & Refund Risk"),
        ("compliance", "Compliance & Regulatory Risk"),
    )

    explanation = {}
    for topic_idx, topic_str in topics or ():
        terms = []
        # Split the topic string into individual weighted terms.
        for part in topic_str.split('+'):
            part = part.strip()
            if '*' not in part:
                continue
            weight_str, word = part.split('*', 1)
            word = word.strip().strip('"').strip("'")
            try:
                weight = float(weight_str)
            except ValueError:
                # An unparsable weight is kept as a zero-weight term
                # rather than discarding the word.
                weight = 0.0
            # Drop stop words and single-character tokens.
            if word.lower() not in stop_words and len(word) > 1:
                terms.append((weight, word))

        # Highest weight first; ties keep their original order.
        terms.sort(key=lambda term: -term[0])

        lowered = [word.lower() for _, word in terms]
        label = next(
            (
                name
                for keyword, name in label_rules
                if any(keyword in word for word in lowered)
            ),
            default_label,
        )

        explanation_text = (
            f"Topic {topic_idx} ({label}) is characterized by dominant terms: "
            + ", ".join(f"'{word}' ({weight:.3f})" for weight, word in terms[:5])
        )
        explanation[topic_idx] = {
            "label": label,
            "explanation": explanation_text,
            "terms": terms,
        }
    return explanation
368
+
369
  def analyze_risk_enhanced(text):
370
  enhanced = get_enhanced_context_info(text)
371
  avg_sentiment = enhanced["average_sentiment"]
372
  risk_score = abs(avg_sentiment) if avg_sentiment < 0 else 0
373
+ topics_raw = enhanced["topics"]
374
+ topics_explanation = explain_topics(topics_raw)
375
+ return {
376
+ "risk_score": risk_score,
377
+ "average_sentiment": avg_sentiment,
378
+ "topics": topics_raw,
379
+ "topics_explanation": topics_explanation
380
+ }
381
 
382
  def analyze_contract_clauses(text):
383
  max_length = 512
 
417
  try:
418
  content = await file.read()
419
  file_hash = compute_md5(content)
 
420
  if file_hash in analysis_cache:
421
  return analysis_cache[file_hash]
422
  text = await run_in_threadpool(extract_text_from_pdf, io.BytesIO(content))
 
640
  clauses = analyze_contract_clauses(text)
641
  if not clauses:
642
  raise HTTPException(status_code=404, detail="No clauses detected.")
 
643
  labels = [c["type"] for c in clauses]
644
  values = [c["confidence"] for c in clauses]
 
645
  labels += labels[:1]
646
  values += values[:1]
647
  angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()