Commit 4adc02d committed by Rohil Bansal
Parent(s): dc3ed8e
Chatbot working.
Browse files:
- app.py +1 -1
- assets/data/{Mandel-IntroEconTheory.pdf → Constitution.pdf} +2 -2
- assets/data/IPC.pdf +3 -0
- assets/data/IPC_and_Constitution.pdf +3 -0
- graphs/workflow_graph.jpg +3 -0
- src/__pycache__/buildgraph.cpython-312.pyc +0 -0
- src/__pycache__/graph.cpython-312.pyc +0 -0
- src/__pycache__/index.cpython-312.pyc +0 -0
- src/__pycache__/llm.cpython-312.pyc +0 -0
- src/buildgraph.py +60 -11
- src/graph.py +86 -116
- src/index.py +23 -8
- src/llm.py +40 -16
- vectordb/{08d73b15-e800-45c5-a450-5b9d696166f3 → 65ba2328-ffa1-497d-b641-c6b84db7f0e1}/data_level0.bin +0 -0
- vectordb/{08d73b15-e800-45c5-a450-5b9d696166f3 → 65ba2328-ffa1-497d-b641-c6b84db7f0e1}/header.bin +0 -0
- vectordb/{08d73b15-e800-45c5-a450-5b9d696166f3 → 65ba2328-ffa1-497d-b641-c6b84db7f0e1}/length.bin +1 -1
- vectordb/{08d73b15-e800-45c5-a450-5b9d696166f3 → 65ba2328-ffa1-497d-b641-c6b84db7f0e1}/link_lists.bin +0 -0
- vectordb/chroma.sqlite3 +2 -2
app.py CHANGED
@@ -71,7 +71,7 @@ if prompt := st.chat_input("What is your question?"):
     full_response = "⚠️ **_Note: Information provided may be inaccurate._** \n\n\n"
     for char in response_content:
         full_response += char
-        time.sleep(0.
+        time.sleep(0.03)
         message_placeholder.markdown(full_response + "▌")
     message_placeholder.markdown(full_response)
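For context, the loop this hunk tunes is Streamlit's common typewriter effect: the accumulated response is re-rendered one character at a time with a trailing cursor, and the commit settles on a 0.03-second delay per character. A minimal standalone sketch of the same pattern (the response text is a made-up placeholder):

import time
import streamlit as st

response_content = "Section 302 of the IPC prescribes the punishment for murder."  # placeholder
message_placeholder = st.empty()

full_response = ""
for char in response_content:
    full_response += char
    time.sleep(0.03)  # ~33 characters per second, the delay this commit sets
    message_placeholder.markdown(full_response + "▌")  # cursor while streaming
message_placeholder.markdown(full_response)  # final render without the cursor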
assets/data/{Mandel-IntroEconTheory.pdf → Constitution.pdf} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:043686f3266b9a88fd7949f87120520e544aebd20189ee4fbfb246e871333540
+size 655093
assets/data/IPC.pdf ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ef8945c5d1b02904da67959e245b87bd5751ed5563d03ab0079758909f145309
+size 842456
assets/data/IPC_and_Constitution.pdf ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d44daff2e1184960f888e303558384322b719cfb82cdc1f50dec07794a7ed554
+size 1500316
graphs/workflow_graph.jpg ADDED
Binary image file added (stored with Git LFS).
src/__pycache__/buildgraph.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/buildgraph.cpython-312.pyc and b/src/__pycache__/buildgraph.cpython-312.pyc differ

src/__pycache__/graph.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/graph.cpython-312.pyc and b/src/__pycache__/graph.cpython-312.pyc differ

src/__pycache__/index.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/index.cpython-312.pyc and b/src/__pycache__/index.cpython-312.pyc differ

src/__pycache__/llm.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/llm.cpython-312.pyc and b/src/__pycache__/llm.cpython-312.pyc differ
src/buildgraph.py CHANGED
@@ -1,9 +1,7 @@
 from src.graph import *
-from pprint import pprint
 from langgraph.graph import END, StateGraph, START
 import sys
 from langgraph.checkpoint.memory import MemorySaver
-import json
 
 memory = MemorySaver()
 
@@ -12,23 +10,33 @@ try:
     workflow = StateGraph(GraphState)
 
     print("Adding nodes to the graph...")
-    workflow.add_node("web_search", web_search)
+    workflow.add_node("understand_intent", understand_intent)
+    # workflow.add_node("intent_aware_response", intent_aware_response)
+    workflow.add_node("greeting", greeting)
+    workflow.add_node("off_topic", off_topic)
+    workflow.add_node("route_question", route_question)
     workflow.add_node("retrieve", retrieve)
+    workflow.add_node("web_search", web_search)
     workflow.add_node("grade_documents", grade_documents)
     workflow.add_node("generate", generate)
     workflow.add_node("transform_query", transform_query)
     print("Nodes added successfully.")
 
     print("Building graph edges...")
+    workflow.add_edge(START, "understand_intent")
     workflow.add_conditional_edges(
-        START,
-        route_question,
+        "understand_intent",
+        intent_aware_response,
         {
-            "web_search": "web_search",
-            "vectorstore": "retrieve",
-        },
+            "off_topic": "off_topic",
+            "greeting": "greeting",
+            "route_question": "route_question",
+        }
     )
-
+
+    workflow.add_edge("greeting", END)
+    workflow.add_edge("off_topic", END)
+
     workflow.add_edge("retrieve", "grade_documents")
     workflow.add_conditional_edges(
         "grade_documents",
@@ -48,12 +56,52 @@ try:
             "not useful": "transform_query",
         },
     )
+    workflow.add_conditional_edges(
+        "route_question",
+        lambda x: x["route_question"],
+        {
+            "web_search": "web_search",
+            "vectorstore": "retrieve",
+        }
+    )
     print("Graph edges built successfully.")
 
     print("Compiling the workflow...")
     app = workflow.compile(checkpointer=memory)
     print("Workflow compiled successfully.")
+
+    try:
+        from IPython import get_ipython
+        from IPython.display import Image, display
+
+        # Check if we're in an IPython environment
+        if get_ipython() is not None:
+            print("Attempting to display graph visualization...")
+            graph_image = app.get_graph().draw_mermaid_png()
+            display(Image(graph_image))
+            print("Graph visualization displayed successfully.")
+        else:
+            print("Not running in IPython environment. Saving graph as JPG...")
+            import os
+            from PIL import Image
+            import io
 
+            graph_image = app.get_graph().draw_mermaid_png()
+            img = Image.open(io.BytesIO(graph_image))
+            img = img.convert('RGB')
+
+            # Create a 'graphs' directory if it doesn't exist
+            if not os.path.exists('graphs'):
+                os.makedirs('graphs')
+
+            img.save('graphs/workflow_graph.jpg', 'JPEG')
+            print("Graph saved as 'graphs/workflow_graph.jpg'")
+    except ImportError as e:
+        print(f"Required libraries not available. Graph visualization skipped. Error: {e}")
+    except Exception as e:
+        print(f"Error handling graph visualization: {e}")
+        print("Graph visualization skipped.")
+
 except Exception as e:
     print(f"Error building the graph: {e}")
     sys.exit(1)
@@ -74,8 +122,9 @@ def run_workflow(question, config):
     final_output = None
     for output in app.stream(input_state, config):
         for key, value in output.items():
-            print(f"Node '{key}'
-
+            print(f"Node '{key}'")
+            # print(f"Value: {json.dumps(value, default=str)}")  # Debug print
+            if key in ["generate", "off_topic", "greeting"]:
                 final_output = value
 
     if final_output is None:
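Because the workflow is compiled with checkpointer=memory (a MemorySaver), every invocation needs a config carrying a thread_id so LangGraph can key the per-conversation checkpoint. A minimal sketch of driving the compiled graph, with an illustrative question and thread id:

from src.buildgraph import app

config = {"configurable": {"thread_id": "demo-user-1"}}  # any stable conversation id
input_state = {"question": "What does Section 420 of the IPC cover?", "chat_history": []}

for output in app.stream(input_state, config):
    for key, value in output.items():
        print(f"Node '{key}' completed")  # same stream shape run_workflow iterates over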
src/graph.py CHANGED
@@ -2,94 +2,96 @@ from typing import List, Dict
 from typing_extensions import TypedDict
 from src.websearch import *
 from src.llm import *
+from langchain.schema import Document, AIMessage
 
-#%%
 class GraphState(TypedDict):
-    """
-    Represents the state of our graph.
-
-    Attributes:
-        question: current question
-        generation: LLM generation
-        documents: list of documents
-        chat_history: list of previous messages
-    """
-
     question: str
     generation: str
     documents: List[str]
     chat_history: List[Dict[str, str]]
+
+def understand_intent(state):
+    print("---UNDERSTAND INTENT---")
+    question = state["question"].lower()
+    chat_history = state.get("chat_history", [])
 
-
-
+    # context = "\n".join([f"{msg['role']}: {msg['content']}" for msg in chat_history[-2:]])
+
+    intent = intent_classifier.invoke({"question": question})
+    print(f"Intent: {intent}")  # Debug print
+    return {"intent": intent, "question": question}
 
+def intent_aware_response(state):
+    print("---INTENT-AWARE RESPONSE---")
+    question = state["question"]
+    chat_history = state.get("chat_history", [])
+    intent = state.get("intent", "")
+
+    print(f"Responding to intent: {intent}")  # Debug print
+
+    # Check if intent is an IntentClassifier object
+    if hasattr(intent, 'intent'):
+        intent = intent.intent.lower()
+    elif isinstance(intent, str):
+        intent = intent.lower().strip("intent='").rstrip("'")
+    else:
+        print(f"Unexpected intent type: {type(intent)}")
+        intent = "unknown"
+
+    if intent == 'greeting':
+        return "greeting"
+    elif intent == 'off_topic':
+        return "off_topic"
+    elif intent in ["legal_query", "follow_up"]:
+        return "route_question"
+    else:
+        print(f"Unknown intent '{intent}', treating as off-topic")
+        return "off_topic"
 
 def retrieve(state):
-    """
-    Retrieve documents
-
-    Args:
-        state (dict): The current graph state
-
-    Returns:
-        state (dict): New key added to state, documents, that contains retrieved documents
-    """
     print("---RETRIEVE---")
     question = state["question"]
-
-    # Retrieval
     documents = retriever.invoke(question)
     return {"documents": documents, "question": question}
 
-
 def generate(state):
-    """
-    Generate answer
-
-    Args:
-        state (dict): The current graph state
-
-    Returns:
-        state (dict): New key added to state, generation, that contains LLM generation
-    """
     print("---GENERATE---")
     question = state["question"]
     documents = state["documents"]
     chat_history = state.get("chat_history", [])
 
-
-
+    context = "\n".join([doc.page_content for doc in documents])
+    chat_context = "\n".join([f"{msg['role']}: {msg['content']}" for msg in chat_history[-5:]])
+
+    generation_prompt = f"""
+    As LegalAlly, an AI assistant specializing in the Indian Penal Code, provide a helpful and informative response to the following question. Use the given context and chat history for reference.
+
+    Context:
+    {context}
+
+    Chat History:
+    {chat_context}
+
+    Question: {question}
+
+    Response:
+    """
+
+    generation = llm.invoke(generation_prompt)
+    generation = generation.content if hasattr(generation, 'content') else str(generation)
 
-    # RAG generation
-    generation = rag_chain.invoke({
-        "context": documents,
-        "question": question,
-        "chat_history": context
-    })
     return {
         "documents": documents,
         "question": question,
-        "generation": generation,
+        "generation": generation,
         "chat_history": chat_history + [{"role": "human", "content": question}, {"role": "ai", "content": generation}]
     }
 
-
 def grade_documents(state):
-    """
-    Determines whether the retrieved documents are relevant to the question.
-
-    Args:
-        state (dict): The current graph state
-
-    Returns:
-        state (dict): Updates documents key with only filtered relevant documents
-    """
-
     print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
     question = state["question"]
     documents = state["documents"]
 
-    # Score each doc
     filtered_docs = []
     for d in documents:
         score = retrieval_grader.invoke(
@@ -104,45 +106,20 @@ def grade_documents(state):
             continue
     return {"documents": filtered_docs, "question": question}
 
-
 def transform_query(state):
-    """
-    Transform the query to produce a better question.
-
-    Args:
-        state (dict): The current graph state
-
-    Returns:
-        state (dict): Updates question key with a re-phrased question
-    """
-
     print("---TRANSFORM QUERY---")
     question = state["question"]
     documents = state["documents"]
 
-    # Re-write question
     better_question = question_rewriter.invoke({"question": question})
     return {"documents": documents, "question": better_question}
 
-
 def web_search(state):
-    """
-    Web search based on the re-phrased question.
-
-    Args:
-        state (dict): The current graph state
-
-    Returns:
-        state (dict): Updates documents key with appended web results
-    """
-
     print("---WEB SEARCH---")
     question = state["question"]
 
-    # Web search
     web_results = web_search_tool.invoke({"query": question})
 
-    # Check if web_results is a string (single result) or a list of results
     if isinstance(web_results, str):
         web_results = [{"content": web_results}]
     elif isinstance(web_results, list):
@@ -155,10 +132,6 @@ def web_search(state):
 
     return {"documents": [web_document], "question": question}
 
-
-### Edges ###
-
-
 def route_question(state):
     """
     Route question to web search or RAG.
@@ -167,59 +140,47 @@ def route_question(state):
         state (dict): The current graph state
 
     Returns:
-        str: Next node to call
+        dict: Updated state with routing information
     """
 
     print("---ROUTE QUESTION---")
     question = state["question"]
     source = question_router.invoke({"question": question})
+
     if source.datasource == "web_search":
         print("---ROUTE QUESTION TO WEB SEARCH---")
-        return "web_search"
+        return {
+            "route_question": "web_search",
+            "question": question  # Maintain the current question
+        }
     elif source.datasource == "vectorstore":
         print("---ROUTE QUESTION TO RAG---")
-        return "vectorstore"
-
+        return {
+            "route_question": "vectorstore",
+            "question": question  # Maintain the current question
+        }
+    else:
+        print("---UNKNOWN ROUTE, DEFAULTING TO RAG---")
+        return {
+            "route_question": "vectorstore",
+            "question": question  # Maintain the current question
+        }
 
 def decide_to_generate(state):
-    """
-    Determines whether to generate an answer, or re-generate a question.
-
-    Args:
-        state (dict): The current graph state
-
-    Returns:
-        str: Binary decision for next node to call
-    """
-
     print("---ASSESS GRADED DOCUMENTS---")
     state["question"]
     filtered_documents = state["documents"]
 
     if not filtered_documents:
-        # All documents have been filtered check_relevance
-        # We will re-generate a new query
        print(
             "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
         )
         return "transform_query"
     else:
-        # We have relevant documents, so generate answer
         print("---DECISION: GENERATE---")
         return "generate"
 
-
 def grade_generation_v_documents_and_question(state):
-    """
-    Determines whether the generation is grounded in the document and answers question.
-
-    Args:
-        state (dict): The current graph state
-
-    Returns:
-        str: Decision for next node to call
-    """
-
     print("---CHECK HALLUCINATIONS---")
     question = state["question"]
     documents = state["documents"]
@@ -230,11 +191,8 @@ def grade_generation_v_documents_and_question(state):
     )
     grade = score.binary_score
 
-    # Check hallucination
     if grade == "yes":
         print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
-        # Check question-answering
-        print("---GRADE GENERATION vs QUESTION---")
         score = answer_grader.invoke({"question": question, "generation": generation})
         grade = score.binary_score
         if grade == "yes":
@@ -245,4 +203,16 @@ def grade_generation_v_documents_and_question(state):
             return "not useful"
     else:
         print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
-        return "not supported"
+        return "not supported"
+
+def greeting(state):
+    print("---GREETING---")
+    return {
+        "generation": "Hello! I'm LegalAlly, an AI assistant specializing in Indian law, particularly the Indian Penal Code and Indian Constitution. How can I assist you today?"
+    }
+
+def off_topic(state):
+    print("---OFF-TOPIC---")
+    return {
+        "generation": "I apologize, but I specialize in matters related to the Indian Penal Code. Could you please ask a question about Indian law or legal matters?"
+    }
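Two contracts in this file are easy to miss: router callables such as intent_aware_response return a plain string that the path map in src/buildgraph.py translates into the next node, while node callables such as greeting return a partial state dict that LangGraph merges into GraphState. A self-contained toy sketch of that pattern (all names here are illustrative, not the project's):

from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END

class ToyState(TypedDict):
    question: str
    generation: str

def understand(state: ToyState):
    return {"question": state["question"]}  # a node: returns a state update

def route(state: ToyState) -> str:
    # a router: returns an edge label, not state
    return "greeting" if "hello" in state["question"].lower() else "off_topic"

def greeting(state: ToyState):
    return {"generation": "Hello!"}

def off_topic(state: ToyState):
    return {"generation": "I only cover IPC questions."}

g = StateGraph(ToyState)
g.add_node("understand_intent", understand)
g.add_node("greeting", greeting)
g.add_node("off_topic", off_topic)
g.add_edge(START, "understand_intent")
g.add_conditional_edges("understand_intent", route, {"greeting": "greeting", "off_topic": "off_topic"})
g.add_edge("greeting", END)
g.add_edge("off_topic", END)

print(g.compile().invoke({"question": "hello there", "generation": ""})["generation"])  # -> Hello!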
src/index.py CHANGED
@@ -6,6 +6,9 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.document_loaders import PyPDFLoader
 from langchain_community.vectorstores import Chroma
 from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI
+import time
+from tenacity import retry, stop_after_attempt, wait_exponential
+from tqdm import tqdm  # Add this import for progress bar
 
 # Load environment variables
 load_dotenv()
@@ -53,7 +56,7 @@ def vector_store_exists(persist_directory):
 # Load and process documents
 try:
     print("Loading PDF document...")
-    docs = PyPDFLoader("assets/data/
+    docs = PyPDFLoader("assets/data/IPC_and_Constitution.pdf").load()
     print("PDF loaded successfully.")
 
     print("Splitting documents...")
@@ -66,17 +69,29 @@ except Exception as e:
     print(f"Error processing documents: {e}")
     sys.exit(1)
 
+@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10))
+def create_vector_store_batch(persist_directory, documents, embedding, batch_size=50):
+    vectorstore = None
+    for i in tqdm(range(0, len(documents), batch_size), desc="Processing batches"):
+        batch = documents[i:i+batch_size]
+        if vectorstore is None:
+            vectorstore = Chroma.from_documents(
+                documents=batch,
+                collection_name="rag-chroma",
+                embedding=embedding,
+                persist_directory=persist_directory
+            )
+        else:
+            vectorstore.add_documents(batch)
+        time.sleep(1)  # Add a small delay between batches
+    return vectorstore
+
 # Create or load vector store
 try:
     persist_directory = './vectordb'
     if not vector_store_exists(persist_directory):
         print("Creating new vector store...")
-        vectorstore = Chroma.from_documents(
-            documents=doc_splits,
-            collection_name="rag-chroma",
-            embedding=embd,
-            persist_directory=persist_directory
-        )
+        vectorstore = create_vector_store_batch(persist_directory, doc_splits, embd)
         print("New vector store created and populated.")
     else:
         print("Loading existing vector store...")
@@ -87,7 +102,7 @@ try:
         )
         print("Existing vector store loaded.")
 
-    retriever = vectorstore.as_retriever(
+    retriever = vectorstore.as_retriever()
     print("Retriever set up successfully.")
 except Exception as e:
     print(f"Error with vector store operations: {e}")
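Note on create_vector_store_batch: the @retry decorator wraps the whole loop, so a transient failure on a late batch re-runs the function from the first batch and can re-add documents that were already persisted. A sketch of scoping the retry to a single batch instead (a hypothetical variant using the same tenacity policy):

from tenacity import retry, stop_after_attempt, wait_exponential

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10))
def add_batch_with_retry(vectorstore, batch):
    # Only the failing batch is retried; earlier batches stay persisted once.
    vectorstore.add_documents(batch)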
src/llm.py CHANGED
@@ -24,11 +24,11 @@ class RouteQuery(BaseModel):
 # llm = AzureChatOpenAI(model="gpt-4o-mini", temperature=0.3)
 structured_llm_router = llm.with_structured_output(RouteQuery)
 
-#%%
 # Prompt
 system = """You are an expert at routing a user question to a vectorstore or web search.
-The vectorstore contains documents related to
-
+The vectorstore contains documents related to Indian Penal Code and The Indian Constitution.
+It can answer most questions related to IPC and the Constitution.
+Use web-search if the answer is not in the vectorstore."""
 route_prompt = ChatPromptTemplate.from_messages(
     [
         ("system", system),
@@ -36,7 +36,6 @@ route_prompt = ChatPromptTemplate.from_messages(
     ]
 )
 
-#%%
 question_router = route_prompt | structured_llm_router
 
 # %%
@@ -50,8 +49,6 @@ class GradeDocuments(BaseModel):
         description="Documents are relevant to the question, 'yes' or 'no'"
     )
 
-
-#%%
 # LLM with function call
 # llm = AzureChatOpenAI(model="gpt-4o-mini", temperature=0.3)
 structured_llm_grader = llm.with_structured_output(GradeDocuments)
@@ -69,10 +66,12 @@ grade_prompt = ChatPromptTemplate.from_messages(
 )
 
 retrieval_grader = grade_prompt | structured_llm_grader
-question = "agent memory"
-docs = retriever.invoke(question)
-doc_txt = docs[1].page_content
-print(retrieval_grader.invoke({"question": question, "document": doc_txt}))
+
+
+# question = "agent memory"
+# docs = retriever.invoke(question)
+# doc_txt = docs[1].page_content
+# print(retrieval_grader.invoke({"question": question, "document": doc_txt}))
 
 #%%
@@ -95,9 +94,9 @@ def format_docs(docs):
 # Chain
 rag_chain = prompt | llm | StrOutputParser()
 
-# Run
-generation = rag_chain.invoke({"context": docs, "question": question})
-print(generation)
+# # Run
+# generation = rag_chain.invoke({"context": docs, "question": question})
+# print(generation)
 
 #%%
@@ -128,7 +127,7 @@ hallucination_prompt = ChatPromptTemplate.from_messages(
 )
 
 hallucination_grader = hallucination_prompt | structured_llm_grader
-hallucination_grader.invoke({"documents": docs, "generation": generation})
+# hallucination_grader.invoke({"documents": docs, "generation": generation})
 
 #%%
 ### Answer Grader
@@ -158,7 +157,7 @@ answer_prompt = ChatPromptTemplate.from_messages(
 )
 
 answer_grader = answer_prompt | structured_llm_grader
-answer_grader.invoke({"question": question, "generation": generation})
+# answer_grader.invoke({"question": question, "generation": generation})
 
 #%%
 ### Question Re-writer
@@ -180,4 +179,29 @@ re_write_prompt = ChatPromptTemplate.from_messages(
 )
 
 question_rewriter = re_write_prompt | llm | StrOutputParser()
-question_rewriter.invoke({"question": question})
+# question_rewriter.invoke({"question": question})
+
+class IntentClassifier(BaseModel):
+    """Classify the intent of the user query."""
+
+    intent: Literal["greeting", "legal_query", "follow_up", "off_topic"] = Field(
+        ...,
+        description="Classify the intent of the user query. 'greeting' if the user is saying greetings, 'legal_query' if the user is asking for information about law, 'follow_up' if the user is asking for information related to the previous conversation, 'off_topic' if the user is asking for information about anything else.",
+    )
+
+# LLM with function call
+# llm = AzureChatOpenAI(model="gpt-4o-mini", temperature=0.3)
+structured_llm_intent_classifier = llm.with_structured_output(IntentClassifier)
+
+# Prompt
+system = """You are an intent classifier that classifies the intent of a user query. \n
+Give the intent as one of the following: 'greeting', 'legal_query', 'follow_up', 'off_topic'."""
+intent_classifier_prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", system),
+        ("human", "Here is the user query: \n\n {question} \n\n Classify the intent of the user query."),
+    ]
+)
+
+intent_classifier = intent_classifier_prompt | structured_llm_intent_classifier
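The new intent_classifier chain returns an IntentClassifier Pydantic object rather than raw text, which is what the hasattr(intent, 'intent') check in src/graph.py accommodates. A usage sketch (assumes the Azure OpenAI environment variables the module loads are configured):

from src.llm import intent_classifier

result = intent_classifier.invoke({"question": "Hello there!"})
print(result.intent)  # one of: 'greeting', 'legal_query', 'follow_up', 'off_topic'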
vectordb/{08d73b15-e800-45c5-a450-5b9d696166f3 → 65ba2328-ffa1-497d-b641-c6b84db7f0e1}/data_level0.bin RENAMED
File without changes

vectordb/{08d73b15-e800-45c5-a450-5b9d696166f3 → 65ba2328-ffa1-497d-b641-c6b84db7f0e1}/header.bin RENAMED
File without changes

vectordb/{08d73b15-e800-45c5-a450-5b9d696166f3 → 65ba2328-ffa1-497d-b641-c6b84db7f0e1}/length.bin RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:f67fade90d336844894516fb804b85fd8b744c00595381c3203e9fd8f5db576b
 size 4000

vectordb/{08d73b15-e800-45c5-a450-5b9d696166f3 → 65ba2328-ffa1-497d-b641-c6b84db7f0e1}/link_lists.bin RENAMED
File without changes

vectordb/chroma.sqlite3 CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:aab4817b7e371b5ddea619c2acb598da4c22f8b8a47e32fd84528a50018b8668
+size 13512704