import logging
import sys
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, List

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Import the required components from the kig_core library.
# kig_core must be on the Python path or installed as a package.
try:
    from kig_core.config import settings  # noqa: F401  # Loads config on import
    from kig_core.schemas import PlannerState, KeyIssue as KigKeyIssue, GraphConfig
    from kig_core.planner import build_graph
    from kig_core.graph_client import neo4j_client  # Import the initialized client instance
    from langchain_core.messages import HumanMessage
except ImportError as e:
    # Logging is not configured yet at this point, so report the problem
    # directly on stderr (not stdout) before re-raising the original error.
    print(f"Error importing kig_core components: {e}", file=sys.stderr)
    print("Please ensure kig_core is in your Python path or installed.", file=sys.stderr)
    raise

# Configure logging for the API.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


# --- Pydantic Models for API Request/Response ---
class KeyIssueRequest(BaseModel):
    """Request body containing the user's technical query."""

    query: str


class KeyIssueResponse(BaseModel):
    """Response body containing the generated key issues."""

    # Each item uses the KeyIssue schema from kig_core.schemas.
    key_issues: List[KigKeyIssue]


# --- Global Variables / State ---
# The compiled LangGraph application is built once at startup (in the
# lifespan handler) and shared by every request.
# NOTE(review): sharing it assumes the compiled graph and the LLM clients it
# wraps keep no per-request mutable state — confirm before relying on it.
app_graph = None  # Compiled LangGraph application; assigned by lifespan() at startup.


# --- Application Lifecycle (Startup/Shutdown) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Handle API startup and shutdown.

    Startup:
      * Verifies Neo4j connectivity. A failure is logged but deliberately
        does NOT stop the application from starting.
      * Builds the LangGraph application into the module-level ``app_graph``.
        A failure here aborts startup with ``RuntimeError``.

    Shutdown only logs; the Neo4j client is not closed here.
    """
    global app_graph
    logger.info("API starting up...")

    # The Neo4j client instance comes from kig_core.graph_client, already
    # initialized at import time; here we only check that it can connect.
    try:
        logger.info("Verifying Neo4j connection...")
        neo4j_client._get_driver().verify_connectivity()
        logger.info("Neo4j connection verified.")
    except Exception as e:
        logger.error(f"Neo4j connection verification failed on startup: {e}", exc_info=True)
        # Best-effort check: to make it fatal, raise RuntimeError from e here.

    logger.info("Building LangGraph application...")
    try:
        app_graph = build_graph()
        logger.info("LangGraph application built successfully.")
    except Exception as e:
        logger.error(f"Failed to build LangGraph application on startup: {e}", exc_info=True)
        raise RuntimeError("Failed to build LangGraph on startup.") from e

    yield  # The API serves requests while suspended here.

    # --- Shutdown ---
    logger.info("API shutting down...")
    # NOTE(review): closing the Neo4j connection is presumably left to an
    # atexit handler registered in kig_core.graph_client — confirm.
    logger.info("Neo4j client closed (likely via atexit).")
    logger.info("API shutdown complete.")


# --- FastAPI Application ---
app = FastAPI(
    title="Key Issue Generator API",
    description="API to generate Key Issues based on a technical query using LLMs and Neo4j.",
    version="1.0.0",
    lifespan=lifespan,
)


# --- API Endpoints ---
@app.get("/")
def read_root():
    """Health check: report that the API process is up."""
    return {"status": "ok"}


def _build_initial_state(user_query: str) -> PlannerState:
    """Return the initial LangGraph state for *user_query*."""
    # NOTE(review): these keys must match what build_graph()'s entry node expects.
    return {
        "user_query": user_query,
        "messages": [HumanMessage(content=user_query)],
        "plan": [],
        "current_plan_step_index": -1,
        "step_outputs": {},
        "key_issues": [],
        "error": None,
    }


def _status_code_for_workflow_error(error_msg: str) -> int:
    """Map a workflow error message to an HTTP status code (case-insensitive).

    503 (Service Unavailable) for database/connection problems, 502 (Bad
    Gateway) for upstream LLM or parse problems, 500 for anything else.
    """
    lowered = error_msg.lower()
    if "neo4j" in lowered or "connection" in lowered:
        return 503
    if "llm error" in lowered or "parse" in lowered:
        return 502
    return 500


@app.post("/generate-key-issues", response_model=KeyIssueResponse)
async def generate_issues(request: KeyIssueRequest):
    """
    Accepts a technical query and returns a list of generated Key Issues.

    Args:
        request: Request body whose ``query`` is the user's technical question.

    Returns:
        A ``KeyIssueResponse`` holding the generated key issues.

    Raises:
        HTTPException: 503 if the graph is not initialized or a
            ``ConnectionError`` occurs; 400 for an empty/blank query or a
            ``ValueError`` raised while processing; 502/503/500 for an error
            reported in the workflow's final state; 500 if no final state is
            produced, the key issues do not match the response schema, or any
            other unexpected error occurs.
    """
    if app_graph is None:
        logger.error("Graph application is not initialized.")
        raise HTTPException(status_code=503, detail="Service Unavailable: Graph not initialized")

    user_query = request.query
    # Reject whitespace-only queries as well as empty ones.
    if not user_query.strip():
        raise HTTPException(status_code=400, detail="Query cannot be empty.")

    # Only mark the preview with "..." when it was actually truncated.
    preview = user_query if len(user_query) <= 100 else f"{user_query[:100]}..."
    logger.info("Received request to generate key issues for query: '%s'", preview)
    start_time = time.perf_counter()  # monotonic clock for elapsed time

    try:
        initial_state = _build_initial_state(user_query)
        # No thread_id is passed. NOTE(review): this assumes the graph uses no
        # checkpointer/memory; if it does, supply
        # {"configurable": {"thread_id": ...}} instead.
        config: GraphConfig = {"configurable": {}}

        logger.info("Invoking LangGraph workflow...")
        final_state = await app_graph.ainvoke(initial_state, config=config)
        logger.info("Workflow finished in %.2f seconds.", time.perf_counter() - start_time)

        if final_state is None:
            logger.error("Workflow execution did not produce a final state.")
            raise HTTPException(status_code=500, detail="Workflow execution failed to produce a result.")

        if final_state.get("error"):
            # str() so non-string error values (e.g. exception objects) can be
            # matched and reported instead of failing inside this branch.
            error_msg = str(final_state["error"])
            logger.error("Workflow failed with error: %s", error_msg)
            raise HTTPException(
                status_code=_status_code_for_workflow_error(error_msg),
                detail=f"Workflow failed: {error_msg}",
            )

        # A missing or None "key_issues" entry means no issues were generated.
        generated_issues_data = final_state.get("key_issues") or []

        # Validate explicitly here: response_model validation only runs after
        # this function returns, where this handler cannot catch or log it.
        try:
            response = KeyIssueResponse(key_issues=generated_issues_data)
        except Exception as pydantic_error:
            logger.error(
                "Failed to validate final key issues against response model: %s",
                pydantic_error,
                exc_info=True,
            )
            logger.error("Data that failed validation: %s", generated_issues_data)
            raise HTTPException(
                status_code=500,
                detail="Internal error: Failed to format key issues response.",
            ) from pydantic_error

        logger.info("Successfully generated %d key issues.", len(response.key_issues))
        return response

    except HTTPException:
        # Already mapped to a status code above; propagate unchanged.
        raise
    except ConnectionError as e:
        logger.error("Connection Error during API request: %s", e, exc_info=True)
        raise HTTPException(status_code=503, detail=f"Service Unavailable: {e}")
    except ValueError as e:
        # NOTE(review): any ValueError from inside the workflow is reported as a
        # client error (400), not only input-validation failures.
        logger.error("Value Error during API request: %s", e, exc_info=True)
        raise HTTPException(status_code=400, detail=f"Bad Request: {e}")
    except Exception as e:
        logger.error("An unexpected error occurred during API request: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal Server Error: An unexpected error occurred.")


# --- How to Run ---
if __name__ == "__main__":
    from pathlib import Path

    # Set the config environment variables (NEO4J_URI, NEO4J_PASSWORD,
    # GEMINI_API_KEY, etc.) or provide a .env file in the working directory.
    print("Starting API server...")
    print("Ensure required environment variables (e.g., NEO4J_URI, NEO4J_PASSWORD, GEMINI_API_KEY) are set or .env file is present.")
    # reload=True requires an import string ("module:attr"). Derive the module
    # name from this file rather than hard-coding "api", so a renamed file
    # still starts. Equivalent CLI: uvicorn <module>:app --reload --host 0.0.0.0 --port 8000
    # Use reload=False in production.
    uvicorn.run(f"{Path(__file__).stem}:app", host="0.0.0.0", port=8000, reload=True)