import os
import re
import asyncio
import json
import time
import logging
from typing import Any, Dict

from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from tenacity import RetryError
from openai import RateLimitError
from anthropic import RateLimitError as AnthropicRateLimitError
from google.api_core.exceptions import ResourceExhausted

logger = logging.getLogger()
logger.setLevel(logging.INFO)

CONTEXT_LENGTH = 128000
BUFFER = 10000
MAX_TOKENS_ALLOWED = CONTEXT_LENGTH - BUFFER

# Path to the .env file
ENV_FILE_PATH = os.getenv("WRITABLE_DIR", "/tmp") + "/.env"

# Per-session state
SESSION_STORE: Dict[str, Dict[str, Any]] = {}


# Format an error message for SSE
def format_error_sse(event_type: str, data: str) -> str:
    lines = data.splitlines()
    sse_message = f"event: {event_type}\n"
    for line in lines:
        sse_message += f"data: {line}\n"
    sse_message += "\n"
    return sse_message


# Stop the task on error (non-FastAPI helper)
def stop_on_error():
    state = SESSION_STORE
    if "process_task" in state:
        state["process_task"].cancel()
        del state["process_task"]


# Initialize the components
def initialize_components():
    load_dotenv(ENV_FILE_PATH, override=True)

    from src.search.search_engine import SearchEngine
    from src.query_processing.query_processor import QueryProcessor
    # from src.rag.neo4j_graphrag import Neo4jGraphRAG
    from src.rag.graph_rag import GraphRAG
    from src.evaluation.evaluator import Evaluator
    from src.reasoning.reasoner import Reasoner
    from src.crawl.crawler import CustomCrawler
    from src.utils.api_key_manager import APIKeyManager
    from src.query_processing.late_chunking.late_chunker import LateChunker

    manager = APIKeyManager()
    manager._reinit()
    SESSION_STORE['search_engine'] = SearchEngine()
    SESSION_STORE['query_processor'] = QueryProcessor()
    SESSION_STORE['crawler'] = CustomCrawler(max_concurrent_requests=1000)
    # SESSION_STORE['graph_rag'] = Neo4jGraphRAG(num_workers=os.cpu_count() * 2)
    SESSION_STORE['graph_rag'] = GraphRAG(num_workers=os.cpu_count() * 2)
    SESSION_STORE['evaluator'] = Evaluator()
    SESSION_STORE['reasoner'] = Reasoner()
    SESSION_STORE['model'] = manager.get_llm()
    SESSION_STORE['late_chunker'] = LateChunker()
    SESSION_STORE["initialized"] = True
    SESSION_STORE["session_id"] = None


async def process_query(user_query: str, sse_queue: asyncio.Queue):
    state = SESSION_STORE
    try:
        category = await state["query_processor"].classify_query(user_query)
        cat_lower = category.lower().strip()

        if state["session_id"] is None:
            state["session_id"] = await state["crawler"].create_session()

        user_query = re.sub(r'category:.*', '', user_query, flags=re.IGNORECASE).strip()

        if cat_lower == "basic":
            response = ""
            chunk_counter = 1
            async for chunk in state["reasoner"].reason(user_query):
                await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter})))
                response += chunk
                chunk_counter += 1

            await sse_queue.put(("final_message", response))
            SESSION_STORE["chat_history"].append({"query": user_query, "response": response})

            await sse_queue.put(("action", {
                "name": "evaluate",
                "payload": {"query": user_query, "response": response}
            }))
            await sse_queue.put(("complete", "done"))

        elif cat_lower == "advanced":
            await sse_queue.put(("step", "Searching..."))

            optimized_query = await state['search_engine'].generate_optimized_query(user_query)
            search_results = await state['search_engine'].search(
                optimized_query,
                num_results=3,
                exclude_filetypes=["pdf"]
            )

            urls = [r.get('link', 'No URL') for r in search_results]
            search_contents = await state['crawler'].fetch_page_contents(
                urls,
                user_query,
                state["session_id"],
                max_attempts=1
            )

            contents = ""
            if search_contents:
                for k, content in enumerate(search_contents, 1):
                    if isinstance(content, Exception):
                        print(f"Error fetching content: {content}")
                    elif content:
                        contents += f"Document {k}:\n{content}\n\n"

            if len(contents.strip()) > 0:
                await sse_queue.put(("step", "Generating Response..."))

                token_count = state['model'].get_num_tokens(contents)
                if token_count > MAX_TOKENS_ALLOWED:
                    contents = await state['late_chunker'].chunker(contents, user_query, MAX_TOKENS_ALLOWED)

                await sse_queue.put(("sources_read", len(search_contents)))

                response = ""
                chunk_counter = 1
                async for chunk in state["reasoner"].reason(user_query, contents):
                    await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter})))
                    response += chunk
                    chunk_counter += 1

                await sse_queue.put(("final_message", response))
                SESSION_STORE["chat_history"].append({"query": user_query, "response": response})

                await sse_queue.put(("action", {
                    "name": "sources",
                    "payload": {"search_results": search_results, "search_contents": search_contents}
                }))
                await sse_queue.put(("action", {
                    "name": "evaluate",
                    "payload": {"query": user_query, "contents": [contents], "response": response}
                }))
                await sse_queue.put(("complete", "done"))
            else:
                await sse_queue.put(("error", "No results found."))

        elif cat_lower == "pro":
            current_search_results = []
            current_search_contents = []

            await sse_queue.put(("step", "Thinking..."))

            start = time.time()
            intent = await state['query_processor'].get_query_intent(user_query)
            sub_queries, _ = await state['query_processor'].decompose_query(user_query, intent)

            async def sub_query_task(sub_query):
                try:
                    await sse_queue.put(("step", "Searching..."))
                    await sse_queue.put(("task", (sub_query, "RUNNING")))

                    optimized_query = await state['search_engine'].generate_optimized_query(sub_query)
                    search_results = await state['search_engine'].search(
                        optimized_query,
                        num_results=10,
                        exclude_filetypes=["pdf"]
                    )
                    filtered_urls = await state['search_engine'].filter_urls(
                        sub_query,
                        category,
                        search_results
                    )
                    current_search_results.extend(filtered_urls)

                    urls = [r.get('link', 'No URL') for r in filtered_urls]
                    search_contents = await state['crawler'].fetch_page_contents(
                        urls,
                        sub_query,
                        state["session_id"],
                        max_attempts=1
                    )
                    current_search_contents.extend(search_contents)

                    contents = ""
                    if search_contents:
                        for k, c in enumerate(search_contents, 1):
                            if isinstance(c, Exception):
                                logger.info(f"Error fetching content: {c}")
                            elif c:
                                contents += f"Document {k}:\n{c}\n\n"

                    if len(contents.strip()) > 0:
                        await sse_queue.put(("task", (sub_query, "DONE")))
                    else:
                        await sse_queue.put(("task", (sub_query, "FAILED")))

                    return contents
                except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError):
                    await sse_queue.put(("task", (sub_query, "FAILED")))
                    return ""

            tasks = []
            if len(sub_queries) > 1 and sub_queries[0] != user_query:
                for sub_query in sub_queries:
                    tasks.append(sub_query_task(sub_query))

            results = await asyncio.gather(*tasks)
            end = time.time()

            contents = "\n\n".join(r for r in results if r.strip())

            unique_results = []
            seen = set()
            for entry in current_search_results:
                link = entry["link"]
                if link not in seen:
                    seen.add(link)
                    unique_results.append(entry)
            current_search_results = unique_results
            current_search_contents = list(set(current_search_contents))

            if len(contents.strip()) > 0:
                await sse_queue.put(("step", "Generating Response..."))
                token_count = state['model'].get_num_tokens(contents)
                if token_count > MAX_TOKENS_ALLOWED:
                    contents = await state['late_chunker'].chunker(
                        text=contents,
                        query=user_query,
                        max_tokens=MAX_TOKENS_ALLOWED
                    )

                logger.info(f"Number of tokens in the answer: {token_count}")
                logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(contents)}")

                await sse_queue.put(("sources_read", len(current_search_contents)))

                response = ""
                chunk_counter = 1
                is_first_chunk = True
                async for chunk in state['reasoner'].reason(user_query, contents):
                    if is_first_chunk:
                        await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds"))
                        is_first_chunk = False

                    await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter})))
                    response += chunk
                    chunk_counter += 1

                await sse_queue.put(("final_message", response))
                SESSION_STORE["chat_history"].append({"query": user_query, "response": response})

                await sse_queue.put(("action", {
                    "name": "sources",
                    "payload": {
                        "search_results": current_search_results,
                        "search_contents": current_search_contents
                    }
                }))
                await sse_queue.put(("action", {
                    "name": "evaluate",
                    "payload": {"query": user_query, "contents": [contents], "response": response}
                }))
                await sse_queue.put(("complete", "done"))
            else:
                await sse_queue.put(("error", "No results found."))

        elif cat_lower == "super":
            current_search_results = []
            current_search_contents = []

            await sse_queue.put(("step", "Thinking..."))

            start = time.time()
            main_query_intent = await state['query_processor'].get_query_intent(user_query)
            sub_queries, _ = await state['query_processor'].decompose_query(user_query, main_query_intent)

            await sse_queue.put(("step", "Searching..."))

            async def sub_query_task(sub_query):
                try:
                    async def sub_sub_query_task(sub_sub_query):
                        optimized_query = await state['search_engine'].generate_optimized_query(sub_sub_query)
                        search_results = await state['search_engine'].search(
                            optimized_query,
                            num_results=10,
                            exclude_filetypes=["pdf"]
                        )
                        filtered_urls = await state['search_engine'].filter_urls(
                            sub_sub_query,
                            category,
                            search_results
                        )
                        current_search_results.extend(filtered_urls)

                        urls = [r.get('link', 'No URL') for r in filtered_urls]
                        search_contents = await state['crawler'].fetch_page_contents(
                            urls,
                            sub_sub_query,
                            state["session_id"],
                            max_attempts=1,
                            timeout=20
                        )
                        current_search_contents.extend(search_contents)

                        contents = ""
                        if search_contents:
                            for k, c in enumerate(search_contents, 1):
                                if isinstance(c, Exception):
                                    logger.info(f"Error fetching content: {c}")
                                elif c:
                                    contents += f"Document {k}:\n{c}\n\n"

                        return contents

                    await sse_queue.put(("task", (sub_query, "RUNNING")))

                    sub_sub_queries, _ = await state['query_processor'].decompose_query(sub_query)

                    tasks = []
                    if len(sub_sub_queries) > 1 and sub_sub_queries[0] != user_query:
                        for sub_sub_query in sub_sub_queries:
                            tasks.append(sub_sub_query_task(sub_sub_query))

                    results = await asyncio.gather(*tasks)

                    if any(result.strip() for result in results):
                        await sse_queue.put(("task", (sub_query, "DONE")))
                    else:
                        await sse_queue.put(("task", (sub_query, "FAILED")))

                    return results
                except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError):
                    await sse_queue.put(("task", (sub_query, "FAILED")))
                    return []

            tasks = []
            if len(sub_queries) > 1 and sub_queries[0] != user_query:
                for sub_query in sub_queries:
                    tasks.append(sub_query_task(sub_query))

            results = await asyncio.gather(*tasks)
            end = time.time()

            previous_contents = []
            for result in results:
                if result:
                    for content in result:
                        if isinstance(content, str) and len(content.strip()) > 0:
                            previous_contents.append(content)

            contents = "\n\n".join(previous_contents)

            unique_results = []
            seen = set()
            for entry in current_search_results:
                link = entry["link"]
                if link not in seen:
                    seen.add(link)
                    unique_results.append(entry)
            current_search_results = unique_results
            current_search_contents = list(set(current_search_contents))

            if len(contents.strip()) > 0:
                await sse_queue.put(("step", "Generating Response..."))

                token_count = state['model'].get_num_tokens(contents)
                if token_count > MAX_TOKENS_ALLOWED:
                    contents = await state['late_chunker'].chunker(
                        text=contents,
                        query=user_query,
                        max_tokens=MAX_TOKENS_ALLOWED
                    )

                logger.info(f"Number of tokens in the answer: {token_count}")
                logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(contents)}")

                await sse_queue.put(("sources_read", len(current_search_contents)))

                response = ""
                chunk_counter = 1
                is_first_chunk = True
                async for chunk in state['reasoner'].reason(user_query, contents):
                    if is_first_chunk:
                        await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds"))
                        is_first_chunk = False

                    await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter})))
                    response += chunk
                    chunk_counter += 1

                await sse_queue.put(("final_message", response))
                SESSION_STORE["chat_history"].append({"query": user_query, "response": response})

                await sse_queue.put(("action", {
                    "name": "sources",
                    "payload": {
                        "search_results": current_search_results,
                        "search_contents": current_search_contents
                    }
                }))
                await sse_queue.put(("action", {
                    "name": "evaluate",
                    "payload": {"query": user_query, "contents": [contents], "response": response}
                }))
                await sse_queue.put(("complete", "done"))
            else:
                await sse_queue.put(("error", "No results found."))

        elif cat_lower == "ultra":
            current_search_results = []
            current_search_contents = []

            match = re.search(
                r"^This is the previous context of the conversation:\s*.*?\s*Current Query:\s*(.*)$",
                user_query,
                flags=re.DOTALL | re.MULTILINE
            )
            if match:
                user_query = match.group(1)

            await sse_queue.put(("step", "Thinking..."))
            await asyncio.sleep(0.01)  # Sleep for a short time to allow the message to be sent

            async def on_event_callback(event_type, data):
                if event_type == "graph_operation":
                    if data["operation_type"] == "creating_new_graph":
                        await sse_queue.put(("step", "Creating New Graph..."))
                    elif data["operation_type"] == "modifying_existing_graph":
                        await sse_queue.put(("step", "Modifying Existing Graph..."))
                    elif data["operation_type"] == "loading_existing_graph":
                        await sse_queue.put(("step", "Loading Existing Graph..."))
                elif event_type == "sub_query_created":
                    sub_query = data["sub_query"]
                    await sse_queue.put(("task", (sub_query, "RUNNING")))
                elif event_type == "search_process_started":
                    await sse_queue.put(("step", "Searching..."))
                elif event_type == "sub_query_processed":
                    sub_query = data["sub_query"]
                    await sse_queue.put(("task", (sub_query, "DONE")))
                elif event_type == "sub_query_failed":
                    sub_query = data["sub_query"]
                    await sse_queue.put(("task", (sub_query, "FAILED")))
                elif event_type == "search_results_filtered":
                    filtered_urls = data["filtered_urls"]
                    current_search_results.extend(filtered_urls)
                elif event_type == "search_contents_fetched":
                    contents = data["contents"]
                    current_search_contents.extend(contents)
                elif event_type == "search_process_completed":
                    await sse_queue.put(("step", "Processing final graph tasks..."))
                    await asyncio.sleep(0.01)  # Sleep for a short time to allow the message to be sent
            state['graph_rag'].set_on_event_callback(on_event_callback)

            start = time.time()
            # state['graph_rag'].initialize_schema()
            await state['graph_rag'].process_graph(
                user_query,
                similarity_threshold=0.8,
                relevance_threshold=0.8,
                max_tokens_allowed=MAX_TOKENS_ALLOWED
            )
            end = time.time()

            unique_results = []
            seen = set()
            for entry in current_search_results:
                link = entry["link"]
                if link not in seen:
                    seen.add(link)
                    unique_results.append(entry)
            current_search_results = unique_results
            current_search_contents = list(set(current_search_contents))

            await sse_queue.put(("step", "Generating Response..."))

            answer = state['graph_rag'].query_graph(user_query)
            if answer:
                token_count = state['model'].get_num_tokens(answer)
                if token_count > MAX_TOKENS_ALLOWED:
                    answer = await state['late_chunker'].chunker(
                        text=answer,
                        query=user_query,
                        max_tokens=MAX_TOKENS_ALLOWED
                    )

                logger.info(f"Number of tokens in the answer: {token_count}")
                logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(answer)}")

                await sse_queue.put(("sources_read", len(current_search_contents)))

                response = ""
                chunk_counter = 1
                is_first_chunk = True
                async for chunk in state['reasoner'].reason(user_query, answer):
                    if is_first_chunk:
                        await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds"))
                        is_first_chunk = False

                    await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter})))
                    response += chunk
                    chunk_counter += 1

                await sse_queue.put(("final_message", response))
                SESSION_STORE["chat_history"].append({"query": user_query, "response": response})

                await sse_queue.put(("action", {
                    "name": "sources",
                    "payload": {"search_results": current_search_results, "search_contents": current_search_contents},
                }))
                await sse_queue.put(("action", {
                    "name": "graph",
                    "payload": {"query": user_query},
                }))
                await sse_queue.put(("action", {
                    "name": "evaluate",
                    "payload": {"query": user_query, "contents": [answer], "response": response},
                }))
                await sse_queue.put(("complete", "done"))
            else:
                await sse_queue.put(("error", "No results found."))

        else:
            await sse_queue.put(("final_message", "I'm not sure how to handle your query."))

    except Exception as e:
        await sse_queue.put(("error", str(e)))
        stop()


# Create a FastAPI app
app = FastAPI()

# Define allowed origins
origins = [
    "http://localhost:3000",
    "http://localhost:7860",
    "http://localhost:8000",
    "http://localhost"
]

# Add the CORS middleware to the FastAPI app
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,      # Allows only these origins
    allow_credentials=True,
    allow_methods=["*"],        # Allows all HTTP methods (GET, POST, etc.)
    allow_headers=["*"],        # Allows all headers
)

# Serve the React app (the production build) at the root URL.
app.mount("/static", StaticFiles(directory="frontend/build/static", html=True), name="static")


# Define the routes for the FastAPI app

# Route for the sources action: display search results
@app.post("/action/sources")
def action_sources(payload: Dict[str, Any]) -> Dict[str, Any]:
    try:
        search_contents = payload.get("search_contents", [])
        search_results = payload.get("search_results", [])

        sources = []
        word_limit = 15  # Maximum number of words for the description

        for result, contents in zip(search_results, search_contents):
            if contents:
                title = result.get('title', 'No Title')
                link = result.get('link', 'No URL')
                snippet = result.get('snippet', 'No snippet')

                cleaned = re.sub(r'<[^>]+>|\[\/?.*?\]', '', snippet)
                words = cleaned.split()
                if len(words) > word_limit:
                    description = " ".join(words[:word_limit]) + "..."
                else:
                    description = " ".join(words)

                source_obj = {
                    "title": title,
                    "link": link,
                    "description": description
                }
                sources.append(source_obj)

        return {"result": sources}
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


# Route for the graph action: display the graph
@app.post("/action/graph")
def action_graph() -> Dict[str, Any]:
    state = SESSION_STORE
    try:
        html_str = state['graph_rag'].display_graph()
        return {"result": html_str}
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


# Route for the evaluate action: display evaluation results
@app.post("/action/evaluate")
async def action_evaluate(payload: Dict[str, Any]) -> Dict[str, Any]:
    try:
        query = payload.get("query", "")
        contents = payload.get("contents", [])
        response = payload.get("response", "")
        metrics = payload.get("metrics", [])

        state = SESSION_STORE
        evaluator = state["evaluator"]
        result = await evaluator.evaluate_response(query, response, contents, include_metrics=metrics)
        return {"result": result}
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


@app.post("/settings")
async def update_settings(data: Dict[str, Any]):
    from src.helpers.helper import (
        prepare_provider_key_updates,
        prepare_proxy_list_updates,
        update_env_vars
    )

    provider = data.get("Model_Provider", "").strip()
    model_name = data.get("Model_Name", "").strip()
    multiple_api_keys = data.get("Model_API_Keys", "").strip()
    brave_api_key = data.get("Brave_Search_API_Key", "").strip()
    proxy_list = data.get("Proxy_List", "").strip()
    # neo4j_url = data.get("Neo4j_URL", "").strip()
    # neo4j_username = data.get("Neo4j_Username", "").strip()
    # neo4j_password = data.get("Neo4j_Password", "").strip()
    model_temperature = str(data.get("Model_Temperature", 0.0))
    model_top_p = str(data.get("Model_Top_P", 1.0))

    prov_lower = provider.lower()
    key_updates = prepare_provider_key_updates(prov_lower, multiple_api_keys)

    env_updates = {}
    env_updates.update(key_updates)

    px = prepare_proxy_list_updates(proxy_list)
    if px:
        env_updates.update(px)

    env_updates["BRAVE_API_KEY"] = brave_api_key
    # env_updates["NEO4J_URI"] = neo4j_url
    # env_updates["NEO4J_USER"] = neo4j_username
    # env_updates["NEO4J_PASSWORD"] = neo4j_password
    env_updates["MODEL_PROVIDER"] = prov_lower
    env_updates["MODEL_NAME"] = model_name
    env_updates["MODEL_TEMPERATURE"] = model_temperature
    env_updates["MODEL_TOP_P"] = model_top_p

    update_env_vars(env_updates)
    load_dotenv(override=True)

    initialize_components()
    return {"success": True}
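
# Illustrative sketch only (not part of the original module): a client can update the runtime
# settings by POSTing a JSON body whose keys match those read in update_settings above.
# The host/port and every value below are placeholder assumptions.
#
#     import requests
#
#     requests.post(
#         "http://localhost:8000/settings",
#         json={
#             "Model_Provider": "openai",       # assumed provider name
#             "Model_Name": "gpt-4o",           # assumed model name
#             "Model_API_Keys": "key1,key2",    # placeholder keys
#             "Brave_Search_API_Key": "...",    # placeholder
#             "Proxy_List": "",
#             "Model_Temperature": 0.0,
#             "Model_Top_P": 1.0,
#         },
#     )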
@app.on_event("startup")
def init_chat():
    if not SESSION_STORE:
        print("Initializing chat...")
        SESSION_STORE["settings_saved"] = False
        SESSION_STORE["session_id"] = None
        SESSION_STORE["chat_history"] = []
        print("Chat initialized!")
        return {"success": True}
    else:
        print("Chat already initialized!")
        return {"success": False}


@app.get("/message-sse")
async def sse_message(request: Request, user_message: str):
    state = SESSION_STORE
    sse_queue = asyncio.Queue()

    async def event_generator():
        # Build the prompt
        context = state["chat_history"][-3:]
        if context:
            prompt = \
                f"""This is the previous context of the conversation:
{context}

Current Query: {user_message}"""
        else:
            prompt = user_message

        task = asyncio.create_task(process_query(prompt, sse_queue))
        state["process_task"] = task

        while True:
            if await request.is_disconnected():
                task.cancel()
                break

            try:
                event_type, data = await asyncio.wait_for(sse_queue.get(), timeout=5)

                if event_type == "token":
                    yield f"event: token\ndata: {data}\n\n"
                elif event_type == "final_message":
                    yield f"event: final_message\ndata: {data}\n\n"
                elif event_type == "error":
                    stop_on_error()
                    yield format_error_sse("error", data)
                elif event_type == "step":
                    yield f"event: step\ndata: {data}\n\n"
                elif event_type == "task":
                    subq, status = data
                    j = {"task": subq, "status": status}
                    yield f"event: task\ndata: {json.dumps(j)}\n\n"
                elif event_type == "sources_read":
                    yield f"event: sources_read\ndata: {data}\n\n"
                elif event_type == "action":
                    yield f"event: action\ndata: {json.dumps(data)}\n\n"
                elif event_type == "complete":
                    yield f"event: complete\ndata: {data}\n\n"
                    break
                else:
                    yield f"event: message\ndata: {data}\n\n"
            except asyncio.TimeoutError:
                if task.done():
                    break
                continue
            except asyncio.CancelledError:
                break

        if not task.done():
            task.cancel()

        if "process_task" in state:
            del state["process_task"]

    return StreamingResponse(event_generator(), media_type="text/event-stream")


@app.post("/stop")
def stop():
    state = SESSION_STORE
    if "process_task" in state:
        state["process_task"].cancel()
        del state["process_task"]
    return {"message": "Stopped task manually"}


# Catch-all route for frontend paths.
@app.get("/{full_path:path}")
async def serve_frontend(full_path: str, request: Request):
    index_path = os.path.join("frontend", "build", "index.html")
    if not os.path.exists(index_path):
        raise HTTPException(status_code=500, detail="Frontend build not found")
    return FileResponse(index_path)
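
# Local-run sketch (an assumption, not taken from the original file): the app is normally
# served by an external ASGI server; if this module is executed directly, something like
# the following would start it. Host and port are placeholders.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)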