import re
import logging
import time
from typing import List, Dict, Any, Optional, Tuple
from random import sample, shuffle
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.pydantic_v1 import Field, BaseModel as V1BaseModel # For grader models if needed
from .config import settings
from .graph_client import neo4j_client # Use the central client
from .llm_interface import get_llm, invoke_llm
from .prompts import (
CYPHER_GENERATION_PROMPT, CONCEPT_SELECTION_PROMPT,
BINARY_GRADER_PROMPT, SCORE_GRADER_PROMPT
)
from .schemas import KeyIssue  # Not used directly here yet; kept for potential type references
logger = logging.getLogger(__name__)
# --- Helper Functions ---
def extract_cypher(text: str) -> str:
"""Extracts the first Cypher code block or returns the text itself."""
pattern = r"```(?:cypher)?\s*(.*?)\s*```"
match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
return match.group(1).strip() if match else text.strip()
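
# Example behaviour of extract_cypher (illustrative only):
#   extract_cypher("```cypher\nMATCH (n) RETURN n\n```")  -> "MATCH (n) RETURN n"
#   extract_cypher("MATCH (n) RETURN n")                  -> "MATCH (n) RETURN n"  (no fence: text returned as-is)
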
def format_doc_for_llm(doc: Dict[str, Any]) -> str:
"""Formats a document dictionary into a string for LLM context."""
return "\n".join(f"**{key}**: {value}" for key, value in doc.items() if value)
# --- Cypher Generation ---
def generate_cypher_auto(question: str) -> str:
"""Generates Cypher using the 'auto' method."""
logger.info("Generating Cypher using 'auto' method.")
    # Schema fetching is not implemented yet; see the _fetch_schema_summary sketch below.
    # schema_info = neo4j_client.get_schema()  # Placeholder
    schema_info = "Schema not available."  # Default until schema fetching exists
cypher_llm = get_llm(settings.main_llm_model) # Or a specific cypher model
chain = (
{"question": RunnablePassthrough(), "schema": lambda x: schema_info}
| CYPHER_GENERATION_PROMPT
| cypher_llm
| StrOutputParser()
| extract_cypher
)
    return invoke_llm(chain, question)
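
# Sketch (assumption): one way to provide the schema context that generate_cypher_auto
# currently stubs out, using the existing neo4j_client.query() helper and Neo4j's
# built-in db.labels() / db.relationshipTypes() procedures. Not wired in yet.
def _fetch_schema_summary() -> str:
    """Best-effort, human-readable schema summary for the Cypher generation prompt."""
    try:
        # Assumes query() returns rows as dicts keyed by the procedures' yield names,
        # consistent with how retrieve_documents() iterates results below.
        labels = [row["label"] for row in neo4j_client.query("CALL db.labels()")]
        rel_types = [row["relationshipType"] for row in neo4j_client.query("CALL db.relationshipTypes()")]
        return f"Node labels: {', '.join(labels)}\nRelationship types: {', '.join(rel_types)}"
    except Exception as e:
        logger.warning(f"Could not fetch schema summary: {e}")
        return "Schema not available."
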
def generate_cypher_guided(question: str, plan_step: int) -> str:
"""Generates Cypher using the 'guided' method based on concepts."""
logger.info(f"Generating Cypher using 'guided' method for plan step {plan_step}.")
try:
concepts = neo4j_client.get_concepts()
if not concepts:
logger.warning("No concepts found in Neo4j for guided cypher generation.")
return "" # Or raise error
concept_llm = get_llm(settings.main_llm_model) # Or a specific concept model
concept_chain = (
CONCEPT_SELECTION_PROMPT
| concept_llm
| StrOutputParser()
)
        selected_concept = invoke_llm(concept_chain, {
            "question": question,
            "concepts": "\n".join(concepts)
        }).strip()
logger.info(f"Concept selected by LLM: {selected_concept}")
# Basic check if the selected concept is valid
if selected_concept not in concepts:
logger.warning(f"LLM selected concept '{selected_concept}' not in the known list. Attempting fallback or ignoring.")
# Optional: Add fuzzy matching or similarity search here
# For now, we might default or return empty
# Let's try a simple substring check as a fallback
found_match = None
for c in concepts:
if selected_concept.lower() in c.lower():
found_match = c
logger.info(f"Found potential match: '{found_match}'")
break
if not found_match:
logger.error(f"Could not validate selected concept: {selected_concept}")
return "" # Return empty query if concept is invalid
selected_concept = found_match
# Determine the target node type based on plan step (example logic)
# This mapping might need adjustment based on the actual plan structure
if plan_step <= 1: # Steps 0 and 1: Context gathering
target = "(ts:TechnicalSpecification)"
fields = "ts.title, ts.scope, ts.description"
elif plan_step == 2: # Step 2: Research papers?
target = "(rp:ResearchPaper)"
fields = "rp.title, rp.abstract"
else: # Later steps might involve KeyIssues themselves or other types
target = "(n)" # Generic fallback
fields = "n.title, n.description" # Assuming common fields
        # Construct the Cypher query. Parameter binding ({name: $conceptName}) would be
        # safer against injection and special characters, but the planner currently expects
        # a plain query string, so the concept name is inlined with basic escaping.
        escaped_concept = selected_concept.replace("'", "\\'")
        cypher = f"MATCH (c:Concept {{name: '{escaped_concept}'}})-[:RELATED_TO]-{target} RETURN {fields}"
logger.info(f"Generated guided Cypher: {cypher}")
return cypher
    except Exception as e:
        logger.error(f"Error during guided cypher generation: {e}", exc_info=True)
        time.sleep(60)  # Crude backoff (e.g. after LLM rate-limit errors) before the caller retries
        return ""  # Return empty on error
# --- Document Retrieval ---
def retrieve_documents(cypher_query: str) -> List[Dict[str, Any]]:
"""Retrieves documents from Neo4j using a Cypher query."""
if not cypher_query:
logger.warning("Received empty Cypher query, skipping retrieval.")
return []
logger.info(f"Retrieving documents with Cypher: {cypher_query} limit 10")
try:
# Use the centralized client's query method
raw_results = neo4j_client.query(cypher_query + " limit 10")
# Basic cleaning/deduplication (can be enhanced)
processed_results = []
seen = set()
for doc in raw_results:
# Create a frozenset of items for hashable representation to detect duplicates
doc_items = frozenset(doc.items())
if doc_items not in seen:
processed_results.append(doc)
seen.add(doc_items)
logger.info(f"Retrieved {len(processed_results)} unique documents.")
return processed_results
except (ConnectionError, ValueError, RuntimeError) as e:
# Errors already logged in neo4j_client
logger.error(f"Document retrieval failed: {e}")
return [] # Return empty list on failure
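
# Example usage (illustrative only; assumes a populated graph and a configured LLM):
#   cypher = generate_cypher_guided("What are the key issues around network slicing?", plan_step=1)
#   docs = retrieve_documents(cypher)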
# --- Document Evaluation ---
# Define Pydantic models for structured LLM grader output (if not using built-in LCEL structured output)
class GradeDocumentsBinary(V1BaseModel):
"""Binary score for relevance check."""
binary_score: str = Field(description="Relevant? 'yes' or 'no'")
class GradeDocumentsScore(V1BaseModel):
"""Score for relevance check."""
rationale: str = Field(description="Rationale for the score.")
score: float = Field(description="Relevance score (0.0 to 1.0)")
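
# Sketch (assumption): if the configured eval model supports structured output,
# LangChain's .with_structured_output() can bind the Pydantic schemas above directly
# and avoid hand-rolled string/JSON parsing. evaluate_documents() below keeps the
# parser-based approach as its default.
def _build_binary_grader(llm) -> Runnable:
    """Grader chain that returns a GradeDocumentsBinary instance instead of raw text."""
    return BINARY_GRADER_PROMPT | llm.with_structured_output(GradeDocumentsBinary)
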
def evaluate_documents(
docs: List[Dict[str, Any]],
query: str
) -> List[Dict[str, Any]]:
"""Evaluates document relevance to a query using configured method."""
if not docs:
return []
logger.info(f"Evaluating {len(docs)} documents for relevance to query: '{query}' using method: {settings.eval_method}")
eval_llm = get_llm(settings.eval_llm_model)
valid_docs_with_scores: List[Tuple[Dict[str, Any], float]] = []
    # Consider using LCEL's structured-output support directly if the model handles it
    # well; this simplifies parsing (see the _build_binary_grader sketch above).
if settings.eval_method == "binary":
binary_grader = BINARY_GRADER_PROMPT | eval_llm | StrOutputParser() # Fallback to string parsing
for doc in docs:
formatted_doc = format_doc_for_llm(doc)
if not formatted_doc.strip(): continue
try:
                result = invoke_llm(binary_grader, {"question": query, "document": formatted_doc})
logger.debug(f"Binary grader result for doc '{doc.get('title', 'N/A')}': {result}")
if result and 'yes' in result.lower():
valid_docs_with_scores.append((doc, 1.0)) # Score 1.0 for relevant
except Exception as e:
logger.warning(f"Binary grading failed for a document: {e}", exc_info=True)
elif settings.eval_method == "score":
# Using JSON parser as a robust fallback for score extraction
score_grader = SCORE_GRADER_PROMPT | eval_llm | JsonOutputParser(pydantic_object=GradeDocumentsScore)
for doc in docs:
formatted_doc = format_doc_for_llm(doc)
if not formatted_doc.strip(): continue
try:
                result = invoke_llm(score_grader, {"query": query, "document": formatted_doc})
                # JsonOutputParser yields a dict, not a GradeDocumentsScore instance
                score = float(result.get("score", 0.0))
                logger.debug(f"Score grader result for doc '{doc.get('title', 'N/A')}': Score={score}, Rationale={result.get('rationale', '')}")
                if score >= settings.eval_threshold:
                    valid_docs_with_scores.append((doc, score))
except Exception as e:
logger.warning(f"Score grading failed for a document: {e}", exc_info=True)
# Optionally treat as relevant on failure? Or skip? Skipping for now.
# Sort by score if applicable, then limit
if settings.eval_method == 'score':
valid_docs_with_scores.sort(key=lambda item: item[1], reverse=True)
    # Limit to max_docs
    final_docs = [doc for doc, _score in valid_docs_with_scores[:settings.max_docs]]
    logger.info(f"Found {len(final_docs)} relevant documents after evaluation and filtering.")
    return final_docs
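
# Minimal end-to-end example (illustrative only; requires a reachable Neo4j instance
# and a configured LLM backend). The question below is made up for demonstration.
if __name__ == "__main__":
    question = "Which specifications discuss QoS requirements for network slicing?"
    cypher = generate_cypher_auto(question)
    docs = retrieve_documents(cypher)
    relevant = evaluate_documents(docs, question)
    print(f"Retrieved {len(docs)} documents, {len(relevant)} judged relevant.")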