import os
import tempfile

import requests
import streamlit as st

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.postprocessor.llm_rerank import LLMRerank
from llama_index.core.response_synthesizers import CompactAndRefine
from llama_index.core.schema import NodeWithScore, QueryBundle
from llama_index.core.workflow import (
    Context,
    Event,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq

# The Groq API key is read from Streamlit's secrets store (.streamlit/secrets.toml).
os.environ["GROQ_API_KEY"] = st.secrets["GROQ_API_KEY"]


class RetrieverEvent(Event):
    """Carries the nodes returned by the retrieval step."""

    nodes: list[NodeWithScore]


class RerankEvent(Event):
    """Carries the nodes after LLM reranking."""

    nodes: list[NodeWithScore]


class RAGWorkflowBasic(Workflow):
    @step
    async def ingest(self, ctx: Context, ev: StartEvent) -> StopEvent | None:
        """Ingestion entry point: triggered by run(dirname=...)."""
        dir_path = ev.get("dirname")
        if dir_path is None:
            return None
        documents = SimpleDirectoryReader(dir_path).load_data()
        embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
        index = VectorStoreIndex.from_documents(documents, embed_model=embed_model)
        return StopEvent(result=index)

    @step
    async def retrieve(self, ctx: Context, ev: StartEvent) -> RetrieverEvent | None:
        """Query entry point: triggered by run(index=..., query=...)."""
        index = ev.get("index")
        query = ev.get("query")
        if index is None or query is None:
            return None
        # Store the query in the context so later steps can read it.
        await ctx.set("query", query)
        retriever = index.as_retriever(similarity_top_k=2)
        nodes = retriever.retrieve(query)
        return RetrieverEvent(nodes=nodes)

    @step
    async def synthesize(self, ctx: Context, ev: RetrieverEvent) -> StopEvent:
        query = await ctx.get("query")
        llm = Groq(model="llama3-70b-8192")
        summarizer = CompactAndRefine(llm=llm)
        response = summarizer.synthesize(query, nodes=ev.nodes)
        return StopEvent(result=response)


class RAGWorkflowRerank(RAGWorkflowBasic):
    @step
    async def rerank(self, ctx: Context, ev: RetrieverEvent) -> RerankEvent:
        """Rerank the retrieved nodes with an LLM before synthesis."""
        query_str = await ctx.get("query")
        query_bundle = QueryBundle(query_str=query_str)
        llm_rerank = LLMRerank(top_n=2, llm=Groq(model="llama3-70b-8192"))
        reranked_nodes = llm_rerank.postprocess_nodes(ev.nodes, query_bundle)
        return RerankEvent(nodes=reranked_nodes)

    @step
    async def synthesize(self, ctx: Context, ev: RerankEvent) -> StopEvent:
        """Override the parent step so synthesis consumes RerankEvent (otherwise
        RerankEvent would be produced but never consumed and the workflow is invalid)."""
        query = await ctx.get("query")
        summarizer = CompactAndRefine(llm=Groq(model="llama3-70b-8192"))
        response = summarizer.synthesize(query, nodes=ev.nodes)
        return StopEvent(result=response)


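# Note: the Streamlit app below re-implements each workflow step inline so the
# intermediate results can be shown as they are produced. To drive the workflow
# classes directly instead, a minimal sketch following the usual LlamaIndex
# Workflow.run() pattern (the timeout value here is just an example) would be:
#
#     workflow = RAGWorkflowBasic(timeout=120)
#     index = await workflow.run(dirname="./data")  # ingestion entry point
#     response = await workflow.run(index=index, query="What is Fibromyalgia?")  # query entry point

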
def visualize_workflow(workflow_class, filename):
    """Draw every possible event flow for the given workflow class and embed it in the app."""
    from llama_index.utils.workflow import draw_all_possible_flows

    try:
        # draw_all_possible_flows writes a self-contained pyvis HTML file to `filename`.
        draw_all_possible_flows(workflow_class, filename=filename)
    except Exception as exc:
        st.error(f"Could not draw the workflow graph: {exc}")
        return

    with open(filename, "r") as f:
        st.components.v1.html(f.read(), height=750)


st.title("RAG Workflow Experimentation")
workflow_choice = st.selectbox("Choose Workflow:", ["Basic Workflow", "Workflow with Reranking"])

data_source = st.sidebar.radio("Data Source", ["Upload PDF", "Provide PDF URL"])
query = st.sidebar.text_input("Enter Query", "What is Fibromyalgia?")

# Keep the temp directory in session state so downloaded/uploaded PDFs survive
# Streamlit reruns (a fresh mkdtemp on every rerun would lose the files).
if "temp_dir" not in st.session_state:
    st.session_state.temp_dir = tempfile.mkdtemp()
temp_dir = st.session_state.temp_dir

if data_source == "Upload PDF":
    uploaded_files = st.sidebar.file_uploader("Upload PDFs", accept_multiple_files=True, type=["pdf"])
    if uploaded_files:
        for file in uploaded_files:
            with open(os.path.join(temp_dir, file.name), "wb") as f:
                f.write(file.read())
        st.sidebar.success("PDFs Uploaded Successfully!")
elif data_source == "Provide PDF URL":
    pdf_url = st.sidebar.text_input("Enter PDF URL")
    if st.sidebar.button("Download PDF"):
        # Download with requests instead of shelling out to wget, which may not be
        # installed and is unsafe with unvalidated input.
        try:
            resp = requests.get(pdf_url, timeout=30)
            resp.raise_for_status()
            with open(os.path.join(temp_dir, "downloaded.pdf"), "wb") as f:
                f.write(resp.content)
            st.sidebar.success("PDF Downloaded Successfully!")
        except requests.RequestException as exc:
            st.sidebar.error(f"Download failed: {exc}")


if st.button("Run Workflow"):
    if os.listdir(temp_dir):
        st.write("### Step 1: Ingesting Documents...")
        embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
        documents = SimpleDirectoryReader(temp_dir).load_data()
        index = VectorStoreIndex.from_documents(documents, embed_model=embed_model)

        st.write("### Step 2: Retrieving Documents...")
        retriever = index.as_retriever(similarity_top_k=2)
        nodes = retriever.retrieve(query)

        if workflow_choice == "Workflow with Reranking":
            st.write("### Step 3: Reranking Results...")
            query_bundle = QueryBundle(query_str=query)
            reranker = LLMRerank(top_n=2, llm=Groq(model="llama3-70b-8192"))
            nodes = reranker.postprocess_nodes(nodes, query_bundle)

        st.write("### Final Step: Synthesizing Response...")
        summarizer = CompactAndRefine(llm=Groq(model="llama3-70b-8192"))
        response = summarizer.synthesize(query, nodes=nodes)
        st.markdown(f"### **Response:** {response}")

        st.write("### Workflow Visualization")
        workflow_class = RAGWorkflowRerank if workflow_choice == "Workflow with Reranking" else RAGWorkflowBasic
        visualize_workflow(workflow_class, "workflow.html")
    else:
        st.error("No PDF files found. Upload or download a PDF first.")
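
# To try the app locally (assuming this file is saved as app.py and GROQ_API_KEY
# is set in .streamlit/secrets.toml):
#
#     streamlit run app.py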