import os

import streamlit as st
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.schema import Document
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END
from graphviz import Digraph  # For workflow visualization
from typing_extensions import TypedDict
from typing import List

from utils.build_rag import RAG

# Fetch API Keys (load from a local .env file if present)
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

# Check for Missing API Keys
if not OPENAI_API_KEY or not TAVILY_API_KEY:
    st.error("❌ API keys missing! Please set `OPENAI_API_KEY` and `TAVILY_API_KEY` in your `.env` file.")
    st.stop()  # Stop the app execution

# Set up LLM and Tools
llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)
web_search_tool = TavilySearchResults(max_results=2)  # Reads TAVILY_API_KEY from the environment

# Prompt template
def get_prompt():
    template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
    return ChatPromptTemplate.from_template(template)

# Define Graph State
class GraphState(TypedDict):
    question: str
    generation: str
    web_search: str
    documents: List[Document]

# RAG Setup
rag = RAG()
retriever = rag.get_retriever()
prompt = get_prompt()
output_parser = StrOutputParser()
rag_chain = prompt | llm | output_parser  # {context, question} -> answer string

# Nodes
def retrieve(state):
    """Retrieve documents for the question from the vector store."""
    question = state["question"]
    documents = retriever.invoke(question)
    st.sidebar.write(f"Retrieved Documents: {len(documents)}")
    return {"documents": documents, "question": question}

def grade_documents(state):
    """Keep only relevant documents; flag a web search if any are dropped."""
    question = state["question"]
    documents = state["documents"]
    filtered_docs = []
    web_search = "No"
    for doc in documents:
        score = {"binary_score": "yes"}  # Dummy grader; integrate as needed (see sketch at the end of this file)
        if score["binary_score"] == "yes":
            filtered_docs.append(doc)
        else:
            web_search = "Yes"
    st.sidebar.write(f"Document Grading Results: {len(filtered_docs)} relevant")
    return {"documents": filtered_docs, "web_search": web_search, "question": question}

def generate(state):
    """Generate an answer from the (filtered) documents."""
    context = "\n".join(doc.page_content for doc in state["documents"])
    response = rag_chain.invoke({"context": context, "question": state["question"]})
    return {"generation": response}

def transform_query(state):
    """Rewrite the question into a better search query."""
    question = state["question"]
    new_question = llm.invoke(
        f"Rewrite the following question as a concise web search query: {question}"
    ).content
    st.sidebar.write(f"Rewritten Question: {new_question}")
    return {"question": new_question}

def web_search(state):
    """Run a Tavily web search and append the results as a document."""
    question = state["question"]
    results = web_search_tool.invoke({"query": question})
    docs = "\n".join(result["content"] for result in results)
    state["documents"].append(Document(page_content=docs))
    st.sidebar.write("Web Search Completed")
    return {"documents": state["documents"], "question": question}

def decide_to_generate(state):
    """Route to generation if all documents were relevant, else rewrite the query."""
    return "generate" if state["web_search"] == "No" else "transform_query"

# Build Graph
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search_node", web_search)

workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {"transform_query": "transform_query", "generate": "generate"},
)
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
"generate") workflow.add_edge("generate", END) app = workflow.compile() # Visualize Workflows def plot_workflow(): graph = Digraph() graph.attr(size='6,6') # Add nodes graph.node("retrieve", "Retrieve Documents") graph.node("grade_documents", "Grade Documents") graph.node("generate", "Generate Answer") graph.node("transform_query", "Transform Query") graph.node("web_search_node", "Web Search") graph.node("END", "End") # Add edges graph.edge("retrieve", "grade_documents") graph.edge("grade_documents", "generate", label="Relevant Docs") graph.edge("grade_documents", "transform_query", label="No Relevant Docs") graph.edge("transform_query", "web_search_node") graph.edge("web_search_node", "generate") graph.edge("generate", "END") return graph # Streamlit App st.title("Self-Corrective RAG") st.write("### Compare RAG Pipeline Outputs (With and Without Self-Correction)") # Plot Workflow st.subheader("Workflow Visualization") st.graphviz_chart(plot_workflow().source) # User Input question = st.text_input("Enter your question:", "What is Llama2?") if st.button("Run Comparison"): # Run Basic RAG st.subheader("Without Self-Correction:") docs = retriever.invoke(question) basic_context = "\n".join([doc.page_content for doc in docs]) basic_response = output_parser.parse(llm.invoke({"context": basic_context, "question": question}).content) st.write(basic_response) # Run Self-Corrective RAG st.subheader("With Self-Correction:") inputs = {"question": question} final_generation = "" for output in app.stream(inputs): for key, value in output.items(): if key == "generation": final_generation = value st.write(final_generation)