tejash300 hardik8588 commited on
Commit
fcda643
·
verified ·
1 Parent(s): ad0c9e7

Upload app.py (#3)

Browse files

- Upload app.py (1b15cff9f9517ef9ca83d0aeb7f34dc1b9145fe5)


Co-authored-by: hardik kandpal <[email protected]>

Files changed (1) hide show
  1. app.py +1526 -0
app.py ADDED
@@ -0,0 +1,1526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import snapshot_download
2
+
3
+ import os
4
+ import io
5
+ import time
6
+ import uuid
7
+ import tempfile
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ import pdfplumber
11
+ import spacy
12
+ import torch
13
+ import sqlite3
14
+ import uvicorn
15
+ import moviepy.editor as mp
16
+ from threading import Thread
17
+ from datetime import datetime, timedelta
18
+ from typing import List, Dict, Optional
19
+ from fastapi import FastAPI, File, UploadFile, Form, Depends, HTTPException, status, Header
20
+ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
21
+ from fastapi.staticfiles import StaticFiles
22
+ from fastapi.middleware.cors import CORSMiddleware
23
+ import logging
24
+ from pydantic import BaseModel
25
+ from transformers import (
26
+ AutoTokenizer,
27
+ AutoModelForQuestionAnswering,
28
+ pipeline,
29
+ TrainingArguments,
30
+ Trainer
31
+ )
32
+ from sentence_transformers import SentenceTransformer
33
+ from passlib.context import CryptContext
34
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
35
+ import jwt
36
+ from dotenv import load_dotenv
37
+ # Import get_db_connection from auth
38
+ from auth import (
39
+ User, UserCreate, Token, get_current_active_user, authenticate_user,
40
+ create_access_token, hash_password, register_user, check_subscription_access,
41
+ SUBSCRIPTION_TIERS, JWT_EXPIRATION_DELTA, get_db_connection, update_auth_db_schema
42
+ )
43
+ # Add this import near the top with your other imports
44
+ from paypal_integration import (
45
+ create_user_subscription, verify_subscription_payment,
46
+ update_user_subscription, handle_subscription_webhook, initialize_database
47
+ )
48
+ from fastapi import Request # Add this if not already imported
49
+
50
+ logging.basicConfig(
51
+ level=logging.INFO,
52
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
53
+ )
54
+ logger = logging.getLogger("app")
55
+
56
+ # Initialize the database
57
+ # Initialize FastAPI app
58
+ app = FastAPI(
59
+ title="Legal Document Analysis API",
60
+ description="API for analyzing legal documents, videos, and audio",
61
+ version="1.0.0"
62
+ )
63
+
64
+ # Set up CORS middleware
65
+ app.add_middleware(
66
+ CORSMiddleware,
67
+ allow_origins=["https://testing-78wtxfqt0-hardikkandpals-projects.vercel.app", "http://localhost:3000"], # Frontend URL
68
+ allow_credentials=True,
69
+ allow_methods=["*"],
70
+ allow_headers=["*"],
71
+ )
72
+ initialize_database()
73
+ try:
74
+ update_auth_db_schema()
75
+ logger.info("Database schema updated successfully")
76
+ except Exception as e:
77
+ logger.error(f"Database schema update error: {e}")
78
+
79
+ # Create static directory for file storage
80
+ os.makedirs("static", exist_ok=True)
81
+ os.makedirs("uploads", exist_ok=True)
82
+ os.makedirs("temp", exist_ok=True)
83
+ app.mount("/static", StaticFiles(directory="static"), name="static")
84
+
85
+ # Set device for model inference
86
+ device = "cuda" if torch.cuda.is_available() else "cpu"
87
+ print(f"Using device: {device}")
88
+
89
+ # Initialize chat history
90
+ chat_history = []
91
+
92
+ # Document context storage
93
+ document_contexts = {}
94
+
95
+ def store_document_context(task_id, text):
96
+ """Store document text for later retrieval."""
97
+ document_contexts[task_id] = text
98
+
99
+ def load_document_context(task_id):
100
+ """Load document text for a given task ID."""
101
+ return document_contexts.get(task_id, "")
102
+
103
+ def get_db_connection():
104
+ """Get a connection to the SQLite database."""
105
+ db_path = os.path.join(os.path.dirname(__file__), "legal_analysis.db")
106
+ conn = sqlite3.connect(db_path)
107
+ conn.row_factory = sqlite3.Row
108
+ return conn
109
+
110
+ load_dotenv()
111
+ DB_PATH = os.getenv("DB_PATH", os.path.join(os.path.dirname(__file__), "data/user_data.db"))
112
+ os.makedirs(os.path.join(os.path.dirname(__file__), "data"), exist_ok=True)
113
+
114
+ def fine_tune_qa_model():
115
+ """Fine-tunes a QA model on the CUAD dataset."""
116
+ print("Loading base model for fine-tuning...")
117
+ tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
118
+ model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
119
+
120
+ # Load and preprocess CUAD dataset
121
+ print("Loading CUAD dataset...")
122
+ from datasets import load_dataset
123
+
124
+ try:
125
+ dataset = load_dataset("cuad")
126
+ except Exception as e:
127
+ print(f"Error loading CUAD dataset: {str(e)}")
128
+ print("Downloading CUAD dataset from alternative source...")
129
+ # Implement alternative dataset loading here
130
+ return tokenizer, model
131
+
132
+ print(f"Dataset loaded with {len(dataset['train'])} training examples")
133
+
134
+ # Preprocess the dataset
135
+ def preprocess_function(examples):
136
+ questions = [q.strip() for q in examples["question"]]
137
+ contexts = [c.strip() for c in examples["context"]]
138
+
139
+ inputs = tokenizer(
140
+ questions,
141
+ contexts,
142
+ max_length=384,
143
+ truncation="only_second",
144
+ stride=128,
145
+ return_overflowing_tokens=True,
146
+ return_offsets_mapping=True,
147
+ padding="max_length",
148
+ )
149
+
150
+ offset_mapping = inputs.pop("offset_mapping")
151
+ sample_map = inputs.pop("overflow_to_sample_mapping")
152
+
153
+ answers = examples["answers"]
154
+ start_positions = []
155
+ end_positions = []
156
+
157
+ for i, offset in enumerate(offset_mapping):
158
+ sample_idx = sample_map[i]
159
+ answer = answers[sample_idx]
160
+
161
+ start_char = answer["answer_start"][0] if len(answer["answer_start"]) > 0 else 0
162
+ end_char = start_char + len(answer["text"][0]) if len(answer["text"]) > 0 else 0
163
+
164
+ sequence_ids = inputs.sequence_ids(i)
165
+
166
+ # Find the start and end of the context
167
+ idx = 0
168
+ while sequence_ids[idx] != 1:
169
+ idx += 1
170
+ context_start = idx
171
+
172
+ while idx < len(sequence_ids) and sequence_ids[idx] == 1:
173
+ idx += 1
174
+ context_end = idx - 1
175
+
176
+ # If the answer is not fully inside the context, label is (0, 0)
177
+ if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
178
+ start_positions.append(0)
179
+ end_positions.append(0)
180
+ else:
181
+ # Otherwise it's the start and end token positions
182
+ idx = context_start
183
+ while idx <= context_end and offset[idx][0] <= start_char:
184
+ idx += 1
185
+ start_positions.append(idx - 1)
186
+
187
+ idx = context_end
188
+ while idx >= context_start and offset[idx][1] >= end_char:
189
+ idx -= 1
190
+ end_positions.append(idx + 1)
191
+
192
+ inputs["start_positions"] = start_positions
193
+ inputs["end_positions"] = end_positions
194
+ return inputs
195
+
196
+ print("Preprocessing dataset...")
197
+ processed_dataset = dataset.map(
198
+ preprocess_function,
199
+ batched=True,
200
+ remove_columns=dataset["train"].column_names,
201
+ )
202
+
203
+ print("Splitting dataset...")
204
+ train_dataset = processed_dataset["train"]
205
+ val_dataset = processed_dataset["validation"]
206
+
207
+ train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
208
+ val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
209
+
210
+ training_args = TrainingArguments(
211
+ output_dir="./fine_tuned_legal_qa",
212
+ evaluation_strategy="steps",
213
+ eval_steps=100,
214
+ learning_rate=2e-5,
215
+ per_device_train_batch_size=16,
216
+ per_device_eval_batch_size=16,
217
+ num_train_epochs=1,
218
+ weight_decay=0.01,
219
+ logging_steps=50,
220
+ save_steps=100,
221
+ load_best_model_at_end=True,
222
+ report_to=[]
223
+ )
224
+
225
+ print("✅ Starting fine tuning on CUAD QA dataset...")
226
+ trainer = Trainer(
227
+ model=model,
228
+ args=training_args,
229
+ train_dataset=train_dataset,
230
+ eval_dataset=val_dataset,
231
+ tokenizer=tokenizer,
232
+ )
233
+
234
+ trainer.train()
235
+ print("✅ Fine tuning completed. Saving model...")
236
+
237
+ model.save_pretrained("./fine_tuned_legal_qa")
238
+ tokenizer.save_pretrained("./fine_tuned_legal_qa")
239
+
240
+ return tokenizer, model
241
+
242
+ #############################
243
+ # Load NLP Models #
244
+ #############################
245
+
246
+ # Initialize model variables
247
+ nlp = None
248
+ summarizer = None
249
+ embedding_model = None
250
+ ner_model = None
251
+ speech_to_text = None
252
+ cuad_model = None
253
+ cuad_tokenizer = None
254
+ qa_model = None
255
+
256
+ # Add model caching functionality
257
+ import pickle
258
+ import os.path
259
+
260
+ #MODELS_CACHE_DIR = "c:\\Users\\hardi\\OneDrive\\Desktop\\New folder (7)\\doc-vid-analyze-main\\models_cache"
261
+ MODELS_CACHE_DIR = os.getenv("MODELS_CACHE_DIR", "models_cache")
262
+ os.makedirs(MODELS_CACHE_DIR, exist_ok=True)
263
+
264
+ def download_model_from_hub(model_id, subfolder=None):
265
+ """Download a model from Hugging Face Hub"""
266
+ try:
267
+ local_dir = snapshot_download(
268
+ repo_id=model_id,
269
+ subfolder=subfolder,
270
+ local_dir=os.path.join(MODELS_CACHE_DIR, model_id.replace("/", "_"))
271
+ )
272
+ print(f"✅ Downloaded model {model_id} to {local_dir}")
273
+ return local_dir
274
+ except Exception as e:
275
+ print(f"⚠️ Error downloading model {model_id}: {str(e)}")
276
+ return None
277
+
278
+
279
+ def save_model_to_cache(model, model_name):
280
+ """Save a model to the cache directory"""
281
+ try:
282
+ cache_path = os.path.join(MODELS_CACHE_DIR, f"{model_name}.pkl")
283
+ with open(cache_path, 'wb') as f:
284
+ pickle.dump(model, f)
285
+ print(f"✅ Saved {model_name} to cache")
286
+ return True
287
+ except Exception as e:
288
+ print(f"⚠️ Failed to save {model_name} to cache: {str(e)}")
289
+ return False
290
+
291
+ def load_model_from_cache(model_name):
292
+ """Load a model from the cache directory"""
293
+ try:
294
+ cache_path = os.path.join(MODELS_CACHE_DIR, f"{model_name}.pkl")
295
+ if os.path.exists(cache_path):
296
+ with open(cache_path, 'rb') as f:
297
+ model = pickle.load(f)
298
+ print(f"✅ Loaded {model_name} from cache")
299
+ return model
300
+ return None
301
+ except Exception as e:
302
+ print(f"⚠️ Failed to load {model_name} from cache: {str(e)}")
303
+ return None
304
+
305
+ # Add a flag to control model loading
306
+ LOAD_MODELS = os.getenv("LOAD_MODELS", "True").lower() in ("true", "1", "t")
307
+
308
+ try:
309
+ if LOAD_MODELS:
310
+ # Try to load SpaCy from cache first
311
+ nlp = load_model_from_cache("spacy_model")
312
+ if nlp is None:
313
+ try:
314
+ nlp = spacy.load("en_core_web_sm")
315
+ save_model_to_cache(nlp, "spacy_model")
316
+ except:
317
+ print("⚠️ SpaCy model not found, downloading...")
318
+ spacy.cli.download("en_core_web_sm")
319
+ nlp = spacy.load("en_core_web_sm")
320
+ save_model_to_cache(nlp, "spacy_model")
321
+
322
+ print("✅ Loading NLP models...")
323
+
324
+ # Load the summarizer with caching
325
+ print("Loading summarizer model...")
326
+ summarizer = load_model_from_cache("summarizer_model")
327
+ if summarizer is None:
328
+ try:
329
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn",
330
+ device=0 if torch.cuda.is_available() else -1)
331
+ save_model_to_cache(summarizer, "summarizer_model")
332
+ print("✅ Summarizer loaded successfully")
333
+ except Exception as e:
334
+ print(f"⚠️ Error loading summarizer: {str(e)}")
335
+ try:
336
+ print("Trying alternative summarizer model...")
337
+ summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6",
338
+ device=0 if torch.cuda.is_available() else -1)
339
+ save_model_to_cache(summarizer, "summarizer_model")
340
+ print("✅ Alternative summarizer loaded successfully")
341
+ except Exception as e2:
342
+ print(f"⚠️ Error loading alternative summarizer: {str(e2)}")
343
+ summarizer = None
344
+
345
+ # Load the embedding model with caching
346
+ print("Loading embedding model...")
347
+ embedding_model = load_model_from_cache("embedding_model")
348
+ if embedding_model is None:
349
+ try:
350
+ embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)
351
+ save_model_to_cache(embedding_model, "embedding_model")
352
+ print("✅ Embedding model loaded successfully")
353
+ except Exception as e:
354
+ print(f"⚠️ Error loading embedding model: {str(e)}")
355
+ embedding_model = None
356
+
357
+ # Load the NER model with caching
358
+ print("Loading NER model...")
359
+ ner_model = load_model_from_cache("ner_model")
360
+ if ner_model is None:
361
+ try:
362
+ ner_model = pipeline("ner", model="dslim/bert-base-NER",
363
+ device=0 if torch.cuda.is_available() else -1)
364
+ save_model_to_cache(ner_model, "ner_model")
365
+ print("✅ NER model loaded successfully")
366
+ except Exception as e:
367
+ print(f"⚠️ Error loading NER model: {str(e)}")
368
+ ner_model = None
369
+
370
+ # Speech to text model with caching
371
+ print("Loading speech to text model...")
372
+ speech_to_text = load_model_from_cache("speech_to_text_model")
373
+ if speech_to_text is None:
374
+ try:
375
+ speech_to_text = pipeline("automatic-speech-recognition",
376
+ model="openai/whisper-medium",
377
+ chunk_length_s=30,
378
+ device_map="auto" if torch.cuda.is_available() else "cpu")
379
+ save_model_to_cache(speech_to_text, "speech_to_text_model")
380
+ print("✅ Speech to text model loaded successfully")
381
+ except Exception as e:
382
+ print(f"⚠️ Error loading speech to text model: {str(e)}")
383
+ speech_to_text = None
384
+
385
+ # Load the fine-tuned model with caching
386
+ print("Loading fine-tuned CUAD QA model...")
387
+ cuad_model = load_model_from_cache("cuad_model")
388
+ cuad_tokenizer = load_model_from_cache("cuad_tokenizer")
389
+
390
+ if cuad_model is None or cuad_tokenizer is None:
391
+ try:
392
+ cuad_tokenizer = AutoTokenizer.from_pretrained("hardik8588/fine-tuned-legal-qa")
393
+ from transformers import AutoModelForQuestionAnswering
394
+ cuad_model = AutoModelForQuestionAnswering.from_pretrained("hardik8588/fine-tuned-legal-qa")
395
+ cuad_model.to(device)
396
+ save_model_to_cache(cuad_tokenizer, "cuad_tokenizer")
397
+ save_model_to_cache(cuad_model, "cuad_model")
398
+ print("✅ Successfully loaded fine-tuned model")
399
+ except Exception as e:
400
+ print(f"⚠️ Error loading fine-tuned model: {str(e)}")
401
+ print("⚠️ Falling back to pre-trained model...")
402
+ try:
403
+ cuad_tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
404
+ from transformers import AutoModelForQuestionAnswering
405
+ cuad_model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
406
+ cuad_model.to(device)
407
+ save_model_to_cache(cuad_tokenizer, "cuad_tokenizer")
408
+ save_model_to_cache(cuad_model, "cuad_model")
409
+ print("✅ Pre-trained model loaded successfully")
410
+ except Exception as e2:
411
+ print(f"⚠️ Error loading pre-trained model: {str(e2)}")
412
+ cuad_model = None
413
+ cuad_tokenizer = None
414
+
415
+ # Load a general QA model with caching
416
+ print("Loading general QA model...")
417
+ qa_model = load_model_from_cache("qa_model")
418
+ if qa_model is None:
419
+ try:
420
+ qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
421
+ save_model_to_cache(qa_model, "qa_model")
422
+ print("✅ QA model loaded successfully")
423
+ except Exception as e:
424
+ print(f"⚠️ Error loading QA model: {str(e)}")
425
+ qa_model = None
426
+
427
+ print("✅ All models loaded successfully")
428
+ else:
429
+ print("⚠️ Model loading skipped (LOAD_MODELS=False)")
430
+
431
+ except Exception as e:
432
+ print(f"⚠️ Error loading models: {str(e)}")
433
+ # Instead of raising an error, set fallback behavior
434
+ nlp = None
435
+ summarizer = None
436
+ embedding_model = None
437
+ ner_model = None
438
+ speech_to_text = None
439
+ cuad_model = None
440
+ cuad_tokenizer = None
441
+ qa_model = None
442
+ print("⚠️ Running with limited functionality due to model loading errors")
443
+
444
+ def legal_chatbot(user_input, context):
445
+ """Uses a real NLP model for legal Q&A."""
446
+ global chat_history
447
+ chat_history.append({"role": "user", "content": user_input})
448
+ response = qa_model(question=user_input, context=context)["answer"]
449
+ chat_history.append({"role": "assistant", "content": response})
450
+ return response
451
+
452
+ def extract_text_from_pdf(pdf_file):
453
+ """Extracts text from a PDF file using pdfplumber."""
454
+ try:
455
+ # Suppress pdfplumber warnings about CropBox
456
+ import logging
457
+ logging.getLogger("pdfminer").setLevel(logging.ERROR)
458
+
459
+ with pdfplumber.open(pdf_file) as pdf:
460
+ print(f"Processing PDF with {len(pdf.pages)} pages")
461
+ text = ""
462
+ for i, page in enumerate(pdf.pages):
463
+ page_text = page.extract_text() or ""
464
+ text += page_text + "\n"
465
+ if (i + 1) % 10 == 0: # Log progress every 10 pages
466
+ print(f"Processed {i + 1} pages...")
467
+
468
+ print(f"✅ PDF text extraction complete: {len(text)} characters extracted")
469
+ return text.strip() if text else None
470
+ except Exception as e:
471
+ print(f"❌ PDF extraction error: {str(e)}")
472
+ raise HTTPException(status_code=400, detail=f"PDF extraction failed: {str(e)}")
473
+
474
+ def process_video_to_text(video_file_path):
475
+ """Extract audio from video and convert to text."""
476
+ try:
477
+ print(f"Processing video file at {video_file_path}")
478
+ temp_audio_path = os.path.join("temp", "extracted_audio.wav")
479
+ video = mp.VideoFileClip(video_file_path)
480
+ video.audio.write_audiofile(temp_audio_path, codec='pcm_s16le')
481
+ print(f"Audio extracted to {temp_audio_path}")
482
+ result = speech_to_text(temp_audio_path)
483
+ transcript = result["text"]
484
+ print(f"Transcription completed: {len(transcript)} characters")
485
+ if os.path.exists(temp_audio_path):
486
+ os.remove(temp_audio_path)
487
+ return transcript
488
+ except Exception as e:
489
+ print(f"Error in video processing: {str(e)}")
490
+ raise HTTPException(status_code=400, detail=f"Video processing failed: {str(e)}")
491
+
492
+ def process_audio_to_text(audio_file_path):
493
+ """Process audio file and convert to text."""
494
+ try:
495
+ print(f"Processing audio file at {audio_file_path}")
496
+ result = speech_to_text(audio_file_path)
497
+ transcript = result["text"]
498
+ print(f"Transcription completed: {len(transcript)} characters")
499
+ return transcript
500
+ except Exception as e:
501
+ print(f"Error in audio processing: {str(e)}")
502
+ raise HTTPException(status_code=400, detail=f"Audio processing failed: {str(e)}")
503
+
504
+ def extract_named_entities(text):
505
+ """Extracts named entities from legal text."""
506
+ max_length = 10000
507
+ entities = []
508
+ for i in range(0, len(text), max_length):
509
+ chunk = text[i:i+max_length]
510
+ doc = nlp(chunk)
511
+ entities.extend([{"entity": ent.text, "label": ent.label_} for ent in doc.ents])
512
+ return entities
513
+
514
+ def analyze_risk(text):
515
+ """Analyzes legal risk in the document using keyword-based analysis."""
516
+ risk_keywords = {
517
+ "Liability": ["liability", "responsible", "responsibility", "legal obligation"],
518
+ "Termination": ["termination", "breach", "contract end", "default"],
519
+ "Indemnification": ["indemnification", "indemnify", "hold harmless", "compensate", "compensation"],
520
+ "Payment Risk": ["payment", "terms", "reimbursement", "fee", "schedule", "invoice", "money"],
521
+ "Insurance": ["insurance", "coverage", "policy", "claims"],
522
+ }
523
+ risk_scores = {category: 0 for category in risk_keywords}
524
+ lower_text = text.lower()
525
+ for category, keywords in risk_keywords.items():
526
+ for keyword in keywords:
527
+ risk_scores[category] += lower_text.count(keyword.lower())
528
+ return risk_scores
529
+
530
+ def extract_context_for_risk_terms(text, risk_keywords, window=1):
531
+ """
532
+ Extracts and summarizes the context around risk terms.
533
+ """
534
+ doc = nlp(text)
535
+ sentences = list(doc.sents)
536
+ risk_contexts = {category: [] for category in risk_keywords}
537
+ for i, sent in enumerate(sentences):
538
+ sent_text_lower = sent.text.lower()
539
+ for category, details in risk_keywords.items():
540
+ for keyword in details["keywords"]:
541
+ if keyword.lower() in sent_text_lower:
542
+ start_idx = max(0, i - window)
543
+ end_idx = min(len(sentences), i + window + 1)
544
+ context_chunk = " ".join([s.text for s in sentences[start_idx:end_idx]])
545
+ risk_contexts[category].append(context_chunk)
546
+ summarized_contexts = {}
547
+ for category, contexts in risk_contexts.items():
548
+ if contexts:
549
+ combined_context = " ".join(contexts)
550
+ try:
551
+ summary_result = summarizer(combined_context, max_length=100, min_length=30, do_sample=False)
552
+ summary = summary_result[0]['summary_text']
553
+ except Exception as e:
554
+ summary = "Context summarization failed."
555
+ summarized_contexts[category] = summary
556
+ else:
557
+ summarized_contexts[category] = "No contextual details found."
558
+ return summarized_contexts
559
+
560
+ def get_detailed_risk_info(text):
561
+ """
562
+ Returns detailed risk information by merging risk scores with descriptive details
563
+ and contextual summaries from the document.
564
+ """
565
+ risk_details = {
566
+ "Liability": {
567
+ "description": "Liability refers to the legal responsibility for losses or damages.",
568
+ "common_concerns": "Broad liability clauses may expose parties to unforeseen risks.",
569
+ "recommendations": "Review and negotiate clear limits on liability.",
570
+ "example": "E.g., 'The party shall be liable for direct damages due to negligence.'"
571
+ },
572
+ "Termination": {
573
+ "description": "Termination involves conditions under which a contract can be ended.",
574
+ "common_concerns": "Unilateral termination rights or ambiguous conditions can be risky.",
575
+ "recommendations": "Ensure termination clauses are balanced and include notice periods.",
576
+ "example": "E.g., 'Either party may terminate the agreement with 30 days notice.'"
577
+ },
578
+ "Indemnification": {
579
+ "description": "Indemnification requires one party to compensate for losses incurred by the other.",
580
+ "common_concerns": "Overly broad indemnification can shift significant risk.",
581
+ "recommendations": "Negotiate clear limits and carve-outs where necessary.",
582
+ "example": "E.g., 'The seller shall indemnify the buyer against claims from product defects.'"
583
+ },
584
+ "Payment Risk": {
585
+ "description": "Payment risk pertains to terms regarding fees, schedules, and reimbursements.",
586
+ "common_concerns": "Vague payment terms or hidden charges increase risk.",
587
+ "recommendations": "Clarify payment conditions and include penalties for delays.",
588
+ "example": "E.g., 'Payments must be made within 30 days, with a 2% late fee thereafter.'"
589
+ },
590
+ "Insurance": {
591
+ "description": "Insurance risk covers the adequacy and scope of required coverage.",
592
+ "common_concerns": "Insufficient insurance can leave parties exposed in unexpected events.",
593
+ "recommendations": "Review insurance requirements to ensure they meet the risk profile.",
594
+ "example": "E.g., 'The contractor must maintain liability insurance with at least $1M coverage.'"
595
+ }
596
+ }
597
+ risk_scores = analyze_risk(text)
598
+ risk_keywords_context = {
599
+ "Liability": {"keywords": ["liability", "responsible", "responsibility", "legal obligation"]},
600
+ "Termination": {"keywords": ["termination", "breach", "contract end", "default"]},
601
+ "Indemnification": {"keywords": ["indemnification", "indemnify", "hold harmless", "compensate", "compensation"]},
602
+ "Payment Risk": {"keywords": ["payment", "terms", "reimbursement", "fee", "schedule", "invoice", "money"]},
603
+ "Insurance": {"keywords": ["insurance", "coverage", "policy", "claims"]}
604
+ }
605
+ risk_contexts = extract_context_for_risk_terms(text, risk_keywords_context, window=1)
606
+ detailed_info = {}
607
+ for risk_term, score in risk_scores.items():
608
+ if score > 0:
609
+ info = risk_details.get(risk_term, {"description": "No details available."})
610
+ detailed_info[risk_term] = {
611
+ "score": score,
612
+ "description": info.get("description", ""),
613
+ "common_concerns": info.get("common_concerns", ""),
614
+ "recommendations": info.get("recommendations", ""),
615
+ "example": info.get("example", ""),
616
+ "context_summary": risk_contexts.get(risk_term, "No context available.")
617
+ }
618
+ return detailed_info
619
+
620
+ def analyze_contract_clauses(text):
621
+ """Analyzes contract clauses using the fine-tuned CUAD QA model."""
622
+ max_length = 512
623
+ step = 256
624
+ clauses_detected = []
625
+ try:
626
+ clause_types = list(cuad_model.config.id2label.values())
627
+ except Exception as e:
628
+ clause_types = [
629
+ "Obligations of Seller", "Governing Law", "Termination", "Indemnification",
630
+ "Confidentiality", "Insurance", "Non-Compete", "Change of Control",
631
+ "Assignment", "Warranty", "Limitation of Liability", "Arbitration",
632
+ "IP Rights", "Force Majeure", "Revenue/Profit Sharing", "Audit Rights"
633
+ ]
634
+ chunks = [text[i:i+max_length] for i in range(0, len(text), step) if i+step < len(text)]
635
+ for chunk in chunks:
636
+ inputs = cuad_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512).to(device)
637
+ with torch.no_grad():
638
+ outputs = cuad_model(**inputs)
639
+ predictions = torch.sigmoid(outputs.start_logits).cpu().numpy()[0]
640
+ for idx, confidence in enumerate(predictions):
641
+ if confidence > 0.5 and idx < len(clause_types):
642
+ clauses_detected.append({"type": clause_types[idx], "confidence": float(confidence)})
643
+ aggregated_clauses = {}
644
+ for clause in clauses_detected:
645
+ clause_type = clause["type"]
646
+ if clause_type not in aggregated_clauses or clause["confidence"] > aggregated_clauses[clause_type]["confidence"]:
647
+ aggregated_clauses[clause_type] = clause
648
+ return list(aggregated_clauses.values())
649
+
650
+ def summarize_text(text):
651
+ """Summarizes legal text using the summarizer model."""
652
+ try:
653
+ if summarizer is None:
654
+ return "Basic analysis (NLP models not available)"
655
+
656
+ # Split text into chunks if it's too long
657
+ max_chunk_size = 1024
658
+ if len(text) > max_chunk_size:
659
+ chunks = [text[i:i+max_chunk_size] for i in range(0, len(text), max_chunk_size)]
660
+ summaries = []
661
+ for chunk in chunks:
662
+ summary = summarizer(chunk, max_length=100, min_length=30, do_sample=False)
663
+ summaries.append(summary[0]['summary_text'])
664
+ return " ".join(summaries)
665
+ else:
666
+ summary = summarizer(text, max_length=100, min_length=30, do_sample=False)
667
+ return summary[0]['summary_text']
668
+ except Exception as e:
669
+ print(f"Error in summarization: {str(e)}")
670
+ return "Summarization failed. Please try again later."
671
+
672
+ @app.post("/analyze_legal_document")
673
+ async def analyze_legal_document(
674
+ file: UploadFile = File(...),
675
+ current_user: User = Depends(get_current_active_user)
676
+ ):
677
+ """Analyzes a legal document (PDF) and returns insights based on subscription tier."""
678
+ try:
679
+ # Calculate file size in MB
680
+ file_content = await file.read()
681
+ file_size_mb = len(file_content) / (1024 * 1024)
682
+
683
+ # Check subscription access for document analysis
684
+ check_subscription_access(current_user, "document_analysis", file_size_mb)
685
+
686
+ print(f"Processing file: {file.filename}")
687
+
688
+ # Create a temporary file to store the uploaded PDF
689
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
690
+ tmp.write(file_content)
691
+ tmp_path = tmp.name
692
+
693
+ # Extract text from PDF
694
+ text = extract_text_from_pdf(tmp_path)
695
+
696
+ # Clean up the temporary file
697
+ os.unlink(tmp_path)
698
+
699
+ if not text:
700
+ raise HTTPException(status_code=400, detail="Could not extract text from PDF")
701
+
702
+ # Generate a task ID
703
+ task_id = str(uuid.uuid4())
704
+
705
+ # Store document context for later retrieval
706
+ store_document_context(task_id, text)
707
+
708
+ # Basic analysis available to all tiers
709
+ summary = summarize_text(text)
710
+ entities = extract_named_entities(text)
711
+ risk_scores = analyze_risk(text)
712
+
713
+ # Prepare response based on subscription tier
714
+ response = {
715
+ "task_id": task_id,
716
+ "summary": summary,
717
+ "entities": entities,
718
+ "risk_assessment": risk_scores,
719
+ "subscription_tier": current_user.subscription_tier
720
+ }
721
+
722
+ # Add premium features if user has access
723
+ if current_user.subscription_tier == "premium_tier":
724
+ # Add detailed risk assessment
725
+ if "detailed_risk_assessment" in SUBSCRIPTION_TIERS[current_user.subscription_tier]["features"]:
726
+ detailed_risk = get_detailed_risk_info(text)
727
+ response["detailed_risk_assessment"] = detailed_risk
728
+
729
+ # Add contract clause analysis
730
+ if "contract_clause_analysis" in SUBSCRIPTION_TIERS[current_user.subscription_tier]["features"]:
731
+ clauses = analyze_contract_clauses(text)
732
+ response["contract_clauses"] = clauses
733
+
734
+ return response
735
+
736
+ except Exception as e:
737
+ print(f"Error analyzing document: {str(e)}")
738
+ raise HTTPException(status_code=500, detail=f"Error analyzing document: {str(e)}")
739
+
740
+ # Add this function to check resource limits based on subscription tier
741
+ def check_resource_limits(user: User, resource_type: str, size_mb: float = None, count: int = 1):
742
+ """
743
+ Check if the user has exceeded their subscription limits for a specific resource
744
+
745
+ Args:
746
+ user: The user making the request
747
+ resource_type: Type of resource (document, video, audio)
748
+ size_mb: Size of the resource in MB
749
+ count: Number of resources being used (default 1)
750
+
751
+ Returns:
752
+ bool: True if within limits, raises HTTPException otherwise
753
+ """
754
+ # Get the user's subscription tier limits
755
+ tier = user.subscription_tier
756
+ tier_limits = SUBSCRIPTION_TIERS.get(tier, SUBSCRIPTION_TIERS["free_tier"])["limits"]
757
+
758
+ # Check size limits
759
+ if size_mb is not None:
760
+ if resource_type == "document" and size_mb > tier_limits["document_size_mb"]:
761
+ raise HTTPException(
762
+ status_code=status.HTTP_403_FORBIDDEN,
763
+ detail=f"Document size exceeds the {tier_limits['document_size_mb']}MB limit for your {tier} subscription"
764
+ )
765
+ elif resource_type == "video" and size_mb > tier_limits["video_size_mb"]:
766
+ raise HTTPException(
767
+ status_code=status.HTTP_403_FORBIDDEN,
768
+ detail=f"Video size exceeds the {tier_limits['video_size_mb']}MB limit for your {tier} subscription"
769
+ )
770
+ elif resource_type == "audio" and size_mb > tier_limits["audio_size_mb"]:
771
+ raise HTTPException(
772
+ status_code=status.HTTP_403_FORBIDDEN,
773
+ detail=f"Audio size exceeds the {tier_limits['audio_size_mb']}MB limit for your {tier} subscription"
774
+ )
775
+
776
+ # Check monthly document count
777
+ if resource_type == "document":
778
+ # Get current month and year
779
+ now = datetime.now()
780
+ month, year = now.month, now.year
781
+
782
+ # Check usage stats for current month
783
+ conn = get_db_connection()
784
+ cursor = conn.cursor()
785
+ cursor.execute(
786
+ "SELECT analyses_used FROM usage_stats WHERE user_id = ? AND month = ? AND year = ?",
787
+ (user.id, month, year)
788
+ )
789
+ result = cursor.fetchone()
790
+
791
+ current_usage = result[0] if result else 0
792
+
793
+ # Check if adding this usage would exceed the limit
794
+ if current_usage + count > tier_limits["documents_per_month"]:
795
+ conn.close()
796
+ raise HTTPException(
797
+ status_code=status.HTTP_403_FORBIDDEN,
798
+ detail=f"You have reached your monthly limit of {tier_limits['documents_per_month']} document analyses for your {tier} subscription"
799
+ )
800
+
801
+ # Update usage stats
802
+ if result:
803
+ cursor.execute(
804
+ "UPDATE usage_stats SET analyses_used = ? WHERE user_id = ? AND month = ? AND year = ?",
805
+ (current_usage + count, user.id, month, year)
806
+ )
807
+ else:
808
+ usage_id = str(uuid.uuid4())
809
+ cursor.execute(
810
+ "INSERT INTO usage_stats (id, user_id, month, year, analyses_used) VALUES (?, ?, ?, ?, ?)",
811
+ (usage_id, user.id, month, year, count)
812
+ )
813
+
814
+ conn.commit()
815
+ conn.close()
816
+
817
+ # Check if feature is available in the tier
818
+ if resource_type == "video" and tier_limits["video_size_mb"] == 0:
819
+ raise HTTPException(
820
+ status_code=status.HTTP_403_FORBIDDEN,
821
+ detail=f"Video analysis is not available in your {tier} subscription"
822
+ )
823
+
824
+ if resource_type == "audio" and tier_limits["audio_size_mb"] == 0:
825
+ raise HTTPException(
826
+ status_code=status.HTTP_403_FORBIDDEN,
827
+ detail=f"Audio analysis is not available in your {tier} subscription"
828
+ )
829
+
830
+ return True
831
+
832
+ @app.post("/analyze_legal_video")
833
+ async def analyze_legal_video(
834
+ file: UploadFile = File(...),
835
+ current_user: User = Depends(get_current_active_user)
836
+ ):
837
+ """Analyzes legal video by transcribing and analyzing the transcript."""
838
+ try:
839
+ # Calculate file size in MB
840
+ file_content = await file.read()
841
+ file_size_mb = len(file_content) / (1024 * 1024)
842
+
843
+ # Check subscription access for video analysis
844
+ check_subscription_access(current_user, "video_analysis", file_size_mb)
845
+
846
+ print(f"Processing video file: {file.filename}")
847
+
848
+ # Create a temporary file to store the uploaded video
849
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp:
850
+ tmp.write(file_content)
851
+ tmp_path = tmp.name
852
+
853
+ # Process video to extract transcript
854
+ transcript = process_video_to_text(tmp_path)
855
+
856
+ # Clean up the temporary file
857
+ os.unlink(tmp_path)
858
+
859
+ if not transcript:
860
+ raise HTTPException(status_code=400, detail="Could not extract transcript from video")
861
+
862
+ # Generate a task ID
863
+ task_id = str(uuid.uuid4())
864
+
865
+ # Store document context for later retrieval
866
+ store_document_context(task_id, transcript)
867
+
868
+ # Basic analysis
869
+ summary = summarize_text(transcript)
870
+ entities = extract_named_entities(transcript)
871
+ risk_scores = analyze_risk(transcript)
872
+
873
+ # Prepare response
874
+ response = {
875
+ "task_id": task_id,
876
+ "transcript": transcript,
877
+ "summary": summary,
878
+ "entities": entities,
879
+ "risk_assessment": risk_scores,
880
+ "subscription_tier": current_user.subscription_tier
881
+ }
882
+
883
+ # Add premium features if user has access
884
+ if current_user.subscription_tier == "premium_tier":
885
+ # Add detailed risk assessment
886
+ if "detailed_risk_assessment" in SUBSCRIPTION_TIERS[current_user.subscription_tier]["features"]:
887
+ detailed_risk = get_detailed_risk_info(transcript)
888
+ response["detailed_risk_assessment"] = detailed_risk
889
+
890
+ return response
891
+
892
+ except Exception as e:
893
+ print(f"Error analyzing video: {str(e)}")
894
+ raise HTTPException(status_code=500, detail=f"Error analyzing video: {str(e)}")
895
+
896
+
897
+ @app.post("/legal_chatbot/{task_id}")
898
+ async def chat_with_document(
899
+ task_id: str,
900
+ question: str = Form(...),
901
+ current_user: User = Depends(get_current_active_user)
902
+ ):
903
+ """Chat with a document using the legal chatbot."""
904
+ try:
905
+ # Check if user has access to chatbot feature
906
+ if "chatbot" not in SUBSCRIPTION_TIERS[current_user.subscription_tier]["features"]:
907
+ raise HTTPException(
908
+ status_code=403,
909
+ detail=f"The chatbot feature is not available in your {current_user.subscription_tier} subscription. Please upgrade to access this feature."
910
+ )
911
+
912
+ # Check if document context exists
913
+ context = load_document_context(task_id)
914
+ if not context:
915
+ raise HTTPException(status_code=404, detail="Document context not found. Please analyze a document first.")
916
+
917
+ # Use the chatbot to answer the question
918
+ answer = legal_chatbot(question, context)
919
+
920
+ return {"answer": answer, "chat_history": chat_history}
921
+
922
+ except Exception as e:
923
+ print(f"Error in chatbot: {str(e)}")
924
+ raise HTTPException(status_code=500, detail=f"Error in chatbot: {str(e)}")
925
+
926
+ @app.get("/")
927
+ async def root():
928
+ """Root endpoint that returns a welcome message."""
929
+ return HTMLResponse(content="""
930
+ <html>
931
+ <head>
932
+ <title>Legal Document Analysis API</title>
933
+ <style>
934
+ body {
935
+ font-family: Arial, sans-serif;
936
+ max-width: 800px;
937
+ margin: 0 auto;
938
+ padding: 20px;
939
+ }
940
+ h1 {
941
+ color: #2c3e50;
942
+ }
943
+ .endpoint {
944
+ background-color: #f8f9fa;
945
+ padding: 15px;
946
+ margin-bottom: 10px;
947
+ border-radius: 5px;
948
+ }
949
+ .method {
950
+ font-weight: bold;
951
+ color: #e74c3c;
952
+ }
953
+ </style>
954
+ </head>
955
+ <body>
956
+ <h1>Legal Document Analysis API</h1>
957
+ <p>Welcome to the Legal Document Analysis API. This API provides tools for analyzing legal documents, videos, and audio.</p>
958
+ <h2>Available Endpoints:</h2>
959
+ <div class="endpoint">
960
+ <p><span class="method">POST</span> /analyze_legal_document - Analyze a legal document (PDF)</p>
961
+ </div>
962
+ <div class="endpoint">
963
+ <p><span class="method">POST</span> /analyze_legal_video - Analyze a legal video</p>
964
+ </div>
965
+ <div class="endpoint">
966
+ <p><span class="method">POST</span> /analyze_legal_audio - Analyze legal audio</p>
967
+ </div>
968
+ <div class="endpoint">
969
+ <p><span class="method">POST</span> /legal_chatbot/{task_id} - Chat with a document</p>
970
+ </div>
971
+ <div class="endpoint">
972
+ <p><span class="method">POST</span> /register - Register a new user</p>
973
+ </div>
974
+ <div class="endpoint">
975
+ <p><span class="method">POST</span> /token - Login to get an access token</p>
976
+ </div>
977
+ <div class="endpoint">
978
+ <p><span class="method">GET</span> /users/me - Get current user information</p>
979
+ </div>
980
+ <div class="endpoint">
981
+ <p><span class="method">POST</span> /subscribe/{tier} - Subscribe to a plan</p>
982
+ </div>
983
+ <p>For more details, visit the <a href="/docs">API documentation</a>.</p>
984
+ </body>
985
+ </html>
986
+ """)
987
+
988
+ @app.post("/register", response_model=Token)
989
+ async def register_new_user(user_data: UserCreate):
990
+ """Register a new user with a free subscription"""
991
+ try:
992
+ success, result = register_user(user_data.email, user_data.password)
993
+
994
+ if not success:
995
+ raise HTTPException(status_code=400, detail=result)
996
+
997
+ return {"access_token": result["access_token"], "token_type": "bearer"}
998
+
999
+ except HTTPException:
1000
+ # Re-raise HTTP exceptions
1001
+ raise
1002
+ except Exception as e:
1003
+ print(f"Registration error: {str(e)}")
1004
+ raise HTTPException(status_code=500, detail=f"Registration failed: {str(e)}")
1005
+
1006
+ @app.post("/token", response_model=Token)
1007
+ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
1008
+ """Endpoint for OAuth2 token generation"""
1009
+ try:
1010
+ # Add debug logging
1011
+ logger.info(f"Token request for username: {form_data.username}")
1012
+
1013
+ user = authenticate_user(form_data.username, form_data.password)
1014
+ if not user:
1015
+ logger.warning(f"Authentication failed for: {form_data.username}")
1016
+ raise HTTPException(
1017
+ status_code=status.HTTP_401_UNAUTHORIZED,
1018
+ detail="Incorrect username or password",
1019
+ headers={"WWW-Authenticate": "Bearer"},
1020
+ )
1021
+
1022
+ access_token = create_access_token(user.id)
1023
+ if not access_token:
1024
+ logger.error(f"Failed to create access token for user: {user.id}")
1025
+ raise HTTPException(
1026
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
1027
+ detail="Could not create access token",
1028
+ )
1029
+
1030
+ logger.info(f"Login successful for: {form_data.username}")
1031
+ return {"access_token": access_token, "token_type": "bearer"}
1032
+ except Exception as e:
1033
+ logger.error(f"Token endpoint error: {e}")
1034
+ raise HTTPException(
1035
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
1036
+ detail=f"Login error: {str(e)}",
1037
+ )
1038
+
1039
+
1040
+ @app.get("/debug/token")
1041
+ async def debug_token(authorization: str = Header(None)):
1042
+ """Debug endpoint to check token validity"""
1043
+ try:
1044
+ if not authorization:
1045
+ return {"valid": False, "error": "No authorization header provided"}
1046
+
1047
+ # Extract token from Authorization header
1048
+ scheme, token = authorization.split()
1049
+ if scheme.lower() != 'bearer':
1050
+ return {"valid": False, "error": "Not a bearer token"}
1051
+
1052
+ # Log the token for debugging
1053
+ logger.info(f"Debugging token: {token[:10]}...")
1054
+
1055
+ # Try to validate the token
1056
+ try:
1057
+ user = await get_current_active_user(token)
1058
+ return {"valid": True, "user_id": user.id, "email": user.email}
1059
+ except Exception as e:
1060
+ return {"valid": False, "error": str(e)}
1061
+ except Exception as e:
1062
+ return {"valid": False, "error": f"Token debug error: {str(e)}"}
1063
+
1064
+
1065
+ @app.post("/login")
1066
+ async def api_login(email: str, password: str):
1067
+ success, result = login_user(email, password)
1068
+ if not success:
1069
+ raise HTTPException(
1070
+ status_code=status.HTTP_401_UNAUTHORIZED,
1071
+ detail=result
1072
+ )
1073
+ return result
1074
+
1075
+ @app.get("/health")
1076
+ def health_check():
1077
+ """Simple health check endpoint to verify the API is running"""
1078
+ return {"status": "ok", "message": "API is running"}
1079
+
1080
+ @app.get("/users/me", response_model=User)
1081
+ async def read_users_me(current_user: User = Depends(get_current_active_user)):
1082
+ return current_user
1083
+
1084
+ @app.post("/analyze_legal_audio")
1085
+ async def analyze_legal_audio(
1086
+ file: UploadFile = File(...),
1087
+ current_user: User = Depends(get_current_active_user)
1088
+ ):
1089
+ """Analyzes legal audio by transcribing and analyzing the transcript."""
1090
+ try:
1091
+ # Calculate file size in MB
1092
+ file_content = await file.read()
1093
+ file_size_mb = len(file_content) / (1024 * 1024)
1094
+
1095
+ # Check subscription access for audio analysis
1096
+ check_subscription_access(current_user, "audio_analysis", file_size_mb)
1097
+
1098
+ print(f"Processing audio file: {file.filename}")
1099
+
1100
+ # Create a temporary file to store the uploaded audio
1101
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp:
1102
+ tmp.write(file_content)
1103
+ tmp_path = tmp.name
1104
+
1105
+ # Process audio to extract transcript
1106
+ transcript = process_audio_to_text(tmp_path)
1107
+
1108
+ # Clean up the temporary file
1109
+ os.unlink(tmp_path)
1110
+
1111
+ if not transcript:
1112
+ raise HTTPException(status_code=400, detail="Could not extract transcript from audio")
1113
+
1114
+ # Generate a task ID
1115
+ task_id = str(uuid.uuid4())
1116
+
1117
+ # Store document context for later retrieval
1118
+ store_document_context(task_id, transcript)
1119
+
1120
+ # Basic analysis
1121
+ summary = summarize_text(transcript)
1122
+ entities = extract_named_entities(transcript)
1123
+ risk_scores = analyze_risk(transcript)
1124
+
1125
+ # Prepare response
1126
+ response = {
1127
+ "task_id": task_id,
1128
+ "transcript": transcript,
1129
+ "summary": summary,
1130
+ "entities": entities,
1131
+ "risk_assessment": risk_scores,
1132
+ "subscription_tier": current_user.subscription_tier
1133
+ }
1134
+
1135
+ # Add premium features if user has access
1136
+ if current_user.subscription_tier == "premium_tier": # Change from premium_tier to premium
1137
+ # Add detailed risk assessment
1138
+ if "detailed_risk_assessment" in SUBSCRIPTION_TIERS[current_user.subscription_tier]["features"]:
1139
+ detailed_risk = get_detailed_risk_info(transcript)
1140
+ response["detailed_risk_assessment"] = detailed_risk
1141
+
1142
+ return response
1143
+
1144
+ except Exception as e:
1145
+ print(f"Error analyzing audio: {str(e)}")
1146
+ raise HTTPException(status_code=500, detail=f"Error analyzing audio: {str(e)}")
1147
+
1148
+
1149
+
1150
+ # Add these new endpoints before the if __name__ == "__main__" line
1151
+ @app.get("/users/me/subscription")
1152
+ async def get_user_subscription(current_user: User = Depends(get_current_active_user)):
1153
+ """Get the current user's subscription details"""
1154
+ try:
1155
+ # Get subscription details from database
1156
+ conn = get_db_connection()
1157
+ cursor = conn.cursor()
1158
+
1159
+ # Get the most recent active subscription
1160
+ try:
1161
+ cursor.execute(
1162
+ "SELECT id, tier, status, created_at, expires_at, paypal_subscription_id FROM subscriptions "
1163
+ "WHERE user_id = ? AND status = 'active' ORDER BY created_at DESC LIMIT 1",
1164
+ (current_user.id,)
1165
+ )
1166
+ subscription = cursor.fetchone()
1167
+ except sqlite3.OperationalError as e:
1168
+ # Handle missing tier column
1169
+ if "no such column: tier" in str(e):
1170
+ logger.warning("Subscriptions table missing 'tier' column. Returning default subscription.")
1171
+ subscription = None
1172
+ else:
1173
+ raise
1174
+
1175
+ # Get subscription tiers with pricing directly from SUBSCRIPTION_TIERS
1176
+ subscription_tiers = {
1177
+ "free_tier": {
1178
+ "price": SUBSCRIPTION_TIERS["free_tier"]["price"],
1179
+ "currency": SUBSCRIPTION_TIERS["free_tier"]["currency"],
1180
+ "features": SUBSCRIPTION_TIERS["free_tier"]["features"]
1181
+ },
1182
+ "standard_tier": {
1183
+ "price": SUBSCRIPTION_TIERS["standard_tier"]["price"],
1184
+ "currency": SUBSCRIPTION_TIERS["standard_tier"]["currency"],
1185
+ "features": SUBSCRIPTION_TIERS["standard_tier"]["features"]
1186
+ },
1187
+ "premium_tier": {
1188
+ "price": SUBSCRIPTION_TIERS["premium_tier"]["price"],
1189
+ "currency": SUBSCRIPTION_TIERS["premium_tier"]["currency"],
1190
+ "features": SUBSCRIPTION_TIERS["premium_tier"]["features"]
1191
+ }
1192
+ }
1193
+
1194
+ if subscription:
1195
+ sub_id, tier, status, created_at, expires_at, paypal_id = subscription
1196
+ result = {
1197
+ "id": sub_id,
1198
+ "tier": tier,
1199
+ "status": status,
1200
+ "created_at": created_at,
1201
+ "expires_at": expires_at,
1202
+ "paypal_subscription_id": paypal_id,
1203
+ "current_tier": current_user.subscription_tier,
1204
+ "subscription_tiers": subscription_tiers
1205
+ }
1206
+ else:
1207
+ result = {
1208
+ "tier": "free_tier",
1209
+ "status": "active",
1210
+ "current_tier": current_user.subscription_tier,
1211
+ "subscription_tiers": subscription_tiers
1212
+ }
1213
+
1214
+ conn.close()
1215
+ return result
1216
+ except Exception as e:
1217
+ logger.error(f"Error getting subscription: {str(e)}")
1218
+ raise HTTPException(status_code=500, detail=f"Error getting subscription: {str(e)}")
1219
+ # Add this model definition before your endpoints
1220
+ class SubscriptionCreate(BaseModel):
1221
+ tier: str
1222
+
1223
+ @app.post("/create_subscription")
1224
+ async def create_subscription(
1225
+ subscription: SubscriptionCreate,
1226
+ current_user: User = Depends(get_current_active_user)
1227
+ ):
1228
+ """Create a subscription for the current user"""
1229
+ try:
1230
+ # Log the request for debugging
1231
+ logger.info(f"Creating subscription for user {current_user.email} with tier {subscription.tier}")
1232
+ logger.info(f"Available tiers: {list(SUBSCRIPTION_TIERS.keys())}")
1233
+
1234
+ # Validate tier
1235
+ valid_tiers = ["standard_tier", "premium_tier"]
1236
+ if subscription.tier not in valid_tiers:
1237
+ logger.warning(f"Invalid tier requested: {subscription.tier}")
1238
+ raise HTTPException(status_code=400, detail=f"Invalid tier: {subscription.tier}. Must be one of {valid_tiers}")
1239
+
1240
+ # Create subscription
1241
+ logger.info(f"Calling create_user_subscription with email: {current_user.email}, tier: {subscription.tier}")
1242
+ success, result = create_user_subscription(current_user.email, subscription.tier)
1243
+
1244
+ if not success:
1245
+ logger.error(f"Failed to create subscription: {result}")
1246
+ raise HTTPException(status_code=400, detail=result)
1247
+
1248
+ logger.info(f"Subscription created successfully: {result}")
1249
+ return result
1250
+ except Exception as e:
1251
+ logger.error(f"Error creating subscription: {str(e)}")
1252
+ # Include the full traceback for better debugging
1253
+ import traceback
1254
+ logger.error(f"Traceback: {traceback.format_exc()}")
1255
+ raise HTTPException(status_code=500, detail=f"Error creating subscription: {str(e)}")
1256
+
1257
+ @app.post("/subscribe/{tier}")
1258
+ async def subscribe_to_tier(
1259
+ tier: str,
1260
+ current_user: User = Depends(get_current_active_user)
1261
+ ):
1262
+ """Subscribe to a specific tier"""
1263
+ try:
1264
+ # Validate tier
1265
+ valid_tiers = ["standard_tier", "premium_tier"]
1266
+ if tier not in valid_tiers:
1267
+ raise HTTPException(status_code=400, detail=f"Invalid tier: {tier}. Must be one of {valid_tiers}")
1268
+
1269
+ # Create subscription
1270
+ success, result = create_user_subscription(current_user.email, tier)
1271
+
1272
+ if not success:
1273
+ raise HTTPException(status_code=400, detail=result)
1274
+
1275
+ return result
1276
+ except Exception as e:
1277
+ logger.error(f"Error creating subscription: {str(e)}")
1278
+ raise HTTPException(status_code=500, detail=f"Error creating subscription: {str(e)}")
1279
+
1280
+ @app.post("/subscription/create")
1281
+ async def create_subscription(request: Request, current_user: User = Depends(get_current_active_user)):
1282
+ """Create a subscription for the current user"""
1283
+ try:
1284
+ data = await request.json()
1285
+ tier = data.get("tier")
1286
+
1287
+ if not tier:
1288
+ return JSONResponse(
1289
+ status_code=400,
1290
+ content={"detail": "Tier is required"}
1291
+ )
1292
+
1293
+ # Log the request for debugging
1294
+ logger.info(f"Creating subscription for user {current_user.email} with tier {tier}")
1295
+
1296
+ # Create the subscription using the imported function directly
1297
+ success, result = create_user_subscription(current_user.email, tier)
1298
+
1299
+ if success:
1300
+ # Make sure we're returning the approval_url in the response
1301
+ logger.info(f"Subscription created successfully: {result}")
1302
+ logger.info(f"Approval URL: {result.get('approval_url')}")
1303
+
1304
+ return {
1305
+ "success": True,
1306
+ "data": {
1307
+ "approval_url": result["approval_url"],
1308
+ "subscription_id": result["subscription_id"],
1309
+ "tier": result["tier"]
1310
+ }
1311
+ }
1312
+ else:
1313
+ logger.error(f"Failed to create subscription: {result}")
1314
+ return JSONResponse(
1315
+ status_code=400,
1316
+ content={"success": False, "detail": result}
1317
+ )
1318
+ except Exception as e:
1319
+ logger.error(f"Error creating subscription: {str(e)}")
1320
+ import traceback
1321
+ logger.error(f"Traceback: {traceback.format_exc()}")
1322
+ return JSONResponse(
1323
+ status_code=500,
1324
+ content={"success": False, "detail": f"Error creating subscription: {str(e)}"}
1325
+ )
1326
+
1327
+ @app.post("/admin/initialize-paypal-plans")
1328
+ async def initialize_paypal_plans(request: Request):
1329
+ """Initialize PayPal subscription plans"""
1330
+ try:
1331
+ # This should be protected with admin authentication in production
1332
+ plans = initialize_subscription_plans()
1333
+
1334
+ if plans:
1335
+ return JSONResponse(
1336
+ status_code=200,
1337
+ content={"success": True, "plans": plans}
1338
+ )
1339
+ else:
1340
+ return JSONResponse(
1341
+ status_code=500,
1342
+ content={"success": False, "detail": "Failed to initialize plans"}
1343
+ )
1344
+ except Exception as e:
1345
+ logger.error(f"Error initializing PayPal plans: {str(e)}")
1346
+ return JSONResponse(
1347
+ status_code=500,
1348
+ content={"success": False, "detail": f"Error initializing plans: {str(e)}"}
1349
+ )
1350
+
1351
+
1352
+ @app.post("/subscription/verify")
1353
+ async def verify_subscription(request: Request, current_user: User = Depends(get_current_active_user)):
1354
+ """Verify a subscription after payment"""
1355
+ try:
1356
+ data = await request.json()
1357
+ subscription_id = data.get("subscription_id")
1358
+
1359
+ if not subscription_id:
1360
+ return JSONResponse(
1361
+ status_code=400,
1362
+ content={"success": False, "detail": "Subscription ID is required"}
1363
+ )
1364
+
1365
+ logger.info(f"Verifying subscription: {subscription_id}")
1366
+
1367
+ # Verify the subscription with PayPal
1368
+ success, result = verify_paypal_subscription(subscription_id)
1369
+
1370
+ if not success:
1371
+ logger.error(f"Subscription verification failed: {result}")
1372
+ return JSONResponse(
1373
+ status_code=400,
1374
+ content={"success": False, "detail": str(result)}
1375
+ )
1376
+
1377
+ # Update the user's subscription in the database
1378
+ conn = get_db_connection()
1379
+ cursor = conn.cursor()
1380
+
1381
+ # Get the subscription details
1382
+ cursor.execute(
1383
+ "SELECT tier FROM subscriptions WHERE paypal_subscription_id = ?",
1384
+ (subscription_id,)
1385
+ )
1386
+ subscription = cursor.fetchone()
1387
+
1388
+ if not subscription:
1389
+ # This is a new subscription, get the tier from the PayPal response
1390
+ tier = "standard_tier" # Default to standard tier
1391
+ # You could extract the tier from the PayPal plan ID if needed
1392
+
1393
+ # Create a new subscription record
1394
+ sub_id = str(uuid.uuid4())
1395
+ start_date = datetime.now()
1396
+ expires_at = start_date + timedelta(days=30)
1397
+
1398
+ cursor.execute(
1399
+ "INSERT INTO subscriptions (id, user_id, tier, status, created_at, expires_at, paypal_subscription_id) VALUES (?, ?, ?, ?, ?, ?, ?)",
1400
+ (sub_id, current_user.id, tier, "active", start_date, expires_at, subscription_id)
1401
+ )
1402
+ else:
1403
+ # Update existing subscription
1404
+ tier = subscription[0]
1405
+ cursor.execute(
1406
+ "UPDATE subscriptions SET status = 'active' WHERE paypal_subscription_id = ?",
1407
+ (subscription_id,)
1408
+ )
1409
+
1410
+ # Update user's subscription tier
1411
+ cursor.execute(
1412
+ "UPDATE users SET subscription_tier = ? WHERE id = ?",
1413
+ (tier, current_user.id)
1414
+ )
1415
+
1416
+ conn.commit()
1417
+ conn.close()
1418
+
1419
+ return JSONResponse(
1420
+ status_code=200,
1421
+ content={"success": True, "detail": "Subscription verified successfully"}
1422
+ )
1423
+
1424
+ except Exception as e:
1425
+ logger.error(f"Error verifying subscription: {str(e)}")
1426
+ return JSONResponse(
1427
+ status_code=500,
1428
+ content={"success": False, "detail": f"Error verifying subscription: {str(e)}"}
1429
+ )
1430
+
1431
+ @app.post("/subscription/webhook")
1432
+ async def subscription_webhook(request: Request):
1433
+ """Handle PayPal subscription webhooks"""
1434
+ try:
1435
+ payload = await request.json()
1436
+ success, result = handle_subscription_webhook(payload)
1437
+
1438
+ if not success:
1439
+ logger.error(f"Webhook processing failed: {result}")
1440
+ return {"status": "error", "message": result}
1441
+
1442
+ return {"status": "success", "message": result}
1443
+ except Exception as e:
1444
+ logger.error(f"Error processing webhook: {str(e)}")
1445
+ return {"status": "error", "message": f"Error processing webhook: {str(e)}"}
1446
+
1447
+ @app.get("/subscription/verify/{subscription_id}")
1448
+ async def verify_subscription(
1449
+ subscription_id: str,
1450
+ current_user: User = Depends(get_current_active_user)
1451
+ ):
1452
+ """Verify a subscription payment and update user tier"""
1453
+ try:
1454
+ # Verify the subscription
1455
+ success, result = verify_subscription_payment(subscription_id)
1456
+
1457
+ if not success:
1458
+ raise HTTPException(status_code=400, detail=f"Subscription verification failed: {result}")
1459
+
1460
+ # Get the plan ID from the subscription to determine tier
1461
+ plan_id = result.get("plan_id", "")
1462
+
1463
+ # Connect to DB to get the tier for this plan
1464
+ conn = get_db_connection()
1465
+ cursor = conn.cursor()
1466
+ cursor.execute("SELECT tier FROM paypal_plans WHERE plan_id = ?", (plan_id,))
1467
+ tier_result = cursor.fetchone()
1468
+ conn.close()
1469
+
1470
+ if not tier_result:
1471
+ raise HTTPException(status_code=400, detail="Could not determine subscription tier")
1472
+
1473
+ tier = tier_result[0]
1474
+
1475
+ # Update the user's subscription
1476
+ success, update_result = update_user_subscription(current_user.email, subscription_id, tier)
1477
+
1478
+ if not success:
1479
+ raise HTTPException(status_code=500, detail=f"Failed to update subscription: {update_result}")
1480
+
1481
+ return {
1482
+ "message": f"Successfully subscribed to {tier} tier",
1483
+ "subscription_id": subscription_id,
1484
+ "status": result.get("status", ""),
1485
+ "next_billing_time": result.get("billing_info", {}).get("next_billing_time", "")
1486
+ }
1487
+
1488
+ except HTTPException:
1489
+ raise
1490
+ except Exception as e:
1491
+ print(f"Subscription verification error: {str(e)}")
1492
+ raise HTTPException(status_code=500, detail=f"Subscription verification failed: {str(e)}")
1493
+
1494
+ @app.post("/webhook/paypal")
1495
+ async def paypal_webhook(request: Request):
1496
+ """Handle PayPal subscription webhooks"""
1497
+ try:
1498
+ payload = await request.json()
1499
+ logger.info(f"Received PayPal webhook: {payload.get('event_type', 'unknown event')}")
1500
+
1501
+ # Process the webhook
1502
+ result = handle_subscription_webhook(payload)
1503
+
1504
+ return {"status": "success", "message": "Webhook processed"}
1505
+ except Exception as e:
1506
+ logger.error(f"Webhook processing error: {str(e)}")
1507
+ # Return 200 even on error to acknowledge receipt to PayPal
1508
+ return {"status": "error", "message": str(e)}
1509
+
1510
+ # Add this to your startup code
1511
+ @app.on_event("startup")
1512
+ async def startup_event():
1513
+ """Initialize subscription plans on startup"""
1514
+ try:
1515
+ # Initialize PayPal subscription plans if needed
1516
+ # If you have an initialize_subscription_plans function in your paypal_integration.py,
1517
+ # you can call it here
1518
+ print("Application started successfully")
1519
+ except Exception as e:
1520
+ print(f"Error during startup: {str(e)}")
1521
+
1522
+ if __name__ == "__main__":
1523
+ import uvicorn
1524
+ port = int(os.environ.get("PORT", 7860))
1525
+ host = os.environ.get("HOST", "0.0.0.0")
1526
+ uvicorn.run("app:app", host=host, port=port, reload=True)