Upload 12 files
- .env +34 -0
- app.py +175 -0
- kig_core/config.py +89 -0
- kig_core/graph_client.py +91 -0
- kig_core/graph_operations.py +210 -0
- kig_core/llm_interface.py +59 -0
- kig_core/planner.py +226 -0
- kig_core/processing.py +127 -0
- kig_core/prompts.py +140 -0
- kig_core/schemas.py +55 -0
- kig_core/utils.py +41 -0
- requirements.txt +26 -0
.env
ADDED
@@ -0,0 +1,34 @@
# Neo4j Credentials
NEO4J_URI="neo4j+s://4985272f.databases.neo4j.io"
NEO4J_USERNAME="neo4j"
NEO4J_PASSWORD="YOUR_NEO4J_PASSWORD" # Replace with your actual password

# API Keys
OPENAI_API_KEY="YOUR_OPENAI_API_KEY" # Replace if using OpenAI models
GEMINI_API_KEY="YOUR_GEMINI_API_KEY" # Replace with your actual key
LANGSMITH_API_KEY="YOUR_LANGSMITH_API_KEY" # Replace with your actual key (optional but recommended for tracing)
LANGCHAIN_PROJECT="KIG_Refactored" # Optional: For LangSmith tracing

# LLM Configuration
MAIN_LLM_MODEL="gemini-1.5-flash" # Or another preferred model
EVAL_LLM_MODEL="gemini-1.5-flash"
SUMMARIZE_LLM_MODEL="gemini-1.5-flash"

# Planner Configuration
PLAN_METHOD="generation" # or "modification"
USE_DETAILED_QUERY="false" # or "true"

# Graph Operations Configuration
CYPHER_GEN_METHOD="guided" # or "auto"
VALIDATE_CYPHER="false" # or "true"
EVAL_METHOD="binary" # or "score"
EVAL_THRESHOLD="0.7"
MAX_DOCS="10"

# Processing Configuration
# Define processing steps as a JSON string; use a richer structure here if more complex steps are needed
PROCESS_STEPS='["summarize"]' # Example: just summarize
COMPRESSION_METHOD="llm_lingua" # if used
COMPRESS_RATE="0.5" # if used

# Add other parameters as needed
app.py
ADDED
@@ -0,0 +1,175 @@
import streamlit as st
import pandas as pd
import logging
import time
import json  # For displaying dicts/lists nicely

from langchain_core.messages import HumanMessage  # Used to seed the initial graph state

# Import core components from the refactored library
from kig_core.config import settings  # Loads config on import
from kig_core.schemas import PlannerState, KeyIssue, GraphConfig
from kig_core.planner import build_graph
from kig_core.utils import key_issues_to_dataframe, dataframe_to_excel_bytes
from kig_core.graph_client import neo4j_client  # Import the initialized client instance

# Configure logging for Streamlit app
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Streamlit Page Configuration ---
st.set_page_config(page_title="Key Issue Generator (KIG)", layout="wide")
st.title("KIG - Key Issue Generator")
st.write("Generate structured Key Issues from knowledge graph context.")

# --- Sidebar ---
with st.sidebar:
    st.header("Status & Info")
    # Check Neo4j connectivity on startup
    neo4j_status = st.empty()
    try:
        # Accessing the client instance will trigger verification if not already done
        neo4j_client._get_driver()  # Ensure connection is attempted
        neo4j_status.success("Neo4j Connection Verified")
        can_run = True
    except ConnectionError as e:
        neo4j_status.error(f"Neo4j Error: {e}")
        can_run = False
    except Exception as e:
        neo4j_status.error(f"Neo4j Init Error: {e}")
        can_run = False

    st.header("Configuration")
    # Display some key settings (be careful with secrets)
    st.text(f"Main LLM: {settings.main_llm_model}")
    st.text(f"Neo4j URI: {settings.neo4j_uri}")
    st.text(f"Plan Method: {settings.plan_method}")
    st.text(f"Max Docs: {settings.max_docs}")

    st.header("About")
    st.info("""
    This app uses LLMs and a Neo4j graph to:
    1. Plan an approach based on your query.
    2. Execute the plan, retrieving & processing graph data.
    3. Generate structured Key Issues.
    4. Output results to an Excel file.
    """)

# --- Main Application Logic ---
st.header("Enter Your Query")
user_query = st.text_area(
    "Describe the technical requirement or area you want to explore for Key Issues:",
    "What are the main challenges and potential key issues in deploying edge computing for real-time AI-driven traffic management systems in smart cities?",
    height=150
)

# Session state to store results across reruns if needed
if 'key_issues_result' not in st.session_state:
    st.session_state.key_issues_result = None
if 'log_messages' not in st.session_state:
    st.session_state.log_messages = []

# Placeholders for status updates
status_placeholder = st.empty()
results_placeholder = st.container()
log_placeholder = st.expander("Show Execution Log")

if st.button("Generate Key Issues", type="primary", disabled=not can_run):
    if not user_query:
        st.error("Please enter a query.")
    else:
        st.session_state.key_issues_result = None  # Clear previous results
        st.session_state.log_messages = ["Starting Key Issue generation..."]

        with st.spinner("Processing... Building graph and executing workflow..."):
            start_time = time.time()
            try:
                # Build the graph
                status_placeholder.info("Building workflow graph...")
                app_graph = build_graph()
                st.session_state.log_messages.append("Workflow graph built.")

                # Define the initial state
                initial_state: PlannerState = {
                    "user_query": user_query,
                    "messages": [HumanMessage(content=user_query)],
                    "plan": [],
                    "current_plan_step_index": -1,  # Will be set by start_planning
                    "step_outputs": {},
                    "key_issues": [],
                    "error": None
                }

                # Configuration for the graph run (e.g., thread_id for memory)
                # Using the user query hash as a simple thread identifier for memory (if used)
                import hashlib
                thread_id = hashlib.sha256(user_query.encode()).hexdigest()[:8]
                config: GraphConfig = {"configurable": {"thread_id": thread_id}}

                status_placeholder.info("Executing workflow... (This may take a while)")
                st.session_state.log_messages.append("Invoking graph stream...")

                final_state = None
                # Stream events for logging/updates
                for i, step_state in enumerate(app_graph.stream(initial_state, config=config)):
                    # step_state is a dictionary where keys are node names
                    node_name = list(step_state.keys())[0]
                    node_output = step_state[node_name]
                    log_msg = f"Step {i+1}: Node '{node_name}' executed."
                    st.session_state.log_messages.append(log_msg)
                    # logger.info(log_msg) # Log to console as well
                    # logger.debug(f"Node output: {node_output}")

                    # You could update the status placeholder more dynamically here
                    # status_placeholder.info(f"Executing: {node_name}...")

                    final_state = node_output  # Keep track of the latest state

                end_time = time.time()
                st.session_state.log_messages.append(f"Workflow finished in {end_time - start_time:.2f} seconds.")
                status_placeholder.success(f"Processing Complete! ({end_time - start_time:.2f}s)")

                # --- Process Final Results ---
                if final_state and not final_state.get("error"):
                    generated_issues = final_state.get("key_issues", [])
                    st.session_state.key_issues_result = generated_issues
                    st.session_state.log_messages.append(f"Successfully extracted {len(generated_issues)} key issues.")
                elif final_state and final_state.get("error"):
                    error_msg = final_state.get("error", "Unknown error")
                    st.session_state.log_messages.append(f"Workflow failed: {error_msg}")
                    status_placeholder.error(f"Workflow failed: {error_msg}")
                else:
                    st.session_state.log_messages.append("Workflow finished, but no final state or key issues found.")
                    status_placeholder.warning("Workflow finished, but no key issues were generated.")

            except Exception as e:
                end_time = time.time()
                logger.error(f"An error occurred during graph execution: {e}", exc_info=True)
                status_placeholder.error(f"An unexpected error occurred: {e}")
                st.session_state.log_messages.append(f"FATAL ERROR: {e}")


# --- Display Results ---
if st.session_state.key_issues_result:
    issues = st.session_state.key_issues_result
    results_placeholder.subheader(f"Generated Key Issues ({len(issues)})")

    df = key_issues_to_dataframe(issues)

    if not df.empty:
        # Display as DataFrame
        results_placeholder.dataframe(df, use_container_width=True)

        # Provide download button
        excel_bytes = dataframe_to_excel_bytes(df)
        results_placeholder.download_button(
            label="📥 Download Key Issues as Excel",
            data=excel_bytes,
            file_name="key_issues_output.xlsx",
            mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
        )
    else:
        results_placeholder.info("No key issues were generated or parsed correctly.")

# Display logs
with log_placeholder:
    st.code("\n".join(st.session_state.log_messages), language="text")
kig_core/config.py
ADDED
@@ -0,0 +1,89 @@
import os
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field, SecretStr, HttpUrl, validator, Json
from typing import List, Optional, Literal, Union

# Load the .env file if it exists
# Ensure python-dotenv is installed: pip install python-dotenv
try:
    from dotenv import load_dotenv
    print("Attempting to load .env file...")
    if load_dotenv():
        print(".env file loaded successfully.")
    else:
        print(".env file not found or empty.")
except ImportError:
    print("python-dotenv not installed, skipping .env file loading.")
    pass  # Optional: Handle missing dotenv library


class Settings(BaseSettings):
    # Load from .env file
    model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', extra='ignore')

    # Neo4j Credentials
    neo4j_uri: str = Field(..., validation_alias='NEO4J_URI')
    neo4j_username: str = Field("neo4j", validation_alias='NEO4J_USERNAME')
    neo4j_password: SecretStr = Field(..., validation_alias='NEO4J_PASSWORD')

    # API Keys
    openai_api_key: Optional[SecretStr] = Field(None, validation_alias='OPENAI_API_KEY')
    gemini_api_key: Optional[SecretStr] = Field(None, validation_alias='GEMINI_API_KEY')
    langsmith_api_key: Optional[SecretStr] = Field(None, validation_alias='LANGSMITH_API_KEY')
    langchain_project: Optional[str] = Field("KIG_Refactored", validation_alias='LANGCHAIN_PROJECT')

    # LLM Configuration
    main_llm_model: str = Field("gemini-1.5-flash", validation_alias='MAIN_LLM_MODEL')
    eval_llm_model: str = Field("gemini-1.5-flash", validation_alias='EVAL_LLM_MODEL')
    summarize_llm_model: str = Field("gemini-1.5-flash", validation_alias='SUMMARIZE_LLM_MODEL')
    # Add other models if needed (e.g., cypher gen, concept selection)

    # Planner Configuration
    plan_method: Literal["generation", "modification"] = Field("generation", validation_alias='PLAN_METHOD')
    use_detailed_query: bool = Field(False, validation_alias='USE_DETAILED_QUERY')

    # Graph Operations Configuration
    cypher_gen_method: Literal["guided", "auto"] = Field("guided", validation_alias='CYPHER_GEN_METHOD')
    validate_cypher: bool = Field(False, validation_alias='VALIDATE_CYPHER')
    eval_method: Literal["binary", "score"] = Field("binary", validation_alias='EVAL_METHOD')
    eval_threshold: float = Field(0.7, validation_alias='EVAL_THRESHOLD')
    max_docs: int = Field(10, validation_alias='MAX_DOCS')

    # Processing Configuration
    # Load processing steps from a JSON string in .env
    process_steps: Json[List[Union[str, dict]]] = Field('["summarize"]', validation_alias='PROCESS_STEPS')
    compression_method: Optional[str] = Field(None, validation_alias='COMPRESSION_METHOD')
    compress_rate: Optional[float] = Field(0.5, validation_alias='COMPRESS_RATE')

    # Langsmith Tracing (set automatically based on key)
    langsmith_tracing_v2: str = "false"

    @validator('langsmith_tracing_v2', pre=True, always=True)
    def set_langsmith_tracing(cls, v, values):
        return "true" if values.get('langsmith_api_key') else "false"

    def configure_langsmith(self):
        """Sets Langsmith environment variables if an API key is provided."""
        if self.langsmith_api_key:
            os.environ["LANGCHAIN_TRACING_V2"] = self.langsmith_tracing_v2
            os.environ["LANGCHAIN_API_KEY"] = self.langsmith_api_key.get_secret_value()
            if self.langchain_project:
                os.environ["LANGCHAIN_PROJECT"] = self.langchain_project
            print("Langsmith configured.")
        else:
            # Ensure tracing is disabled if no key
            os.environ["LANGCHAIN_TRACING_V2"] = "false"
            print("Langsmith key not found, tracing disabled.")


# Create a single instance to be imported elsewhere
settings = Settings()
# Automatically configure Langsmith on import
settings.configure_langsmith()

# Optionally set the Gemini key in the environment if needed by the library implicitly
if settings.gemini_api_key:
    os.environ["GOOGLE_API_KEY"] = settings.gemini_api_key.get_secret_value()
    print("Set GOOGLE_API_KEY environment variable.")
if settings.openai_api_key:
    os.environ["OPENAI_API_KEY"] = settings.openai_api_key.get_secret_value()
    print("Set OPENAI_API_KEY environment variable.")
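Usage sketch (illustrative, not one of the uploaded files): how the .env values above surface through the settings singleton. The commented values assume the example .env shown earlier; PROCESS_STEPS is parsed from its JSON string by the Json field type, and secrets stay masked because they are SecretStr.

from kig_core.config import settings

print(settings.main_llm_model)    # "gemini-1.5-flash" unless MAIN_LLM_MODEL overrides it
print(settings.process_steps)     # ["summarize"] -- PROCESS_STEPS='["summarize"]' parsed by the Json field
print(settings.neo4j_password)    # prints '**********' (SecretStr keeps the value masked)
print(settings.neo4j_password.get_secret_value())  # the raw password; use sparingly
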
kig_core/graph_client.py
ADDED
@@ -0,0 +1,91 @@
from neo4j import GraphDatabase, Driver, exceptions
from .config import settings
import logging
from typing import List, Dict, Any, Optional

logger = logging.getLogger(__name__)

class Neo4jClient:
    _driver: Optional[Driver] = None

    def _get_driver(self) -> Driver:
        """Initializes and returns the Neo4j driver instance."""
        if self._driver is None or self._driver.closed():
            logger.info(f"Initializing Neo4j Driver for URI: {settings.neo4j_uri}")
            try:
                self._driver = GraphDatabase.driver(
                    settings.neo4j_uri,
                    auth=(settings.neo4j_username, settings.neo4j_password.get_secret_value())
                )
                # Verify connectivity during initialization
                self._driver.verify_connectivity()
                logger.info("Neo4j Driver initialized and connection verified.")
            except exceptions.AuthError as e:
                logger.error(f"Neo4j Authentication Error: {e}", exc_info=True)
                raise ConnectionError("Neo4j Authentication Failed. Check credentials.") from e
            except exceptions.ServiceUnavailable as e:
                logger.error(f"Neo4j Service Unavailable: {e}", exc_info=True)
                raise ConnectionError(f"Could not connect to Neo4j at {settings.neo4j_uri}. Ensure DB is running and reachable.") from e
            except Exception as e:
                logger.error(f"Unexpected error initializing Neo4j Driver: {e}", exc_info=True)
                raise ConnectionError("An unexpected error occurred connecting to Neo4j.") from e
        return self._driver

    def close(self):
        """Closes the Neo4j driver connection."""
        if self._driver and not self._driver.closed():
            logger.info("Closing Neo4j Driver.")
            self._driver.close()
            self._driver = None

    def query(self, cypher_query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Executes a Cypher query and returns the results."""
        driver = self._get_driver()
        logger.debug(f"Executing Cypher: {cypher_query} with params: {params}")
        try:
            # Use a session/transaction for robust execution
            with driver.session() as session:
                result = session.run(cypher_query, params or {})
                # Convert Neo4j Records to dictionaries
                data = [record.data() for record in result]
                logger.debug(f"Query returned {len(data)} records.")
                return data
        except (exceptions.ServiceUnavailable, exceptions.SessionExpired) as e:
            logger.error(f"Neo4j connection error during query: {e}", exc_info=True)
            # Attempt to close the potentially broken driver so it reconnects next time
            self.close()
            raise ConnectionError("Neo4j connection error during query execution.") from e
        except exceptions.CypherSyntaxError as e:
            logger.error(f"Neo4j Cypher Syntax Error: {e}\nQuery: {cypher_query}", exc_info=True)
            raise ValueError("Invalid Cypher query syntax.") from e
        except Exception as e:
            logger.error(f"Unexpected error during Neo4j query: {e}", exc_info=True)
            raise RuntimeError("An unexpected error occurred during the Neo4j query.") from e

    def get_schema(self, force_refresh: bool = False) -> Dict[str, Any]:
        """Fetches the graph schema. Placeholder - Langchain community graph has better schema fetching."""
        # For simplicity, returning empty. Implement actual schema fetching if needed.
        # Consider using langchain_community.graphs.Neo4jGraph for schema handling if complex interactions are needed.
        logger.warning("Neo4jClient.get_schema() is a placeholder. Implement if schema is needed.")
        return {}  # Placeholder

    def get_concepts(self) -> List[str]:
        """Fetches all Concept names from the graph."""
        cypher = "MATCH (c:Concept) RETURN c.name AS name ORDER BY name"
        results = self.query(cypher)
        return [record['name'] for record in results if 'name' in record]

    def get_concept_description(self, concept_name: str) -> Optional[str]:
        """Fetches the description for a specific concept."""
        cypher = "MATCH (c:Concept {name: $name}) RETURN c.description AS description LIMIT 1"
        params = {"name": concept_name}
        results = self.query(cypher, params)
        return results[0]['description'] if results and 'description' in results[0] else None


# Create a single instance for the application to use
neo4j_client = Neo4jClient()

# Ensure the client is closed gracefully when the application exits
import atexit
atexit.register(neo4j_client.close)
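Usage sketch (illustrative, not one of the uploaded files): querying through the shared neo4j_client instance with parameter binding, which avoids the string-interpolation caveat noted later in graph_operations.py. The Concept label, RELATED_TO relationship and TechnicalSpecification label come from the guided query builder; the concept name itself is hypothetical.

from kig_core.graph_client import neo4j_client

rows = neo4j_client.query(
    "MATCH (c:Concept {name: $name})-[:RELATED_TO]-(ts:TechnicalSpecification) "
    "RETURN ts.title AS title LIMIT 5",
    params={"name": "Edge Computing"},  # hypothetical concept name, bound as a parameter
)
for row in rows:
    print(row["title"])
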
kig_core/graph_operations.py
ADDED
@@ -0,0 +1,210 @@
import re
import logging
from typing import List, Dict, Any, Optional, Tuple
from random import sample, shuffle

from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.pydantic_v1 import Field, BaseModel as V1BaseModel  # For grader models if needed

from .config import settings
from .graph_client import neo4j_client  # Use the central client
from .llm_interface import get_llm
from .prompts import (
    CYPHER_GENERATION_PROMPT, CONCEPT_SELECTION_PROMPT,
    BINARY_GRADER_PROMPT, SCORE_GRADER_PROMPT
)
from .schemas import KeyIssue  # Import if needed here, maybe not

logger = logging.getLogger(__name__)

# --- Helper Functions ---
def extract_cypher(text: str) -> str:
    """Extracts the first Cypher code block or returns the text itself."""
    pattern = r"```(?:cypher)?\s*(.*?)\s*```"
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else text.strip()

def format_doc_for_llm(doc: Dict[str, Any]) -> str:
    """Formats a document dictionary into a string for LLM context."""
    return "\n".join(f"**{key}**: {value}" for key, value in doc.items() if value)


# --- Cypher Generation ---
def generate_cypher_auto(question: str) -> str:
    """Generates Cypher using the 'auto' method."""
    logger.info("Generating Cypher using 'auto' method.")
    # Schema fetching needs implementation if required by the prompt/LLM
    # schema_info = neo4j_client.get_schema() # Placeholder
    schema_info = "Schema not available."  # Default if not implemented

    cypher_llm = get_llm(settings.main_llm_model)  # Or a specific cypher model
    chain = (
        {"question": RunnablePassthrough(), "schema": lambda x: schema_info}
        | CYPHER_GENERATION_PROMPT
        | cypher_llm
        | StrOutputParser()
        | extract_cypher
    )
    return chain.invoke(question)

def generate_cypher_guided(question: str, plan_step: int) -> str:
    """Generates Cypher using the 'guided' method based on concepts."""
    logger.info(f"Generating Cypher using 'guided' method for plan step {plan_step}.")
    try:
        concepts = neo4j_client.get_concepts()
        if not concepts:
            logger.warning("No concepts found in Neo4j for guided cypher generation.")
            return ""  # Or raise an error

        concept_llm = get_llm(settings.main_llm_model)  # Or a specific concept model
        concept_chain = (
            CONCEPT_SELECTION_PROMPT
            | concept_llm
            | StrOutputParser()
        )
        selected_concept = concept_chain.invoke({
            "question": question,
            "concepts": "\n".join(concepts)
        }).strip()

        logger.info(f"Concept selected by LLM: {selected_concept}")

        # Basic check that the selected concept is valid
        if selected_concept not in concepts:
            logger.warning(f"LLM selected concept '{selected_concept}' not in the known list. Attempting fallback or ignoring.")
            # Optional: Add fuzzy matching or similarity search here
            # For now, try a simple substring check as a fallback
            found_match = None
            for c in concepts:
                if selected_concept.lower() in c.lower():
                    found_match = c
                    logger.info(f"Found potential match: '{found_match}'")
                    break
            if not found_match:
                logger.error(f"Could not validate selected concept: {selected_concept}")
                return ""  # Return an empty query if the concept is invalid
            selected_concept = found_match

        # Determine the target node type based on the plan step (example logic)
        # This mapping might need adjustment based on the actual plan structure
        if plan_step <= 1:  # Steps 0 and 1: Context gathering
            target = "(ts:TechnicalSpecification)"
            fields = "ts.title, ts.scope, ts.description"
        elif plan_step == 2:  # Step 2: Research papers?
            target = "(rp:ResearchPaper)"
            fields = "rp.title, rp.abstract"
        else:  # Later steps might involve KeyIssues themselves or other types
            target = "(n)"  # Generic fallback
            fields = "n.title, n.description"  # Assuming common fields

        # Construct the Cypher query
        # Ensure selected_concept is properly escaped if needed, though parameters are safer
        cypher = f"MATCH (c:Concept {{name: $conceptName}})-[:RELATED_TO]-{target} RETURN {fields}"
        # We could return the query and the parameters separately for safe execution.
        # However, the planner currently expects just the string, so construct it directly for now.
        # Be cautious about injection if concept names can contain special chars. Binding is preferred.
        escaped_concept = selected_concept.replace("'", "\\'")  # Basic escaping
        cypher = f"MATCH (c:Concept {{name: '{escaped_concept}'}})-[:RELATED_TO]-{target} RETURN {fields}"

        logger.info(f"Generated guided Cypher: {cypher}")
        return cypher

    except Exception as e:
        logger.error(f"Error during guided cypher generation: {e}", exc_info=True)
        return ""  # Return empty on error


# --- Document Retrieval ---
def retrieve_documents(cypher_query: str) -> List[Dict[str, Any]]:
    """Retrieves documents from Neo4j using a Cypher query."""
    if not cypher_query:
        logger.warning("Received empty Cypher query, skipping retrieval.")
        return []
    logger.info(f"Retrieving documents with Cypher: {cypher_query}")
    try:
        # Use the centralized client's query method
        raw_results = neo4j_client.query(cypher_query)
        # Basic cleaning/deduplication (can be enhanced)
        processed_results = []
        seen = set()
        for doc in raw_results:
            # Create a frozenset of items as a hashable representation to detect duplicates
            doc_items = frozenset(doc.items())
            if doc_items not in seen:
                processed_results.append(doc)
                seen.add(doc_items)
        logger.info(f"Retrieved {len(processed_results)} unique documents.")
        return processed_results
    except (ConnectionError, ValueError, RuntimeError) as e:
        # Errors already logged in neo4j_client
        logger.error(f"Document retrieval failed: {e}")
        return []  # Return an empty list on failure


# --- Document Evaluation ---
# Pydantic models for structured LLM grader output (if not using built-in LCEL structured output)
class GradeDocumentsBinary(V1BaseModel):
    """Binary score for relevance check."""
    binary_score: str = Field(description="Relevant? 'yes' or 'no'")

class GradeDocumentsScore(V1BaseModel):
    """Score for relevance check."""
    rationale: str = Field(description="Rationale for the score.")
    score: float = Field(description="Relevance score (0.0 to 1.0)")

def evaluate_documents(
    docs: List[Dict[str, Any]],
    query: str
) -> List[Dict[str, Any]]:
    """Evaluates document relevance to a query using the configured method."""
    if not docs:
        return []

    logger.info(f"Evaluating {len(docs)} documents for relevance to query: '{query}' using method: {settings.eval_method}")
    eval_llm = get_llm(settings.eval_llm_model)
    valid_docs_with_scores: List[Tuple[Dict[str, Any], float]] = []

    # Consider using LCEL's structured output capabilities directly if the model supports it well.
    # This simplifies parsing. Example for binary:
    # binary_grader = BINARY_GRADER_PROMPT | eval_llm.with_structured_output(GradeDocumentsBinary)

    if settings.eval_method == "binary":
        binary_grader = BINARY_GRADER_PROMPT | eval_llm | StrOutputParser()  # Fallback to string parsing
        for doc in docs:
            formatted_doc = format_doc_for_llm(doc)
            if not formatted_doc.strip(): continue
            try:
                result = binary_grader.invoke({"question": query, "document": formatted_doc})
                logger.debug(f"Binary grader result for doc '{doc.get('title', 'N/A')}': {result}")
                if result and 'yes' in result.lower():
                    valid_docs_with_scores.append((doc, 1.0))  # Score 1.0 for relevant
            except Exception as e:
                logger.warning(f"Binary grading failed for a document: {e}", exc_info=True)

elif settings.eval_method == "score":
|
188 |
+
# Using JSON parser as a robust fallback for score extraction
|
189 |
+
score_grader = SCORE_GRADER_PROMPT | eval_llm | JsonOutputParser(pydantic_object=GradeDocumentsScore)
|
190 |
+
for doc in docs:
|
191 |
+
formatted_doc = format_doc_for_llm(doc)
|
192 |
+
if not formatted_doc.strip(): continue
|
193 |
+
try:
|
194 |
+
result: GradeDocumentsScore = score_grader.invoke({"query": query, "document": formatted_doc})
|
195 |
+
logger.debug(f"Score grader result for doc '{doc.get('title', 'N/A')}': Score={result.score}, Rationale={result.rationale}")
|
196 |
+
if result.score >= settings.eval_threshold:
|
197 |
+
valid_docs_with_scores.append((doc, result.score))
|
198 |
+
except Exception as e:
|
199 |
+
logger.warning(f"Score grading failed for a document: {e}", exc_info=True)
|
200 |
+
# Optionally treat as relevant on failure? Or skip? Skipping for now.
|
201 |
+
|
202 |
+
# Sort by score if applicable, then limit
|
203 |
+
if settings.eval_method == 'score':
|
204 |
+
valid_docs_with_scores.sort(key=lambda item: item[1], reverse=True)
|
205 |
+
|
206 |
+
# Limit to max_docs
|
207 |
+
final_docs = [doc for doc, score in valid_docs_with_scores[:settings.max_docs]]
|
208 |
+
logger.info(f"Found {len(final_docs)} relevant documents after evaluation and filtering.")
|
209 |
+
|
210 |
+
return final_docs
|
kig_core/llm_interface.py
ADDED
@@ -0,0 +1,59 @@
import os
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_core.language_models.chat_models import BaseChatModel
from .config import settings
import logging

logger = logging.getLogger(__name__)

# Store initialized models to avoid re-creating them repeatedly
_llm_cache = {}

def get_llm(model_name: str) -> BaseChatModel:
    """
    Returns an initialized LangChain chat model based on the provided name.
    Caches initialized models.
    """
    global _llm_cache
    if model_name in _llm_cache:
        return _llm_cache[model_name]

    logger.info(f"Initializing LLM: {model_name}")

    if model_name.startswith("gemini"):
        if not settings.gemini_api_key:
            raise ValueError("GEMINI_API_KEY is not configured.")
        try:
            # Uses the GOOGLE_API_KEY environment variable set in config.py
            llm = ChatGoogleGenerativeAI(model=model_name)
            _llm_cache[model_name] = llm
            logger.info(f"Initialized Google Generative AI model: {model_name}")
            return llm
        except Exception as e:
            logger.error(f"Failed to initialize Gemini model '{model_name}': {e}", exc_info=True)
            raise RuntimeError(f"Could not initialize Gemini model: {e}") from e

    elif model_name.startswith("gpt"):
        if not settings.openai_api_key:
            raise ValueError("OPENAI_API_KEY is not configured.")
        try:
            # A base URL can be added here if using a proxy
            # base_url = "https://your-proxy-if-needed/"
            llm = ChatOpenAI(model=model_name, api_key=settings.openai_api_key)  # Base URL optional
            _llm_cache[model_name] = llm
            logger.info(f"Initialized OpenAI model: {model_name}")
            return llm
        except Exception as e:
            logger.error(f"Failed to initialize OpenAI model '{model_name}': {e}", exc_info=True)
            raise RuntimeError(f"Could not initialize OpenAI model: {e}") from e

    # Add other model providers (Anthropic, Groq, etc.) here if needed

    else:
        logger.error(f"Unsupported model provider for model name: {model_name}")
        raise ValueError(f"Model '{model_name}' is not supported or configuration is missing.")

# Example usage (could be called from other modules)
# main_llm = get_llm(settings.main_llm_model)
# eval_llm = get_llm(settings.eval_llm_model)
kig_core/planner.py
ADDED
@@ -0,0 +1,226 @@
import logging
import re
from typing import List, Dict, Any
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver  # Or SqliteSaver etc.

from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser

from .config import settings
from .schemas import PlannerState, KeyIssue, GraphConfig  # Import schemas
from .prompts import get_initial_planner_prompt, KEY_ISSUE_STRUCTURING_PROMPT
from .llm_interface import get_llm
from .graph_operations import (
    generate_cypher_auto, generate_cypher_guided,
    retrieve_documents, evaluate_documents
)
from .processing import process_documents

logger = logging.getLogger(__name__)

# --- Graph Nodes ---

def start_planning(state: PlannerState) -> Dict[str, Any]:
    """Generates the initial plan based on the user query."""
    logger.info("Node: start_planning")
    user_query = state['user_query']
    if not user_query:
        return {"error": "User query is empty."}

    initial_prompt = get_initial_planner_prompt(settings.plan_method, user_query)
    llm = get_llm(settings.main_llm_model)
    chain = initial_prompt | llm | StrOutputParser()

    try:
        plan_text = chain.invoke({})  # Prompt already includes the query
        logger.debug(f"Raw plan text: {plan_text}")

        # Extract plan steps (simple regex, might need refinement)
        plan_match = re.search(r"Plan:(.*?)<END_OF_PLAN>", plan_text, re.DOTALL | re.IGNORECASE)
        if plan_match:
            plan_steps = [step.strip() for step in re.split(r"\n\s*\d+\.\s*", plan_match.group(1)) if step.strip()]
            logger.info(f"Extracted plan: {plan_steps}")
            return {
                "plan": plan_steps,
                "current_plan_step_index": 0,
                "messages": [AIMessage(content=plan_text)],
                "step_outputs": {}  # Initialize step outputs
            }
        else:
            logger.error("Could not parse plan from LLM response.")
            return {"error": "Failed to parse plan from LLM response.", "messages": [AIMessage(content=plan_text)]}
    except Exception as e:
        logger.error(f"Error during plan generation: {e}", exc_info=True)
        return {"error": f"LLM error during plan generation: {e}"}


def execute_plan_step(state: PlannerState) -> Dict[str, Any]:
    """Executes the current step of the plan (retrieval, processing)."""
    current_index = state['current_plan_step_index']
    plan = state['plan']
    user_query = state['user_query']  # Use the original query for context

    if current_index >= len(plan):
        logger.warning("Plan step index out of bounds, attempting to finalize.")
        # This should ideally be handled by the conditional edge, but serves as a fallback
        return {"error": "Plan execution finished unexpectedly."}

    step_description = plan[current_index]
    logger.info(f"Node: execute_plan_step - Step {current_index + 1}/{len(plan)}: {step_description}")

    # --- Determine Query for Retrieval ---
    # Simple approach: use the step description combined with the original query for context.
    query_for_retrieval = f"Regarding the query '{user_query}', focus on: {step_description}"
    logger.info(f"Query for retrieval: {query_for_retrieval}")

    # --- Generate Cypher ---
    cypher_query = ""
    if settings.cypher_gen_method == 'auto':
        cypher_query = generate_cypher_auto(query_for_retrieval)
    elif settings.cypher_gen_method == 'guided':
        cypher_query = generate_cypher_guided(query_for_retrieval, current_index)
    # TODO: Add cypher validation if settings.validate_cypher is True

    # --- Retrieve Documents ---
    retrieved_docs = retrieve_documents(cypher_query)

    # --- Evaluate Documents ---
    evaluated_docs = evaluate_documents(retrieved_docs, query_for_retrieval)

    # --- Process Documents ---
    # Using the configured processing steps
    processed_docs_content = process_documents(evaluated_docs, settings.process_steps)

    # --- Store Step Output ---
    # Store the processed content relevant to this step
    step_output = "\n\n".join(processed_docs_content) if processed_docs_content else "No relevant information found for this step."
    current_step_outputs = state.get('step_outputs', {})
    current_step_outputs[current_index] = step_output

    logger.info(f"Finished executing plan step {current_index + 1}. Stored output.")

    return {
        "current_plan_step_index": current_index + 1,
        "messages": [SystemMessage(content=f"Completed plan step {current_index + 1}. Context gathered:\n{step_output[:500]}...")],  # Add a summary message
        "step_outputs": current_step_outputs
    }


def generate_structured_issues(state: PlannerState) -> Dict[str, Any]:
    """Generates the final structured Key Issues based on all gathered context."""
    logger.info("Node: generate_structured_issues")

    user_query = state['user_query']
    step_outputs = state.get('step_outputs', {})

    # --- Combine Context from All Steps ---
    full_context = f"Original User Query: {user_query}\n\n"
    full_context += "Context gathered during planning:\n"
    for i, output in sorted(step_outputs.items()):
        full_context += f"--- Context from Step {i+1} ---\n{output}\n\n"

    if not step_outputs:
        full_context += "No context was gathered during the planning steps.\n"

    logger.info(f"Generating key issues using combined context (length: {len(full_context)} chars).")
    # logger.debug(f"Full Context for Key Issue Generation:\n{full_context}") # Optional: log full context

    # --- Call LLM for Structured Output ---
    issue_llm = get_llm(settings.main_llm_model)
    # JSON output parser; the KeyIssue model documents the expected item shape
    output_parser = JsonOutputParser(pydantic_object=List[KeyIssue])

    prompt = KEY_ISSUE_STRUCTURING_PROMPT.partial(
        # schema=output_parser.get_format_instructions(), # Inject schema instructions if needed by the prompt
    )

    chain = prompt | issue_llm | output_parser

    try:
        structured_issues = chain.invoke({
            "user_query": user_query,
            "context": full_context
        })

        # JsonOutputParser yields plain dicts; coerce them into KeyIssue models so downstream
        # code receives typed objects (assumes the LLM followed the schema in the prompt)
        structured_issues = [
            issue if isinstance(issue, KeyIssue) else KeyIssue(**issue)
            for issue in structured_issues
        ]

        # Ensure IDs are sequential if the LLM didn't assign them correctly
        for i, issue in enumerate(structured_issues):
            issue.id = i + 1

        logger.info(f"Successfully generated {len(structured_issues)} structured key issues.")
        final_message = f"Generated {len(structured_issues)} Key Issues based on the query '{user_query}'."
        return {
            "key_issues": structured_issues,
            "messages": [AIMessage(content=final_message)],  # Final summary message
            "error": None  # Clear any previous errors
        }
    except Exception as e:
        logger.error(f"Failed to generate or parse structured key issues: {e}", exc_info=True)
        # Attempt to get the raw output for debugging if possible
        raw_output = "Could not retrieve raw output."
        try:
            raw_chain = prompt | issue_llm | StrOutputParser()
            raw_output = raw_chain.invoke({"user_query": user_query, "context": full_context})
            logger.debug(f"Raw output from failed JSON parsing:\n{raw_output}")
        except Exception as raw_e:
            logger.error(f"Could not even get raw output: {raw_e}")

        return {"error": f"Failed to generate structured key issues: {e}. Raw output hint: {raw_output[:500]}..."}

# --- Conditional Edges ---

def should_continue_planning(state: PlannerState) -> str:
    """Determines whether there are more plan steps to execute."""
    logger.debug("Edge: should_continue_planning")
    if state.get("error"):
        logger.error(f"Error state detected: {state['error']}. Ending execution.")
        return "error_state"  # Go to a potential error handling end node

    current_index = state['current_plan_step_index']
    plan_length = len(state.get('plan', []))

    if current_index < plan_length:
        logger.debug(f"Continuing plan execution. Next step index: {current_index}")
        return "continue_execution"
    else:
        logger.debug("Plan finished. Proceeding to final generation.")
        return "finalize"


# --- Build Graph ---
def build_graph():
    """Builds the LangGraph workflow."""
    workflow = StateGraph(PlannerState)

    # Add nodes
    workflow.add_node("start_planning", start_planning)
    workflow.add_node("execute_plan_step", execute_plan_step)
    workflow.add_node("generate_issues", generate_structured_issues)
    # Optional: Add an error handling node
    workflow.add_node("error_node", lambda state: {"messages": [SystemMessage(content=f"Execution failed: {state.get('error', 'Unknown error')}")]})

    # Define edges
    workflow.set_entry_point("start_planning")
    workflow.add_edge("start_planning", "execute_plan_step")  # Assume a plan is always generated

    workflow.add_conditional_edges(
        "execute_plan_step",
        should_continue_planning,
        {
            "continue_execution": "execute_plan_step",  # Loop back to execute the next step
            "finalize": "generate_issues",  # Move to final generation
            "error_state": "error_node"  # Go to the error node
        }
    )

    workflow.add_edge("generate_issues", END)
    workflow.add_edge("error_node", END)  # End after error

    # Compile the graph with memory (optional)
    # memory = MemorySaver() # Use if state needs persistence between runs
    # app_graph = workflow.compile(checkpointer=memory)
    app_graph = workflow.compile()
    return app_graph
kig_core/processing.py
ADDED
@@ -0,0 +1,127 @@
import logging
from typing import List, Dict, Any, Union, Optional
from langchain_core.output_parsers import StrOutputParser

from .config import settings
from .llm_interface import get_llm
from .prompts import SUMMARIZER_PROMPT
from .graph_operations import format_doc_for_llm  # Reuse formatting

# Import llmlingua if compression is used
try:
    from llmlingua import PromptCompressor
    LLMLINGUA_AVAILABLE = True
except ImportError:
    LLMLINGUA_AVAILABLE = False
    PromptCompressor = None  # Define as None if not available

logger = logging.getLogger(__name__)

_compressor_cache = {}

def get_compressor(method: str) -> Optional['PromptCompressor']:
    """Initializes and caches the llmlingua compressor."""
    if not LLMLINGUA_AVAILABLE:
        logger.warning("LLMLingua not installed, compression unavailable.")
        return None
    if method not in _compressor_cache:
        logger.info(f"Initializing LLMLingua compressor: {method}")
        try:
            # Adjust model names and params as needed
            if method == "llm_lingua2":
                model_name = "microsoft/llmlingua-2-xlm-roberta-large-meetingbank"
                use_llmlingua2 = True
            elif method == "llm_lingua":
                model_name = "microsoft/phi-2"  # Requires ~8GB RAM
                use_llmlingua2 = False
            else:
                logger.error(f"Unsupported compression method: {method}")
                return None

            _compressor_cache[method] = PromptCompressor(
                model_name=model_name,
                use_llmlingua2=use_llmlingua2,
                device_map="cpu"  # Or "cuda" if a GPU is available
            )
        except Exception as e:
            logger.error(f"Failed to initialize LLMLingua compressor '{method}': {e}", exc_info=True)
            return None
    return _compressor_cache[method]


def summarize_document(doc_content: str) -> str:
    """Summarizes a single document using the configured LLM."""
    logger.debug("Summarizing document...")
    try:
        summarize_llm = get_llm(settings.summarize_llm_model)
        summarize_chain = SUMMARIZER_PROMPT | summarize_llm | StrOutputParser()
        summary = summarize_chain.invoke({"document": doc_content})
        logger.debug("Summarization complete.")
        return summary
    except Exception as e:
        logger.error(f"Summarization failed: {e}", exc_info=True)
        return f"Error during summarization: {e}"  # Return an error message instead of failing


def compress_document(doc_content: str) -> str:
    """Compresses a single document using LLMLingua."""
    logger.debug(f"Compressing document using method: {settings.compression_method}...")
    if not settings.compression_method:
        logger.warning("Compression method not configured, skipping.")
        return doc_content

    compressor = get_compressor(settings.compression_method)
    if not compressor:
        logger.warning("Compressor not available, skipping compression.")
        return doc_content

    try:
        # Adjust compression parameters as needed
        # rate = settings.compress_rate or 0.5
        # force_tokens = ['\n', '.', ',', '?', '!'] # Example tokens
        # context? instructions? question?

        # Simple compression for now:
        result = compressor.compress_prompt(doc_content, rate=settings.compress_rate or 0.5)
        compressed_text = result.get("compressed_prompt", doc_content)

        original_len = len(doc_content.split())
        compressed_len = len(compressed_text.split())
        logger.debug(f"Compression complete. Original words: {original_len}, Compressed words: {compressed_len}")
        return compressed_text
    except Exception as e:
        logger.error(f"Compression failed: {e}", exc_info=True)
        return f"Error during compression: {e}"  # Return an error message


def process_documents(
    docs: List[Dict[str, Any]],
    processing_steps: List[Union[str, dict]]
) -> List[str]:
    """Processes a list of documents according to the specified steps."""
    logger.info(f"Processing {len(docs)} documents with steps: {processing_steps}")
    if not docs:
        return []

    processed_outputs = []
    for i, doc in enumerate(docs):
        logger.info(f"Processing document {i+1}/{len(docs)}...")
        current_content = format_doc_for_llm(doc)  # Start with the formatted original doc

        for step in processing_steps:
            if step == "summarize":
                current_content = summarize_document(current_content)
            elif step == "compress":
                current_content = compress_document(current_content)
            elif isinstance(step, dict):
                # Placeholder for custom processing steps defined by dicts
                logger.warning(f"Custom processing step not implemented: {step}")
                # Add logic here if needed: extract params, call a specific LLM/function
                pass
            else:
                logger.warning(f"Unknown processing step type: {step}")

        processed_outputs.append(current_content)  # Add the final processed content for this doc

    logger.info("Document processing finished.")
    return processed_outputs
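Usage sketch (illustrative, not one of the uploaded files): how the PROCESS_STEPS setting drives process_documents. The sample document dict is hypothetical; in the pipeline these dicts come from retrieve_documents, and each configured step (e.g. "summarize") is applied in order.

from kig_core.config import settings
from kig_core.processing import process_documents

docs = [{"title": "Edge offloading survey", "abstract": "Latency trade-offs in edge AI deployments..."}]  # hypothetical document
outputs = process_documents(docs, settings.process_steps)  # e.g. ["summarize"] from PROCESS_STEPS
print(outputs[0])
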
kig_core/prompts.py
ADDED
@@ -0,0 +1,140 @@
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from .schemas import KeyIssue  # Import the Pydantic model

# --- Cypher Generation ---
CYPHER_GENERATION_TEMPLATE = """Task: Generate Cypher statement to query a graph database.
Instructions:
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Schema:
{schema}

Note: Do not include explanations or apologies. Respond only with the Cypher statement.
Do not respond to questions unrelated to Cypher generation.

The question is:
{question}"""
CYPHER_GENERATION_PROMPT = PromptTemplate.from_template(CYPHER_GENERATION_TEMPLATE)


# --- Concept Selection (for 'guided' cypher gen) ---
CONCEPT_SELECTION_TEMPLATE = """Task: Select the most relevant Concept from the list below for the user's question.
Instructions:
Output ONLY the name of the single most relevant concept. No explanations.

Concepts:
{concepts}

User Question:
{question}"""
CONCEPT_SELECTION_PROMPT = PromptTemplate.from_template(CONCEPT_SELECTION_TEMPLATE)


# --- Document Relevance Grading ---
BINARY_GRADER_TEMPLATE = """Assess the relevance of the retrieved document to the user question.
The goal is to filter out clearly erroneous retrievals.
If the document contains keywords or semantic meaning related to the question, grade it as relevant.
Output 'yes' or 'no'."""
BINARY_GRADER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", BINARY_GRADER_TEMPLATE),
    ("human", "Retrieved document:\n\n{document}\n\nUser question: {question}"),
])

SCORE_GRADER_TEMPLATE = """Analyze the query and the document and quantify their relevance.
Provide a short rationale before the score.
Output a score between 0 (irrelevant) and 1 (completely relevant)."""
SCORE_GRADER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", SCORE_GRADER_TEMPLATE),
    ("human", "Passage:\n\n{document}\n\nUser query: {query}"),
])


# --- Planning ---
PLAN_GENERATION_TEMPLATE = """You are a standardization expert planning to identify NEW and INNOVATIVE Key Issues related to a technical requirement.
Devise a concise, step-by-step plan to achieve this.
Consider steps like: understanding the core problem, researching existing standards/innovations, identifying potential gaps/challenges, formulating Key Issues, and refining/detailing them.
Output the plan starting with 'Plan:' and numbering each step. End the plan with '<END_OF_PLAN>'."""

PLAN_MODIFICATION_TEMPLATE = """You are a standardization expert planning to identify NEW and INNOVATIVE Key Issues related to a technical requirement.
Adapt the following generic plan template to the specific requirement. Keep it concise.

### PLAN TEMPLATE ###
Plan:
1. **Understand Core Requirement**: Analyze the user query to define the scope.
2. **Gather Context**: Retrieve relevant specifications, standards, and recent research papers.
3. **Identify Gaps & Challenges**: Based on the context, brainstorm potential new issues and challenges.
4. **Formulate Key Issues**: Structure the findings into distinct Key Issues.
5. **Refine & Detail**: Elaborate on each Key Issue, outlining its specific challenges.
<END_OF_PLAN>
### END OF PLAN TEMPLATE ###

Output the adapted plan starting with 'Plan:' and numbering each step. End with '<END_OF_PLAN>'."""


# --- Document Processing ---
SUMMARIZER_TEMPLATE = """You are a 3GPP standardization expert.
Summarize the key information in the provided document in simple technical English relevant to identifying potential Key Issues. Focus on challenges, gaps, or novel aspects.

Document:
{document}"""
SUMMARIZER_PROMPT = ChatPromptTemplate.from_template(SUMMARIZER_TEMPLATE)


# --- Key Issue Structuring (New) ---
# This prompt guides the LLM to output structured Key Issues based on the gathered context.
# It references the Pydantic model 'KeyIssue' for the desired format.
# Note: this is a plain string (not an f-string) so that {user_query} and {context} remain
# template variables, while the doubled braces below render as literal JSON braces.
KEY_ISSUE_STRUCTURING_TEMPLATE = """Based on the provided context (summaries of relevant documents, research findings, etc.), identify and formulate distinct Key Issues related to the original user query.

User Query: {user_query}

Context:
{context}

For each Key Issue identified, provide the following information in the exact JSON format described below. Output a JSON list containing multiple KeyIssue objects.

JSON Schema for each Key Issue object:
{{
    "id": "Sequential integer ID starting from 1",
    "title": "Concise title for the key issue (max 15 words)",
    "description": "Detailed description of the key issue (2-4 sentences)",
    "challenges": ["List of specific challenges related to this issue (strings)", "Each challenge as a separate string"],
    "potential_impact": "Brief description of the potential impact if not addressed (optional, max 30 words)"
}}

Example Format:
[
    {{
        "id": 1,
        "title": "Scalability of AI Models in Low-Resource Settings",
        "description": "Deploying complex AI models for healthcare diagnostics in areas with limited computational power and data connectivity presents significant scalability challenges. Existing models often require substantial resources.",
        "challenges": ["High computational requirements of current models", "Intermittent or low-bandwidth network connectivity", "Lack of large, localized datasets for training/fine-tuning"],
        "potential_impact": "Limits equitable access to advanced AI-driven healthcare diagnostics."
    }},
    {{
        "id": 2,
        "title": "...",
        "description": "...",
        "challenges": ["...", "..."],
        "potential_impact": "..."
    }}
]

Generate the JSON list of Key Issues based *only* on the provided context and user query. Ensure the output is a valid JSON list.
"""
KEY_ISSUE_STRUCTURING_PROMPT = ChatPromptTemplate.from_template(KEY_ISSUE_STRUCTURING_TEMPLATE)


# --- Initial Prompt Selection ---
def get_initial_planner_prompt(plan_method: str, user_query: str) -> ChatPromptTemplate:
    """Builds the initial planner prompt for either plan generation or plan modification."""
    if plan_method == "generation":
        template = PLAN_GENERATION_TEMPLATE
    elif plan_method == "modification":
        template = PLAN_MODIFICATION_TEMPLATE
    else:
        raise ValueError(f"Invalid plan_method: {plan_method!r}")

    # Return a ChatPromptTemplate for consistency with the other prompts
    return ChatPromptTemplate.from_messages([
        ("system", template),
        ("human", user_query),
    ])
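For orientation, here is a minimal sketch of how these templates can be rendered outside the graph. The schema, question, and context strings are invented for illustration; in the actual pipeline the rendering presumably happens inside the planner and graph-operation nodes.

# Illustrative only: example inputs are made up.
from kig_core.prompts import (
    CYPHER_GENERATION_PROMPT,
    KEY_ISSUE_STRUCTURING_PROMPT,
    get_initial_planner_prompt,
)

# PromptTemplate -> plain string
cypher_prompt_text = CYPHER_GENERATION_PROMPT.format(
    schema="(:Document)-[:MENTIONS]->(:Concept {name: STRING})",  # made-up schema
    question="Which documents mention network slicing?",
)

# ChatPromptTemplate -> list of chat messages ready for an LLM call
structuring_messages = KEY_ISSUE_STRUCTURING_PROMPT.format_messages(
    user_query="Key issues for AI/ML model transfer in 5G",  # made-up query
    context="Doc 1 summary...\nDoc 2 summary...",             # made-up context
)

planner_messages = get_initial_planner_prompt(
    "generation", "Identify key issues for energy-efficient RAN"
).format_messages()

print(cypher_prompt_text[:80], len(structuring_messages), len(planner_messages))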
kig_core/schemas.py
ADDED
@@ -0,0 +1,55 @@
from typing import Annotated, List, Dict, Any, Optional, Union
from typing_extensions import TypedDict
from langchain_core.messages import BaseMessage
from pydantic import BaseModel, Field
from langgraph.graph.message import add_messages

# --- Pydantic Models for Structured Output ---

class KeyIssue(BaseModel):
    """Represents a single generated Key Issue."""
    id: int = Field(..., description="Sequential ID for the key issue")
    title: str = Field(..., description="A concise title for the key issue")
    description: str = Field(..., description="Detailed description of the key issue")
    challenges: List[str] = Field(default_factory=list, description="Specific challenges associated with this issue")
    potential_impact: Optional[str] = Field(None, description="Potential impact if the issue is not addressed")
    # Add source tracking if possible/needed from the processed docs:
    # sources: List[str] = Field(default_factory=list, description="Source documents relevant to this issue")


# --- TypedDicts for LangGraph State ---

class GraphConfig(TypedDict):
    """Configuration passed to the graph execution."""
    thread_id: str
    # Add other config items needed at runtime if not globally available via settings

class BaseState(TypedDict):
    """Base state common across potentially multiple graphs."""
    messages: Annotated[List[BaseMessage], add_messages]
    error: Optional[str]  # Stores potential errors during execution

class PlannerState(BaseState):
    """State specific to the main planner graph."""
    user_query: str
    plan: List[str]  # The high-level plan steps
    current_plan_step_index: int  # Index of the plan step currently being executed
    # Data gathered by previous steps (e.g., summaries), keyed by plan step index
    step_outputs: Dict[int, Any]  # Output (e.g., processed docs) from each step
    # Final structured output
    key_issues: List[KeyIssue]


class DataRetrievalState(TypedDict):
    """State for a potential data retrieval sub-graph."""
    query_for_retrieval: str  # The specific query for this retrieval step
    retrieved_docs: List[Dict[str, Any]]  # Raw docs from Neo4j
    evaluated_docs: List[Dict[str, Any]]  # Docs after relevance grading
    cypher_queries: List[str]  # Generated Cypher queries

class ProcessingState(TypedDict):
    """State for a potential document processing sub-graph."""
    docs_to_process: List[Dict[str, Any]]  # Documents passed in for processing
    processed_docs: List[Union[str, Dict[str, Any]]]  # Processed/summarized docs
    processing_steps_config: List[Union[str, dict]]  # Configuration for the processing steps
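Because the structuring prompt in kig_core/prompts.py asks the LLM for a JSON list mirroring these fields, a consumer would typically parse and validate that output against KeyIssue. A minimal sketch, assuming the model returned valid JSON (the payload below is invented):

# Illustrative only: the raw LLM output below is made up.
import json
from kig_core.schemas import KeyIssue

raw_llm_output = '[{"id": 1, "title": "Example issue", "description": "Example description.", "challenges": ["Challenge A"], "potential_impact": "Example impact."}]'

issues = [KeyIssue(**item) for item in json.loads(raw_llm_output)]
assert issues[0].id == 1

# Pydantic v2 serialization back to plain dicts (this is what key_issues_to_dataframe relies on)
print([issue.model_dump() for issue in issues])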
kig_core/utils.py
ADDED
@@ -0,0 +1,41 @@
import pandas as pd
import io
import logging
from typing import List
from .schemas import KeyIssue  # Import the Pydantic model

logger = logging.getLogger(__name__)

def key_issues_to_dataframe(key_issues: List[KeyIssue]) -> pd.DataFrame:
    """Converts a list of KeyIssue objects into a Pandas DataFrame."""
    if not key_issues:
        return pd.DataFrame()
    # Use Pydantic's .model_dump() for robust serialization
    data = [ki.model_dump() for ki in key_issues]
    df = pd.DataFrame(data)
    # Optional: reorder or rename columns if needed
    # df = df[['id', 'title', 'description', 'challenges', 'potential_impact']]  # Example reordering
    return df

def dataframe_to_excel_bytes(df: pd.DataFrame) -> bytes:
    """Converts a Pandas DataFrame to Excel format in memory (bytes)."""
    logger.info("Generating Excel file from DataFrame...")
    output = io.BytesIO()
    try:
        # Use the BytesIO object as the target file
        with pd.ExcelWriter(output, engine='openpyxl') as writer:
            df.to_excel(writer, index=False, sheet_name='Key Issues')
        excel_data = output.getvalue()
        logger.info("Excel file generated successfully.")
        return excel_data
    except Exception as e:
        logger.error(f"Failed to generate Excel file: {e}", exc_info=True)
        raise RuntimeError("Failed to create Excel output.") from e

# Removed: format_df (HTML specific, less relevant for Excel output)
# Removed: init_app (handled by config.py)
# Removed: get_model (handled by llm_interface.py)
# Removed: clear_memory (handle state/memory management within the LangGraph setup if needed)
# Removed: _set_env (handled by config.py and dotenv)
# Kept: format_doc (renamed to format_doc_for_llm in graph_operations.py)
# Removed: update_doc_history (reducer logic should be handled in the LangGraph state definition/nodes)
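A minimal sketch of the intended export path, assuming these helpers are wired to a Streamlit download button inside the app (the KeyIssue instance below is invented):

# Illustrative only: run inside a Streamlit app; the issue data is made up.
import streamlit as st
from kig_core.schemas import KeyIssue
from kig_core.utils import key_issues_to_dataframe, dataframe_to_excel_bytes

issues = [KeyIssue(id=1, title="Example issue", description="Example description.")]

df = key_issues_to_dataframe(issues)      # KeyIssue objects -> DataFrame
excel_bytes = dataframe_to_excel_bytes(df)  # DataFrame -> in-memory .xlsx bytes

st.download_button(
    label="Download Key Issues (Excel)",
    data=excel_bytes,
    file_name="key_issues.xlsx",
    mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
)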
requirements.txt
ADDED
@@ -0,0 +1,26 @@
# Langchain Core & Ecosystem
langchain-core>=0.2.29
langchain-google-genai>=1.0.9   # For Gemini
langchain-openai>=0.1.21        # If using OpenAI
langgraph>=0.1.10
langchain-community>=0.2.10     # For Neo4jGraph and other community integrations

# LLM & Processing Libraries
# llmlingua==0.2.2              # Uncomment if using compression
google-generativeai>=0.7.2      # Underlying Gemini library

# Neo4j
neo4j>=5.24.0

# Streamlit & Data Handling
streamlit>=1.31.0
pandas>=2.1.3
openpyxl>=3.1.5                 # For Excel writing with Pandas

# Configuration & Utilities
pydantic>=2.9.0
pydantic-settings>=2.4.0        # For BaseSettings
python-dotenv>=1.0.1            # For loading .env files

# Optional: For LangSmith Tracing
# langsmith>=0.1.100