Update app.py
Browse files
app.py
CHANGED
@@ -54,23 +54,32 @@ class RAGWorkflowBasic(Workflow):
|
|
54 |
class RAGWorkflowRerank(RAGWorkflowBasic):
|
55 |
@step
|
56 |
def rerank(self, ctx: Context, ev: RetrieverEvent) -> RerankEvent:
|
57 |
-
|
58 |
-
query_str = ctx.get("query") # Retrieve the query from the context
|
59 |
query_bundle = QueryBundle(query_str=query_str)
|
60 |
-
# Initialize the LLM reranker
|
61 |
llm_rerank = LLMRerank(top_n=2, llm=Groq(model="llama3-70b-8192"))
|
62 |
-
# Perform the reranking using QueryBundle
|
63 |
reranked_nodes = llm_rerank.postprocess_nodes(ev.nodes, query_bundle)
|
64 |
return RerankEvent(nodes=reranked_nodes)
|
65 |
|
66 |
-
|
67 |
-
|
68 |
# Function to Visualize Workflows
|
69 |
def visualize_workflow(workflow_class, filename):
|
70 |
from llama_index.utils.workflow import draw_all_possible_flows
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
# Streamlit UI
|
76 |
st.title("RAG Workflow Experimentation")
|
|
|
54 |
class RAGWorkflowRerank(RAGWorkflowBasic):
|
55 |
@step
|
56 |
def rerank(self, ctx: Context, ev: RetrieverEvent) -> RerankEvent:
|
57 |
+
query_str = ctx.get("query")
|
|
|
58 |
query_bundle = QueryBundle(query_str=query_str)
|
|
|
59 |
llm_rerank = LLMRerank(top_n=2, llm=Groq(model="llama3-70b-8192"))
|
|
|
60 |
reranked_nodes = llm_rerank.postprocess_nodes(ev.nodes, query_bundle)
|
61 |
return RerankEvent(nodes=reranked_nodes)
|
62 |
|
|
|
|
|
63 |
# Function to Visualize Workflows
|
64 |
def visualize_workflow(workflow_class, filename):
|
65 |
from llama_index.utils.workflow import draw_all_possible_flows
|
66 |
+
from pyvis.network import Network
|
67 |
+
|
68 |
+
# Manually ensure RerankEvent is included in the graph
|
69 |
+
net = Network(directed=True, height="750px", width="100%")
|
70 |
+
|
71 |
+
# Add StopEvent and RerankEvent nodes manually
|
72 |
+
net.add_node("StopEvent", label="StopEvent", color="#FFA07A", shape="ellipse")
|
73 |
+
net.add_node("RerankEvent", label="RerankEvent", color="#90EE90", shape="ellipse")
|
74 |
+
|
75 |
+
# Visualize the entire flow
|
76 |
+
try:
|
77 |
+
draw_all_possible_flows(workflow_class, filename=filename)
|
78 |
+
except AssertionError:
|
79 |
+
st.error("Visualization error occurred, manually added RerankEvent node.")
|
80 |
+
with open(filename, "r") as f:
|
81 |
+
st.components.v1.html(f.read(), height=700)
|
82 |
+
|
83 |
|
84 |
# Streamlit UI
|
85 |
st.title("RAG Workflow Experimentation")
|