DrishtiSharma commited on
Commit
7e6d9e2
Β·
verified Β·
1 Parent(s): 10624e9

Update interim.py

Browse files
Files changed (1) hide show
  1. interim.py +13 -5
interim.py CHANGED
@@ -54,13 +54,15 @@ class RAGWorkflowBasic(Workflow):
54
  class RAGWorkflowRerank(RAGWorkflowBasic):
55
  @step
56
  def rerank(self, ctx: Context, ev: RetrieverEvent) -> RerankEvent:
57
- # Use query directly as a mock QueryBundle
58
- query_bundle = {"query_str": ctx.get("query")}
59
- # Initialize the reranker
 
60
  llm_rerank = LLMRerank(top_n=2, llm=Groq(model="llama3-70b-8192"))
61
- # Rerank nodes
62
  reranked_nodes = llm_rerank.postprocess_nodes(ev.nodes, query_bundle)
63
  return RerankEvent(nodes=reranked_nodes)
 
64
 
65
 
66
  # Function to Visualize Workflows
@@ -96,24 +98,30 @@ elif data_source == "Provide PDF URL":
96
  if st.button("Run Workflow"):
97
  if os.listdir(temp_dir):
98
  st.write("### Step 1: Ingesting Documents...")
 
99
  embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
100
  documents = SimpleDirectoryReader(temp_dir).load_data()
101
  index = VectorStoreIndex.from_documents(documents, embed_model=embed_model)
102
 
103
  st.write("### Step 2: Retrieving Documents...")
 
104
  retriever = index.as_retriever(similarity_top_k=2)
105
  nodes = retriever.retrieve(query)
106
 
107
  if workflow_choice == "Workflow with Reranking":
108
  st.write("### Step 3: Reranking Results...")
 
 
109
  reranker = LLMRerank(top_n=2, llm=Groq(model="llama3-70b-8192"))
110
- nodes = reranker.postprocess_nodes(nodes, query)
111
 
112
  st.write("### Step 4: Synthesizing Response...")
 
113
  summarizer = CompactAndRefine(llm=Groq(model="llama3-70b-8192"))
114
  response = summarizer.synthesize(query, nodes=nodes)
115
  st.markdown(f"### **Response:** {response}")
116
 
 
117
  st.write("### Workflow Visualization")
118
  workflow_class = RAGWorkflowRerank if workflow_choice == "Workflow with Reranking" else RAGWorkflowBasic
119
  visualize_workflow(workflow_class, "workflow.html")
 
54
  class RAGWorkflowRerank(RAGWorkflowBasic):
55
  @step
56
  def rerank(self, ctx: Context, ev: RetrieverEvent) -> RerankEvent:
57
+ # Construct QueryBundle properly using the query string
58
+ query_str = ctx.get("query") # Retrieve the query from the context
59
+ query_bundle = QueryBundle(query_str=query_str)
60
+ # Initialize the LLM reranker
61
  llm_rerank = LLMRerank(top_n=2, llm=Groq(model="llama3-70b-8192"))
62
+ # Perform the reranking using QueryBundle
63
  reranked_nodes = llm_rerank.postprocess_nodes(ev.nodes, query_bundle)
64
  return RerankEvent(nodes=reranked_nodes)
65
+
66
 
67
 
68
  # Function to Visualize Workflows
 
98
  if st.button("Run Workflow"):
99
  if os.listdir(temp_dir):
100
  st.write("### Step 1: Ingesting Documents...")
101
+ # Step 1: Ingest Documents
102
  embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
103
  documents = SimpleDirectoryReader(temp_dir).load_data()
104
  index = VectorStoreIndex.from_documents(documents, embed_model=embed_model)
105
 
106
  st.write("### Step 2: Retrieving Documents...")
107
+ # Step 2: Retrieve Documents
108
  retriever = index.as_retriever(similarity_top_k=2)
109
  nodes = retriever.retrieve(query)
110
 
111
  if workflow_choice == "Workflow with Reranking":
112
  st.write("### Step 3: Reranking Results...")
113
+ # Step 3: Wrap query into QueryBundle and rerank
114
+ query_bundle = QueryBundle(query_str=query) # Wrap query into QueryBundle
115
  reranker = LLMRerank(top_n=2, llm=Groq(model="llama3-70b-8192"))
116
+ nodes = reranker.postprocess_nodes(nodes, query_bundle)
117
 
118
  st.write("### Step 4: Synthesizing Response...")
119
+ # Step 4: Synthesize Response
120
  summarizer = CompactAndRefine(llm=Groq(model="llama3-70b-8192"))
121
  response = summarizer.synthesize(query, nodes=nodes)
122
  st.markdown(f"### **Response:** {response}")
123
 
124
+ # Workflow Visualization
125
  st.write("### Workflow Visualization")
126
  workflow_class = RAGWorkflowRerank if workflow_choice == "Workflow with Reranking" else RAGWorkflowBasic
127
  visualize_workflow(workflow_class, "workflow.html")