# FastAPI_KIG / api.py
import logging
import time

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from typing import List, Dict, Any

# Import the project-local kig_core components. These imports have side
# effects: config is loaded and the Neo4j client is initialized on import.
try:
    from kig_core.config import settings  # Loads config on import
    from kig_core.schemas import PlannerState, KeyIssue as KigKeyIssue, GraphConfig
    from kig_core.planner import build_graph
    from kig_core.graph_client import neo4j_client  # Already-initialized client instance
    from langchain_core.messages import HumanMessage
except ImportError as e:
    # Surface a readable hint before propagating the failure.
    print(f"Error importing kig_core components: {e}")
    print("Please ensure kig_core is in your Python path or installed.")
    raise

# Configure logging for the API
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Pydantic Models for API Request/Response ---
class KeyIssueRequest(BaseModel):
    """Request body: the user's technical query to generate key issues for."""

    query: str
class KeyIssueResponse(BaseModel):
    """Response body: the list of generated key issues (kig_core schema)."""

    key_issues: List[KigKeyIssue]
# --- Global Variables / State ---
# The compiled LangGraph application, shared across requests for efficiency.
# Set once in lifespan() at startup; None until then. NOTE(review): if the
# graph or its LLMs hold per-request state, rebuilding per request is safer.
app_graph = None # Will be initialized at startup
# --- Application Lifecycle (Startup/Shutdown) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook for the FastAPI application.

    Startup: verifies the Neo4j connection (non-fatal on failure) and builds
    the LangGraph application into the module-level ``app_graph`` (fatal on
    failure). Shutdown: the Neo4j driver close is handled by the atexit hook
    registered in graph_client.py, so nothing is closed here explicitly.
    """
    global app_graph
    logger.info("API starting up...")

    # The Neo4j client was already created when graph_client was imported;
    # this only double-checks that the database is actually reachable.
    try:
        logger.info("Verifying Neo4j connection...")
        neo4j_client._get_driver().verify_connectivity()
        logger.info("Neo4j connection verified.")
    except Exception as e:
        # Deliberately non-fatal: the API still starts without Neo4j.
        logger.error(f"Neo4j connection verification failed on startup: {e}", exc_info=True)
        # raise RuntimeError("Failed to connect to Neo4j on startup.") from e

    # Build the LangGraph application once so every request can reuse it.
    logger.info("Building LangGraph application...")
    try:
        app_graph = build_graph()
        logger.info("LangGraph application built successfully.")
    except Exception as e:
        logger.error(f"Failed to build LangGraph application on startup: {e}", exc_info=True)
        # Without the graph the service cannot do anything useful — abort startup.
        raise RuntimeError("Failed to build LangGraph on startup.") from e

    yield  # API serves requests while suspended here

    # --- Shutdown ---
    logger.info("API shutting down...")
    # neo4j_client.close() is not called here: atexit in graph_client.py does it.
    logger.info("Neo4j client closed (likely via atexit).")
    logger.info("API shutdown complete.")
# --- FastAPI Application ---
# The FastAPI application instance; lifecycle is managed by lifespan().
app = FastAPI(
    title="Key Issue Generator API",
    description="API to generate Key Issues based on a technical query using LLMs and Neo4j.",
    version="1.0.0",
    lifespan=lifespan,
)
# --- API Endpoint ---
# API state check route
@app.get("/")
def read_root():
    """Liveness probe: confirm the API process is up."""
    return {"status": "ok"}
def _status_for_workflow_error(error_msg: str) -> int:
    """Map an internal workflow error message onto an HTTP status code.

    Heuristic string matching: database-sounding errors become 503,
    LLM/parsing errors become 502, everything else 500.
    """
    if "Neo4j" in error_msg or "connection" in error_msg.lower():
        return 503  # Service Unavailable (database issue)
    if "LLM error" in error_msg or "parse" in error_msg.lower():
        return 502  # Bad Gateway (issue with upstream LLM)
    return 500  # Internal Server Error by default


@app.post("/generate-key-issues", response_model=KeyIssueResponse)
async def generate_issues(request: KeyIssueRequest):
    """
    Accept a technical query and return a list of generated Key Issues.

    Runs the LangGraph workflow built at startup on the query and extracts
    the ``key_issues`` list from the final state.

    Raises:
        HTTPException 503: graph not initialized, or database/connection errors.
        HTTPException 400: empty query, or ValueError from the workflow.
        HTTPException 502: workflow reported an upstream LLM/parsing error.
        HTTPException 500: any other failure.
    """
    global app_graph
    if app_graph is None:
        logger.error("Graph application is not initialized.")
        raise HTTPException(status_code=503, detail="Service Unavailable: Graph not initialized")

    user_query = request.query
    if not user_query:
        raise HTTPException(status_code=400, detail="Query cannot be empty.")

    logger.info(f"Received request to generate key issues for query: '{user_query[:100]}...'")
    start_time = time.time()

    try:
        # --- Prepare Initial State for LangGraph ---
        # NOTE(review): assumes PlannerState is a TypedDict-like mapping with
        # exactly these keys — confirm against kig_core.schemas.
        initial_state: PlannerState = {
            "user_query": user_query,
            "messages": [HumanMessage(content=user_query)],
            "plan": [],
            "current_plan_step_index": -1,  # Pre-start index expected by the graph's entry point
            "step_outputs": {},
            "key_issues": [],
            "error": None
        }

        # No persistent memory / thread id is used, so an empty "configurable"
        # section is sufficient. Add a thread_id here if checkpointing is enabled.
        config: GraphConfig = {"configurable": {}}

        # --- Execute the LangGraph Workflow ---
        logger.info("Invoking LangGraph workflow...")
        final_state = await app_graph.ainvoke(initial_state, config=config)

        end_time = time.time()
        logger.info(f"Workflow finished in {end_time - start_time:.2f} seconds.")

        # --- Process Final Results ---
        if final_state is None:
            logger.error("Workflow execution did not produce a final state.")
            raise HTTPException(status_code=500, detail="Workflow execution failed to produce a result.")

        if final_state.get("error"):
            error_msg = final_state.get("error", "Unknown error")
            logger.error(f"Workflow failed with error: {error_msg}")
            raise HTTPException(status_code=_status_for_workflow_error(error_msg),
                                detail=f"Workflow failed: {error_msg}")

        # --- Extract Key Issues ---
        # FastAPI validates the returned dict against KeyIssueResponse via
        # response_model after this function returns; the guard below only
        # covers failures while assembling/logging the payload locally.
        generated_issues_data = final_state.get("key_issues", [])
        try:
            response_data = {"key_issues": generated_issues_data}
            logger.info(f"Successfully generated {len(generated_issues_data)} key issues.")
            return response_data
        except Exception as pydantic_error:
            logger.error(f"Failed to validate final key issues against response model: {pydantic_error}", exc_info=True)
            logger.error(f"Data that failed validation: {generated_issues_data}")
            raise HTTPException(status_code=500, detail="Internal error: Failed to format key issues response.")

    except HTTPException:
        # Re-raise HTTPExceptions untouched so FastAPI serializes them as-is.
        raise
    except ConnectionError as e:
        logger.error(f"Connection Error during API request: {e}", exc_info=True)
        raise HTTPException(status_code=503, detail=f"Service Unavailable: {e}")
    except ValueError as e:
        logger.error(f"Value Error during API request: {e}", exc_info=True)
        raise HTTPException(status_code=400, detail=f"Bad Request: {e}")  # Often input validation issues
    except Exception as e:
        logger.error(f"An unexpected error occurred during API request: {e}", exc_info=True)
        # Fixed: this detail was an f-string with no placeholders.
        raise HTTPException(status_code=500, detail="Internal Server Error: An unexpected error occurred.")
# --- How to Run ---
if __name__ == "__main__":
    # Configuration is read from environment variables (NEO4J_URI,
    # NEO4J_PASSWORD, GEMINI_API_KEY, ...) or a .env file in the working
    # directory where this script is launched.
    print("Starting API server...")
    print("Ensure required environment variables (e.g., NEO4J_URI, NEO4J_PASSWORD, GEMINI_API_KEY) are set or .env file is present.")
    # CLI equivalent: uvicorn api:app --reload --host 0.0.0.0 --port 8000
    uvicorn.run("api:app", host="0.0.0.0", port=8000, reload=True)  # Use reload=False for production