heymenn commited on
Commit
96ed966
·
verified ·
1 Parent(s): 3cf99bc

Update kig_core/planner.py

Browse files
Files changed (1) hide show
  1. kig_core/planner.py +16 -5
kig_core/planner.py CHANGED
@@ -4,6 +4,8 @@ from typing import List, Dict, Any
4
  from langgraph.graph import StateGraph, END
5
  from langgraph.checkpoint.memory import MemorySaver # Or SqliteSaver etc.
6
 
 
 
7
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
8
  from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
9
 
@@ -107,6 +109,14 @@ def execute_plan_step(state: PlannerState) -> Dict[str, Any]:
107
  "step_outputs": current_step_outputs
108
  }
109
 
 
 
 
 
 
 
 
 
110
 
111
  def generate_structured_issues(state: PlannerState) -> Dict[str, Any]:
112
  """Generates the final structured Key Issues based on all gathered context."""
@@ -130,8 +140,9 @@ def generate_structured_issues(state: PlannerState) -> Dict[str, Any]:
130
  # --- Call LLM for Structured Output ---
131
  issue_llm = get_llm(settings.main_llm_model)
132
  # Use PydanticOutputParser for robust parsing
133
- output_parser = JsonOutputParser(pydantic_object=List[KeyIssue])
134
 
 
135
  prompt = KEY_ISSUE_STRUCTURING_PROMPT.partial(
136
  schema=output_parser.get_format_instructions(), # Inject schema instructions if needed by prompt
137
  )
@@ -139,11 +150,11 @@ def generate_structured_issues(state: PlannerState) -> Dict[str, Any]:
139
  chain = prompt | issue_llm | output_parser
140
 
141
  try:
142
- structured_issues = chain.invoke({
143
  "user_query": user_query,
144
  "context": full_context
145
  })
146
-
147
  # Ensure IDs are sequential if the LLM didn't assign them correctly
148
  for i, issue in enumerate(structured_issues):
149
  issue.id = i + 1
@@ -152,8 +163,8 @@ def generate_structured_issues(state: PlannerState) -> Dict[str, Any]:
152
  final_message = f"Generated {len(structured_issues)} Key Issues based on the query '{user_query}'."
153
  return {
154
  "key_issues": structured_issues,
155
- "messages": [AIMessage(content=final_message)], # Final summary message
156
- "error": None # Clear any previous errors
157
  }
158
  except Exception as e:
159
  logger.error(f"Failed to generate or parse structured key issues: {e}", exc_info=True)
 
4
  from langgraph.graph import StateGraph, END
5
  from langgraph.checkpoint.memory import MemorySaver # Or SqliteSaver etc.
6
 
7
+ from pydantic import BaseModel, Field
8
+
9
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
10
  from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
11
 
 
109
  "step_outputs": current_step_outputs
110
  }
111
 
112
+ class KeyIssue(BaseModel):
113
+ # define your fields here
114
+ id: int
115
+ description: str
116
+
117
+ class KeyIssueList(BaseModel):
118
+ key_issues: List[KeyIssue] = Field(description="List of key issues")
119
+
120
 
121
  def generate_structured_issues(state: PlannerState) -> Dict[str, Any]:
122
  """Generates the final structured Key Issues based on all gathered context."""
 
140
  # --- Call LLM for Structured Output ---
141
  issue_llm = get_llm(settings.main_llm_model)
142
  # Use PydanticOutputParser for robust parsing
143
+ output_parser = JsonOutputParser(pydantic_object=KeyIssueList)
144
 
145
+
146
  prompt = KEY_ISSUE_STRUCTURING_PROMPT.partial(
147
  schema=output_parser.get_format_instructions(), # Inject schema instructions if needed by prompt
148
  )
 
150
  chain = prompt | issue_llm | output_parser
151
 
152
  try:
153
+ structured_issues_obj = chain.invoke({
154
  "user_query": user_query,
155
  "context": full_context
156
  })
157
+ structured_issues = structured_issues_obj.key_issues
158
  # Ensure IDs are sequential if the LLM didn't assign them correctly
159
  for i, issue in enumerate(structured_issues):
160
  issue.id = i + 1
 
163
  final_message = f"Generated {len(structured_issues)} Key Issues based on the query '{user_query}'."
164
  return {
165
  "key_issues": structured_issues,
166
+ "messages": [AIMessage(content=final_message)],
167
+ "error": None
168
  }
169
  except Exception as e:
170
  logger.error(f"Failed to generate or parse structured key issues: {e}", exc_info=True)