import logging
import re
import time
from typing import List, Dict, Any, Optional

from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver  # Or SqliteSaver etc.
from pydantic import BaseModel, Field
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser

from .config import settings
from .schemas import PlannerState, KeyIssue, GraphConfig  # Import schemas
from .prompts import get_initial_planner_prompt, KEY_ISSUE_STRUCTURING_PROMPT
from .llm_interface import get_llm, invoke_llm
from .graph_operations import (
    generate_cypher_auto, generate_cypher_guided,
    retrieve_documents, evaluate_documents
)
from .processing import process_documents

logger = logging.getLogger(__name__)

# --- Graph Nodes ---

def start_planning(state: PlannerState) -> Dict[str, Any]:
    """Generates the initial plan based on the user query."""
    logger.info("Node: start_planning")
    user_query = state['user_query']
    if not user_query:
        return {"error": "User query is empty."}

    initial_prompt = get_initial_planner_prompt(settings.plan_method, user_query)
    llm = get_llm(settings.main_llm_model)
    chain = initial_prompt | llm | StrOutputParser()

    try:
        plan_text = invoke_llm(chain, {})  # Prompt already includes the query
        logger.debug(f"Raw plan text: {plan_text}")
        # Extract plan steps (simple regex; may need refinement)
        plan_match = re.search(r"Plan:(.*?)<END_OF_PLAN>", plan_text, re.DOTALL | re.IGNORECASE)
        if plan_match:
            plan_steps = [step.strip() for step in re.split(r"\n\s*\d+\.\s*", plan_match.group(1)) if step.strip()]
            logger.info(f"Extracted plan: {plan_steps}")
            return {
                "plan": plan_steps,
                "current_plan_step_index": 0,
                "messages": [AIMessage(content=plan_text)],
                "step_outputs": {},  # Initialize step outputs
            }
        else:
            logger.error("Could not parse plan from LLM response.")
            return {"error": "Failed to parse plan from LLM response.", "messages": [AIMessage(content=plan_text)]}
    except Exception as e:
        logger.error(f"Error during plan generation: {e}", exc_info=True)
        return {"error": f"LLM error during plan generation: {e}"}


def execute_plan_step(state: PlannerState) -> Dict[str, Any]:
    """Executes the current step of the plan (retrieval, evaluation, processing)."""
    current_index = state['current_plan_step_index']
    plan = state['plan']
    user_query = state['user_query']  # Keep the original query for context

    if current_index >= len(plan):
        logger.warning("Plan step index out of bounds, attempting to finalize.")
        # This should be handled by the conditional edge; kept here as a fallback.
        return {"error": "Plan execution finished unexpectedly."}

    step_description = plan[current_index]
    logger.info(f"Node: execute_plan_step - Step {current_index + 1}/{len(plan)}: {step_description}")

    # --- Determine Query for Retrieval ---
    # Combine the step description with the original query so retrieval keeps the overall context.
    query_for_retrieval = f"Regarding the query '{user_query}', focus on: {step_description}"
    logger.info(f"Query for retrieval: {query_for_retrieval}")

    # --- Generate Cypher ---
    cypher_query = ""
    if settings.cypher_gen_method == 'auto':
        cypher_query = generate_cypher_auto(query_for_retrieval)
    elif settings.cypher_gen_method == 'guided':
        cypher_query = generate_cypher_guided(query_for_retrieval, current_index)
    # TODO: Add cypher validation if settings.validate_cypher is True
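    # A minimal validation sketch for the TODO above (an assumption: this codebase
    # does not currently expose a validator, so validate_cypher() below is a
    # hypothetical helper and is left commented out):
    # if settings.validate_cypher:
    #     validated_query = validate_cypher(cypher_query)  # hypothetical helper
    #     if validated_query:
    #         cypher_query = validated_query
    #     else:
    #         logger.warning("Cypher validation failed; using the unvalidated query.")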

    # --- Retrieve Documents ---
    retrieved_docs = retrieve_documents(cypher_query)

    # --- Evaluate Documents ---
    evaluated_docs = evaluate_documents(retrieved_docs, query_for_retrieval)

    # --- Process Documents ---
    # Apply the configured processing steps
    processed_docs_content = process_documents(evaluated_docs, settings.process_steps)

    # --- Store Step Output ---
    # Store the processed content relevant to this step
    step_output = "\n\n".join(processed_docs_content) if processed_docs_content else "No relevant information found for this step."
    current_step_outputs = dict(state.get('step_outputs', {}))  # Copy to avoid mutating state in place
    current_step_outputs[current_index] = step_output

    logger.info(f"Finished executing plan step {current_index + 1}. Stored output.")
    return {
        "current_plan_step_index": current_index + 1,
        "messages": [SystemMessage(content=f"Completed plan step {current_index + 1}. Context gathered:\n{step_output[:500]}...")],  # Add summary message
        "step_outputs": current_step_outputs
    }


class KeyIssueInvoke(BaseModel):
    """Local model for a single structured Key Issue as returned by the LLM.

    Kept distinct from the KeyIssue schema imported from .schemas to avoid shadowing it.
    """
    id: int
    title: str
    description: str
    challenges: List[str]
    potential_impact: Optional[str] = None


class KeyIssueList(BaseModel):
    key_issues: List[KeyIssueInvoke] = Field(description="List of key issues")
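
# Illustrative target shape for the parser output (an assumption based on the
# models above; the actual LLM response may vary):
# {"key_issues": [{"id": 1, "title": "...", "description": "...",
#                  "challenges": ["..."], "potential_impact": "..."}]}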


def generate_structured_issues(state: PlannerState) -> Dict[str, Any]:
    """Generates the final structured Key Issues based on all gathered context."""
    logger.info("Node: generate_structured_issues")
    user_query = state['user_query']
    step_outputs = state.get('step_outputs', {})

    # --- Combine Context from All Steps ---
    full_context = f"Original User Query: {user_query}\n\n"
    full_context += "Context gathered during planning:\n"
    for i, output in sorted(step_outputs.items()):
        full_context += f"--- Context from Step {i+1} ---\n{output}\n\n"
    if not step_outputs:
        full_context += "No context was gathered during the planning steps.\n"

    logger.info(f"Generating key issues using combined context (length: {len(full_context)} chars).")
    # logger.debug(f"Full Context for Key Issue Generation:\n{full_context}")  # Optional: log full context

    # --- Call LLM for Structured Output ---
    issue_llm = get_llm(settings.main_llm_model)
    # Parse the response against the KeyIssueList schema
    output_parser = JsonOutputParser(pydantic_object=KeyIssueList)
    prompt = KEY_ISSUE_STRUCTURING_PROMPT.partial(
        schema=output_parser.get_format_instructions(),  # Inject schema instructions if the prompt expects them
    )
    chain = prompt | issue_llm | output_parser

    try:
        structured_issues_obj = invoke_llm(chain, {
            "user_query": user_query,
            "context": full_context
        })
        logger.debug(f"structured_issues_obj => type: {type(structured_issues_obj)}, value: {structured_issues_obj}")
        # The parser may return either the full object or just the list of issues
        if isinstance(structured_issues_obj, dict) and 'key_issues' in structured_issues_obj:
            issues_data = structured_issues_obj['key_issues']
        else:
            issues_data = structured_issues_obj  # Assume it is already a list of dicts

        # Convert to KeyIssueInvoke objects for validation
        key_issues_list = [KeyIssueInvoke(**issue_dict) for issue_dict in issues_data]

        # Ensure IDs are sequential in case the LLM did not assign them correctly
        for i, issue in enumerate(key_issues_list):
            issue.id = i + 1

        logger.info(f"Successfully generated {len(key_issues_list)} structured key issues.")
        final_message = f"Generated {len(key_issues_list)} Key Issues based on the query '{user_query}'."
        return {
            "key_issues": key_issues_list,
            "messages": [AIMessage(content=final_message)],
            "error": None
        }
    except Exception as e:
        logger.error(f"Failed to generate or parse structured key issues: {e}", exc_info=True)
        # Attempt to get raw output for debugging
        raw_output = "Could not retrieve raw output."
        try:
            raw_chain = prompt | issue_llm | StrOutputParser()
            raw_output = invoke_llm(raw_chain, {"user_query": user_query, "context": full_context})
            logger.debug(f"Raw output from failed JSON parsing:\n{raw_output}")
        except Exception as raw_e:
            logger.error(f"Could not retrieve raw output either: {raw_e}")
        return {"error": f"Failed to generate structured key issues: {e}. Raw output hint: {raw_output[:500]}..."}


# --- Conditional Edges ---

def should_continue_planning(state: PlannerState) -> str:
    """Determines whether there are more plan steps to execute."""
    logger.debug("Edge: should_continue_planning")
    if state.get("error"):
        logger.error(f"Error state detected: {state['error']}. Ending execution.")
        return "error_state"  # Route to the error handling node

    current_index = state['current_plan_step_index']
    plan_length = len(state.get('plan', []))
    if current_index < plan_length:
        logger.debug(f"Continuing plan execution. Next step index: {current_index}")
        return "continue_execution"
    else:
        logger.debug("Plan finished. Proceeding to final generation.")
        return "finalize"


# --- Build Graph ---

def build_graph():
    """Builds the LangGraph workflow."""
    workflow = StateGraph(PlannerState)

    # Add nodes
    workflow.add_node("start_planning", start_planning)
    workflow.add_node("execute_plan_step", execute_plan_step)
    workflow.add_node("generate_issues", generate_structured_issues)
    # Optional: error handling node
    workflow.add_node(
        "error_node",
        lambda state: {"messages": [SystemMessage(content=f"Execution failed: {state.get('error', 'Unknown error')}")]}
    )

    # Define edges
    workflow.set_entry_point("start_planning")
    workflow.add_edge("start_planning", "execute_plan_step")  # Assume a plan is always generated
    workflow.add_conditional_edges(
        "execute_plan_step",
        should_continue_planning,
        {
            "continue_execution": "execute_plan_step",  # Loop back to execute the next step
            "finalize": "generate_issues",              # Move on to final generation
            "error_state": "error_node"                 # Route to the error node
        }
    )
    workflow.add_edge("generate_issues", END)
    workflow.add_edge("error_node", END)  # End after an error

    # Compile the graph with memory (optional)
    # memory = MemorySaver()  # Use if state needs persistence between runs
    # app_graph = workflow.compile(checkpointer=memory)
    app_graph = workflow.compile()
    return app_graph
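

# --- Example usage (illustrative sketch) ---
# A minimal invocation sketch, assuming PlannerState accepts the keys used in this
# module (user_query, messages, step_outputs, ...); the example query is made up.
# Adjust to the actual schema defined in .schemas before relying on it.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    graph = build_graph()
    final_state = graph.invoke({
        "user_query": "What are the key issues when migrating a monolith to microservices?",
        "messages": [],
    })
    for issue in final_state.get("key_issues", []):
        print(f"{issue.id}. {issue.title}: {issue.description}")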