Rahul-8799 commited on
Commit
e26200a
Β·
verified Β·
1 Parent(s): da93d0d

Update utils/langgraph_pipeline.py

Browse files
Files changed (1) hide show
  1. utils/langgraph_pipeline.py +22 -33
utils/langgraph_pipeline.py CHANGED
@@ -12,9 +12,6 @@ from langchain_core.messages.base import BaseMessage
12
  from typing import TypedDict, List
13
 
14
 
15
- # --------------------------
16
- # State Definitions
17
- # --------------------------
18
  class InputState(TypedDict):
19
  messages: List[BaseMessage]
20
  chat_log: list
@@ -29,58 +26,49 @@ class OutputState(TypedDict):
29
  chat_log: list
30
 
31
 
32
- # --------------------------
33
- # Correctly Implemented Bridge Node
34
- # --------------------------
35
- def bridge_to_product_manager(state: dict) -> dict:
36
- """Ensure AIMessage is added after HumanMessage."""
37
  messages = state.get("messages", [])
38
  chat_log = state.get("chat_log", [])
39
 
40
  if not messages or not isinstance(messages[-1], HumanMessage):
41
  raise ValueError("Expected last message to be a HumanMessage")
42
 
43
- prompt = messages[-1].content
44
 
45
- structured = f"""# Stakeholder Prompt
46
 
47
- A new product request has been submitted:
48
 
49
- "{prompt}"
50
-
51
- Please convert it into a structured product specification including:
52
  - Goals
53
- - Features
54
  - User Stories
55
- - KPIs
56
  """
57
 
58
- ai_msg = AIMessage(content=structured)
59
 
60
  return {
61
- "messages": messages + [ai_msg],
62
- "chat_log": chat_log + [{"role": "System", "content": structured}],
63
  }
64
 
65
 
66
- # --------------------------
67
- # ToolNodes for All Agents
68
- # --------------------------
 
69
  pm_node = ToolNode([product_manager_agent.run])
70
  proj_node = ToolNode([project_manager_agent.run])
71
  arch_node = ToolNode([software_architect_agent.run])
72
  dev_node = ToolNode([software_engineer_agent.run])
73
  qa_node = ToolNode([quality_assurance_agent.run])
74
 
75
-
76
- # --------------------------
77
- # Graph Definition
78
- # --------------------------
79
  graph = StateGraph(input=InputState, output=OutputState)
80
 
81
- # βœ… Native function for bridge node (no ToolNode needed)
82
- graph.add_node("Bridge", bridge_to_product_manager)
83
-
84
  graph.add_node("ProductManager", pm_node)
85
  graph.add_node("ProjectManager", proj_node)
86
  graph.add_node("SoftwareArchitect", arch_node)
@@ -98,14 +86,15 @@ graph.add_edge("QualityAssurance", END)
98
  compiled_graph = graph.compile()
99
 
100
 
101
- # --------------------------
102
- # Pipeline Entry Function
103
- # --------------------------
104
  def run_pipeline_and_save(prompt: str):
105
  initial_state = {
106
  "messages": [HumanMessage(content=prompt)],
107
  "chat_log": [],
108
  }
109
 
 
 
 
110
  final_state = compiled_graph.invoke(initial_state)
111
- return final_state["chat_log"], final_state["qa_output"]
 
12
  from typing import TypedDict, List
13
 
14
 
 
 
 
15
  class InputState(TypedDict):
16
  messages: List[BaseMessage]
17
  chat_log: list
 
26
  chat_log: list
27
 
28
 
29
+ # βœ… Fixed Bridge node (must return a proper state with messages and chat_log)
30
+ def bridge_to_pm(state: dict) -> dict:
 
 
 
31
  messages = state.get("messages", [])
32
  chat_log = state.get("chat_log", [])
33
 
34
  if not messages or not isinstance(messages[-1], HumanMessage):
35
  raise ValueError("Expected last message to be a HumanMessage")
36
 
37
+ user_prompt = messages[-1].content
38
 
39
+ structured_context = f"""# Stakeholder Prompt
40
 
41
+ "{user_prompt}"
42
 
43
+ Please convert this into a structured product specification including:
 
 
44
  - Goals
45
+ - Key Features
46
  - User Stories
47
+ - Success Metrics
48
  """
49
 
50
+ new_ai_msg = AIMessage(content=structured_context)
51
 
52
  return {
53
+ "messages": messages + [new_ai_msg], # Important: last item must be AIMessage
54
+ "chat_log": chat_log + [{"role": "System", "content": structured_context}],
55
  }
56
 
57
 
58
+ # βœ… Wrap bridge node in a ToolNode
59
+ bridge_node = ToolNode([bridge_to_pm])
60
+
61
+ # βœ… Other agents as ToolNodes
62
  pm_node = ToolNode([product_manager_agent.run])
63
  proj_node = ToolNode([project_manager_agent.run])
64
  arch_node = ToolNode([software_architect_agent.run])
65
  dev_node = ToolNode([software_engineer_agent.run])
66
  qa_node = ToolNode([quality_assurance_agent.run])
67
 
68
+ # βœ… Build LangGraph
 
 
 
69
  graph = StateGraph(input=InputState, output=OutputState)
70
 
71
+ graph.add_node("Bridge", bridge_node)
 
 
72
  graph.add_node("ProductManager", pm_node)
73
  graph.add_node("ProjectManager", proj_node)
74
  graph.add_node("SoftwareArchitect", arch_node)
 
86
  compiled_graph = graph.compile()
87
 
88
 
89
+ # βœ… Run function
 
 
90
  def run_pipeline_and_save(prompt: str):
91
  initial_state = {
92
  "messages": [HumanMessage(content=prompt)],
93
  "chat_log": [],
94
  }
95
 
96
+ # Sanity check
97
+ assert isinstance(initial_state["messages"][-1], HumanMessage)
98
+
99
  final_state = compiled_graph.invoke(initial_state)
100
+ return final_state["chat_log"], final_state["qa_output"]