veerukhannan committed · Commit f68e1d5 · verified · 1 Parent(s): e05fa4e

Update add_embeddings.py

Files changed (1):
  add_embeddings.py (+22 -13)
add_embeddings.py CHANGED
@@ -1,19 +1,34 @@
 import os
 from PyPDF2 import PdfReader
-from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer, AutoModel
+import torch
 import chromadb
 from typing import List, Dict
 import re
+import numpy as np
 
 class LegalDocumentProcessor:
     def __init__(self):
-        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+        self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+        self.model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
         self.chroma_client = chromadb.Client()
         self.collection = self.chroma_client.create_collection(
             name="indian_legal_docs",
             metadata={"description": "Indian Criminal Law Documents"}
         )
 
+    def mean_pooling(self, model_output, attention_mask):
+        token_embeddings = model_output[0]
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+    def get_embedding(self, text: str) -> List[float]:
+        inputs = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
+        with torch.no_grad():
+            model_output = self.model(**inputs)
+        sentence_embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
+        return sentence_embeddings[0].tolist()
+
     def process_pdf(self, pdf_path: str) -> List[str]:
         """Extract text from PDF and split into chunks"""
         reader = PdfReader(pdf_path)
@@ -21,13 +36,11 @@ class LegalDocumentProcessor:
         for page in reader.pages:
             text += page.extract_text()
 
-        # Split into meaningful chunks (by sections/paragraphs)
         chunks = self._split_into_chunks(text)
         return chunks
 
     def _split_into_chunks(self, text: str, max_chunk_size: int = 1000) -> List[str]:
         """Split text into smaller chunks while preserving context"""
-        # Split on section boundaries or paragraphs
        sections = re.split(r'(Chapter \d+|Section \d+|\n\n)', text)
 
         chunks = []
@@ -55,16 +68,14 @@ class LegalDocumentProcessor:
         }
 
         for law_code, pdf_path in pdf_files.items():
-            # Process PDF
             chunks = self.process_pdf(pdf_path)
 
-            # Generate embeddings and store in ChromaDB
             for i, chunk in enumerate(chunks):
-                embeddings = self.embedding_model.encode([chunk]).tolist()
+                embedding = self.get_embedding(chunk)
 
                 self.collection.add(
                     documents=[chunk],
-                    embeddings=embeddings,
+                    embeddings=[embedding],
                     metadatas=[{
                         "law_code": law_code,
                         "chunk_id": f"{law_code}_chunk_{i}",
@@ -73,11 +84,11 @@ class LegalDocumentProcessor:
                     ids=[f"{law_code}_chunk_{i}"]
                 )
 
-    def search_documents(self, query: str, n_results: int = 3) -> List[Dict]:
+    def search_documents(self, query: str, n_results: int = 3) -> Dict:
         """Search for relevant legal information"""
-        query_embedding = self.embedding_model.encode([query]).tolist()
+        query_embedding = self.get_embedding(query)
         results = self.collection.query(
-            query_embeddings=query_embedding,
+            query_embeddings=[query_embedding],
             n_results=n_results
         )
 
@@ -87,11 +98,9 @@ class LegalDocumentProcessor:
         }
 
 if __name__ == "__main__":
-    # Initialize and run document processing
     processor = LegalDocumentProcessor()
     processor.process_and_store_documents()
 
-    # Test search functionality
     test_query = "What are the provisions for digital evidence?"
     results = processor.search_documents(test_query)
     print(f"Query: {test_query}")
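Notes on this change:

The new get_embedding reproduces the mean-pooling step that sentence-transformers performs internally, but SentenceTransformer('all-MiniLM-L6-v2').encode() also L2-normalizes its output; the manual pipeline here skips that step, so similarity scores will not exactly match the previous version. A minimal sketch of a normalized drop-in body for get_embedding, assuming torch.nn.functional is available as F:

    import torch.nn.functional as F

    def get_embedding(self, text: str) -> List[float]:
        inputs = self.tokenizer(text, padding=True, truncation=True,
                                max_length=512, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**inputs)
        embedding = self.mean_pooling(model_output, inputs['attention_mask'])
        # L2-normalize so cosine and L2 rankings agree; Chroma's
        # default hnsw:space is "l2".
        embedding = F.normalize(embedding, p=2, dim=1)
        return embedding[0].tolist()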
 
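get_embedding runs one forward pass per chunk, which is slow for long statutes. The tokenizer and mean_pooling are already batch-aware, so a batched helper is a natural extension. A sketch (get_embeddings_batch is a hypothetical name, not part of this commit):

    def get_embeddings_batch(self, texts: List[str]) -> List[List[float]]:
        # One forward pass for a whole batch of chunks.
        inputs = self.tokenizer(texts, padding=True, truncation=True,
                                max_length=512, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**inputs)
        embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
        return embeddings.tolist()

collection.add also accepts lists, so process_and_store_documents could then insert all of a document's chunks in a single call instead of one add per chunk.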
 
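_split_into_chunks splits on a capturing group, so re.split keeps the "Chapter n" / "Section n" headings as separate list items instead of dropping them; the loop over sections (trimmed from this diff) presumably re-attaches them to the following text. For illustration:

    import re

    text = "Section 1 Preliminary\n\nSection 2 Definitions"
    parts = re.split(r'(Chapter \d+|Section \d+|\n\n)', text)
    # ['', 'Section 1', ' Preliminary', '\n\n', '', 'Section 2', ' Definitions']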
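chromadb.Client() is an in-memory client, so the whole corpus is re-embedded on every run, and create_collection raises if a collection with the same name already exists in the process. If persistence is wanted, Chroma (0.4+) provides a persistent client; a sketch with an illustrative storage path:

    import chromadb

    client = chromadb.PersistentClient(path="./chroma_db")  # illustrative path
    collection = client.get_or_create_collection(
        name="indian_legal_docs",
        metadata={"description": "Indian Criminal Law Documents"}
    )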
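collection.query returns a dict of parallel lists ("ids", "documents", "metadatas", "distances"), with one inner list per query embedding. Assuming search_documents passes that result through (its return dict is trimmed from this diff), the __main__ smoke test could print matches like so:

    processor = LegalDocumentProcessor()
    results = processor.search_documents("What are the provisions for digital evidence?")
    # Index [0] selects the result lists for the first (only) query.
    for doc, meta in zip(results["documents"][0], results["metadatas"][0]):
        print(meta["law_code"], "->", doc[:100])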