Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,8 @@ import gradio as gr
|
|
2 |
from typing import List, Dict, Tuple
|
3 |
from langchain_core.prompts import ChatPromptTemplate
|
4 |
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
|
5 |
-
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
|
|
|
6 |
import torch
|
7 |
import os
|
8 |
from astrapy.db import AstraDB
|
@@ -10,8 +11,8 @@ from dotenv import load_dotenv
|
|
10 |
from huggingface_hub import login
|
11 |
import time
|
12 |
import logging
|
13 |
-
from functools import lru_cache
|
14 |
import numpy as np
|
|
|
15 |
|
16 |
# Configure logging
|
17 |
logging.basicConfig(
|
@@ -34,7 +35,7 @@ class LegalTextSearchBot:
|
|
34 |
)
|
35 |
self.collection = self.astra_db.collection(os.getenv("ASTRA_DB_COLLECTION"))
|
36 |
|
37 |
-
# Initialize language model
|
38 |
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
39 |
model = AutoModelForCausalLM.from_pretrained(
|
40 |
model_name,
|
@@ -56,13 +57,8 @@ class LegalTextSearchBot:
|
|
56 |
)
|
57 |
self.llm = HuggingFacePipeline(pipeline=pipe)
|
58 |
|
59 |
-
# Initialize
|
60 |
-
self.
|
61 |
-
self.embedding_pipeline = pipeline(
|
62 |
-
"feature-extraction",
|
63 |
-
model=self.embedding_model_name,
|
64 |
-
device_map="auto"
|
65 |
-
)
|
66 |
|
67 |
self.template = """
|
68 |
IMPORTANT: You are a legal assistant that provides accurate information based on the Indian legal sections provided in the context.
|
@@ -87,6 +83,8 @@ class LegalTextSearchBot:
|
|
87 |
self.chat_history = ""
|
88 |
self.is_searching = False
|
89 |
|
|
|
|
|
90 |
except Exception as e:
|
91 |
logger.error(f"Error initializing LegalTextSearchBot: {str(e)}")
|
92 |
raise
|
@@ -96,21 +94,19 @@ class LegalTextSearchBot:
|
|
96 |
try:
|
97 |
# Clean and prepare text
|
98 |
text = text.replace('\n', ' ').strip()
|
|
|
|
|
99 |
|
100 |
# Generate embedding
|
101 |
-
|
102 |
-
embeddings = torch.mean(torch.tensor(outputs[0]), dim=0)
|
103 |
-
|
104 |
-
# Convert to list and ensure correct dimension
|
105 |
-
embedding_list = embeddings.tolist()
|
106 |
|
107 |
-
# Pad or truncate to
|
108 |
-
if len(
|
109 |
-
|
110 |
-
elif len(
|
111 |
-
|
112 |
|
113 |
-
return
|
114 |
|
115 |
except Exception as e:
|
116 |
logger.error(f"Error generating embedding: {str(e)}")
|
@@ -125,7 +121,7 @@ class LegalTextSearchBot:
|
|
125 |
|
126 |
results = list(self.collection.vector_find(
|
127 |
query_embedding,
|
128 |
-
|
129 |
fields=["section_number", "title", "chapter_number", "chapter_title",
|
130 |
"content", "type", "metadata"]
|
131 |
))
|
@@ -142,10 +138,13 @@ class LegalTextSearchBot:
|
|
142 |
results = list(self._cached_search(query))
|
143 |
|
144 |
if not results and self.is_searching:
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
|
|
|
|
|
|
149 |
|
150 |
return results
|
151 |
|
|
|
2 |
from typing import List, Dict, Tuple
|
3 |
from langchain_core.prompts import ChatPromptTemplate
|
4 |
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
|
5 |
+
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
import torch
|
8 |
import os
|
9 |
from astrapy.db import AstraDB
|
|
|
11 |
from huggingface_hub import login
|
12 |
import time
|
13 |
import logging
|
|
|
14 |
import numpy as np
|
15 |
+
from functools import lru_cache
|
16 |
|
17 |
# Configure logging
|
18 |
logging.basicConfig(
|
|
|
35 |
)
|
36 |
self.collection = self.astra_db.collection(os.getenv("ASTRA_DB_COLLECTION"))
|
37 |
|
38 |
+
# Initialize language model
|
39 |
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
40 |
model = AutoModelForCausalLM.from_pretrained(
|
41 |
model_name,
|
|
|
57 |
)
|
58 |
self.llm = HuggingFacePipeline(pipeline=pipe)
|
59 |
|
60 |
+
# Initialize sentence transformer for embeddings
|
61 |
+
self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
self.template = """
|
64 |
IMPORTANT: You are a legal assistant that provides accurate information based on the Indian legal sections provided in the context.
|
|
|
83 |
self.chat_history = ""
|
84 |
self.is_searching = False
|
85 |
|
86 |
+
logger.info("Successfully initialized LegalTextSearchBot")
|
87 |
+
|
88 |
except Exception as e:
|
89 |
logger.error(f"Error initializing LegalTextSearchBot: {str(e)}")
|
90 |
raise
|
|
|
94 |
try:
|
95 |
# Clean and prepare text
|
96 |
text = text.replace('\n', ' ').strip()
|
97 |
+
if not text:
|
98 |
+
text = " " # Ensure non-empty input
|
99 |
|
100 |
# Generate embedding
|
101 |
+
embedding = self.embedding_model.encode(text)
|
|
|
|
|
|
|
|
|
102 |
|
103 |
+
# Pad or truncate to 1024 dimensions
|
104 |
+
if len(embedding) < 1024:
|
105 |
+
embedding = np.pad(embedding, (0, 1024 - len(embedding)))
|
106 |
+
elif len(embedding) > 1024:
|
107 |
+
embedding = embedding[:1024]
|
108 |
|
109 |
+
return embedding.tolist()
|
110 |
|
111 |
except Exception as e:
|
112 |
logger.error(f"Error generating embedding: {str(e)}")
|
|
|
121 |
|
122 |
results = list(self.collection.vector_find(
|
123 |
query_embedding,
|
124 |
+
top_k=5, # Using top_k instead of limit
|
125 |
fields=["section_number", "title", "chapter_number", "chapter_title",
|
126 |
"content", "type", "metadata"]
|
127 |
))
|
|
|
138 |
results = list(self._cached_search(query))
|
139 |
|
140 |
if not results and self.is_searching:
|
141 |
+
# Fallback to regular search
|
142 |
+
cursor = self.collection.find({})
|
143 |
+
results = []
|
144 |
+
for doc in cursor:
|
145 |
+
if len(results) >= 5:
|
146 |
+
break
|
147 |
+
results.append(doc)
|
148 |
|
149 |
return results
|
150 |
|