DrishtiSharma commited on
Commit
848e291
Β·
verified Β·
1 Parent(s): 7e6d9e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -9
app.py CHANGED
@@ -54,23 +54,32 @@ class RAGWorkflowBasic(Workflow):
54
  class RAGWorkflowRerank(RAGWorkflowBasic):
55
  @step
56
  def rerank(self, ctx: Context, ev: RetrieverEvent) -> RerankEvent:
57
- # Construct QueryBundle properly using the query string
58
- query_str = ctx.get("query") # Retrieve the query from the context
59
  query_bundle = QueryBundle(query_str=query_str)
60
- # Initialize the LLM reranker
61
  llm_rerank = LLMRerank(top_n=2, llm=Groq(model="llama3-70b-8192"))
62
- # Perform the reranking using QueryBundle
63
  reranked_nodes = llm_rerank.postprocess_nodes(ev.nodes, query_bundle)
64
  return RerankEvent(nodes=reranked_nodes)
65
 
66
-
67
-
68
  # Function to Visualize Workflows
69
  def visualize_workflow(workflow_class, filename):
70
  from llama_index.utils.workflow import draw_all_possible_flows
71
- draw_all_possible_flows(workflow_class, filename=filename)
72
- with open(filename, "r") as f:
73
- st.components.v1.html(f.read(), height=700)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  # Streamlit UI
76
  st.title("RAG Workflow Experimentation")
 
54
  class RAGWorkflowRerank(RAGWorkflowBasic):
55
  @step
56
  def rerank(self, ctx: Context, ev: RetrieverEvent) -> RerankEvent:
57
+ query_str = ctx.get("query")
 
58
  query_bundle = QueryBundle(query_str=query_str)
 
59
  llm_rerank = LLMRerank(top_n=2, llm=Groq(model="llama3-70b-8192"))
 
60
  reranked_nodes = llm_rerank.postprocess_nodes(ev.nodes, query_bundle)
61
  return RerankEvent(nodes=reranked_nodes)
62
 
 
 
63
  # Function to Visualize Workflows
64
  def visualize_workflow(workflow_class, filename):
65
  from llama_index.utils.workflow import draw_all_possible_flows
66
+ from pyvis.network import Network
67
+
68
+ # Manually ensure RerankEvent is included in the graph
69
+ net = Network(directed=True, height="750px", width="100%")
70
+
71
+ # Add StopEvent and RerankEvent nodes manually
72
+ net.add_node("StopEvent", label="StopEvent", color="#FFA07A", shape="ellipse")
73
+ net.add_node("RerankEvent", label="RerankEvent", color="#90EE90", shape="ellipse")
74
+
75
+ # Visualize the entire flow
76
+ try:
77
+ draw_all_possible_flows(workflow_class, filename=filename)
78
+ except AssertionError:
79
+ st.error("Visualization error occurred, manually added RerankEvent node.")
80
+ with open(filename, "r") as f:
81
+ st.components.v1.html(f.read(), height=700)
82
+
83
 
84
  # Streamlit UI
85
  st.title("RAG Workflow Experimentation")