import logging
import re
from typing import List, Dict, Any, Optional

from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver  # Or SqliteSaver etc.
from pydantic import BaseModel, Field
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser

from .config import settings
from .schemas import PlannerState, GraphConfig  # Import schemas (KeyIssue is defined locally below)
from .prompts import get_initial_planner_prompt, KEY_ISSUE_STRUCTURING_PROMPT
from .llm_interface import get_llm, invoke_llm
from .graph_operations import (
    generate_cypher_auto,
    generate_cypher_guided,
    retrieve_documents,
    evaluate_documents,
)
from .processing import process_documents

logger = logging.getLogger(__name__)


# --- Graph Nodes ---

def start_planning(state: PlannerState) -> Dict[str, Any]:
    """Generates the initial plan based on the user query."""
    logger.info("Node: start_planning")
    user_query = state["user_query"]
    if not user_query:
        return {"error": "User query is empty."}

    initial_prompt = get_initial_planner_prompt(settings.plan_method, user_query)
    llm = get_llm(settings.main_llm_model)
    chain = initial_prompt | llm | StrOutputParser()

    try:
        plan_text = invoke_llm(chain, {})  # The prompt already embeds the query.
        logger.debug(f"Raw plan text: {plan_text}")

        # Extract plan steps (simple regex, might need refinement). The capture
        # must be greedy: a lazy `(.*?)` at the end of the pattern would match
        # the empty string.
        plan_match = re.search(r"Plan:(.*)", plan_text, re.DOTALL | re.IGNORECASE)
        if plan_match:
            plan_steps = [
                step.strip()
                for step in re.split(r"\n\s*\d+\.\s*", plan_match.group(1))
                if step.strip()
            ]
            logger.info(f"Extracted plan: {plan_steps}")
            return {
                "plan": plan_steps,
                "current_plan_step_index": 0,
                "messages": [AIMessage(content=plan_text)],
                "step_outputs": {},  # Initialize per-step context storage.
            }
        else:
            logger.error("Could not parse plan from LLM response.")
            return {
                "error": "Failed to parse plan from LLM response.",
                "messages": [AIMessage(content=plan_text)],
            }
    except Exception as e:
        logger.error(f"Error during plan generation: {e}", exc_info=True)
        return {"error": f"LLM error during plan generation: {e}"}


def execute_plan_step(state: PlannerState) -> Dict[str, Any]:
    """Executes the current step of the plan (retrieval, processing)."""
    current_index = state["current_plan_step_index"]
    plan = state["plan"]
    user_query = state["user_query"]  # Keep the original query for context.

    if current_index >= len(plan):
        # The conditional edge should prevent this; treat it as a fallback.
        logger.warning("Plan step index out of bounds, attempting to finalize.")
        return {"error": "Plan execution finished unexpectedly."}

    step_description = plan[current_index]
    logger.info(
        f"Node: execute_plan_step - Step {current_index + 1}/{len(plan)}: {step_description}"
    )

    # --- Determine Query for Retrieval ---
    # Combine the step description with the original query so retrieval keeps
    # the overall context of the request.
    query_for_retrieval = f"Regarding the query '{user_query}', focus on: {step_description}"
    logger.info(f"Query for retrieval: {query_for_retrieval}")

    # --- Generate Cypher ---
    cypher_query = ""
    if settings.cypher_gen_method == "auto":
        cypher_query = generate_cypher_auto(query_for_retrieval)
    elif settings.cypher_gen_method == "guided":
        cypher_query = generate_cypher_guided(query_for_retrieval, current_index)
    # TODO: Add cypher validation if settings.validate_cypher is True.

    # --- Retrieve Documents ---
    retrieved_docs = retrieve_documents(cypher_query)

    # --- Evaluate Documents ---
    evaluated_docs = evaluate_documents(retrieved_docs, query_for_retrieval)

    # --- Process Documents ---
    # Apply the configured processing steps.
    processed_docs_content = process_documents(evaluated_docs, settings.process_steps)

    # --- Store Step Output ---
    # Store the processed content relevant to this step.
    step_output = (
        "\n\n".join(processed_docs_content)
        if processed_docs_content
        else "No relevant information found for this step."
    )
    step_outputs = dict(state.get("step_outputs", {}))  # Copy so shared state is not mutated in place.
    step_outputs[current_index] = step_output

    logger.info(f"Finished executing plan step {current_index + 1}. Stored output.")
    return {
        "current_plan_step_index": current_index + 1,
        # Summary message for the conversation trace.
        "messages": [
            SystemMessage(
                content=f"Completed plan step {current_index + 1}. Context gathered:\n{step_output[:500]}..."
            )
        ],
        "step_outputs": step_outputs,
    }


# --- Structured Output Models ---
# KeyIssue/KeyIssueList describe the schema shown to the LLM via the output
# parser's format instructions. KeyIssue's fields must stay in sync with
# KeyIssueInvoke, the object handed to downstream consumers; otherwise the
# LLM is instructed to emit fields that KeyIssueInvoke cannot validate.

class KeyIssue(BaseModel):
    id: int
    title: str
    description: str
    challenges: List[str]
    potential_impact: Optional[str] = None


class KeyIssueList(BaseModel):
    key_issues: List[KeyIssue] = Field(description="List of key issues")


class KeyIssueInvoke(BaseModel):
    id: int
    title: str
    description: str
    challenges: List[str]
    potential_impact: Optional[str] = None


def generate_structured_issues(state: PlannerState) -> Dict[str, Any]:
    """Generates the final structured Key Issues based on all gathered context."""
    logger.info("Node: generate_structured_issues")
    user_query = state["user_query"]
    step_outputs = state.get("step_outputs", {})

    # --- Combine Context from All Steps ---
    full_context = f"Original User Query: {user_query}\n\n"
    full_context += "Context gathered during planning:\n"
    for i, output in sorted(step_outputs.items()):
        full_context += f"--- Context from Step {i + 1} ---\n{output}\n\n"
    if not step_outputs:
        full_context += "No context was gathered during the planning steps.\n"

    logger.info(f"Generating key issues using combined context (length: {len(full_context)} chars).")
    # logger.debug(f"Full Context for Key Issue Generation:\n{full_context}")  # Optional: log full context.

    # --- Call LLM for Structured Output ---
    issue_llm = get_llm(settings.main_llm_model)
    # JsonOutputParser validates against the Pydantic schema and supplies
    # format instructions for the prompt.
    output_parser = JsonOutputParser(pydantic_object=KeyIssueList)
    prompt = KEY_ISSUE_STRUCTURING_PROMPT.partial(
        schema=output_parser.get_format_instructions(),  # Inject schema instructions if needed by prompt.
    )
    chain = prompt | issue_llm | output_parser

    try:
        structured_issues_obj = invoke_llm(
            chain, {"user_query": user_query, "context": full_context}
        )
        logger.debug(
            f"structured_issues_obj type: {type(structured_issues_obj)}, value: {structured_issues_obj}"
        )

        # The parser may return the full object or a bare list of issues.
        if isinstance(structured_issues_obj, dict) and "key_issues" in structured_issues_obj:
            issues_data = structured_issues_obj["key_issues"]
        else:
            issues_data = structured_issues_obj  # Assume it is already a list of dicts.

        # Always convert to KeyIssueInvoke objects.
        key_issues_list = [KeyIssueInvoke(**issue_dict) for issue_dict in issues_data]
        # Ensure IDs are sequential in case the LLM assigned them inconsistently.
        for i, issue in enumerate(key_issues_list):
            issue.id = i + 1

        logger.info(f"Successfully generated {len(key_issues_list)} structured key issues.")
        final_message = f"Generated {len(key_issues_list)} Key Issues based on the query '{user_query}'."
        return {
            "key_issues": key_issues_list,
            "messages": [AIMessage(content=final_message)],
            "error": None,
        }
    except Exception as e:
        logger.error(f"Failed to generate or parse structured key issues: {e}", exc_info=True)
        # Attempt to capture the raw LLM output for debugging.
        raw_output = "Could not retrieve raw output."
        try:
            raw_chain = prompt | issue_llm | StrOutputParser()
            raw_output = invoke_llm(raw_chain, {"user_query": user_query, "context": full_context})
            logger.debug(f"Raw output from failed JSON parsing:\n{raw_output}")
        except Exception as raw_e:
            logger.error(f"Could not retrieve raw output either: {raw_e}")
        return {
            "error": f"Failed to generate structured key issues: {e}. Raw output hint: {raw_output[:500]}..."
        }


# --- Conditional Edges ---

def should_continue_planning(state: PlannerState) -> str:
    """Determines whether there are more plan steps to execute."""
    logger.debug("Edge: should_continue_planning")
    if state.get("error"):
        logger.error(f"Error state detected: {state['error']}. Ending execution.")
        return "error_state"  # Route to the error handling node.

    current_index = state["current_plan_step_index"]
    plan_length = len(state.get("plan", []))
    if current_index < plan_length:
        logger.debug(f"Continuing plan execution. Next step index: {current_index}")
        return "continue_execution"
    else:
        logger.debug("Plan finished. Proceeding to final generation.")
        return "finalize"


# --- Build Graph ---

def build_graph():
    """Builds the LangGraph workflow."""
    workflow = StateGraph(PlannerState)

    # Add nodes.
    workflow.add_node("start_planning", start_planning)
    workflow.add_node("execute_plan_step", execute_plan_step)
    workflow.add_node("generate_issues", generate_structured_issues)
    # Error handling node.
    workflow.add_node(
        "error_node",
        lambda state: {
            "messages": [
                SystemMessage(content=f"Execution failed: {state.get('error', 'Unknown error')}")
            ]
        },
    )

    # Define edges.
    workflow.set_entry_point("start_planning")
    # Route to the error node if plan generation failed; otherwise begin
    # executing the plan.
    workflow.add_conditional_edges(
        "start_planning",
        lambda state: "error_state" if state.get("error") else "continue_execution",
        {
            "continue_execution": "execute_plan_step",
            "error_state": "error_node",
        },
    )
    workflow.add_conditional_edges(
        "execute_plan_step",
        should_continue_planning,
        {
            "continue_execution": "execute_plan_step",  # Loop back to execute the next step.
            "finalize": "generate_issues",  # Move to final generation.
            "error_state": "error_node",  # Go to error node.
        },
    )
    workflow.add_edge("generate_issues", END)
    workflow.add_edge("error_node", END)  # End after error.

    # Compile the graph with memory (optional).
    # memory = MemorySaver()  # Use if state needs persistence between runs.
    # app_graph = workflow.compile(checkpointer=memory)
    app_graph = workflow.compile()
    return app_graph
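
# --- Example Usage ---
# A minimal smoke-test sketch, assuming `settings`, the prompt templates, and
# the graph database behind `.graph_operations` are already configured. The
# query string below is illustrative only.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    app_graph = build_graph()

    # The initial state only needs the user query; the nodes populate the rest
    # (plan, step outputs, key issues).
    final_state = app_graph.invoke(
        {"user_query": "How can we improve data quality in the ingestion pipeline?"}  # Hypothetical query.
    )

    if final_state.get("error"):
        print(f"Run failed: {final_state['error']}")
    else:
        for issue in final_state.get("key_issues", []):
            print(f"{issue.id}. {issue.title}: {issue.description}")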