# kig_core/graph_operations.py
import re
import logging
import time
from typing import List, Dict, Any, Optional, Tuple
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.pydantic_v1 import Field, BaseModel as V1BaseModel # For grader models if needed
from .config import settings
from .graph_client import neo4j_client # Use the central client
from .llm_interface import get_llm, invoke_llm
from .prompts import (
    CYPHER_GENERATION_PROMPT, CONCEPT_SELECTION_PROMPT,
    BINARY_GRADER_PROMPT, SCORE_GRADER_PROMPT
)
from .schemas import KeyIssue  # Import if needed here, maybe not

logger = logging.getLogger(__name__)


# --- Helper Functions ---

def extract_cypher(text: str) -> str:
    """Extracts the first Cypher code block or returns the text itself."""
    pattern = r"```(?:cypher)?\s*(.*?)\s*```"
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else text.strip()
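
# Illustrative behavior (hypothetical inputs, not from the original module):
#   extract_cypher("```cypher\nMATCH (n) RETURN n\n```")  ->  "MATCH (n) RETURN n"
#   extract_cypher("MATCH (n) RETURN n")                   ->  "MATCH (n) RETURN n"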


def format_doc_for_llm(doc: Dict[str, Any]) -> str:
    """Formats a document dictionary into a string for LLM context."""
    return "\n".join(f"**{key}**: {value}" for key, value in doc.items() if value)


# --- Cypher Generation ---

def generate_cypher_auto(question: str) -> str:
    """Generates Cypher using the 'auto' method."""
    logger.info("Generating Cypher using 'auto' method.")
    # Schema fetching needs implementation if required by the prompt/LLM
    # schema_info = neo4j_client.get_schema() # Placeholder
    schema_info = "Schema not available."  # Default if not implemented
    cypher_llm = get_llm(settings.main_llm_model)  # Or a specific cypher model
    chain = (
        {"question": RunnablePassthrough(), "schema": lambda x: schema_info}
        | CYPHER_GENERATION_PROMPT
        | cypher_llm
        | StrOutputParser()
        | extract_cypher
    )
    return invoke_llm(chain, question)
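

# A minimal sketch of the schema fetch hinted at above (not part of the original module).
# It assumes neo4j_client.query() accepts a plain Cypher string and that the standard
# Neo4j introspection procedures db.labels() and db.relationshipTypes() are available;
# adapt to whatever the central client actually exposes.
def _fetch_schema_summary() -> str:
    """Returns a rough, human-readable schema summary for prompt context."""
    try:
        labels = [row["label"] for row in neo4j_client.query(
            "CALL db.labels() YIELD label RETURN label")]
        rel_types = [row["relationshipType"] for row in neo4j_client.query(
            "CALL db.relationshipTypes() YIELD relationshipType RETURN relationshipType")]
        return f"Node labels: {', '.join(labels)}\nRelationship types: {', '.join(rel_types)}"
    except Exception as e:
        # Fall back to the same placeholder used by generate_cypher_auto
        logger.warning(f"Schema introspection failed: {e}")
        return "Schema not available."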


def generate_cypher_guided(question: str, plan_step: int) -> str:
    """Generates Cypher using the 'guided' method based on concepts."""
    logger.info(f"Generating Cypher using 'guided' method for plan step {plan_step}.")
    try:
        concepts = neo4j_client.get_concepts()
        if not concepts:
            logger.warning("No concepts found in Neo4j for guided cypher generation.")
            return ""  # Or raise error

        concept_llm = get_llm(settings.main_llm_model)  # Or a specific concept model
        concept_chain = (
            CONCEPT_SELECTION_PROMPT
            | concept_llm
            | StrOutputParser()
        )
        selected_concept = invoke_llm(concept_chain, {
            "question": question,
            "concepts": "\n".join(concepts)
        }).strip()
        logger.info(f"Concept selected by LLM: {selected_concept}")

        # Basic check that the selected concept is valid
        if selected_concept not in concepts:
            logger.warning(f"LLM selected concept '{selected_concept}' not in the known list. Attempting fallback or ignoring.")
            # Optional: add fuzzy matching or similarity search here.
            # For now, fall back to a simple case-insensitive substring check.
            found_match = None
            for c in concepts:
                if selected_concept.lower() in c.lower():
                    found_match = c
                    logger.info(f"Found potential match: '{found_match}'")
                    break
            if not found_match:
                logger.error(f"Could not validate selected concept: {selected_concept}")
                return ""  # Return empty query if concept is invalid
            selected_concept = found_match

        # Determine the target node type based on plan step (example logic).
        # This mapping might need adjustment based on the actual plan structure.
        if plan_step <= 1:  # Steps 0 and 1: context gathering
            target = "(ts:TechnicalSpecification)"
            fields = "ts.title, ts.scope, ts.description"
        elif plan_step == 2:  # Step 2: research papers
            target = "(rp:ResearchPaper)"
            fields = "rp.title, rp.abstract"
        else:  # Later steps might involve KeyIssues themselves or other types
            target = "(n)"  # Generic fallback
            fields = "n.title, n.description"  # Assuming common fields

        # Construct the Cypher query. Parameter binding would be safer, e.g.
        #   MATCH (c:Concept {name: $conceptName})-[:RELATED_TO]-<target> RETURN <fields>
        # but the planner currently expects a plain query string, so the concept name is
        # inlined with basic escaping. Be cautious about injection if concept names can
        # contain special characters.
        escaped_concept = selected_concept.replace("'", "\\'")  # Basic escaping
        cypher = f"MATCH (c:Concept {{name: '{escaped_concept}'}})-[:RELATED_TO]-{target} RETURN {fields}"
        logger.info(f"Generated guided Cypher: {cypher}")
        return cypher

    except Exception as e:
        logger.error(f"Error during guided cypher generation: {e}", exc_info=True)
        time.sleep(60)  # Crude backoff on failure
        return ""  # Return empty on error


# --- Document Retrieval ---

def retrieve_documents(cypher_query: str) -> List[Dict[str, Any]]:
    """Retrieves documents from Neo4j using a Cypher query."""
    if not cypher_query:
        logger.warning("Received empty Cypher query, skipping retrieval.")
        return []

    logger.info(f"Retrieving documents with Cypher: {cypher_query} limit 10")
    try:
        # Use the centralized client's query method
        raw_results = neo4j_client.query(cypher_query + " limit 10")

        # Basic cleaning/deduplication (can be enhanced)
        processed_results = []
        seen = set()
        for doc in raw_results:
            # Create a frozenset of items for hashable representation to detect duplicates
            doc_items = frozenset(doc.items())
            if doc_items not in seen:
                processed_results.append(doc)
                seen.add(doc_items)

        logger.info(f"Retrieved {len(processed_results)} unique documents.")
        return processed_results
    except (ConnectionError, ValueError, RuntimeError) as e:
        # Errors already logged in neo4j_client
        logger.error(f"Document retrieval failed: {e}")
        return []  # Return empty list on failure


# --- Document Evaluation ---

# Define Pydantic models for structured LLM grader output (if not using built-in LCEL structured output)
class GradeDocumentsBinary(V1BaseModel):
    """Binary score for relevance check."""
    binary_score: str = Field(description="Relevant? 'yes' or 'no'")


class GradeDocumentsScore(V1BaseModel):
    """Score for relevance check."""
    rationale: str = Field(description="Rationale for the score.")
    score: float = Field(description="Relevance score (0.0 to 1.0)")
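
# Sketch of how these models could drive LCEL structured output instead of string/JSON
# parsing; assumes the configured chat model supports .with_structured_output(). Not
# wired in; evaluate_documents below currently parses plain strings and JSON.
#   binary_grader = BINARY_GRADER_PROMPT | eval_llm.with_structured_output(GradeDocumentsBinary)
#   grade = binary_grader.invoke({"question": query, "document": formatted_doc})
#   is_relevant = grade.binary_score.strip().lower() == "yes"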


def evaluate_documents(
    docs: List[Dict[str, Any]],
    query: str
) -> List[Dict[str, Any]]:
    """Evaluates document relevance to a query using the configured method."""
    if not docs:
        return []

    logger.info(f"Evaluating {len(docs)} documents for relevance to query: '{query}' using method: {settings.eval_method}")
    eval_llm = get_llm(settings.eval_llm_model)
    valid_docs_with_scores: List[Tuple[Dict[str, Any], float]] = []

    # Consider using LCEL's structured output capabilities directly if the model supports it well.
    # This simplifies parsing. Example for binary:
    # binary_grader = BINARY_GRADER_PROMPT | eval_llm.with_structured_output(GradeDocumentsBinary)
    if settings.eval_method == "binary":
        binary_grader = BINARY_GRADER_PROMPT | eval_llm | StrOutputParser()  # Fallback to string parsing
        for doc in docs:
            formatted_doc = format_doc_for_llm(doc)
            if not formatted_doc.strip():
                continue
            try:
                result = invoke_llm(binary_grader, {"question": query, "document": formatted_doc})
                logger.debug(f"Binary grader result for doc '{doc.get('title', 'N/A')}': {result}")
                if result and 'yes' in result.lower():
                    valid_docs_with_scores.append((doc, 1.0))  # Score 1.0 for relevant
            except Exception as e:
                logger.warning(f"Binary grading failed for a document: {e}", exc_info=True)

    elif settings.eval_method == "score":
        # Using JSON parser as a robust fallback for score extraction
        score_grader = SCORE_GRADER_PROMPT | eval_llm | JsonOutputParser(pydantic_object=GradeDocumentsScore)
        for doc in docs:
            formatted_doc = format_doc_for_llm(doc)
            if not formatted_doc.strip():
                continue
            try:
                # JsonOutputParser yields a plain dict, so read the graded fields by key
                result = invoke_llm(score_grader, {"query": query, "document": formatted_doc})
                score = float(result.get("score", 0.0))
                rationale = result.get("rationale", "")
                logger.debug(f"Score grader result for doc '{doc.get('title', 'N/A')}': Score={score}, Rationale={rationale}")
                if score >= settings.eval_threshold:
                    valid_docs_with_scores.append((doc, score))
            except Exception as e:
                logger.warning(f"Score grading failed for a document: {e}", exc_info=True)
                # Optionally treat as relevant on failure? Or skip? Skipping for now.

    # Sort by score if applicable, then limit to max_docs
    if settings.eval_method == 'score':
        valid_docs_with_scores.sort(key=lambda item: item[1], reverse=True)
    final_docs = [doc for doc, _score in valid_docs_with_scores[:settings.max_docs]]
    logger.info(f"Found {len(final_docs)} relevant documents after evaluation and filtering.")
    return final_docs
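

# Minimal manual smoke test (illustrative only, not part of the original module). It
# assumes settings, the Neo4j connection, and the LLM backends are already configured;
# the question and plan step below are placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    sample_question = "Which technical specifications relate to network slicing?"  # Hypothetical
    cypher = generate_cypher_guided(sample_question, plan_step=0)
    documents = retrieve_documents(cypher)
    relevant = evaluate_documents(documents, sample_question)
    for d in relevant:
        print(format_doc_for_llm(d))
        print("---")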