import os
import streamlit as st
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END
from graphviz import Digraph # For workflow visualization
from typing_extensions import TypedDict
from typing import List
from utils.build_rag import RAG
# Fetch API Keys
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
# Check for Missing API Keys
if not OPENAI_API_KEY or not TAVILY_API_KEY:
    st.error("❌ API keys missing! Please set `OPENAI_API_KEY` and `TAVILY_API_KEY` in your `.env` file.")
    st.stop()  # Stop the app execution
# Set up LLM and Tools
llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)
web_search_tool = TavilySearchResults(max_results=2)  # Takes max_results (not k); reads TAVILY_API_KEY from the environment
# Prompt templates
def get_prompt():
    template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
    return ChatPromptTemplate.from_template(template)
# Define Graph State
class GraphState(TypedDict):
    question: str
    generation: str
    web_search: str
    documents: List[Document]
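
# Each node in the graph receives the current GraphState and returns a partial
# state update; LangGraph merges the returned keys back into the shared state.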
# RAG Setup
rag = RAG()
retriever = rag.get_retriever()
prompt = get_prompt()
output_parser = StrOutputParser()
# Compose prompt -> LLM -> string parser; used by both pipelines below
rag_chain = prompt | llm | output_parser
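
# Note: `RAG` (imported from utils/build_rag.py, which is not shown here) is assumed
# to build or load a vector store and expose `get_retriever()`. A minimal sketch of
# what such a class might look like -- Chroma and OpenAIEmbeddings are assumptions,
# not confirmed by this repo:
#
#   from langchain_openai import OpenAIEmbeddings
#   from langchain_community.vectorstores import Chroma
#
#   class RAG:
#       def __init__(self, persist_directory="./chroma_db"):
#           self.vectorstore = Chroma(
#               persist_directory=persist_directory,
#               embedding_function=OpenAIEmbeddings(),
#           )
#
#       def get_retriever(self):
#           return self.vectorstore.as_retriever()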
# Nodes
def retrieve(state):
    question = state["question"]
    documents = retriever.invoke(question)  # get_relevant_documents is deprecated in favor of invoke
    st.sidebar.write(f"Retrieved Documents: {len(documents)}")
    return {"documents": documents, "question": question}
def grade_documents(state):
    question = state["question"]
    documents = state["documents"]
    filtered_docs = []
    web_search = "No"
    for doc in documents:
        score = {"binary_score": "yes"}  # Dummy grader; see the LLM-based sketch after this function
        if score["binary_score"] == "yes":
            filtered_docs.append(doc)
        else:
            web_search = "Yes"
    st.sidebar.write(f"Document Grading Results: {len(filtered_docs)} relevant")
    return {"documents": filtered_docs, "web_search": web_search, "question": question}
def generate(state):
    context = "\n".join([doc.page_content for doc in state["documents"]])
    # ChatOpenAI.invoke expects messages, not a bare dict, so run the full
    # prompt -> LLM -> parser chain instead
    response = rag_chain.invoke({"context": context, "question": state["question"]})
    return {"generation": response}
def transform_query(state):
    question = state["question"]
    new_question = llm.invoke(f"Rewrite this question so it works well as a web search query: {question}").content
    st.sidebar.write(f"Rewritten Question: {new_question}")
    return {"question": new_question}
def web_search(state):
    question = state["question"]
    results = web_search_tool.invoke({"query": question})
    # Merge all result snippets into a single Document and append it to the state
    combined = "\n".join([result["content"] for result in results])
    state["documents"].append(Document(page_content=combined))
    st.sidebar.write("Web Search Completed")
    return {"documents": state["documents"], "question": question}
def decide_to_generate(state):
    # All documents passed grading -> answer directly; otherwise rewrite the query and search the web
    return "generate" if state["web_search"] == "No" else "transform_query"
# Build Graph
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search_node", web_search)
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {"transform_query": "transform_query", "generate": "generate"},
)
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)
app = workflow.compile()
# Visualize Workflows
def plot_workflow():
    graph = Digraph()
    graph.attr(size='6,6')
    # Add nodes
    graph.node("retrieve", "Retrieve Documents")
    graph.node("grade_documents", "Grade Documents")
    graph.node("generate", "Generate Answer")
    graph.node("transform_query", "Transform Query")
    graph.node("web_search_node", "Web Search")
    graph.node("END", "End")
    # Add edges
    graph.edge("retrieve", "grade_documents")
    graph.edge("grade_documents", "generate", label="Relevant Docs")
    graph.edge("grade_documents", "transform_query", label="No Relevant Docs")
    graph.edge("transform_query", "web_search_node")
    graph.edge("web_search_node", "generate")
    graph.edge("generate", "END")
    return graph
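
# Note: recent langgraph versions can also render the compiled graph directly
# (assumption: `get_graph().draw_mermaid()` is available in the installed version),
# which avoids maintaining this Graphviz chart by hand:
#   mermaid_src = app.get_graph().draw_mermaid()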
# Streamlit App
st.title("Self-Corrective RAG")
st.write("### Compare RAG Pipeline Outputs (With and Without Self-Correction)")
# Plot Workflow
st.subheader("Workflow Visualization")
st.graphviz_chart(plot_workflow().source)
# User Input
question = st.text_input("Enter your question:", "What is Llama2?")
if st.button("Run Comparison"):
# Run Basic RAG
st.subheader("Without Self-Correction:")
docs = retriever.invoke(question)
basic_context = "\n".join([doc.page_content for doc in docs])
basic_response = output_parser.parse(llm.invoke({"context": basic_context, "question": question}).content)
st.write(basic_response)
# Run Self-Corrective RAG
st.subheader("With Self-Correction:")
inputs = {"question": question}
final_generation = ""
for output in app.stream(inputs):
for key, value in output.items():
if key == "generation":
final_generation = value
st.write(final_generation)