import re
import logging
import time
from typing import List, Dict, Any, Tuple

from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.pydantic_v1 import Field, BaseModel as V1BaseModel  # For grader models if needed

from .config import settings
from .graph_client import neo4j_client  # Use the central client
from .llm_interface import get_llm, invoke_llm
from .prompts import (
    CYPHER_GENERATION_PROMPT, CONCEPT_SELECTION_PROMPT,
    BINARY_GRADER_PROMPT, SCORE_GRADER_PROMPT
)

logger = logging.getLogger(__name__)

# --- Helper Functions ---
def extract_cypher(text: str) -> str:
    """Extracts the first Cypher code block or returns the text itself."""
    pattern = r"```(?:cypher)?\s*(.*?)\s*```"
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else text.strip()
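
# Illustrative example of the helper above (not executed):
#   extract_cypher("```cypher\nMATCH (n) RETURN n LIMIT 5\n```")  -> "MATCH (n) RETURN n LIMIT 5"
#   extract_cypher("MATCH (n) RETURN n")                          -> "MATCH (n) RETURN n"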

def format_doc_for_llm(doc: Dict[str, Any]) -> str:
    """Formats a document dictionary into a string for LLM context."""
    return "\n".join(f"**{key}**: {value}" for key, value in doc.items() if value)


# --- Cypher Generation ---
def generate_cypher_auto(question: str) -> str:
    """Generates Cypher using the 'auto' method."""
    logger.info("Generating Cypher using 'auto' method.")
    # Schema fetching needs implementation if required by the prompt/LLM
    # schema_info = neo4j_client.get_schema() # Placeholder
    schema_info = "Schema not available." # Default if not implemented
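    # Sketch (assumption, not enabled): a minimal schema summary could be assembled via
    # Neo4j's built-in introspection procedures through the shared client, e.g.:
    #   labels = [r["label"] for r in neo4j_client.query("CALL db.labels()")]
    #   rel_types = [r["relationshipType"] for r in neo4j_client.query("CALL db.relationshipTypes()")]
    #   schema_info = f"Node labels: {labels}\nRelationship types: {rel_types}"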

    cypher_llm = get_llm(settings.main_llm_model) # Or a specific cypher model
    chain = (
        {"question": RunnablePassthrough(), "schema": lambda x: schema_info}
        | CYPHER_GENERATION_PROMPT
        | cypher_llm
        | StrOutputParser()
        | extract_cypher
    )
    return invoke_llm(chain, question)

def generate_cypher_guided(question: str, plan_step: int) -> str:
    """Generates Cypher using the 'guided' method based on concepts."""
    logger.info(f"Generating Cypher using 'guided' method for plan step {plan_step}.")
    try:
        concepts = neo4j_client.get_concepts()
        if not concepts:
            logger.warning("No concepts found in Neo4j for guided cypher generation.")
            return "" # Or raise error

        concept_llm = get_llm(settings.main_llm_model) # Or a specific concept model
        concept_chain = (
            CONCEPT_SELECTION_PROMPT
            | concept_llm
            | StrOutputParser()
        )
        selected_concept = invoke_llm(concept_chain, {
            "question": question,
            "concepts": "\n".join(concepts)
        }).strip()

        logger.info(f"Concept selected by LLM: {selected_concept}")

        # Basic check that the selected concept is valid
        if selected_concept not in concepts:
            logger.warning(f"LLM selected concept '{selected_concept}' not in the known list. Attempting fallback.")
            # Optional: fuzzy matching or similarity search could go here (see sketch below).
            # For now, fall back to a simple case-insensitive substring check.
            found_match = None
            for c in concepts:
                if selected_concept.lower() in c.lower():
                    found_match = c
                    logger.info(f"Found potential match: '{found_match}'")
                    break
            if not found_match:
                logger.error(f"Could not validate selected concept: {selected_concept}")
                return ""  # Return empty query if the concept is invalid
            selected_concept = found_match
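        # Sketch (assumption, not enabled): the substring fallback above could be replaced
        # with standard-library fuzzy matching, e.g.:
        #   from difflib import get_close_matches
        #   matches = get_close_matches(selected_concept, concepts, n=1, cutoff=0.6)
        #   selected_concept = matches[0] if matches else ""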


        # Determine the target node type based on the plan step (example logic).
        # This mapping may need adjustment based on the actual plan structure.
        if plan_step <= 1:  # Steps 0 and 1: context gathering
            target = "(ts:TechnicalSpecification)"
            fields = "ts.title, ts.scope, ts.description"
        elif plan_step == 2:  # Step 2: research papers?
            target = "(rp:ResearchPaper)"
            fields = "rp.title, rp.abstract"
        else:  # Later steps may involve KeyIssues themselves or other node types
            target = "(n)"  # Generic fallback
            fields = "n.title, n.description"  # Assuming common fields

        # Construct the Cypher query. Parameter binding would be safer, but the planner
        # currently expects a plain query string, so the concept name is escaped and
        # inlined directly. Be cautious about injection if concept names can contain
        # special characters (see the note after this function).
        escaped_concept = selected_concept.replace("'", "\\'")  # Basic escaping
        cypher = f"MATCH (c:Concept {{name: '{escaped_concept}'}})-[:RELATED_TO]-{target} RETURN {fields}"

        logger.info(f"Generated guided Cypher: {cypher}")
        return cypher

    except Exception as e:
        logger.error(f"Error during guided cypher generation: {e}", exc_info=True)
        time.sleep(60)  # Back off (presumably to ride out LLM/API rate limits) before giving up
        return ""  # Return empty on error


# --- Document Retrieval ---
def retrieve_documents(cypher_query: str) -> List[Dict[str, Any]]:
    """Retrieves documents from Neo4j using a Cypher query."""
    if not cypher_query:
        logger.warning("Received empty Cypher query, skipping retrieval.")
        return []
    logger.info(f"Retrieving documents with Cypher: {cypher_query} limit 10")
    try:
        # Use the centralized client's query method
        raw_results = neo4j_client.query(cypher_query + " limit 10")
        # Basic cleaning/deduplication (can be enhanced)
        processed_results = []
        seen = set()
        for doc in raw_results:
             # Create a frozenset of items for hashable representation to detect duplicates
             doc_items = frozenset(doc.items())
             if doc_items not in seen:
                 processed_results.append(doc)
                 seen.add(doc_items)
        logger.info(f"Retrieved {len(processed_results)} unique documents.")
        return processed_results
    except (ConnectionError, ValueError, RuntimeError) as e:
        # Errors already logged in neo4j_client
        logger.error(f"Document retrieval failed: {e}")
        return [] # Return empty list on failure


# --- Document Evaluation ---
# Define Pydantic models for structured LLM grader output (if not using built-in LCEL structured output)
class GradeDocumentsBinary(V1BaseModel):
    """Binary score for relevance check."""
    binary_score: str = Field(description="Relevant? 'yes' or 'no'")

class GradeDocumentsScore(V1BaseModel):
    """Score for relevance check."""
    rationale: str = Field(description="Rationale for the score.")
    score: float = Field(description="Relevance score (0.0 to 1.0)")

def evaluate_documents(
    docs: List[Dict[str, Any]],
    query: str
) -> List[Dict[str, Any]]:
    """Evaluates document relevance to a query using configured method."""
    if not docs:
        return []

    logger.info(f"Evaluating {len(docs)} documents for relevance to query: '{query}' using method: {settings.eval_method}")
    eval_llm = get_llm(settings.eval_llm_model)
    valid_docs_with_scores: List[Tuple[Dict[str, Any], float]] = []

    # Consider using LCEL's structured output capabilities directly if the model supports it well
    # This simplifies parsing. Example for binary:
    # binary_grader = BINARY_GRADER_PROMPT | eval_llm.with_structured_output(GradeDocumentsBinary)

    if settings.eval_method == "binary":
        binary_grader = BINARY_GRADER_PROMPT | eval_llm | StrOutputParser()  # Fallback to string parsing
        for doc in docs:
            formatted_doc = format_doc_for_llm(doc)
            if not formatted_doc.strip():
                continue
            try:
                result = invoke_llm(binary_grader, {"question": query, "document": formatted_doc})
                logger.debug(f"Binary grader result for doc '{doc.get('title', 'N/A')}': {result}")
                if result and 'yes' in result.lower():
                    valid_docs_with_scores.append((doc, 1.0))  # Score 1.0 for relevant
            except Exception as e:
                logger.warning(f"Binary grading failed for a document: {e}", exc_info=True)

    elif settings.eval_method == "score":
        # Using a JSON parser as a robust fallback for score extraction.
        # Note: JsonOutputParser returns a dict, not a GradeDocumentsScore instance,
        # so fields are read via dict access below.
        score_grader = SCORE_GRADER_PROMPT | eval_llm | JsonOutputParser(pydantic_object=GradeDocumentsScore)
        for doc in docs:
            formatted_doc = format_doc_for_llm(doc)
            if not formatted_doc.strip():
                continue
            try:
                result: Dict[str, Any] = invoke_llm(score_grader, {"query": query, "document": formatted_doc})
                score = float(result.get("score", 0.0))
                rationale = result.get("rationale", "")
                logger.debug(f"Score grader result for doc '{doc.get('title', 'N/A')}': Score={score}, Rationale={rationale}")
                if score >= settings.eval_threshold:
                    valid_docs_with_scores.append((doc, score))
            except Exception as e:
                logger.warning(f"Score grading failed for a document: {e}", exc_info=True)
                # Treat grading failures as not relevant and skip the document.

    # Sort by score if applicable, then limit
    if settings.eval_method == 'score':
        valid_docs_with_scores.sort(key=lambda item: item[1], reverse=True)

    # Limit to max_docs
    final_docs = [doc for doc, score in valid_docs_with_scores[:settings.max_docs]]
    logger.info(f"Found {len(final_docs)} relevant documents after evaluation and filtering.")

    return final_docs
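

# Minimal usage sketch (assumes settings, the Neo4j client, and the LLM backend are configured;
# the question text below is illustrative only):
#   cypher = generate_cypher_guided("Which specifications relate to network slicing security?", plan_step=0)
#   docs = retrieve_documents(cypher)
#   relevant = evaluate_documents(docs, "Which specifications relate to network slicing security?")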