# kig_core/planner.py
import logging
import re
import time
from typing import List, Dict, Any, Optional
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver # Or SqliteSaver etc.
from pydantic import BaseModel, Field
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from .config import settings
from .schemas import PlannerState, KeyIssue, GraphConfig # Import schemas
from .prompts import get_initial_planner_prompt, KEY_ISSUE_STRUCTURING_PROMPT
from .llm_interface import get_llm, invoke_llm
from .graph_operations import (
    generate_cypher_auto, generate_cypher_guided,
    retrieve_documents, evaluate_documents
)
from .processing import process_documents

logger = logging.getLogger(__name__)

# --- Graph Nodes ---
def start_planning(state: PlannerState) -> Dict[str, Any]:
    """Generates the initial plan based on the user query."""
    logger.info("Node: start_planning")
    user_query = state['user_query']
    if not user_query:
        return {"error": "User query is empty."}

    initial_prompt = get_initial_planner_prompt(settings.plan_method, user_query)
    llm = get_llm(settings.main_llm_model)
    chain = initial_prompt | llm | StrOutputParser()
    try:
        plan_text = invoke_llm(chain, {})  # Prompt already includes the query
        logger.debug(f"Raw plan text: {plan_text}")
        # Extract plan steps between "Plan:" and "<END_OF_PLAN>" (simple regex, may need refinement)
        plan_match = re.search(r"Plan:(.*?)<END_OF_PLAN>", plan_text, re.DOTALL | re.IGNORECASE)
        if plan_match:
            plan_steps = [step.strip() for step in re.split(r"\n\s*\d+\.\s*", plan_match.group(1)) if step.strip()]
            logger.info(f"Extracted plan: {plan_steps}")
            return {
                "plan": plan_steps,
                "current_plan_step_index": 0,
                "messages": [AIMessage(content=plan_text)],
                "step_outputs": {}  # Initialize step outputs
            }
        else:
            logger.error("Could not parse plan from LLM response.")
            return {"error": "Failed to parse plan from LLM response.", "messages": [AIMessage(content=plan_text)]}
    except Exception as e:
        logger.error(f"Error during plan generation: {e}", exc_info=True)
        return {"error": f"LLM error during plan generation: {e}"}
def execute_plan_step(state: PlannerState) -> Dict[str, Any]:
    """Executes the current step of the plan (retrieval, evaluation, processing)."""
    current_index = state['current_plan_step_index']
    plan = state['plan']
    user_query = state['user_query']  # Use the original query for context
    if current_index >= len(plan):
        logger.warning("Plan step index out of bounds, attempting to finalize.")
        # This should be handled by the conditional edge; kept as a fallback.
        return {"error": "Plan execution finished unexpectedly."}

    step_description = plan[current_index]
    logger.info(f"Node: execute_plan_step - Step {current_index + 1}/{len(plan)}: {step_description}")

    # --- Determine Query for Retrieval ---
    # Combine the step description with the original query for context.
    query_for_retrieval = f"Regarding the query '{user_query}', focus on: {step_description}"
    logger.info(f"Query for retrieval: {query_for_retrieval}")

    # --- Generate Cypher ---
    cypher_query = ""
    if settings.cypher_gen_method == 'auto':
        cypher_query = generate_cypher_auto(query_for_retrieval)
    elif settings.cypher_gen_method == 'guided':
        cypher_query = generate_cypher_guided(query_for_retrieval, current_index)
    # TODO: Add cypher validation if settings.validate_cypher is True

    # --- Retrieve Documents ---
    retrieved_docs = retrieve_documents(cypher_query)

    # --- Evaluate Documents ---
    evaluated_docs = evaluate_documents(retrieved_docs, query_for_retrieval)

    # --- Process Documents ---
    # Apply the configured processing steps.
    processed_docs_content = process_documents(evaluated_docs, settings.process_steps)

    # --- Store Step Output ---
    # Store the processed content relevant to this step.
    step_output = "\n\n".join(processed_docs_content) if processed_docs_content else "No relevant information found for this step."
    # Copy before mutating: state values should not be modified in place.
    current_step_outputs = dict(state.get('step_outputs', {}))
    current_step_outputs[current_index] = step_output
    logger.info(f"Finished executing plan step {current_index + 1}. Stored output.")
    return {
        "current_plan_step_index": current_index + 1,
        "messages": [SystemMessage(content=f"Completed plan step {current_index + 1}. Context gathered:\n{step_output[:500]}...")],  # Summary message
        "step_outputs": current_step_outputs
    }
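
# Illustrative shape of step_outputs after two executed steps (values made up):
#   {0: "Processed content gathered for step 1 ...",
#    1: "Processed content gathered for step 2 ..."}
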
class KeyIssueInvoke(BaseModel):
    """A single structured key issue as produced by the LLM."""
    id: int
    title: str
    description: str
    challenges: List[str]
    potential_impact: Optional[str] = None

class KeyIssueList(BaseModel):
    """Wrapper so the LLM returns a JSON object with a 'key_issues' array."""
    key_issues: List[KeyIssueInvoke] = Field(description="List of key issues")
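
# Illustrative JSON shape that JsonOutputParser(pydantic_object=KeyIssueList)
# instructs the LLM to emit (field values here are made up):
#
#   {
#     "key_issues": [
#       {
#         "id": 1,
#         "title": "Network slicing complexity",
#         "description": "...",
#         "challenges": ["...", "..."],
#         "potential_impact": "..."
#       }
#     ]
#   }
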
def generate_structured_issues(state: PlannerState) -> Dict[str, Any]:
    """Generates the final structured Key Issues based on all gathered context."""
    logger.info("Node: generate_structured_issues")
    user_query = state['user_query']
    step_outputs = state.get('step_outputs', {})

    # --- Combine Context from All Steps ---
    full_context = f"Original User Query: {user_query}\n\n"
    full_context += "Context gathered during planning:\n"
    for i, output in sorted(step_outputs.items()):
        full_context += f"--- Context from Step {i + 1} ---\n{output}\n\n"
    if not step_outputs:
        full_context += "No context was gathered during the planning steps.\n"

    logger.info(f"Generating key issues using combined context (length: {len(full_context)} chars).")
    # logger.debug(f"Full Context for Key Issue Generation:\n{full_context}")  # Optional: log full context

    # --- Call LLM for Structured Output ---
    issue_llm = get_llm(settings.main_llm_model)
    # Use JsonOutputParser with the Pydantic schema for robust parsing.
    output_parser = JsonOutputParser(pydantic_object=KeyIssueList)
    prompt = KEY_ISSUE_STRUCTURING_PROMPT.partial(
        schema=output_parser.get_format_instructions(),  # Inject schema instructions if the prompt expects them
    )
    chain = prompt | issue_llm | output_parser
    try:
        structured_issues_obj = invoke_llm(chain, {
            "user_query": user_query,
            "context": full_context
        })
        logger.debug(f"structured_issues_obj => type: {type(structured_issues_obj)}, value: {structured_issues_obj}")
        # If the output is a dict with a 'key_issues' key, extract the list.
        if isinstance(structured_issues_obj, dict) and 'key_issues' in structured_issues_obj:
            issues_data = structured_issues_obj['key_issues']
        else:
            issues_data = structured_issues_obj  # Assume it is already a list of dicts
        # Always convert to KeyIssueInvoke objects.
        key_issues_list = [KeyIssueInvoke(**issue_dict) for issue_dict in issues_data]
        # Ensure IDs are sequential in case the LLM did not assign them correctly.
        for i, issue in enumerate(key_issues_list):
            issue.id = i + 1
        logger.info(f"Successfully generated {len(key_issues_list)} structured key issues.")
        final_message = f"Generated {len(key_issues_list)} Key Issues based on the query '{user_query}'."
        return {
            "key_issues": key_issues_list,
            "messages": [AIMessage(content=final_message)],
            "error": None
        }
    except Exception as e:
        logger.error(f"Failed to generate or parse structured key issues: {e}", exc_info=True)
        # Attempt to get the raw output for debugging if possible.
        raw_output = "Could not retrieve raw output."
        try:
            raw_chain = prompt | issue_llm | StrOutputParser()
            raw_output = invoke_llm(raw_chain, {"user_query": user_query, "context": full_context})
            logger.debug(f"Raw output from failed JSON parsing:\n{raw_output}")
        except Exception as raw_e:
            logger.error(f"Could not even get raw output: {raw_e}")
        return {"error": f"Failed to generate structured key issues: {e}. Raw output hint: {raw_output[:500]}..."}

# --- Conditional Edges ---
def should_continue_planning(state: PlannerState) -> str:
    """Determines whether there are more plan steps to execute."""
    logger.debug("Edge: should_continue_planning")
    if state.get("error"):
        logger.error(f"Error state detected: {state['error']}. Ending execution.")
        return "error_state"  # Route to the error handling node
    current_index = state['current_plan_step_index']
    plan_length = len(state.get('plan', []))
    if current_index < plan_length:
        logger.debug(f"Continuing plan execution. Next step index: {current_index}")
        return "continue_execution"
    else:
        logger.debug("Plan finished. Proceeding to final generation.")
        return "finalize"

# --- Build Graph ---
def build_graph():
    """Builds the LangGraph workflow."""
    workflow = StateGraph(PlannerState)

    # Add nodes
    workflow.add_node("start_planning", start_planning)
    workflow.add_node("execute_plan_step", execute_plan_step)
    workflow.add_node("generate_issues", generate_structured_issues)
    # Optional: error handling node that records the failure as a message
    workflow.add_node(
        "error_node",
        lambda state: {"messages": [SystemMessage(content=f"Execution failed: {state.get('error', 'Unknown error')}")]}
    )

    # Define edges
    workflow.set_entry_point("start_planning")
    workflow.add_edge("start_planning", "execute_plan_step")  # Assume a plan is always generated
    workflow.add_conditional_edges(
        "execute_plan_step",
        should_continue_planning,
        {
            "continue_execution": "execute_plan_step",  # Loop back to execute the next step
            "finalize": "generate_issues",              # Move to final generation
            "error_state": "error_node"                 # Route to the error node
        }
    )
    workflow.add_edge("generate_issues", END)
    workflow.add_edge("error_node", END)  # End after an error

    # Compile the graph, optionally with a checkpointer:
    # memory = MemorySaver()  # Use if state needs persistence between runs
    # app_graph = workflow.compile(checkpointer=memory)
    app_graph = workflow.compile()
    return app_graph
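
if __name__ == "__main__":
    # Minimal usage sketch, assuming PlannerState accepts at least the keys
    # used by the nodes above (user_query, plan, step_outputs, messages, ...)
    # and that the LLM/graph settings in .config are in place; adjust to the
    # actual schema in schemas.py.
    logging.basicConfig(level=logging.INFO)
    app_graph = build_graph()
    initial_state = {"user_query": "What are the key issues in deploying network slicing?"}
    final_state = app_graph.invoke(initial_state)
    for issue in final_state.get("key_issues", []):
        print(f"{issue.id}. {issue.title}")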