# Author: DrishtiSharma
# Revision 848e291 ("Update app.py", verified)
import os
import tempfile
from typing import Optional

import llama_index
import streamlit as st
from llama_index.core.workflow import (Event,StartEvent,StopEvent,Workflow,step,Context)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import QueryBundle
from llama_index.core.schema import NodeWithScore
from llama_index.core.response_synthesizers import CompactAndRefine
from llama_index.readers.web import SimpleWebPageReader
from llama_index.core import SimpleDirectoryReader
from llama_index.core.schema import NodeWithScore
from llama_index.core.postprocessor.llm_rerank import LLMRerank
from pyvis.network import Network
import chromadb
# Set API Key
# Copy the Groq key from Streamlit's secrets store into the process environment.
# The Groq LLMs in this file are constructed without an explicit api_key, so
# they presumably pick it up from GROQ_API_KEY.
os.environ.update(GROQ_API_KEY=st.secrets["GROQ_API_KEY"])
# Define Events
class RetrieverEvent(Event):
    """Workflow event emitted by the retrieval step."""

    nodes: list[NodeWithScore]  # retrieved nodes, each paired with its score
class RerankEvent(Event):
    """Workflow event emitted after the LLM reranking step."""

    nodes: list[NodeWithScore]  # nodes as returned by the reranker
# Workflow Variations
class RAGWorkflowBasic(Workflow):
    """Retrieve-then-synthesize RAG workflow.

    The StartEvent payload selects the mode:

    * ``dirname=<path>``: load the documents under *path*, embed them, and
      stop with the resulting ``VectorStoreIndex`` as the result.
    * ``index=<VectorStoreIndex>, query=<str>``: retrieve the top-2 nodes for
      *query* and synthesize an answer with a Groq LLM.

    Both ``ingest`` and ``retrieve`` accept StartEvent, so each step returns
    None when its inputs are absent. Only the matching step does any work.
    """

    @step
    async def ingest(self, ctx: Context, ev: StartEvent) -> Optional[StopEvent]:
        """Build a vector index from the directory given as ``dirname``.

        Returns:
            A StopEvent whose result is the index, or None when the StartEvent
            carries no ``dirname`` (i.e. it is meant for ``retrieve``).
        """
        dir_path = ev.get("dirname")
        if not dir_path:
            return None
        documents = SimpleDirectoryReader(dir_path).load_data()
        embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
        index = VectorStoreIndex.from_documents(documents, embed_model=embed_model)
        return StopEvent(result=index)

    @step
    async def retrieve(self, ctx: Context, ev: StartEvent) -> Optional[RetrieverEvent]:
        """Retrieve the nodes most similar to ``query`` from ``index``.

        Returns:
            A RetrieverEvent with the retrieved nodes, or None when the
            StartEvent lacks a query or an index (i.e. it is meant for
            ``ingest``).
        """
        query = ev.get("query")
        index = ev.get("index")
        if not query or index is None:
            return None
        # Downstream steps only receive nodes, so stash the query in the
        # shared context for them to read back.
        await ctx.set("query", query)
        retriever = index.as_retriever(similarity_top_k=2)
        nodes = retriever.retrieve(query)
        return RetrieverEvent(nodes=nodes)

    @step
    async def synthesize(self, ctx: Context, ev: RetrieverEvent) -> StopEvent:
        """Answer the stored query from the retrieved nodes with a Groq LLM."""
        query = await ctx.get("query", default=None)
        llm = Groq(model="llama3-70b-8192")
        summarizer = CompactAndRefine(llm=llm)
        response = summarizer.synthesize(query, nodes=ev.nodes)
        return StopEvent(result=response)
class RAGWorkflowRerank(RAGWorkflowBasic):
    """RAG workflow that reranks retrieved nodes with an LLM before synthesis.

    Event flow: StartEvent -> retrieve -> RetrieverEvent -> rerank
    -> RerankEvent -> synthesize -> StopEvent.
    """

    @step
    async def rerank(self, ctx: Context, ev: RetrieverEvent) -> RerankEvent:
        """Rerank the retrieved nodes against the query, keeping the top 2."""
        # RetrieverEvent carries only nodes; the query is expected to have
        # been stored in the context by the retrieve step.
        query_str = await ctx.get("query", default=None)
        query_bundle = QueryBundle(query_str=query_str)
        llm_rerank = LLMRerank(top_n=2, llm=Groq(model="llama3-70b-8192"))
        reranked_nodes = llm_rerank.postprocess_nodes(ev.nodes, query_bundle)
        return RerankEvent(nodes=reranked_nodes)

    @step
    async def synthesize(self, ctx: Context, ev: RerankEvent) -> StopEvent:
        """Answer the stored query from the reranked nodes.

        This overrides the inherited step so it consumes RerankEvent. If it
        did not, the inherited ``synthesize`` would also consume
        RetrieverEvent, racing ``rerank``, and RerankEvent would never be
        consumed.
        """
        query = await ctx.get("query", default=None)
        summarizer = CompactAndRefine(llm=Groq(model="llama3-70b-8192"))
        response = summarizer.synthesize(query, nodes=ev.nodes)
        return StopEvent(result=response)
# Function to Visualize Workflows
def visualize_workflow(workflow_class, filename):
    """Draw every possible event flow of a workflow and embed it in the app.

    Args:
        workflow_class: The Workflow subclass to visualize.
        filename: Path of the HTML file the graph is written to and then
            embedded from.
    """
    from llama_index.utils.workflow import draw_all_possible_flows

    try:
        draw_all_possible_flows(workflow_class, filename=filename)
    except AssertionError as err:
        # Stop here rather than embedding *filename*: it may not exist, or
        # may hold a graph left over from an earlier call.
        st.error(f"Could not draw the workflow graph: {err}")
        return
    with open(filename, "r") as f:
        st.components.v1.html(f.read(), height=700)
# Streamlit UI
st.title("RAG Workflow Experimentation")
workflow_choice = st.selectbox("Choose Workflow:", ["Basic Workflow", "Workflow with Reranking"])

# Sidebar: Upload or Input PDFs
data_source = st.sidebar.radio("Data Source", ["Upload PDF", "Provide PDF URL"])
query = st.sidebar.text_input("Enter Query", "What is Fibromyalgia?")

# Streamlit re-runs this script from the top on every widget interaction.
# Calling mkdtemp() on every run would leak a directory per run. It would also
# discard a PDF fetched with "Download PDF" before "Run Workflow" (a separate
# click, hence a separate run) could see it. Instead, keep one working
# directory per data source for the lifetime of the session.
if "pdf_dirs" not in st.session_state:
    st.session_state["pdf_dirs"] = {}
if data_source not in st.session_state["pdf_dirs"]:
    st.session_state["pdf_dirs"][data_source] = tempfile.mkdtemp()
temp_dir = st.session_state["pdf_dirs"][data_source]


def _clear_dir(path):
    """Delete the regular files directly inside *path* (non-recursive)."""
    for entry in os.listdir(path):
        entry_path = os.path.join(path, entry)
        if os.path.isfile(entry_path):
            os.remove(entry_path)


def _download_pdf(url, dest_path, timeout=60):
    """Download an http(s) *url* to *dest_path*.

    The payload is written to ``dest_path + ".part"`` first and moved into
    place only on success, so a failed download leaves no partial PDF behind.

    Raises:
        ValueError: if *url* is not an http or https URL.
        OSError: on network, HTTP or file errors.
    """
    import http.client
    import shutil
    import urllib.parse
    import urllib.request

    scheme = urllib.parse.urlparse(url).scheme.lower()
    if scheme not in ("http", "https"):
        # Also rejects file:// and similar schemes that urllib would happily
        # open, which would read local files on the server.
        raise ValueError(f"only http(s) URLs are supported, got {url!r}")
    part_path = dest_path + ".part"
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp, open(part_path, "wb") as out:
            shutil.copyfileobj(resp, out)
        os.replace(part_path, dest_path)
    except http.client.HTTPException as err:
        # e.g. IncompleteRead, which is not an OSError; report it uniformly.
        raise OSError(f"HTTP error while downloading {url!r}: {err}") from err
    finally:
        if os.path.exists(part_path):
            os.remove(part_path)


if data_source == "Upload PDF":
    uploaded_files = st.sidebar.file_uploader("Upload PDFs", accept_multiple_files=True, type=["pdf"])
    # Mirror the uploader exactly: files the user removed from the widget
    # must not linger in the persistent directory.
    _clear_dir(temp_dir)
    if uploaded_files:
        for file in uploaded_files:
            # basename() keeps a crafted upload name from escaping temp_dir;
            # getvalue() returns the full payload regardless of read position.
            with open(os.path.join(temp_dir, os.path.basename(file.name)), "wb") as f:
                f.write(file.getvalue())
        st.sidebar.success("PDFs Uploaded Successfully!")
elif data_source == "Provide PDF URL":
    pdf_url = st.sidebar.text_input("Enter PDF URL").strip()
    if st.sidebar.button("Download PDF"):
        if not pdf_url:
            st.sidebar.error("Please enter a PDF URL first.")
        else:
            # The URL is user input, so fetch it with urllib. Interpolating it
            # into a shell command would allow command injection.
            try:
                _download_pdf(pdf_url, os.path.join(temp_dir, "downloaded.pdf"))
            except (ValueError, OSError) as err:
                st.sidebar.error(f"PDF download failed: {err}")
            else:
                st.sidebar.success("PDF Downloaded Successfully!")
@st.cache_resource
def _get_embed_model():
    """Load the BGE-small embedding model once and reuse it across reruns."""
    return HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")


# Run Workflow
# The pipeline below mirrors the workflow classes step by step, so each
# stage's progress can be shown in the UI.
if st.button("Run Workflow"):
    if not os.listdir(temp_dir):
        st.error("No PDF files found. Upload or download a PDF first.")
    elif not query.strip():
        st.error("Please enter a query.")
    else:
        st.write("### Step 1: Ingesting Documents...")
        # Step 1: Ingest Documents
        documents = SimpleDirectoryReader(temp_dir).load_data()
        index = VectorStoreIndex.from_documents(documents, embed_model=_get_embed_model())

        st.write("### Step 2: Retrieving Documents...")
        # Step 2: Retrieve Documents
        retriever = index.as_retriever(similarity_top_k=2)
        nodes = retriever.retrieve(query)

        if workflow_choice == "Workflow with Reranking":
            st.write("### Step 3: Reranking Results...")
            # Step 3: Rerank. LLMRerank takes the query as a QueryBundle.
            # NOTE(review): with similarity_top_k=2 and top_n=2 the reranker
            # can only reorder/drop the two retrieved nodes; retrieving more
            # candidates would give it something to choose from.
            reranker = LLMRerank(top_n=2, llm=Groq(model="llama3-70b-8192"))
            nodes = reranker.postprocess_nodes(nodes, QueryBundle(query_str=query))

        st.write("### Step 4: Synthesizing Response...")
        # Step 4: Synthesize Response
        summarizer = CompactAndRefine(llm=Groq(model="llama3-70b-8192"))
        response = summarizer.synthesize(query, nodes=nodes)
        st.markdown(f"### **Response:** {response}")

        # Workflow Visualization
        st.write("### Workflow Visualization")
        workflow_class = RAGWorkflowRerank if workflow_choice == "Workflow with Reranking" else RAGWorkflowBasic
        # One HTML file per workflow class, so a graph drawn for one workflow
        # is never embedded for the other.
        visualize_workflow(workflow_class, f"{workflow_class.__name__}.html")