import re
import logging
import time
from typing import List, Dict, Any, Optional, Tuple
from random import sample, shuffle

from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.pydantic_v1 import Field, BaseModel as V1BaseModel  # For grader models if needed

from .config import settings
from .graph_client import neo4j_client  # Use the central client
from .llm_interface import get_llm, invoke_llm
from .prompts import (
    CYPHER_GENERATION_PROMPT, CONCEPT_SELECTION_PROMPT,
    BINARY_GRADER_PROMPT, SCORE_GRADER_PROMPT
)
from .schemas import KeyIssue  # Import if needed here, maybe not

logger = logging.getLogger(__name__)


# --- Helper Functions ---

def extract_cypher(text: str) -> str:
    """Extracts the first Cypher code block or returns the text itself."""
    pattern = r"```(?:cypher)?\s*(.*?)\s*```"
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else text.strip()
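
# Example: extract_cypher("```cypher\nMATCH (n) RETURN n\n```") returns "MATCH (n) RETURN n";
# text without a fenced block is returned stripped but otherwise unchanged.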


def format_doc_for_llm(doc: Dict[str, Any]) -> str:
    """Formats a document dictionary into a string for LLM context."""
    return "\n".join(f"**{key}**: {value}" for key, value in doc.items() if value)


# --- Cypher Generation ---

def generate_cypher_auto(question: str) -> str:
    """Generates Cypher using the 'auto' method."""
    logger.info("Generating Cypher using 'auto' method.")
    # Schema fetching needs implementation if required by the prompt/LLM
    # (see the illustrative _fetch_schema_summary sketch after this function).
    # schema_info = neo4j_client.get_schema()  # Placeholder
    schema_info = "Schema not available."  # Default if not implemented
    cypher_llm = get_llm(settings.main_llm_model)  # Or a specific cypher model
    chain = (
        {"question": RunnablePassthrough(), "schema": lambda x: schema_info}
        | CYPHER_GENERATION_PROMPT
        | cypher_llm
        | StrOutputParser()
        | extract_cypher
    )
    return invoke_llm(chain, question)
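

# A possible implementation of the schema fetch that generate_cypher_auto above leaves as
# a placeholder. This is only an illustrative sketch and is not wired in: it assumes that
# neo4j_client.query() accepts a raw Cypher string and returns a list of dicts, as it is
# used in retrieve_documents below. db.labels() and db.relationshipTypes() are built-in
# Neo4j procedures.
def _fetch_schema_summary() -> str:
    """Returns a minimal textual schema summary (node labels and relationship types)."""
    try:
        labels = [row["label"] for row in neo4j_client.query("CALL db.labels() YIELD label RETURN label")]
        rel_types = [
            row["relationshipType"]
            for row in neo4j_client.query(
                "CALL db.relationshipTypes() YIELD relationshipType RETURN relationshipType"
            )
        ]
        return f"Node labels: {', '.join(labels)}\nRelationship types: {', '.join(rel_types)}"
    except Exception as e:
        # Fall back to the same default string used in generate_cypher_auto
        logger.warning(f"Could not fetch schema summary: {e}")
        return "Schema not available."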


def generate_cypher_guided(question: str, plan_step: int) -> str:
    """Generates Cypher using the 'guided' method based on concepts."""
    logger.info(f"Generating Cypher using 'guided' method for plan step {plan_step}.")
    try:
        concepts = neo4j_client.get_concepts()
        if not concepts:
            logger.warning("No concepts found in Neo4j for guided cypher generation.")
            return ""  # Or raise error

        concept_llm = get_llm(settings.main_llm_model)  # Or a specific concept model
        concept_chain = (
            CONCEPT_SELECTION_PROMPT
            | concept_llm
            | StrOutputParser()
        )
        selected_concept = invoke_llm(concept_chain, {
            "question": question,
            "concepts": "\n".join(concepts)
        }).strip()
        logger.info(f"Concept selected by LLM: {selected_concept}")

        # Basic check if the selected concept is valid
        if selected_concept not in concepts:
            logger.warning(f"LLM selected concept '{selected_concept}' not in the known list. Attempting fallback or ignoring.")
            # Optional: add fuzzy matching or similarity search here
            # (see the illustrative _fuzzy_match_concept sketch after this function).
            # For now, fall back to a simple substring check; if that also fails, return empty.
            found_match = None
            for c in concepts:
                if selected_concept.lower() in c.lower():
                    found_match = c
                    logger.info(f"Found potential match: '{found_match}'")
                    break
            if not found_match:
                logger.error(f"Could not validate selected concept: {selected_concept}")
                return ""  # Return empty query if concept is invalid
            selected_concept = found_match

        # Determine the target node type based on the plan step (example logic).
        # This mapping might need adjustment based on the actual plan structure.
        if plan_step <= 1:  # Steps 0 and 1: context gathering
            target = "(ts:TechnicalSpecification)"
            fields = "ts.title, ts.scope, ts.description"
        elif plan_step == 2:  # Step 2: research papers?
            target = "(rp:ResearchPaper)"
            fields = "rp.title, rp.abstract"
        else:  # Later steps might involve KeyIssues themselves or other types
            target = "(n)"  # Generic fallback
            fields = "n.title, n.description"  # Assuming common fields

        # Construct the Cypher query. Parameter binding ($conceptName) would be safer
        # against special characters and injection, but the planner currently expects a
        # plain query string, so the concept name is inlined with basic escaping (see the
        # parameterized sketch after retrieve_documents below).
        escaped_concept = selected_concept.replace("'", "\\'")  # Basic escaping
        cypher = f"MATCH (c:Concept {{name: '{escaped_concept}'}})-[:RELATED_TO]-{target} RETURN {fields}"
        logger.info(f"Generated guided Cypher: {cypher}")
        return cypher
    except Exception as e:
        logger.error(f"Error during guided cypher generation: {e}", exc_info=True)
        time.sleep(60)  # Crude back-off before giving up
        return ""  # Return empty on error


# --- Document Retrieval ---

def retrieve_documents(cypher_query: str) -> List[Dict[str, Any]]:
    """Retrieves documents from Neo4j using a Cypher query."""
    if not cypher_query:
        logger.warning("Received empty Cypher query, skipping retrieval.")
        return []

    logger.info(f"Retrieving documents with Cypher: {cypher_query} limit 10")
    try:
        # Use the centralized client's query method
        raw_results = neo4j_client.query(cypher_query + " limit 10")

        # Basic cleaning/deduplication (can be enhanced)
        processed_results = []
        seen = set()
        for doc in raw_results:
            # Build a hashable key from the record to detect duplicates; repr() is used
            # because Neo4j values may themselves be unhashable (e.g. lists).
            doc_key = tuple(sorted((k, repr(v)) for k, v in doc.items()))
            if doc_key not in seen:
                processed_results.append(doc)
                seen.add(doc_key)

        logger.info(f"Retrieved {len(processed_results)} unique documents.")
        return processed_results
    except (ConnectionError, ValueError, RuntimeError) as e:
        # Errors already logged in neo4j_client
        logger.error(f"Document retrieval failed: {e}")
        return []  # Return empty list on failure
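

# For reference, a parameterized retrieval variant, as suggested in generate_cypher_guided
# above. This is only a sketch using the official `neo4j` Python driver directly; the
# project's neo4j_client wrapper may or may not expose a parameters argument, and the
# setting names (neo4j_uri, neo4j_user, neo4j_password) and this helper are assumptions,
# not part of the original module.
def _retrieve_documents_parameterized(cypher_query: str, params: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Runs a Cypher query with bound parameters and returns records as dicts."""
    from neo4j import GraphDatabase  # pip install neo4j

    driver = GraphDatabase.driver(
        settings.neo4j_uri,  # assumed setting names; adjust to the real config
        auth=(settings.neo4j_user, settings.neo4j_password),
    )
    try:
        with driver.session() as session:
            result = session.run(cypher_query, params)
            return [record.data() for record in result]
    finally:
        driver.close()

# Hypothetical usage with a guided query and a bound concept name:
# _retrieve_documents_parameterized(
#     "MATCH (c:Concept {name: $conceptName})-[:RELATED_TO]-(ts:TechnicalSpecification) "
#     "RETURN ts.title, ts.scope, ts.description LIMIT 10",
#     {"conceptName": "Network Slicing"},
# )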


# --- Document Evaluation ---

# Define Pydantic models for structured LLM grader output (if not using built-in LCEL structured output)
class GradeDocumentsBinary(V1BaseModel):
    """Binary score for relevance check."""
    binary_score: str = Field(description="Relevant? 'yes' or 'no'")


class GradeDocumentsScore(V1BaseModel):
    """Score for relevance check."""
    rationale: str = Field(description="Rationale for the score.")
    score: float = Field(description="Relevance score (0.0 to 1.0)")
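
# With JsonOutputParser below, the score grader is expected to emit JSON matching
# GradeDocumentsScore, e.g. {"rationale": "Mentions the queried feature.", "score": 0.8}
# (illustrative values, not from the original source).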


def evaluate_documents(
    docs: List[Dict[str, Any]],
    query: str
) -> List[Dict[str, Any]]:
    """Evaluates document relevance to a query using the configured method."""
    if not docs:
        return []

    logger.info(f"Evaluating {len(docs)} documents for relevance to query: '{query}' using method: {settings.eval_method}")
    eval_llm = get_llm(settings.eval_llm_model)
    valid_docs_with_scores: List[Tuple[Dict[str, Any], float]] = []

    # Consider using LCEL's structured output capabilities directly if the model supports
    # them well; this simplifies parsing (see the sketch at the end of this module).
    # Example for binary:
    # binary_grader = BINARY_GRADER_PROMPT | eval_llm.with_structured_output(GradeDocumentsBinary)
    if settings.eval_method == "binary":
        binary_grader = BINARY_GRADER_PROMPT | eval_llm | StrOutputParser()  # Fallback to string parsing
        for doc in docs:
            formatted_doc = format_doc_for_llm(doc)
            if not formatted_doc.strip():
                continue
            try:
                result = invoke_llm(binary_grader, {"question": query, "document": formatted_doc})
                logger.debug(f"Binary grader result for doc '{doc.get('title', 'N/A')}': {result}")
                if result and 'yes' in result.lower():
                    valid_docs_with_scores.append((doc, 1.0))  # Score 1.0 for relevant
            except Exception as e:
                logger.warning(f"Binary grading failed for a document: {e}", exc_info=True)
    elif settings.eval_method == "score":
        # JsonOutputParser returns the parsed JSON as a plain dict; the pydantic_object
        # only shapes the format instructions, so the result is read with dict keys.
        score_grader = SCORE_GRADER_PROMPT | eval_llm | JsonOutputParser(pydantic_object=GradeDocumentsScore)
        for doc in docs:
            formatted_doc = format_doc_for_llm(doc)
            if not formatted_doc.strip():
                continue
            try:
                result: Dict[str, Any] = invoke_llm(score_grader, {"query": query, "document": formatted_doc})
                score = float(result.get("score", 0.0))
                rationale = result.get("rationale", "")
                logger.debug(f"Score grader result for doc '{doc.get('title', 'N/A')}': Score={score}, Rationale={rationale}")
                if score >= settings.eval_threshold:
                    valid_docs_with_scores.append((doc, score))
            except Exception as e:
                logger.warning(f"Score grading failed for a document: {e}", exc_info=True)
                # Optionally treat as relevant on failure? Or skip? Skipping for now.

    # Sort by score if applicable, then limit to max_docs
    if settings.eval_method == "score":
        valid_docs_with_scores.sort(key=lambda item: item[1], reverse=True)

    final_docs = [doc for doc, _score in valid_docs_with_scores[:settings.max_docs]]
    logger.info(f"Found {len(final_docs)} relevant documents after evaluation and filtering.")
    return final_docs
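

# Sketch of the structured-output grader mentioned in evaluate_documents above, assuming
# eval_llm is a chat model that supports langchain-core's with_structured_output(). It is
# not wired in; it only shows how GradeDocumentsBinary could replace string parsing.
def _grade_document_binary_structured(question: str, document: str) -> bool:
    """Returns True if the grader judges the document relevant to the question."""
    eval_llm = get_llm(settings.eval_llm_model)
    binary_grader = BINARY_GRADER_PROMPT | eval_llm.with_structured_output(GradeDocumentsBinary)
    result: GradeDocumentsBinary = binary_grader.invoke({"question": question, "document": document})
    return result.binary_score.strip().lower() == "yes"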