import os
import re
import ast
import json
import time
import shutil
import asyncio
import logging
import traceback
from typing import List, Dict, Any, Optional

from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, Request, HTTPException, UploadFile, File, Form
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from tenacity import RetryError
from openai import RateLimitError
from anthropic import RateLimitError as AnthropicRateLimitError
from google.api_core.exceptions import ResourceExhausted

from src.helpers.helper import get_folder_size, clear_folder

logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Path to the .env file
ENV_FILE_PATH = os.getenv("WRITABLE_DIR", "/tmp") + "/.env"

# Define the upload directory and maximum folder size
UPLOAD_DIRECTORY = os.getenv("WRITABLE_DIR", "/tmp") + "/uploads"
MAX_FOLDER_SIZE = 10 * 1024 * 1024  # 10 MB in bytes
CONTEXT_LENGTH = 128000
BUFFER = 10000
MAX_TOKENS_ALLOWED = CONTEXT_LENGTH - BUFFER

# Per-session state (flat mapping of component names and session data)
SESSION_STORE: Dict[str, Any] = {}


# Format an error message as an SSE event
def format_error_sse(event_type: str, data: str) -> str:
    lines = data.splitlines()
    sse_message = f"event: {event_type}\n"
    for line in lines:
        sse_message += f"data: {line}\n"
    sse_message += "\n"
    return sse_message


# Stop the running task on error (for use outside FastAPI route handlers)
def stop_on_error():
    state = SESSION_STORE
    if "process_task" in state:
        state["process_task"].cancel()
        del state["process_task"]


# Get OAuth tokens for MCP tools
def get_oauth_token(provider: str) -> Optional[str]:
    if "oauth_tokens" in SESSION_STORE and provider in SESSION_STORE["oauth_tokens"]:
        token_data = SESSION_STORE["oauth_tokens"][provider]
        # Check if the token is expired (1 hour lifetime)
        if time.time() - token_data["timestamp"] < 3600:
            return token_data["token"]
        else:
            # Token expired, remove it
            del SESSION_STORE["oauth_tokens"][provider]
            logger.info(f"{provider} token expired and removed")
    return None


# Initialize the components
async def initialize_components():
    load_dotenv(ENV_FILE_PATH, override=True)

    from src.search.search_engine import SearchEngine
    from src.query_processing.query_processor import QueryProcessor
    # from src.rag.neo4j_graphrag import Neo4jGraphRAG
    from src.rag.graph_rag import GraphRAG
    from src.evaluation.evaluator import Evaluator
    from src.reasoning.reasoner import Reasoner
    from src.crawl.crawler import CustomCrawler
    from src.utils.api_key_manager import APIKeyManager
    from src.query_processing.late_chunking.late_chunker import LateChunker
    from src.integrations.mcp_client import MCPClient

    state = SESSION_STORE
    manager = APIKeyManager()
    manager._reinit()
    state['search_engine'] = SearchEngine()
    state['query_processor'] = QueryProcessor()
    state['crawler'] = CustomCrawler(max_concurrent_requests=1000)
    # state['graph_rag'] = Neo4jGraphRAG(num_workers=os.cpu_count() * 2)
    state['graph_rag'] = GraphRAG(num_workers=os.cpu_count() * 2)
    state['evaluator'] = Evaluator()
    state['reasoner'] = Reasoner()
    state['model'] = manager.get_llm()
    state['late_chunker'] = LateChunker()
    state["mcp_client"] = MCPClient()
    state["initialized"] = True
    state["session_id"] = await state["crawler"].create_session()


# Main function to process user queries
async def process_query(user_query: str, sse_queue: asyncio.Queue):
    state = SESSION_STORE
    try:
        # --- Categorize the query ---
        category = await state["query_processor"].classify_query(user_query)
        cat_lower = category.lower().strip()
        user_query = re.sub(r'category:.*', '', user_query, flags=re.IGNORECASE).strip()
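        # Hedged example of the category-hint stripping above (assumed frontend behavior,
        # not enforced anywhere in this module): an incoming query such as
        #   "What is late chunking? Category: Pro"
        # is reduced by the re.sub call to
        #   "What is late chunking?"
        # while `category` itself comes from classify_query, not from this suffix.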
        # --- Read and extract user-provided files and links ---
        # Initialize caches if not present
        if "user_files_cache" not in state:
            state["user_files_cache"] = {}
        if "user_links_cache" not in state:
            state["user_links_cache"] = {}

        # Extract user-provided context
        user_context = ""
        user_links = state.get("user_provided_links", [])

        # Read new uploaded files
        if state["session_id"]:
            session_upload_path = os.path.join(UPLOAD_DIRECTORY, state["session_id"])
            if os.path.exists(session_upload_path):
                for filename in os.listdir(session_upload_path):
                    file_path = os.path.join(session_upload_path, filename)
                    if os.path.isfile(file_path):
                        # Check if the file is already in the cache
                        if filename not in state["user_files_cache"]:
                            try:
                                await sse_queue.put(("step", "Reading User-Provided Files..."))
                                with open(file_path, 'r', encoding='utf-8') as f:
                                    file_content = f.read()
                                state["user_files_cache"][filename] = file_content
                            except Exception as e:
                                logger.error(f"Error reading file {filename}: {e}")
                                # Try reading as binary and decoding
                                try:
                                    with open(file_path, 'rb') as f:
                                        file_content = f.read().decode('utf-8', errors='ignore')
                                    state["user_files_cache"][filename] = file_content
                                except Exception as e2:
                                    logger.error(f"Error reading file {filename} as binary: {e2}")
                                    state["user_files_cache"][filename] = ""  # Cache empty to avoid retrying

        # Add all cached file contents
        for filename, content in state["user_files_cache"].items():
            if content:
                user_context += (
                    f"\n[USER PROVIDED FILE: {filename} START]\n{content}\n"
                    f"[USER PROVIDED FILE: {filename} END]\n\n"
                )

        # Crawl new user-provided links
        if user_links:
            await sse_queue.put(("step", "Crawling User-Provided Links..."))
            new_links = [link for link in user_links if link not in state["user_links_cache"]]
            if new_links:
                # Only crawl new links
                link_contents = await state['crawler'].fetch_page_contents(
                    new_links,
                    user_query,
                    state["session_id"],
                    max_attempts=1
                )
                # Cache the new contents
                for link, content in zip(new_links, link_contents):
                    if not isinstance(content, Exception) and content:
                        state["user_links_cache"][link] = content
                    else:
                        state["user_links_cache"][link] = ""  # Cache empty to avoid retrying

        # Add all cached link contents
        for link, content in state["user_links_cache"].items():
            if content:
                idx = user_links.index(link) + 1 if link in user_links else 0
                user_context += (
                    f"\n[USER PROVIDED LINK {idx} START]\n{content}\n"
                    f"[USER PROVIDED LINK {idx} END]\n\n"
                )

        # --- Fetch apps data from MCP service ---
        app_context = ""
        selected_services = state.get("selected_services", {})

        # Check if any services are selected
        has_google = selected_services.get("google", [])
        has_microsoft = selected_services.get("microsoft", [])
        has_slack = selected_services.get("slack", False)

        if has_google or has_microsoft or has_slack:
            await sse_queue.put(("step", "Fetching Data From Connected Apps..."))

            # Fetch from each provider in parallel, tracking which provider each task belongs to
            tasks = []
            task_providers = []

            # Google services
            if has_google:
                google_token = get_oauth_token("google")
                tasks.append(
                    state['mcp_client'].fetch_app_data(
                        provider="google",
                        services=has_google,
                        query=user_query,
                        user_id=state["session_id"],
                        access_token=google_token
                    )
                )
                task_providers.append("google")

            # Microsoft services
            if has_microsoft:
                microsoft_token = get_oauth_token("microsoft")
                tasks.append(
                    state['mcp_client'].fetch_app_data(
                        provider="microsoft",
                        services=has_microsoft,
                        query=user_query,
                        user_id=state["session_id"],
                        access_token=microsoft_token
                    )
                )
                task_providers.append("microsoft")

            # Slack
            if has_slack:
                slack_token = get_oauth_token("slack")
                tasks.append(
                    state['mcp_client'].fetch_app_data(
                        provider="slack",
                        services=["messages"],  # Slack doesn't have sub-services
                        query=user_query,
                        user_id=state["session_id"],
                        access_token=slack_token
                    )
                )
                task_providers.append("slack")

            # Execute all requests in parallel
            if tasks:
                results = await asyncio.gather(*tasks, return_exceptions=True)

                # Process results
                for i, result in enumerate(results):
                    if isinstance(result, Exception):
                        logger.error(f"Error fetching app data: {result}")
                    elif isinstance(result, dict):
                        # Determine which provider this result came from
                        provider = task_providers[i]

                        # Format the data
                        formatted_context = state['mcp_client'].format_as_context(provider, result)
                        if formatted_context:
                            app_context += formatted_context

            # Log how much app data we got
            if app_context:
                logger.info(f"Retrieved app data: {len(app_context)} characters")

        # Prepend app context to user context
        if app_context:
            user_context = app_context + "\n\n" + user_context

        # Upgrade basic to advanced if the user has provided links
        if cat_lower == "basic" and user_links:
            cat_lower = "advanced"

        # --- Process the query based on the category ---
        if cat_lower == "basic":
            response = ""
            chunk_counter = 1

            if user_context:
                # Include user context if available
                await sse_queue.put(("step", "Generating Response..."))
                async for chunk in state["reasoner"].answer(user_query, user_context, query_type="basic"):
                    await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter})))
                    response += chunk
                    chunk_counter += 1
            else:
                # No user context provided
                async for chunk in state["reasoner"].answer(user_query):
                    await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter})))
                    response += chunk
                    chunk_counter += 1

            await sse_queue.put(("final_message", response))
            SESSION_STORE["chat_history"].append({"query": user_query, "response": response})

            await sse_queue.put(("action", {
                "name": "evaluate",
                "payload": {"query": user_query, "response": response}
            }))
            await sse_queue.put(("complete", "done"))

        elif cat_lower == "advanced":
            await sse_queue.put(("step", "Searching..."))

            optimized_query = await state['search_engine'].generate_optimized_query(user_query)
            search_results = await state['search_engine'].search(
                optimized_query,
                num_results=3,
                exclude_filetypes=["pdf"]
            )
            urls = [r.get('link', 'No URL') for r in search_results]
            search_contents = await state['crawler'].fetch_page_contents(
                urls,
                user_query,
                state["session_id"],
                max_attempts=1
            )

            # Start with user-provided context
            contents = user_context

            # Add crawled contents
            if search_contents:
                for k, content in enumerate(search_contents, 1):
                    if isinstance(content, Exception):
                        logger.error(f"Error fetching content: {content}")
                    elif content:
                        contents += f"[SOURCE {k} START]\n{content}\n[SOURCE {k} END]\n\n"

            if len(contents.strip()) > 0:
                await sse_queue.put(("step", "Generating Response..."))

                token_count = state['model'].get_num_tokens(contents)
                if token_count > MAX_TOKENS_ALLOWED:
                    contents = await state['late_chunker'].chunker(contents, user_query, MAX_TOKENS_ALLOWED)

                await sse_queue.put(("sources_read", len(search_contents)))

                response = ""
                chunk_counter = 1
                async for chunk in state["reasoner"].answer(user_query, contents):
                    await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter})))
                    response += chunk
                    chunk_counter += 1

                sources_for_answer = []
                for idx, result in enumerate(search_results, 1):
                    # Only include sources whose content was successfully fetched
                    if search_contents[idx - 1] and not isinstance(search_contents[idx - 1], Exception):
                        sources_for_answer.append({
                            "id": idx,
                            "title": result.get('title', 'No Title'),
                            "link": result.get('link', 'No URL')
                        })

                await sse_queue.put(("final_message", response))
                await sse_queue.put(("final_sources", json.dumps(sources_for_answer)))

                SESSION_STORE["chat_history"].append({"query": user_query, "response": response})
                SESSION_STORE["answer"] = response
                SESSION_STORE["source_contents"] = contents

                await sse_queue.put(("action", {
                    "name": "sources",
                    "payload": {"search_results": search_results, "search_contents": search_contents}
                }))
                await sse_queue.put(("action", {
                    "name": "evaluate",
                    "payload": {"query": user_query, "contents": [contents], "response": response}
                }))
                await sse_queue.put(("complete", "done"))
            else:
                await sse_queue.put(("error", "No results found."))

        elif cat_lower == "pro":
            current_search_results = []
            current_search_contents = []

            await sse_queue.put(("step", "Thinking..."))

            start = time.time()
            intent = await state['query_processor'].get_query_intent(user_query)
            sub_queries, _ = await state['query_processor'].decompose_query(user_query, intent)

            async def sub_query_task(sub_query):
                try:
                    await sse_queue.put(("step", "Searching..."))
                    await sse_queue.put(("task", (sub_query, "RUNNING")))

                    optimized_query = await state['search_engine'].generate_optimized_query(sub_query)
                    search_results = await state['search_engine'].search(
                        optimized_query,
                        num_results=10,
                        exclude_filetypes=["pdf"]
                    )
                    filtered_urls = await state['search_engine'].filter_urls(
                        sub_query,
                        category,
                        search_results
                    )
                    current_search_results.extend(filtered_urls)

                    # Combine search results with user-provided links
                    all_search_results = search_results + \
                        [{"link": url, "title": f"User provided: {url}", "snippet": ""} for url in user_links]

                    urls = [r.get('link', 'No URL') for r in all_search_results]
                    search_contents = await state['crawler'].fetch_page_contents(
                        urls,
                        sub_query,
                        state["session_id"],
                        max_attempts=1
                    )
                    current_search_contents.extend(search_contents)

                    contents = user_context
                    if search_contents:
                        for k, c in enumerate(search_contents, 1):
                            if isinstance(c, Exception):
                                logger.info(f"Error fetching content: {c}")
                            elif c:
                                contents += f"[SOURCE {k} START]\n{c}\n[SOURCE {k} END]\n\n"

                    if len(contents.strip()) > 0:
                        await sse_queue.put(("task", (sub_query, "DONE")))
                    else:
                        await sse_queue.put(("task", (sub_query, "FAILED")))

                    return contents
                except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError):
                    await sse_queue.put(("task", (sub_query, "FAILED")))
                    return ""

            tasks = []
            if len(sub_queries) > 1 and sub_queries[0] != user_query:
                for sub_query in sub_queries:
                    tasks.append(sub_query_task(sub_query))
            results = await asyncio.gather(*tasks)
            end = time.time()

            # Start with user-provided context
            contents = user_context

            # Add searched contents
            contents += "\n\n".join(r for r in results if r.strip())

            unique_results = []
            seen = set()
            for entry in current_search_results:
                link = entry["link"]
                if link not in seen:
                    seen.add(link)
                    unique_results.append(entry)
            current_search_results = unique_results
            current_search_contents = list(set(current_search_contents))

            if len(contents.strip()) > 0:
                await sse_queue.put(("step", "Generating Response..."))

                token_count = state['model'].get_num_tokens(contents)
                if token_count > MAX_TOKENS_ALLOWED:
                    contents = await state['late_chunker'].chunker(
                        text=contents,
                        query=user_query,
                        max_tokens=MAX_TOKENS_ALLOWED
                    )
                    logger.info(f"Number of tokens in the answer: {token_count}")
                    logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(contents)}")

                await sse_queue.put(("sources_read", len(current_search_contents)))
                response = ""
                chunk_counter = 1
                is_first_chunk = True
                async for chunk in state['reasoner'].answer(user_query, contents):
                    if is_first_chunk:
                        await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds"))
                        is_first_chunk = False

                    await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter})))
                    response += chunk
                    chunk_counter += 1

                sources_for_answer = []
                for idx, (result, content) in enumerate(zip(current_search_results, current_search_contents), 1):
                    if content:  # Only include if content was successfully fetched
                        sources_for_answer.append({
                            "id": idx,
                            "title": result.get('title', 'No Title'),
                            "link": result.get('link', 'No URL')
                        })

                await sse_queue.put(("final_message", response))
                await sse_queue.put(("final_sources", json.dumps(sources_for_answer)))

                SESSION_STORE["chat_history"].append({"query": user_query, "response": response})
                SESSION_STORE["answer"] = response
                SESSION_STORE["source_contents"] = contents

                await sse_queue.put(("action", {
                    "name": "sources",
                    "payload": {
                        "search_results": current_search_results,
                        "search_contents": current_search_contents
                    }
                }))
                await sse_queue.put(("action", {
                    "name": "evaluate",
                    "payload": {"query": user_query, "contents": [contents], "response": response}
                }))
                await sse_queue.put(("complete", "done"))
            else:
                await sse_queue.put(("error", "No results found."))

        elif cat_lower == "super":
            current_search_results = []
            current_search_contents = []

            await sse_queue.put(("step", "Thinking..."))

            start = time.time()
            main_query_intent = await state['query_processor'].get_query_intent(user_query)
            sub_queries, _ = await state['query_processor'].decompose_query(user_query, main_query_intent)

            await sse_queue.put(("step", "Searching..."))

            async def sub_query_task(sub_query):
                try:
                    async def sub_sub_query_task(sub_sub_query):
                        optimized_query = await state['search_engine'].generate_optimized_query(sub_sub_query)
                        search_results = await state['search_engine'].search(
                            optimized_query,
                            num_results=10,
                            exclude_filetypes=["pdf"]
                        )
                        filtered_urls = await state['search_engine'].filter_urls(
                            sub_sub_query,
                            category,
                            search_results
                        )
                        current_search_results.extend(filtered_urls)

                        urls = [r.get('link', 'No URL') for r in filtered_urls]
                        search_contents = await state['crawler'].fetch_page_contents(
                            urls,
                            sub_sub_query,
                            state["session_id"],
                            max_attempts=1,
                            timeout=20
                        )
                        current_search_contents.extend(search_contents)

                        contents = ""
                        if search_contents:
                            for k, c in enumerate(search_contents, 1):
                                if isinstance(c, Exception):
                                    logger.info(f"Error fetching content: {c}")
                                elif c:
                                    contents += f"[SOURCE {k} START]\n{c}\n[SOURCE {k} END]\n\n"
                        return contents

                    await sse_queue.put(("task", (sub_query, "RUNNING")))

                    sub_sub_queries, _ = await state['query_processor'].decompose_query(sub_query)

                    tasks = []
                    if len(sub_sub_queries) > 1 and sub_sub_queries[0] != user_query:
                        for sub_sub_query in sub_sub_queries:
                            tasks.append(sub_sub_query_task(sub_sub_query))
                    results = await asyncio.gather(*tasks)

                    if any(result.strip() for result in results):
                        await sse_queue.put(("task", (sub_query, "DONE")))
                    else:
                        await sse_queue.put(("task", (sub_query, "FAILED")))

                    return results
                except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError):
                    await sse_queue.put(("task", (sub_query, "FAILED")))
                    return []

            tasks = []
            if len(sub_queries) > 1 and sub_queries[0] != user_query:
                for sub_query in sub_queries:
                    tasks.append(sub_query_task(sub_query))
            results = await asyncio.gather(*tasks)
            end = time.time()

            # Start with user-provided context
            previous_contents = []
            if user_context:
                previous_contents.append(user_context)

            for result in results:
                if result:
                    for content in result:
                        if isinstance(content, str) and len(content.strip()) > 0:
                            previous_contents.append(content)
            contents = "\n\n".join(previous_contents)

            unique_results = []
            seen = set()
            for entry in current_search_results:
                link = entry["link"]
                if link not in seen:
                    seen.add(link)
                    unique_results.append(entry)
            current_search_results = unique_results
            current_search_contents = list(set(current_search_contents))

            if len(contents.strip()) > 0:
                await sse_queue.put(("step", "Generating Response..."))

                token_count = state['model'].get_num_tokens(contents)
                if token_count > MAX_TOKENS_ALLOWED:
                    contents = await state['late_chunker'].chunker(
                        text=contents,
                        query=user_query,
                        max_tokens=MAX_TOKENS_ALLOWED
                    )
                    logger.info(f"Number of tokens in the answer: {token_count}")
                    logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(contents)}")

                await sse_queue.put(("sources_read", len(current_search_contents)))

                response = ""
                chunk_counter = 1
                is_first_chunk = True
                async for chunk in state['reasoner'].answer(user_query, contents):
                    if is_first_chunk:
                        await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds"))
                        is_first_chunk = False

                    await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter})))
                    response += chunk
                    chunk_counter += 1

                sources_for_answer = []
                for idx, (result, content) in enumerate(zip(current_search_results, current_search_contents), 1):
                    if content:  # Only include if content was successfully fetched
                        sources_for_answer.append({
                            "id": idx,
                            "title": result.get('title', 'No Title'),
                            "link": result.get('link', 'No URL')
                        })

                await sse_queue.put(("final_message", response))
                await sse_queue.put(("final_sources", json.dumps(sources_for_answer)))

                SESSION_STORE["chat_history"].append({"query": user_query, "response": response})
                SESSION_STORE["answer"] = response
                SESSION_STORE["source_contents"] = contents

                await sse_queue.put(("action", {
                    "name": "sources",
                    "payload": {
                        "search_results": current_search_results,
                        "search_contents": current_search_contents
                    }
                }))
                await sse_queue.put(("action", {
                    "name": "evaluate",
                    "payload": {"query": user_query, "contents": [contents], "response": response}
                }))
                await sse_queue.put(("complete", "done"))
            else:
                await sse_queue.put(("error", "No results found."))

        elif cat_lower == "ultra":
            current_search_results = []
            current_search_contents = []

            match = re.search(
                r"^This is the previous context of the conversation:\s*.*?\s*Current Query:\s*(.*)$",
                user_query,
                flags=re.DOTALL | re.MULTILINE
            )
            if match:
                user_query = match.group(1)

            await sse_queue.put(("step", "Thinking..."))
            await asyncio.sleep(0.01)  # Sleep briefly to allow the message to be sent

            async def on_event_callback(event_type, data):
                if event_type == "graph_operation":
                    if data["operation_type"] == "creating_new_graph":
                        await sse_queue.put(("step", "Creating New Graph..."))
                    elif data["operation_type"] == "modifying_existing_graph":
                        await sse_queue.put(("step", "Modifying Existing Graph..."))
                    elif data["operation_type"] == "loading_existing_graph":
                        await sse_queue.put(("step", "Loading Existing Graph..."))
                elif event_type == "sub_query_created":
                    sub_query = data["sub_query"]
                    await sse_queue.put(("task", (sub_query, "RUNNING")))
                elif event_type == "search_process_started":
                    await sse_queue.put(("step", "Searching..."))
                elif event_type == "sub_query_processed":
                    sub_query = data["sub_query"]
                    await sse_queue.put(("task", (sub_query, "DONE")))
                elif event_type == "sub_query_failed":
                    sub_query = data["sub_query"]
                    await sse_queue.put(("task", (sub_query, "FAILED")))
                elif event_type == "search_results_filtered":
                    current_search_results.extend(data["filtered_urls"])
                elif event_type == "search_contents_fetched":
                    current_search_contents.extend(data["contents"])
                elif event_type == "search_process_completed":
                    await sse_queue.put(("step", "Processing final graph tasks..."))
                    await asyncio.sleep(0.01)  # Sleep briefly to allow the message to be sent

            state['graph_rag'].set_on_event_callback(on_event_callback)

            start = time.time()
            # state['graph_rag'].initialize_schema()
            await state['graph_rag'].process_graph(
                user_query,
                similarity_threshold=0.8,
                relevance_threshold=0.8,
                max_tokens_allowed=MAX_TOKENS_ALLOWED
            )
            end = time.time()

            unique_results = []
            seen = set()
            for entry in current_search_results:
                link = entry["link"]
                if link not in seen:
                    seen.add(link)
                    unique_results.append(entry)
            current_search_results = unique_results
            current_search_contents = list(set(current_search_contents))

            await sse_queue.put(("step", "Generating Response..."))

            answer = state['graph_rag'].query_graph(user_query)
            if answer:
                # Start with user-provided context
                previous_contents = []
                if user_context:
                    previous_contents.append(user_context)

                token_count = state['model'].get_num_tokens(answer)
                if token_count > MAX_TOKENS_ALLOWED:
                    answer = await state['late_chunker'].chunker(
                        text=answer,
                        query=user_query,
                        max_tokens=MAX_TOKENS_ALLOWED
                    )
                    logger.info(f"Number of tokens in the answer: {token_count}")
                    logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(answer)}")

                await sse_queue.put(("sources_read", len(current_search_contents)))

                response = ""
                chunk_counter = 1
                is_first_chunk = True
                async for chunk in state['reasoner'].answer(user_query, answer):
                    if is_first_chunk:
                        await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds"))
                        is_first_chunk = False

                    await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter})))
                    response += chunk
                    chunk_counter += 1

                sources_for_answer = []
                for idx, (result, content) in enumerate(zip(current_search_results, current_search_contents), 1):
                    if content:  # Only include if content was successfully fetched
                        sources_for_answer.append({
                            "id": idx,
                            "title": result.get('title', 'No Title'),
                            "link": result.get('link', 'No URL')
                        })

                await sse_queue.put(("final_message", response))
                await sse_queue.put(("final_sources", json.dumps(sources_for_answer)))

                SESSION_STORE["chat_history"].append({"query": user_query, "response": response})
                SESSION_STORE["answer"] = response
                SESSION_STORE["source_contents"] = answer

                await sse_queue.put(("action", {
                    "name": "sources",
                    "payload": {"search_results": current_search_results, "search_contents": current_search_contents},
                }))
                await sse_queue.put(("action", {
                    "name": "graph",
                    "payload": {"query": user_query},
                }))
                await sse_queue.put(("action", {
                    "name": "evaluate",
                    "payload": {"query": user_query, "contents": [answer], "response": response},
                }))
                await sse_queue.put(("complete", "done"))
            else:
                await sse_queue.put(("error", "No results found."))
        else:
            await sse_queue.put(("final_message", "I'm not sure how to handle your query."))
    except Exception as e:
        await sse_queue.put(("error", str(e)))
        traceback.print_exc()
        stop_on_error()
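
# For orientation, a hedged sketch (not emitted by any code at this point) of the event stream a
# client of /message-sse can expect while process_query runs. Event names and payload shapes are
# taken from the sse_queue.put calls above and the event_generator below; the concrete values are
# illustrative only:
#
#   event: step
#   data: Searching...
#
#   event: task
#   data: {"task": "some sub-query", "status": "RUNNING"}
#
#   event: token
#   data: {"chunk": "partial answer text", "index": 1}
#
#   event: final_message
#   data: ...full response...
#
#   event: complete
#   data: done
#
# "error" events are framed by format_error_sse so that multi-line messages become repeated
# "data:" lines within a single event.
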
# Create a FastAPI app
app = FastAPI()

# Define allowed origins
origins = [
    "http://localhost:3000",
    "http://localhost:7860",
    "http://localhost:8000",
    "http://localhost"
]

# Add the CORS middleware to the FastAPI app
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,    # Allow only these origins
    allow_credentials=True,
    allow_methods=["*"],      # Allow all HTTP methods (GET, POST, etc.)
    allow_headers=["*"],      # Allow all headers
)

# Serve the React app (the production build) at the root URL.
app.mount("/static", StaticFiles(directory="frontend/build/static", html=True), name="static")


# Define the routes for the FastAPI app

# Define the route for the sources action to display search results
@app.post("/action/sources")
def action_sources(payload: Dict[str, Any]) -> Dict[str, Any]:
    try:
        search_contents = payload.get("search_contents", [])
        search_results = payload.get("search_results", [])

        sources = []
        word_limit = 15  # Maximum number of words for the description

        for result, contents in zip(search_results, search_contents):
            if contents:
                title = result.get('title', 'No Title')
                link = result.get('link', 'No URL')
                snippet = result.get('snippet', 'No snippet')

                cleaned = re.sub(r'<[^>]+>|\[\/?.*?\]', '', snippet)
                words = cleaned.split()
                if len(words) > word_limit:
                    description = " ".join(words[:word_limit]) + "..."
                else:
                    description = " ".join(words)

                source_obj = {
                    "title": title,
                    "link": link,
                    "description": description
                }
                sources.append(source_obj)

        return {"result": sources}
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


# Define the route for the graph action to display the graph
@app.post("/action/graph")
def action_graph() -> Dict[str, Any]:
    state = SESSION_STORE
    try:
        html_str = state['graph_rag'].display_graph()
        return {"result": html_str}
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


# Define the route for the evaluate action to display evaluation results
@app.post("/action/evaluate")
async def action_evaluate(payload: Dict[str, Any]) -> Dict[str, Any]:
    try:
        query = payload.get("query", "")
        contents = payload.get("contents", [])
        response = payload.get("response", "")
        metrics = payload.get("metrics", [])

        state = SESSION_STORE
        evaluator = state["evaluator"]
        result = await evaluator.evaluate_response(query, response, contents, include_metrics=metrics)
        return {"result": result}
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


# Define the route for the excerpts action to display excerpts from the sources
@app.post("/action/excerpts")
async def action_excerpts() -> Dict[str, Any]:
    def validate_excerpts_format(excerpts):
        if not isinstance(excerpts, list):
            return False
        for item in excerpts:
            if not isinstance(item, dict):
                return False
            for statement, sources in item.items():
                if not isinstance(statement, str) or not isinstance(sources, dict):
                    return False
                for src_num, excerpt in sources.items():
                    if not (isinstance(src_num, int) or isinstance(src_num, str)):
                        return False
                    if not isinstance(excerpt, str):
                        return False
        return True

    try:
        state = SESSION_STORE
        response = state["answer"]
        contents = state["source_contents"]
        if not response or not contents:
            raise ValueError("Required data for excerpts not found")

        excerpts_list = await state["reasoner"].get_excerpts(response, contents)
        cleaned_excerpts = re.sub(
            r'```[\w\s]*\n?|```|~~~[\w\s]*\n?|~~~', '',
            excerpts_list,
            flags=re.MULTILINE | re.DOTALL
        ).strip()

        try:
            # Parse the model output as a Python literal rather than eval() to avoid executing arbitrary code
            excerpts = ast.literal_eval(cleaned_excerpts)
        except Exception:
            print(f"Error parsing excerpts:\n{cleaned_excerpts}")
            raise ValueError("Excerpts could not be parsed as a Python list.")
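        # Hedged example of the structure validate_excerpts_format accepts (the values are
        # illustrative, not produced by this code): a list of single-key dicts mapping a
        # statement from the answer to {source number: supporting excerpt}, e.g.
        #   [
        #       {"A statement from the answer.": {1: "Excerpt from source 1.", "2": "Excerpt from source 2."}},
        #   ]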
        if not validate_excerpts_format(excerpts):
            print(f"Excerpts format validation failed:\n{excerpts}")
            raise ValueError("Excerpts are not in the required format.")

        print(f"Excerpts:\n{excerpts}")
        return {"result": excerpts}
    except Exception as e:
        print(f"Error in action_excerpts: {e}")
        return JSONResponse(content={"error": str(e)}, status_code=500)


# Define the route for settings to set or update the environment variables
@app.post("/settings")
async def update_settings(data: Dict[str, Any]):
    from src.helpers.helper import (
        prepare_provider_key_updates,
        prepare_proxy_list_updates,
        update_env_vars
    )

    provider = data.get("Model_Provider", "").strip()
    model_name = data.get("Model_Name", "").strip()
    multiple_api_keys = data.get("Model_API_Keys", "").strip()
    brave_api_key = data.get("Brave_Search_API_Key", "").strip()
    proxy_list = data.get("Proxy_List", "").strip()
    model_temperature = str(data.get("Model_Temperature", 0.0))
    model_top_p = str(data.get("Model_Top_P", 1.0))

    prov_lower = provider.lower()
    key_updates = prepare_provider_key_updates(prov_lower, multiple_api_keys)

    env_updates = {}
    env_updates.update(key_updates)

    px = prepare_proxy_list_updates(proxy_list)
    if px:
        env_updates.update(px)

    env_updates["BRAVE_API_KEY"] = brave_api_key
    env_updates["MODEL_PROVIDER"] = prov_lower
    env_updates["MODEL_NAME"] = model_name
    env_updates["MODEL_TEMPERATURE"] = model_temperature
    env_updates["MODEL_TOP_P"] = model_top_p

    update_env_vars(env_updates)
    load_dotenv(override=True)

    await initialize_components()
    return {"success": True}
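
# Hedged example of a /settings request body (field names come from update_settings above; the
# values are placeholders, and the key separator is whatever prepare_provider_key_updates expects):
#   {
#       "Model_Provider": "openai",
#       "Model_Name": "gpt-4o",
#       "Model_API_Keys": "sk-key-1,sk-key-2",
#       "Brave_Search_API_Key": "<brave-key>",
#       "Proxy_List": "",
#       "Model_Temperature": 0.0,
#       "Model_Top_P": 1.0
#   }
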
# Define the route for adding/uploading content for a specific session
@app.post("/add-content")
async def add_content(files: Optional[List[UploadFile]] = File(None), urls: str = Form(...)):
    state = SESSION_STORE
    session_id = state.get("session_id")
    if not session_id:
        raise HTTPException(status_code=400, detail="Session ID is not set. Please start a session first.")

    session_upload_path = os.path.join(UPLOAD_DIRECTORY, session_id)
    os.makedirs(session_upload_path, exist_ok=True)

    saved_filenames = []
    if files:
        total_new_files_size = sum(file.size for file in files)
        current_folder_size = get_folder_size(session_upload_path)

        # Check if the total size would exceed the maximum allowed folder size
        if current_folder_size + total_new_files_size > MAX_FOLDER_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"Cannot add files as total storage would exceed 10 MB. "
                       f"Current size: {current_folder_size / (1024 * 1024):.2f} MB"
            )

        for file in files:
            file_path = os.path.join(session_upload_path, file.filename)
            try:
                with open(file_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)
                saved_filenames.append(file.filename)
            finally:
                file.file.close()

    try:
        parsed_urls = json.loads(urls)
        print(f"Received links: {parsed_urls}")
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid URL format.")

    # Store user-provided links in the session
    if parsed_urls:
        SESSION_STORE["user_provided_links"] = parsed_urls

    return {
        "message": "Content added successfully",
        "files_added": saved_filenames,
        "links_added": parsed_urls
    }


# Define the route to update the selected services for searching
@app.post("/api/selected-services")
async def update_selected_services(data: Dict[str, Any]):
    state = SESSION_STORE
    selected_services = data.get("services", {})
    state["selected_services"] = selected_services
    logger.info(f"Updated selected services: {selected_services}")
    return {"success": True, "services": selected_services}


# Define the route to receive OAuth tokens from the frontend
@app.post("/api/session-token")
async def receive_session_token(data: Dict[str, Any]):
    provider = data.get("provider")  # 'google', 'microsoft', 'slack'
    token = data.get("token")

    if not provider or not token:
        raise HTTPException(status_code=400, detail="Provider and token are required")

    SESSION_STORE["oauth_tokens"][provider] = {
        "token": token,
        "timestamp": time.time()
    }
    # Avoid logging the raw token value
    logger.info(f"Stored OAuth token for provider {provider}")
    return {"success": True, "message": f"{provider} token stored successfully"}
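
# Hedged example of how the frontend might hand a token to /api/session-token (the endpoint and
# field names are from receive_session_token above; the host/port and token value are placeholders):
#   curl -X POST http://localhost:8000/api/session-token \
#        -H "Content-Type: application/json" \
#        -d '{"provider": "google", "token": "<oauth-access-token>"}'
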
# Define the route for cleaning up a session if the session ID matches
@app.post("/cleanup")
async def cleanup_session():
    state = SESSION_STORE
    session_id = state.get("session_id")
    if not session_id:
        raise HTTPException(status_code=400, detail="Session ID is not set. Please start a session first.")

    session_upload_path = os.path.join(UPLOAD_DIRECTORY, session_id)
    if session_id:
        # Clear the session upload directory
        clear_folder(session_upload_path)

        # Clear user-provided links and caches
        SESSION_STORE["user_provided_links"] = []
        SESSION_STORE["user_files_cache"] = {}
        SESSION_STORE["user_links_cache"] = {}
        SESSION_STORE["selected_services"] = {}
        SESSION_STORE["oauth_tokens"] = {}
        return {"message": "Cleanup successful."}
    return {"message": "No session ID provided, cleanup skipped."}


@app.on_event("startup")
def init_chat():
    if not SESSION_STORE:
        print("Initializing chat...")

        # Create the upload directory if it doesn't exist
        print("Creating upload directory...")
        os.makedirs(UPLOAD_DIRECTORY, exist_ok=True)

        # Initialize the session store
        SESSION_STORE["settings_saved"] = False
        SESSION_STORE["session_id"] = None
        SESSION_STORE["answer"] = None
        SESSION_STORE["source_contents"] = None
        SESSION_STORE["chat_history"] = []
        SESSION_STORE["user_provided_links"] = []
        SESSION_STORE["user_files_cache"] = {}
        SESSION_STORE["user_links_cache"] = {}
        SESSION_STORE["selected_services"] = {}
        SESSION_STORE["oauth_tokens"] = {}

        print("Chat initialized!")
        return {"success": True}
    else:
        print("Chat already initialized!")
        return {"success": False}


@app.get("/message-sse")
async def sse_message(request: Request, user_message: str):
    state = SESSION_STORE
    sse_queue = asyncio.Queue()

    async def event_generator():
        # Build the prompt from the recent chat history
        context = state["chat_history"][-3:]
        if context:
            prompt = (
                "This is the previous context of the conversation:\n"
                f"{context}\n\n"
                f"Current Query: {user_message}"
            )
        else:
            prompt = user_message

        task = asyncio.create_task(process_query(prompt, sse_queue))
        state["process_task"] = task

        while True:
            if await request.is_disconnected():
                task.cancel()
                break

            try:
                event_type, data = await asyncio.wait_for(sse_queue.get(), timeout=5)

                if event_type == "token":
                    yield f"event: token\ndata: {data}\n\n"
                elif event_type == "final_message":
                    yield f"event: final_message\ndata: {data}\n\n"
                elif event_type == "error":
                    stop_on_error()
                    yield format_error_sse("error", data)
                elif event_type == "step":
                    yield f"event: step\ndata: {data}\n\n"
                elif event_type == "task":
                    subq, status = data
                    j = {"task": subq, "status": status}
                    yield f"event: task\ndata: {json.dumps(j)}\n\n"
                elif event_type == "sources_read":
                    yield f"event: sources_read\ndata: {data}\n\n"
                elif event_type == "action":
                    yield f"event: action\ndata: {json.dumps(data)}\n\n"
                elif event_type == "complete":
                    yield f"event: complete\ndata: {data}\n\n"
                    break
                else:
                    yield f"event: message\ndata: {data}\n\n"
            except asyncio.TimeoutError:
                if task.done():
                    break
                continue
            except asyncio.CancelledError:
                break

        if not task.done():
            task.cancel()
        if "process_task" in state:
            del state["process_task"]

    return StreamingResponse(event_generator(), media_type="text/event-stream")


@app.post("/stop")
def stop():
    state = SESSION_STORE
    if "process_task" in state:
        state["process_task"].cancel()
        del state["process_task"]
    return {"message": "Stopped task manually"}


# Catch-all route for frontend paths.
@app.get("/{full_path:path}")
async def serve_frontend(full_path: str, request: Request):
    index_path = os.path.join("frontend", "build", "index.html")
    if not os.path.exists(index_path):
        raise HTTPException(status_code=500, detail="Frontend build not found")
    return FileResponse(index_path)
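
# A minimal way to run this app locally (assuming this file is named app.py; adjust the module
# name to match the actual filename):
#   uvicorn app:app --host 0.0.0.0 --port 8000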