import os

import streamlit as st
from graphviz import Digraph
from typing import List
from typing_extensions import TypedDict

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END

from utils.build_rag import RAG
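
# Load variables from a local .env file so the os.getenv calls below can see
# them (assumes python-dotenv is installed; the error message below refers to
# a .env file, which os.getenv alone would not read).
from dotenv import load_dotenv

load_dotenv()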

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

if not OPENAI_API_KEY or not TAVILY_API_KEY:
    st.error("❌ API keys missing! Please set `OPENAI_API_KEY` and `TAVILY_API_KEY` in your `.env` file.")
    st.stop()

llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)
# TavilySearchResults takes max_results (not k) and reads TAVILY_API_KEY from
# the environment rather than an api_key kwarg.
web_search_tool = TavilySearchResults(max_results=2)


def get_prompt():
    template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
    return ChatPromptTemplate.from_template(template)


class GraphState(TypedDict):
    question: str              # current (possibly rewritten) user question
    generation: str            # final generated answer
    web_search: str            # "Yes" when grading finds irrelevant documents
    documents: List[Document]  # retrieved and/or web-searched documents


rag = RAG()
retriever = rag.get_retriever()
prompt = get_prompt()
output_parser = StrOutputParser()
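
# Shared answer chain (prompt -> model -> parser). Composing it once lets both
# pipelines below reuse it; invoking the bare LLM with a dict of template
# variables would fail because nothing would format the prompt first.
rag_chain = prompt | llm | output_parser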


def retrieve(state):
    question = state["question"]
    documents = retriever.invoke(question)
    st.sidebar.write(f"Retrieved Documents: {len(documents)}")
    return {"documents": documents, "question": question}


def grade_documents(state):
    question = state["question"]
    documents = state["documents"]
    filtered_docs = []
    web_search = "No"

    for doc in documents:
        # Grade each document with the LLM instead of the previous hard-coded
        # "yes" placeholder, which made the web-search branch unreachable.
        grade = llm.invoke(
            "You are grading the relevance of a retrieved document to a user question.\n"
            f"Document: {doc.page_content}\n"
            f"Question: {question}\n"
            "Answer with a single word: yes or no."
        ).content.strip().lower()
        if grade.startswith("yes"):
            filtered_docs.append(doc)
        else:
            web_search = "Yes"
    st.sidebar.write(f"Document Grading Results: {len(filtered_docs)} relevant")
    return {"documents": filtered_docs, "web_search": web_search, "question": question}


def generate(state):
    context = "\n".join([doc.page_content for doc in state["documents"]])
    # Run the full chain; the prompt must format the template variables
    # before the model sees them.
    response = rag_chain.invoke({"context": context, "question": state["question"]})
    return {"generation": response}


def transform_query(state):
    question = state["question"]
    # Ask the LLM to rephrase the question into a better web-search query.
    new_question = llm.invoke(
        f"Rewrite the following question so it works well as a web search query: {question}"
    ).content
    st.sidebar.write(f"Rewritten Question: {new_question}")
    return {"question": new_question}


def web_search(state):
    question = state["question"]
    results = web_search_tool.invoke({"query": question})
    # Fold all web results into a single Document appended to the state.
    docs = "\n".join([result["content"] for result in results])
    state["documents"].append(Document(page_content=docs))
    st.sidebar.write("Web Search Completed")
    return {"documents": state["documents"], "question": question}


def decide_to_generate(state):
    return "generate" if state["web_search"] == "No" else "transform_query"


workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search_node", web_search)

workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {"transform_query": "transform_query", "generate": "generate"},
)
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)

app = workflow.compile()
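
# Sketch of direct (non-Streamlit) usage, assuming the same environment:
#   final_state = app.invoke({"question": "What is Llama2?"})
#   print(final_state["generation"])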


def plot_workflow():
    graph = Digraph()
    graph.attr(size='6,6')

    graph.node("retrieve", "Retrieve Documents")
    graph.node("grade_documents", "Grade Documents")
    graph.node("generate", "Generate Answer")
    graph.node("transform_query", "Transform Query")
    graph.node("web_search_node", "Web Search")
    graph.node("END", "End")

    graph.edge("retrieve", "grade_documents")
    graph.edge("grade_documents", "generate", label="Relevant Docs")
    graph.edge("grade_documents", "transform_query", label="No Relevant Docs")
    graph.edge("transform_query", "web_search_node")
    graph.edge("web_search_node", "generate")
    graph.edge("generate", "END")

    return graph


st.title("Self-Corrective RAG")
st.write("### Compare RAG Pipeline Outputs (With and Without Self-Correction)")

st.subheader("Workflow Visualization")
st.graphviz_chart(plot_workflow().source)

question = st.text_input("Enter your question:", "What is Llama2?")

if st.button("Run Comparison"):
    st.subheader("Without Self-Correction:")
    docs = retriever.invoke(question)
    basic_context = "\n".join([doc.page_content for doc in docs])
    basic_response = rag_chain.invoke({"context": basic_context, "question": question})
    st.write(basic_response)

    st.subheader("With Self-Correction:")
    inputs = {"question": question}
    final_generation = ""
    # app.stream yields {node_name: state_update} dicts; the final answer is
    # the "generation" field produced by the "generate" node.
    for output in app.stream(inputs):
        for _node, state_update in output.items():
            if "generation" in state_update:
                final_generation = state_update["generation"]
    st.write(final_generation)