import os
import streamlit as st
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END
from graphviz import Digraph  # For workflow visualization
from typing_extensions import TypedDict
from typing import List
from utils.build_rag import RAG
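
# The error message below points users at a `.env` file; load one here if
# python-dotenv is installed (assumption: the original project loads it
# elsewhere, or the keys are exported directly in the shell).
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass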


# Fetch API Keys
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

# Check for Missing API Keys
if not OPENAI_API_KEY or not TAVILY_API_KEY:
    st.error("❌ API keys missing! Please set `OPENAI_API_KEY` and `TAVILY_API_KEY` in your `.env` file.")
    st.stop()  # Stop the app execution

# Set up LLM and Tools
llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)
web_search_tool = TavilySearchResults(max_results=2)  # Reads TAVILY_API_KEY from the environment

# Prompt templates
def get_prompt():
    template = """Answer the question based only on the following context:
    {context}

    Question: {question}
    """
    return ChatPromptTemplate.from_template(template)

# Define Graph State
class GraphState(TypedDict):
    question: str              # current (possibly rewritten) user question
    generation: str            # final generated answer
    web_search: str            # "Yes" if a web-search fallback is needed
    documents: List[Document]  # retrieved and/or web-augmented context

# RAG Setup
rag = RAG()
retriever = rag.get_retriever()
prompt = get_prompt()
output_parser = StrOutputParser()

# Nodes
def retrieve(state):
    # Pull candidate documents from the vector store for the current question.
    question = state["question"]
    documents = retriever.invoke(question)
    st.sidebar.write(f"Retrieved Documents: {len(documents)}")
    return {"documents": documents, "question": question}

def grade_documents(state):
    # Keep only documents the grader deems relevant; if any are dropped,
    # flag the state so the graph falls back to web search.
    question = state["question"]
    documents = state["documents"]
    filtered_docs = []
    web_search = "No"

    for doc in documents:
        score = {"binary_score": "yes"}  # Dummy grader; see the LLM-based sketch below
        if score["binary_score"] == "yes":
            filtered_docs.append(doc)
        else:
            web_search = "Yes"
    st.sidebar.write(f"Document Grading Results: {len(filtered_docs)} relevant")
    return {"documents": filtered_docs, "web_search": web_search, "question": question}

def generate(state):
    # Answer from the retrieved (and possibly web-augmented) context via the RAG prompt.
    context = "\n".join([doc.page_content for doc in state["documents"]])
    chain = prompt | llm | output_parser
    response = chain.invoke({"context": context, "question": state["question"]})
    return {"generation": response}

def transform_query(state):
    # Rewrite the question so a web search is more likely to find relevant results.
    question = state["question"]
    new_question = llm.invoke(
        f"Rewrite the following question to be clearer and better suited for a web search: {question}"
    ).content
    st.sidebar.write(f"Rewritten Question: {new_question}")
    return {"question": new_question}

def web_search(state):
    # Augment the document set with web results when retrieval came up short.
    question = state["question"]
    results = web_search_tool.invoke({"query": question})
    web_content = "\n".join([result["content"] for result in results])
    documents = state["documents"] + [Document(page_content=web_content)]
    st.sidebar.write("Web Search Completed")
    return {"documents": documents, "question": question}

def decide_to_generate(state):
    # Generate directly if all retrieved docs passed grading; otherwise rewrite
    # the query and search the web first.
    return "generate" if state["web_search"] == "No" else "transform_query"

# Build Graph
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search_node", web_search)

workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges("grade_documents", decide_to_generate, {"transform_query": "transform_query", "generate": "generate"})
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)

app = workflow.compile()
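
# Quick sanity check of the compiled graph outside Streamlit (assumption:
# run as a plain script; the question text is just an example):
#   result = app.invoke({"question": "What is Llama2?"})
#   print(result["generation"])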

# Visualize the Workflow
def plot_workflow():
    graph = Digraph()
    graph.attr(size='6,6')

    # Add nodes
    graph.node("retrieve", "Retrieve Documents")
    graph.node("grade_documents", "Grade Documents")
    graph.node("generate", "Generate Answer")
    graph.node("transform_query", "Transform Query")
    graph.node("web_search_node", "Web Search")
    graph.node("END", "End")

    # Add edges
    graph.edge("retrieve", "grade_documents")
    graph.edge("grade_documents", "generate", label="Relevant Docs")
    graph.edge("grade_documents", "transform_query", label="No Relevant Docs")
    graph.edge("transform_query", "web_search_node")
    graph.edge("web_search_node", "generate")
    graph.edge("generate", "END")

    return graph

# Streamlit App
st.title("Self-Corrective RAG")
st.write("### Compare RAG Pipeline Outputs (With and Without Self-Correction)")

# Plot Workflow
st.subheader("Workflow Visualization")
st.graphviz_chart(plot_workflow().source)

# User Input
question = st.text_input("Enter your question:", "What is Llama2?")

if st.button("Run Comparison"):
    # Run Basic RAG (a single retrieve-then-generate pass, no grading or web search)
    st.subheader("Without Self-Correction:")
    docs = retriever.invoke(question)
    basic_context = "\n".join([doc.page_content for doc in docs])
    basic_response = (prompt | llm | output_parser).invoke({"context": basic_context, "question": question})
    st.write(basic_response)

    # Run Self-Corrective RAG
    st.subheader("With Self-Correction:")
    inputs = {"question": question}
    final_generation = ""
    # app.stream yields one dict per executed node, keyed by node name;
    # the final answer lives in the "generate" node's state update.
    for output in app.stream(inputs):
        for node, state_update in output.items():
            if node == "generate":
                final_generation = state_update["generation"]
    st.write(final_generation)