import os
import streamlit as st
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.schema import Document
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END
from graphviz import Digraph # For workflow visualization
from typing_extensions import TypedDict
from typing import List
from utils.build_rag import RAG
from dotenv import load_dotenv

# Load variables from .env (referenced in the error message below), then fetch API keys
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

# Check for missing API keys
if not OPENAI_API_KEY or not TAVILY_API_KEY:
    st.error("❌ API keys missing! Please set `OPENAI_API_KEY` and `TAVILY_API_KEY` in your `.env` file.")
    st.stop()  # Stop the app execution

# Set up LLM and tools
llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)
web_search_tool = TavilySearchResults(max_results=2)  # Reads TAVILY_API_KEY from the environment

# Prompt template
def get_prompt():
    template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
    return ChatPromptTemplate.from_template(template)

# Define Graph State
class GraphState(TypedDict):
    question: str
    generation: str
    web_search: str
    documents: List[Document]
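
# Note: each node below returns a partial state dict; LangGraph merges these
# updates into GraphState between steps.
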
# RAG Setup
rag = RAG()
retriever = rag.get_retriever()
prompt = get_prompt()
output_parser = StrOutputParser()

# Nodes
def retrieve(state):
    """Fetch documents relevant to the question from the vector store."""
    question = state["question"]
    documents = retriever.invoke(question)
    st.sidebar.write(f"Retrieved Documents: {len(documents)}")
    return {"documents": documents, "question": question}

def grade_documents(state):
    """Filter retrieved documents; flag a web search if any are irrelevant."""
    question = state["question"]
    documents = state["documents"]
    filtered_docs = []
    web_search = "No"
    for doc in documents:
        score = {"binary_score": "yes"}  # Dummy grader; see the LLM-based sketch below
        if score["binary_score"] == "yes":
            filtered_docs.append(doc)
        else:
            web_search = "Yes"
    st.sidebar.write(f"Document Grading Results: {len(filtered_docs)} relevant")
    return {"documents": filtered_docs, "web_search": web_search, "question": question}

def generate(state):
    """Answer the question from the graded documents using the RAG prompt."""
    context = "\n".join([doc.page_content for doc in state["documents"]])
    response = (prompt | llm | output_parser).invoke(
        {"context": context, "question": state["question"]}
    )
    return {"generation": response}

def transform_query(state):
    """Rewrite the question into a form better suited to web search."""
    question = state["question"]
    new_question = llm.invoke(f"Rewrite this question as a concise web search query: {question}").content
    st.sidebar.write(f"Rewritten Question: {new_question}")
    return {"question": new_question}

def web_search(state):
    """Search the web and append the results as an extra document."""
    question = state["question"]
    results = web_search_tool.invoke({"query": question})
    docs = "\n".join([result["content"] for result in results])
    state["documents"].append(Document(page_content=docs))
    st.sidebar.write("Web Search Completed")
    return {"documents": state["documents"], "question": question}

def decide_to_generate(state):
    """Route to generation when all documents passed grading; otherwise rewrite the query."""
    return "generate" if state["web_search"] == "No" else "transform_query"

# Build Graph
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search_node", web_search)
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges("grade_documents", decide_to_generate, {"transform_query": "transform_query", "generate": "generate"})
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)
app = workflow.compile()
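
# Note: recent langgraph versions can also render the compiled graph directly,
# e.g. st.image(app.get_graph().draw_mermaid_png()) (version-dependent, and it
# calls a remote rendering service by default); the hand-drawn Graphviz chart
# below keeps the visualization self-contained.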

# Visualize Workflow
def plot_workflow():
    graph = Digraph()
    graph.attr(size='6,6')
    # Add nodes
    graph.node("retrieve", "Retrieve Documents")
    graph.node("grade_documents", "Grade Documents")
    graph.node("generate", "Generate Answer")
    graph.node("transform_query", "Transform Query")
    graph.node("web_search_node", "Web Search")
    graph.node("END", "End")
    # Add edges
    graph.edge("retrieve", "grade_documents")
    graph.edge("grade_documents", "generate", label="Relevant Docs")
    graph.edge("grade_documents", "transform_query", label="No Relevant Docs")
    graph.edge("transform_query", "web_search_node")
    graph.edge("web_search_node", "generate")
    graph.edge("generate", "END")
    return graph

# Streamlit App
st.title("Self-Corrective RAG")
st.write("### Compare RAG Pipeline Outputs (With and Without Self-Correction)")
# Plot Workflow
st.subheader("Workflow Visualization")
st.graphviz_chart(plot_workflow().source)
# User Input
question = st.text_input("Enter your question:", "What is Llama2?")
if st.button("Run Comparison"):
# Run Basic RAG
st.subheader("Without Self-Correction:")
docs = retriever.invoke(question)
basic_context = "\n".join([doc.page_content for doc in docs])
basic_response = output_parser.parse(llm.invoke({"context": basic_context, "question": question}).content)
st.write(basic_response)
# Run Self-Corrective RAG
st.subheader("With Self-Correction:")
inputs = {"question": question}
final_generation = ""
for output in app.stream(inputs):
for key, value in output.items():
if key == "generation":
final_generation = value
st.write(final_generation)
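
# To run locally (assuming this file is saved as app.py and the utils package
# providing build_rag.RAG is on the path): streamlit run app.py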