DrishtiSharma committed
Commit 76742d2 · verified · 1 Parent(s): 7122d79

Create app.py

Files changed (1)
  1. app.py +161 -0
app.py ADDED
@@ -0,0 +1,161 @@
import os
import streamlit as st
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.schema import Document
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END
from graphviz import Digraph  # For workflow visualization
from typing_extensions import TypedDict
from typing import List
from utils.build_rag import RAG
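
# NOTE: `utils.build_rag.RAG` is a project-local helper that is not part of this
# file; it is assumed here to build the vector store and expose `get_retriever()`.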

# Fetch API keys
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

# Check for missing API keys
if not OPENAI_API_KEY or not TAVILY_API_KEY:
    st.error("❌ API keys missing! Please set `OPENAI_API_KEY` and `TAVILY_API_KEY` in your `.env` file.")
    st.stop()  # Stop the app execution

# Set up LLM and tools
llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)
web_search_tool = TavilySearchResults(max_results=2)  # Reads TAVILY_API_KEY from the environment

# Prompt template
def get_prompt():
    template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
    return ChatPromptTemplate.from_template(template)

# Define graph state
class GraphState(TypedDict):
    question: str
    generation: str
    web_search: str
    documents: List[Document]

# RAG setup
rag = RAG()
retriever = rag.get_retriever()
prompt = get_prompt()
output_parser = StrOutputParser()

# Compose prompt -> LLM -> string parser into a reusable chain
rag_chain = prompt | llm | output_parser

# Nodes
def retrieve(state):
    question = state["question"]
    documents = retriever.invoke(question)
    st.sidebar.write(f"Retrieved Documents: {len(documents)}")
    return {"documents": documents, "question": question}

def grade_documents(state):
    question = state["question"]
    documents = state["documents"]
    filtered_docs = []
    web_search = "No"

    for doc in documents:
        score = {"binary_score": "yes"}  # Dummy grader; see the LLM-based sketch below
        if score["binary_score"] == "yes":
            filtered_docs.append(doc)
        else:
            web_search = "Yes"
    st.sidebar.write(f"Document Grading Results: {len(filtered_docs)} relevant")
    return {"documents": filtered_docs, "web_search": web_search, "question": question}
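
# Hedged sketch of an LLM-based relevance grader that could replace the dummy
# score above. It is NOT wired into the graph, and the yes/no output contract
# is an assumption, not something this commit defines.
def grade_document_llm(question: str, doc: Document) -> str:
    grader_prompt = ChatPromptTemplate.from_template(
        "Is the following document relevant to the question? "
        "Answer with a single word: yes or no.\n\n"
        "Question: {question}\n\nDocument: {document}"
    )
    grader_chain = grader_prompt | llm | StrOutputParser()
    answer = grader_chain.invoke({"question": question, "document": doc.page_content})
    return "yes" if "yes" in answer.lower() else "no"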

def generate(state):
    context = "\n".join([doc.page_content for doc in state["documents"]])
    # Fill the prompt with context and question, then run the chain
    response = rag_chain.invoke({"context": context, "question": state["question"]})
    return {"generation": response}

def transform_query(state):
    question = state["question"]
    # Ask the LLM to rephrase the question before falling back to web search
    new_question = llm.invoke(f"Rewrite: {question}").content
    st.sidebar.write(f"Rewritten Question: {new_question}")
    return {"question": new_question}

def web_search(state):
    question = state["question"]
    results = web_search_tool.invoke({"query": question})
    # Merge the search snippets into a single Document
    docs = "\n".join([result["content"] for result in results])
    state["documents"].append(Document(page_content=docs))
    st.sidebar.write("Web Search Completed")
    return {"documents": state["documents"], "question": question}

def decide_to_generate(state):
    # Answer directly when the docs passed grading; otherwise rewrite and search the web
    return "generate" if state["web_search"] == "No" else "transform_query"

# Build graph
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search_node", web_search)

workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {"transform_query": "transform_query", "generate": "generate"},
)
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)

app = workflow.compile()
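
# Note: with LangGraph's default "updates" streaming, each item yielded by
# `app.stream(...)` maps a node name to that node's returned state update,
# e.g. {"generate": {"generation": "..."}}. The comparison section below
# relies on this shape.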

# Visualize workflow
def plot_workflow():
    graph = Digraph()
    graph.attr(size='6,6')

    # Add nodes
    graph.node("retrieve", "Retrieve Documents")
    graph.node("grade_documents", "Grade Documents")
    graph.node("generate", "Generate Answer")
    graph.node("transform_query", "Transform Query")
    graph.node("web_search_node", "Web Search")
    graph.node("END", "End")

    # Add edges
    graph.edge("retrieve", "grade_documents")
    graph.edge("grade_documents", "generate", label="Relevant Docs")
    graph.edge("grade_documents", "transform_query", label="No Relevant Docs")
    graph.edge("transform_query", "web_search_node")
    graph.edge("web_search_node", "generate")
    graph.edge("generate", "END")

    return graph

# Streamlit app
st.title("Self-Corrective RAG")
st.write("### Compare RAG Pipeline Outputs (With and Without Self-Correction)")

# Plot workflow
st.subheader("Workflow Visualization")
st.graphviz_chart(plot_workflow().source)

# User input
question = st.text_input("Enter your question:", "What is Llama2?")

if st.button("Run Comparison"):
    # Run basic RAG (single retrieve-and-generate pass)
    st.subheader("Without Self-Correction:")
    docs = retriever.invoke(question)
    basic_context = "\n".join([doc.page_content for doc in docs])
    basic_response = rag_chain.invoke({"context": basic_context, "question": question})
    st.write(basic_response)

    # Run self-corrective RAG via the compiled graph
    st.subheader("With Self-Correction:")
    inputs = {"question": question}
    final_generation = ""
    for output in app.stream(inputs):
        for node_name, state_update in output.items():
            if "generation" in state_update:
                final_generation = state_update["generation"]
    st.write(final_generation)
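
# Assumed usage (standard Streamlit invocation; not shown in this commit):
#   OPENAI_API_KEY=... TAVILY_API_KEY=... streamlit run app.py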