veerukhannan committed (verified)
Commit c2dd28c · 1 Parent(s): 86b8124

Update app.py

Files changed (1):
  1. app.py +25 -26
app.py CHANGED
@@ -2,7 +2,8 @@ import gradio as gr
 from typing import List, Dict, Tuple
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoModel
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+from sentence_transformers import SentenceTransformer
 import torch
 import os
 from astrapy.db import AstraDB
@@ -10,8 +11,8 @@ from dotenv import load_dotenv
 from huggingface_hub import login
 import time
 import logging
-from functools import lru_cache
 import numpy as np
+from functools import lru_cache
 
 # Configure logging
 logging.basicConfig(
@@ -34,7 +35,7 @@ class LegalTextSearchBot:
             )
             self.collection = self.astra_db.collection(os.getenv("ASTRA_DB_COLLECTION"))
 
-            # Initialize language model for text generation
+            # Initialize language model
             model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
             model = AutoModelForCausalLM.from_pretrained(
                 model_name,
@@ -56,13 +57,8 @@ class LegalTextSearchBot:
             )
             self.llm = HuggingFacePipeline(pipeline=pipe)
 
-            # Initialize embedding model
-            self.embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
-            self.embedding_pipeline = pipeline(
-                "feature-extraction",
-                model=self.embedding_model_name,
-                device_map="auto"
-            )
+            # Initialize sentence transformer for embeddings
+            self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
 
             self.template = """
             IMPORTANT: You are a legal assistant that provides accurate information based on the Indian legal sections provided in the context.
@@ -87,6 +83,8 @@ class LegalTextSearchBot:
             self.chat_history = ""
             self.is_searching = False
 
+            logger.info("Successfully initialized LegalTextSearchBot")
+
         except Exception as e:
             logger.error(f"Error initializing LegalTextSearchBot: {str(e)}")
             raise
@@ -96,21 +94,19 @@
         try:
             # Clean and prepare text
             text = text.replace('\n', ' ').strip()
+            if not text:
+                text = " "  # Ensure non-empty input
 
             # Generate embedding
-            outputs = self.embedding_pipeline(text)
-            embeddings = torch.mean(torch.tensor(outputs[0]), dim=0)
-
-            # Convert to list and ensure correct dimension
-            embedding_list = embeddings.tolist()
+            embedding = self.embedding_model.encode(text)
 
-            # Pad or truncate to exactly 1024 dimensions
-            if len(embedding_list) < 1024:
-                embedding_list.extend([0.0] * (1024 - len(embedding_list)))
-            elif len(embedding_list) > 1024:
-                embedding_list = embedding_list[:1024]
+            # Pad or truncate to 1024 dimensions
+            if len(embedding) < 1024:
+                embedding = np.pad(embedding, (0, 1024 - len(embedding)))
+            elif len(embedding) > 1024:
+                embedding = embedding[:1024]
 
-            return embedding_list
+            return embedding.tolist()
 
         except Exception as e:
             logger.error(f"Error generating embedding: {str(e)}")
@@ -125,7 +121,7 @@
 
             results = list(self.collection.vector_find(
                 query_embedding,
-                limit=5,
+                top_k=5,  # Using top_k instead of limit
                 fields=["section_number", "title", "chapter_number", "chapter_title",
                         "content", "type", "metadata"]
            ))
@@ -142,10 +138,13 @@
             results = list(self._cached_search(query))
 
             if not results and self.is_searching:
-                results = list(self.collection.find(
-                    {},
-                    limit=5
-                ))
+                # Fallback to regular search
+                cursor = self.collection.find({})
+                results = []
+                for doc in cursor:
+                    if len(results) >= 5:
+                        break
+                    results.append(doc)
 
             return results
 
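A note on the embedding change in this commit: the transformers feature-extraction pipeline is replaced by SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2'), whose encode() returns a 384-dimensional NumPy vector, and the pad/truncate step keeps the stored vector at the 1024 dimensions the Astra DB collection apparently expects. Below is a minimal standalone sketch of that path, assuming sentence-transformers and numpy are installed; the helper name embed_to_1024 is illustrative and not part of app.py.

import numpy as np
from sentence_transformers import SentenceTransformer

# Same model as in the commit; encode() on a single string returns a 1-D array of 384 floats.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

def embed_to_1024(text: str) -> list:
    # Mirror the cleanup in generate_embedding(): flatten newlines, avoid empty input.
    text = text.replace("\n", " ").strip() or " "
    vec = model.encode(text)                     # numpy array of shape (384,)
    if len(vec) < 1024:
        vec = np.pad(vec, (0, 1024 - len(vec)))  # zero-pad up to 1024 dimensions
    elif len(vec) > 1024:
        vec = vec[:1024]                         # truncate if a larger model were used
    return vec.tolist()

print(len(embed_to_1024("Section 378: theft")))  # 1024

Zero-padding keeps writes compatible with a 1024-dimension collection and does not change cosine similarity between stored vectors, since the appended zeros contribute nothing to dot products or norms.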
 
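A note on the fallback search: rather than passing a limit to collection.find(), the new code iterates the cursor and stops after five documents. The same capping pattern can be written with itertools.islice, as in the sketch below, which only assumes the cursor is iterable; first_n is an illustrative helper, not something defined in app.py.

from itertools import islice

def first_n(cursor, n=5):
    # Take at most n documents from any iterable cursor; equivalent to the
    # for/break fallback loop added in this commit.
    return list(islice(cursor, n))

# Hypothetical call site matching the commit's fallback:
# results = first_n(self.collection.find({}), 5)
print(first_n(iter(range(100)), 5))  # [0, 1, 2, 3, 4]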