Update interim.py
Browse files- interim.py +13 -5
interim.py
CHANGED
@@ -54,13 +54,15 @@ class RAGWorkflowBasic(Workflow):
|
|
54 |
class RAGWorkflowRerank(RAGWorkflowBasic):
|
55 |
@step
|
56 |
def rerank(self, ctx: Context, ev: RetrieverEvent) -> RerankEvent:
|
57 |
-
#
|
58 |
-
|
59 |
-
|
|
|
60 |
llm_rerank = LLMRerank(top_n=2, llm=Groq(model="llama3-70b-8192"))
|
61 |
-
#
|
62 |
reranked_nodes = llm_rerank.postprocess_nodes(ev.nodes, query_bundle)
|
63 |
return RerankEvent(nodes=reranked_nodes)
|
|
|
64 |
|
65 |
|
66 |
# Function to Visualize Workflows
|
@@ -96,24 +98,30 @@ elif data_source == "Provide PDF URL":
|
|
96 |
if st.button("Run Workflow"):
|
97 |
if os.listdir(temp_dir):
|
98 |
st.write("### Step 1: Ingesting Documents...")
|
|
|
99 |
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
|
100 |
documents = SimpleDirectoryReader(temp_dir).load_data()
|
101 |
index = VectorStoreIndex.from_documents(documents, embed_model=embed_model)
|
102 |
|
103 |
st.write("### Step 2: Retrieving Documents...")
|
|
|
104 |
retriever = index.as_retriever(similarity_top_k=2)
|
105 |
nodes = retriever.retrieve(query)
|
106 |
|
107 |
if workflow_choice == "Workflow with Reranking":
|
108 |
st.write("### Step 3: Reranking Results...")
|
|
|
|
|
109 |
reranker = LLMRerank(top_n=2, llm=Groq(model="llama3-70b-8192"))
|
110 |
-
nodes = reranker.postprocess_nodes(nodes,
|
111 |
|
112 |
st.write("### Step 4: Synthesizing Response...")
|
|
|
113 |
summarizer = CompactAndRefine(llm=Groq(model="llama3-70b-8192"))
|
114 |
response = summarizer.synthesize(query, nodes=nodes)
|
115 |
st.markdown(f"### **Response:** {response}")
|
116 |
|
|
|
117 |
st.write("### Workflow Visualization")
|
118 |
workflow_class = RAGWorkflowRerank if workflow_choice == "Workflow with Reranking" else RAGWorkflowBasic
|
119 |
visualize_workflow(workflow_class, "workflow.html")
|
|
|
54 |
class RAGWorkflowRerank(RAGWorkflowBasic):
|
55 |
@step
|
56 |
def rerank(self, ctx: Context, ev: RetrieverEvent) -> RerankEvent:
|
57 |
+
# Construct QueryBundle properly using the query string
|
58 |
+
query_str = ctx.get("query") # Retrieve the query from the context
|
59 |
+
query_bundle = QueryBundle(query_str=query_str)
|
60 |
+
# Initialize the LLM reranker
|
61 |
llm_rerank = LLMRerank(top_n=2, llm=Groq(model="llama3-70b-8192"))
|
62 |
+
# Perform the reranking using QueryBundle
|
63 |
reranked_nodes = llm_rerank.postprocess_nodes(ev.nodes, query_bundle)
|
64 |
return RerankEvent(nodes=reranked_nodes)
|
65 |
+
|
66 |
|
67 |
|
68 |
# Function to Visualize Workflows
|
|
|
98 |
if st.button("Run Workflow"):
|
99 |
if os.listdir(temp_dir):
|
100 |
st.write("### Step 1: Ingesting Documents...")
|
101 |
+
# Step 1: Ingest Documents
|
102 |
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
|
103 |
documents = SimpleDirectoryReader(temp_dir).load_data()
|
104 |
index = VectorStoreIndex.from_documents(documents, embed_model=embed_model)
|
105 |
|
106 |
st.write("### Step 2: Retrieving Documents...")
|
107 |
+
# Step 2: Retrieve Documents
|
108 |
retriever = index.as_retriever(similarity_top_k=2)
|
109 |
nodes = retriever.retrieve(query)
|
110 |
|
111 |
if workflow_choice == "Workflow with Reranking":
|
112 |
st.write("### Step 3: Reranking Results...")
|
113 |
+
# Step 3: Wrap query into QueryBundle and rerank
|
114 |
+
query_bundle = QueryBundle(query_str=query) # Wrap query into QueryBundle
|
115 |
reranker = LLMRerank(top_n=2, llm=Groq(model="llama3-70b-8192"))
|
116 |
+
nodes = reranker.postprocess_nodes(nodes, query_bundle)
|
117 |
|
118 |
st.write("### Step 4: Synthesizing Response...")
|
119 |
+
# Step 4: Synthesize Response
|
120 |
summarizer = CompactAndRefine(llm=Groq(model="llama3-70b-8192"))
|
121 |
response = summarizer.synthesize(query, nodes=nodes)
|
122 |
st.markdown(f"### **Response:** {response}")
|
123 |
|
124 |
+
# Workflow Visualization
|
125 |
st.write("### Workflow Visualization")
|
126 |
workflow_class = RAGWorkflowRerank if workflow_choice == "Workflow with Reranking" else RAGWorkflowBasic
|
127 |
visualize_workflow(workflow_class, "workflow.html")
|