Update add_embeddings.py
add_embeddings.py  CHANGED  (+22 -13)
@@ -1,19 +1,34 @@
 import os
 from PyPDF2 import PdfReader
-from
+from transformers import AutoTokenizer, AutoModel
+import torch
 import chromadb
 from typing import List, Dict
 import re
+import numpy as np
 
 class LegalDocumentProcessor:
     def __init__(self):
-        self.
+        self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+        self.model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
         self.chroma_client = chromadb.Client()
         self.collection = self.chroma_client.create_collection(
             name="indian_legal_docs",
             metadata={"description": "Indian Criminal Law Documents"}
         )
 
+    def mean_pooling(self, model_output, attention_mask):
+        token_embeddings = model_output[0]
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+    def get_embedding(self, text: str) -> List[float]:
+        inputs = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
+        with torch.no_grad():
+            model_output = self.model(**inputs)
+        sentence_embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
+        return sentence_embeddings[0].tolist()
+
     def process_pdf(self, pdf_path: str) -> List[str]:
         """Extract text from PDF and split into chunks"""
         reader = PdfReader(pdf_path)
@@ -21,13 +36,11 @@ class LegalDocumentProcessor:
         for page in reader.pages:
             text += page.extract_text()
 
-        # Split into meaningful chunks (by sections/paragraphs)
         chunks = self._split_into_chunks(text)
         return chunks
 
     def _split_into_chunks(self, text: str, max_chunk_size: int = 1000) -> List[str]:
         """Split text into smaller chunks while preserving context"""
-        # Split on section boundaries or paragraphs
         sections = re.split(r'(Chapter \d+|Section \d+|\n\n)', text)
 
         chunks = []
@@ -55,16 +68,14 @@ class LegalDocumentProcessor:
         }
 
         for law_code, pdf_path in pdf_files.items():
-            # Process PDF
             chunks = self.process_pdf(pdf_path)
 
-            # Generate embeddings and store in ChromaDB
             for i, chunk in enumerate(chunks):
-
+                embedding = self.get_embedding(chunk)
 
                 self.collection.add(
                     documents=[chunk],
-                    embeddings=
+                    embeddings=[embedding],
                     metadatas=[{
                         "law_code": law_code,
                         "chunk_id": f"{law_code}_chunk_{i}",
@@ -73,11 +84,11 @@ class LegalDocumentProcessor:
                 ids=[f"{law_code}_chunk_{i}"]
             )
 
-    def search_documents(self, query: str, n_results: int = 3) ->
+    def search_documents(self, query: str, n_results: int = 3) -> Dict:
         """Search for relevant legal information"""
-        query_embedding = self.
+        query_embedding = self.get_embedding(query)
         results = self.collection.query(
-            query_embeddings=query_embedding,
+            query_embeddings=[query_embedding],
             n_results=n_results
         )
 
@@ -87,11 +98,9 @@ class LegalDocumentProcessor:
         }
 
 if __name__ == "__main__":
-    # Initialize and run document processing
    processor = LegalDocumentProcessor()
     processor.process_and_store_documents()
 
-    # Test search functionality
     test_query = "What are the provisions for digital evidence?"
     results = processor.search_documents(test_query)
     print(f"Query: {test_query}")
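Two notes for anyone reusing this commit's embedding path. First, the model card for sentence-transformers/all-MiniLM-L6-v2 L2-normalizes the mean-pooled output before computing similarities; the new get_embedding skips that step (and the added numpy import is unused). Below is a minimal sketch of a normalized variant — embed_normalized is a hypothetical free-function version of get_embedding, not part of this diff:

    import torch
    import torch.nn.functional as F
    from transformers import AutoTokenizer, AutoModel

    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
    model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

    def embed_normalized(text: str) -> list:
        # Tokenize and run the encoder without tracking gradients.
        inputs = tokenizer(text, padding=True, truncation=True,
                           max_length=512, return_tensors='pt')
        with torch.no_grad():
            output = model(**inputs)
        # Mean-pool token embeddings, masking out padding positions.
        token_embeddings = output[0]
        mask = inputs['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
        pooled = torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
        # L2-normalize so dot product equals cosine similarity; on unit
        # vectors, ChromaDB's default L2 distance ranks the same as cosine.
        return F.normalize(pooled, p=2, dim=1)[0].tolist()

Second, chromadb.Client() is an in-process client, and create_collection raises if "indian_legal_docs" already exists — a plausible source of a runtime error if this module is imported or executed twice in the same process. A sketch of the idempotent variant, assuming the same collection settings:

    import chromadb

    client = chromadb.Client()
    # get_or_create_collection returns the existing collection instead of raising.
    collection = client.get_or_create_collection(
        name="indian_legal_docs",
        metadata={"description": "Indian Criminal Law Documents"}
    )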