seekr / main.py
Hemang Thakur
updated environment file path
d9487c9
raw
history blame
50.3 kB
import os
import re
import json
import time
import shutil
import asyncio
import logging
import traceback
from typing import List, Dict, Any, Optional
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, Request, HTTPException, UploadFile, File, Form
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from tenacity import RetryError
from openai import RateLimitError
from anthropic import RateLimitError as AnthropicRateLimitError
from google.api_core.exceptions import ResourceExhausted
from src.helpers.helper import get_folder_size, clear_folder
# Root logger; records at INFO level and above are emitted.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# The .env file lives in the writable directory (WRITABLE_DIR, default /tmp).
ENV_FILE_PATH = f"{os.getenv('WRITABLE_DIR', '/tmp')}/.env"

# Uploaded files are stored under the writable directory; total size is capped.
UPLOAD_DIRECTORY = f"{os.getenv('WRITABLE_DIR', '/tmp')}/uploads"
MAX_FOLDER_SIZE = 10 * 1024 ** 2  # 10 MB in bytes

# Token budget for content handed to the model: CONTEXT_LENGTH minus a
# safety BUFFER (presumably headroom for the prompt and the answer).
CONTEXT_LENGTH = 128000
BUFFER = 10000
MAX_TOKENS_ALLOWED = CONTEXT_LENGTH - BUFFER

# Per-session state, shared by every handler in this module.
SESSION_STORE: Dict[str, Dict[str, Any]] = {}
# Format error message for SSE
def format_error_sse(event_type: str, data: str) -> str:
    """Serialize an event as a Server-Sent Events (SSE) frame.

    Args:
        event_type: Value written to the ``event:`` field.
        data: Payload text. Each line of the payload becomes its own
            ``data:`` field, as multi-line SSE payloads require.

    Returns:
        The frame text: the ``event:`` line, one ``data:`` line per payload
        line, and the blank line that terminates the frame.
    """
    # A frame without any data field is discarded by SSE clients, so an
    # empty payload still produces a single (empty) data line.
    payload_lines = data.splitlines() or [""]
    parts = [f"event: {event_type}\n"]
    parts.extend(f"data: {line}\n" for line in payload_lines)
    parts.append("\n")
    return "".join(parts)
# Stop the task on error (non-fastapi)
def stop_on_error():
    """Cancel the stored query task, if there is one, and forget it.

    Looks for a ``"process_task"`` entry in ``SESSION_STORE``; when present
    the task is cancelled and the entry removed. Otherwise nothing happens.
    """
    if "process_task" not in SESSION_STORE:
        return
    SESSION_STORE["process_task"].cancel()
    del SESSION_STORE["process_task"]
# Get OAuth tokens for MCP tools
def get_oauth_token(provider: str) -> Optional[str]:
    """Return the stored OAuth token for ``provider`` if it is still fresh.

    Args:
        provider: Key under ``SESSION_STORE["oauth_tokens"]``.

    Returns:
        The token when one is stored and is less than one hour old;
        ``None`` when no token is stored or the stored one has expired
        (an expired entry is removed as a side effect).
    """
    if "oauth_tokens" not in SESSION_STORE:
        return None
    stored_tokens = SESSION_STORE["oauth_tokens"]
    if provider not in stored_tokens:
        return None

    entry = stored_tokens[provider]
    age_seconds = time.time() - entry["timestamp"]
    if age_seconds < 3600:  # tokens are treated as valid for one hour
        return entry["token"]

    # Token expired, remove it
    del stored_tokens[provider]
    logger.info(f"{provider} token expired and removed")
    return None
# Initialize the components
async def initialize_components():
    """(Re)load the environment file and build the components kept in ``SESSION_STORE``.

    Loads ``ENV_FILE_PATH`` with ``override=True``, re-initialises the API
    key manager, then stores fresh instances of the search engine, query
    processor, crawler, graph RAG, evaluator, reasoner, LLM, late chunker
    and MCP client in ``SESSION_STORE``. Finally a crawler session is
    created and its id stored under ``"session_id"``.
    """
    load_dotenv(ENV_FILE_PATH, override=True)

    # Project modules are imported only after the .env file has been loaded
    # (presumably so they see the refreshed environment -- confirm).
    from src.search.search_engine import SearchEngine
    from src.query_processing.query_processor import QueryProcessor
    # from src.rag.neo4j_graphrag import Neo4jGraphRAG
    from src.rag.graph_rag import GraphRAG
    from src.evaluation.evaluator import Evaluator
    from src.reasoning.reasoner import Reasoner
    from src.crawl.crawler import CustomCrawler
    from src.utils.api_key_manager import APIKeyManager
    from src.query_processing.late_chunking.late_chunker import LateChunker
    from src.integrations.mcp_client import MCPClient

    state = SESSION_STORE
    manager = APIKeyManager()
    manager._reinit()

    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to a single CPU instead of failing on ``None * 2``.
    num_workers = (os.cpu_count() or 1) * 2

    state['search_engine'] = SearchEngine()
    state['query_processor'] = QueryProcessor()
    state['crawler'] = CustomCrawler(max_concurrent_requests=1000)
    # state['graph_rag'] = Neo4jGraphRAG(num_workers=num_workers)
    state['graph_rag'] = GraphRAG(num_workers=num_workers)
    state['evaluator'] = Evaluator()
    state['reasoner'] = Reasoner()
    state['model'] = manager.get_llm()
    state['late_chunker'] = LateChunker()
    state["mcp_client"] = MCPClient()
    state["initialized"] = True
    state["session_id"] = await state["crawler"].create_session()
# Main function to process user queries
async def process_query(user_query: str, sse_queue: asyncio.Queue):
    """Classify a user query, gather context and stream an answer to ``sse_queue``.

    The session's query processor classifies the query; the lower-cased
    category selects one of the ``basic``, ``advanced``, ``pro``, ``super``
    or ``ultra`` pipelines. A ``basic`` query is upgraded to ``advanced``
    when the user has provided links. Context from uploaded files,
    user-provided links and connected apps is collected first.

    Every item put on ``sse_queue`` is an ``(event_type, data)`` tuple. The
    event types used are ``step``, ``task``, ``token``, ``sources_read``,
    ``final_message``, ``final_sources``, ``action``, ``complete`` and
    ``error``.

    Args:
        user_query: The raw query text; any ``category:...`` text is
            stripped from it after classification.
        sse_queue: Queue that receives the progress and result events.

    Any exception raised while processing is reported as an ``error``
    event, its traceback is printed, and ``stop()`` is called.
    """
    state = SESSION_STORE
    try:
        # --- Categorize the query ---
        category = await state["query_processor"].classify_query(user_query)
        cat_lower = category.lower().strip()
        user_query = re.sub(r'category:.*', '', user_query, flags=re.IGNORECASE).strip()

        # --- Read and extract user-provided files and links ---
        user_links = state.get("user_provided_links", [])
        user_context = await _collect_user_context(state, sse_queue, user_query, user_links)

        # --- Fetch apps data from MCP service ---
        app_context = await _collect_app_context(state, sse_queue, user_query)
        if app_context:
            # Prepend app context to user context
            user_context = app_context + "\n\n" + user_context

        # Upgrade basic to advanced if user has provided links
        if cat_lower == "basic" and user_links:
            cat_lower = "advanced"

        # --- Process the query based on the category ---
        if cat_lower == "basic":
            await _answer_basic(state, sse_queue, user_query, user_context)
        elif cat_lower == "advanced":
            await _answer_advanced(state, sse_queue, user_query, user_context)
        elif cat_lower == "pro":
            await _answer_pro(state, sse_queue, user_query, user_context, user_links, category)
        elif cat_lower == "super":
            await _answer_super(state, sse_queue, user_query, user_context, category)
        elif cat_lower == "ultra":
            await _answer_ultra(state, sse_queue, user_query)
        else:
            await sse_queue.put(("final_message", "I'm not sure how to handle your query."))
    except Exception as e:
        await sse_queue.put(("error", str(e)))
        traceback.print_exc()
        # NOTE(review): `stop` is not defined in the part of the file visible
        # here; presumably it is defined further down -- confirm.
        stop()


def _dedupe_by_link(results):
    """Return ``results`` without repeated ``"link"`` values, keeping first occurrences in order."""
    seen = set()
    unique = []
    for entry in results:
        link = entry["link"]
        if link not in seen:
            seen.add(link)
            unique.append(entry)
    return unique


def _dedupe_preserving_order(items):
    """Return the distinct (hashable) ``items`` in first-seen order."""
    return list(dict.fromkeys(items))


def _format_sources(search_contents) -> str:
    """Wrap each fetched page in numbered ``[SOURCE k START]``/``[SOURCE k END]`` markers.

    Entries that are exceptions are logged and skipped; empty entries are skipped.
    """
    parts = []
    for k, content in enumerate(search_contents or [], 1):
        if isinstance(content, Exception):
            logger.info(f"Error fetching content: {content}")
        elif content:
            parts.append(f"[SOURCE {k} START]\n{content}\n[SOURCE {k} END]\n\n")
    return "".join(parts)


def _build_sources_for_answer(search_results, search_contents):
    """Pair results with fetched contents and list those whose content was fetched.

    The ``id`` of each source is its 1-based position in ``search_results``.
    """
    sources = []
    for idx, (result, content) in enumerate(zip(search_results, search_contents), 1):
        # A failed fetch is stored as an exception object, which is truthy;
        # it must not be reported as a successfully fetched source.
        if content and not isinstance(content, Exception):
            sources.append({
                "id": idx,
                "title": result.get('title', 'No Title'),
                "link": result.get('link', 'No URL'),
            })
    return sources


async def _stream_answer(chunks, sse_queue, first_chunk_step=None) -> str:
    """Forward answer chunks as ``token`` events and return the full response text.

    When ``first_chunk_step`` is given it is sent as a ``step`` event just
    before the first token. Token indices start at 1.
    """
    pieces = []
    index = 0
    async for chunk in chunks:
        index += 1
        if index == 1 and first_chunk_step is not None:
            await sse_queue.put(("step", first_chunk_step))
        await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": index})))
        pieces.append(chunk)
    return "".join(pieces)


async def _publish_answer(state, sse_queue, user_query, response, contents,
                          search_results, search_contents, include_graph=False):
    """Send the final events for an answer and record it in the session state."""
    sources_for_answer = _build_sources_for_answer(search_results, search_contents)
    await sse_queue.put(("final_message", response))
    await sse_queue.put(("final_sources", json.dumps(sources_for_answer)))

    state["chat_history"].append({"query": user_query, "response": response})
    state["answer"] = response
    state["source_contents"] = contents

    await sse_queue.put(("action", {
        "name": "sources",
        "payload": {"search_results": search_results, "search_contents": search_contents},
    }))
    if include_graph:
        await sse_queue.put(("action", {
            "name": "graph",
            "payload": {"query": user_query},
        }))
    await sse_queue.put(("action", {
        "name": "evaluate",
        "payload": {"query": user_query, "contents": [contents], "response": response},
    }))
    await sse_queue.put(("complete", "done"))


async def _generate_and_publish(state, sse_queue, user_query, contents, search_results,
                                search_contents, elapsed_seconds, include_graph=False):
    """Trim ``contents`` to the token budget, stream the answer and publish the results."""
    token_count = state['model'].get_num_tokens(contents)
    if token_count > MAX_TOKENS_ALLOWED:
        contents = await state['late_chunker'].chunker(
            text=contents,
            query=user_query,
            max_tokens=MAX_TOKENS_ALLOWED
        )
    logger.info(f"Number of tokens in the answer: {token_count}")
    logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(contents)}")
    await sse_queue.put(("sources_read", len(search_contents)))

    response = await _stream_answer(
        state['reasoner'].answer(user_query, contents),
        sse_queue,
        first_chunk_step=f"Thought and searched for {elapsed_seconds} seconds",
    )
    await _publish_answer(
        state, sse_queue, user_query, response, contents,
        search_results, search_contents, include_graph=include_graph,
    )


async def _answer_from_search(state, sse_queue, user_query, contents, search_results,
                              search_contents, elapsed_seconds):
    """Shared tail of the ``pro`` and ``super`` pipelines."""
    unique_results = _dedupe_by_link(search_results)
    # NOTE(review): results and contents are de-duplicated independently, so
    # pairing them positionally afterwards may not line up -- confirm intent.
    unique_contents = _dedupe_preserving_order(search_contents)

    if not contents.strip():
        await sse_queue.put(("error", "No results found."))
        return

    await sse_queue.put(("step", "Generating Response..."))
    await _generate_and_publish(
        state, sse_queue, user_query, contents,
        unique_results, unique_contents, elapsed_seconds,
    )


async def _collect_user_context(state, sse_queue, user_query, user_links) -> str:
    """Build the context text from uploaded files and user-provided links.

    File and link contents are cached in ``state`` (``user_files_cache`` and
    ``user_links_cache``) so each file or link is read or crawled only once;
    failures are cached as empty strings to avoid retrying.
    """
    # Initialize caches if not present
    if "user_files_cache" not in state:
        state["user_files_cache"] = {}
    if "user_links_cache" not in state:
        state["user_links_cache"] = {}
    files_cache = state["user_files_cache"]
    links_cache = state["user_links_cache"]
    user_context = ""

    # Read new uploaded files
    if state["session_id"]:
        session_upload_path = os.path.join(UPLOAD_DIRECTORY, state["session_id"])
        if os.path.exists(session_upload_path):
            for filename in os.listdir(session_upload_path):
                file_path = os.path.join(session_upload_path, filename)
                if not os.path.isfile(file_path) or filename in files_cache:
                    continue
                try:
                    await sse_queue.put(("step", "Reading User-Provided Files..."))
                    with open(file_path, 'r', encoding='utf-8') as f:
                        files_cache[filename] = f.read()
                except Exception as e:
                    logger.error(f"Error reading file {filename}: {e}")
                    # Try reading as binary and decode
                    try:
                        with open(file_path, 'rb') as f:
                            files_cache[filename] = f.read().decode('utf-8', errors='ignore')
                    except Exception as e2:
                        logger.error(f"Error reading file {filename} as binary: {e2}")
                        files_cache[filename] = ""  # Cache empty to avoid retrying

    # Add all cached file contents
    for filename, content in files_cache.items():
        if content:
            user_context += (
                f"\n[USER PROVIDED FILE: {filename} START]\n{content}"
                f"\n[USER PROVIDED FILE: {filename} END]\n\n"
            )

    # Crawl new user-provided links
    if user_links:
        await sse_queue.put(("step", "Crawling User-Provided Links..."))
        new_links = [link for link in user_links if link not in links_cache]
        if new_links:
            # Only crawl new links
            link_contents = await state['crawler'].fetch_page_contents(
                new_links,
                user_query,
                state["session_id"],
                max_attempts=1
            )
            # Cache the new contents
            for link, content in zip(new_links, link_contents):
                if not isinstance(content, Exception) and content:
                    links_cache[link] = content
                else:
                    links_cache[link] = ""  # Cache empty to avoid retrying

        # Add all cached link contents
        for link, content in links_cache.items():
            if content:
                idx = user_links.index(link) + 1 if link in user_links else 0
                user_context += (
                    f"\n[USER PROVIDED LINK {idx} START]\n{content}"
                    f"\n[USER PROVIDED LINK {idx} END]\n\n"
                )

    return user_context


async def _collect_app_context(state, sse_queue, user_query) -> str:
    """Fetch data from the selected connected apps and format it as context text.

    Returns an empty string when no service is selected or nothing was retrieved.
    """
    selected_services = state.get("selected_services", {})
    google_services = selected_services.get("google", [])
    microsoft_services = selected_services.get("microsoft", [])
    use_slack = selected_services.get("slack", False)
    if not (google_services or microsoft_services or use_slack):
        return ""

    await sse_queue.put(("step", "Fetching Data From Connected Apps..."))

    # Each request remembers its provider so results are labelled correctly
    # whichever combination of providers is selected.
    requests = []
    if google_services:
        requests.append(("google", google_services))
    if microsoft_services:
        requests.append(("microsoft", microsoft_services))
    if use_slack:
        requests.append(("slack", ["messages"]))  # Slack doesn't have sub-services

    tasks = [
        state['mcp_client'].fetch_app_data(
            provider=provider,
            services=services,
            query=user_query,
            user_id=state["session_id"],
            access_token=get_oauth_token(provider)
        )
        for provider, services in requests
    ]

    # Execute all requests in parallel
    results = await asyncio.gather(*tasks, return_exceptions=True)

    app_context = ""
    for (provider, _services), result in zip(requests, results):
        if isinstance(result, Exception):
            logger.error(f"Error fetching app data: {result}")
        elif isinstance(result, dict):
            formatted_context = state['mcp_client'].format_as_context(provider, result)
            if formatted_context:
                app_context += formatted_context

    # Log how much app data we got
    if app_context:
        logger.info(f"Retrieved app data: {len(app_context)} characters")
    return app_context


async def _answer_basic(state, sse_queue, user_query, user_context):
    """Answer directly with the reasoner, using the user context when there is any."""
    if user_context:  # Include user context if available
        await sse_queue.put(("step", "Generating Response..."))
        chunks = state["reasoner"].answer(user_query, user_context, query_type="basic")
    else:  # No user context provided
        chunks = state["reasoner"].answer(user_query)
    response = await _stream_answer(chunks, sse_queue)

    await sse_queue.put(("final_message", response))
    state["chat_history"].append({"query": user_query, "response": response})
    await sse_queue.put(("action", {
        "name": "evaluate",
        "payload": {"query": user_query, "response": response}
    }))
    await sse_queue.put(("complete", "done"))


async def _answer_advanced(state, sse_queue, user_query, user_context):
    """Run one web search, crawl the top results and answer from them plus the user context."""
    await sse_queue.put(("step", "Searching..."))
    optimized_query = await state['search_engine'].generate_optimized_query(user_query)
    search_results = await state['search_engine'].search(
        optimized_query,
        num_results=3,
        exclude_filetypes=["pdf"]
    )
    urls = [r.get('link', 'No URL') for r in search_results]
    search_contents = await state['crawler'].fetch_page_contents(
        urls,
        user_query,
        state["session_id"],
        max_attempts=1
    )

    # Start with user-provided context, then add crawled contents
    contents = user_context + _format_sources(search_contents)
    if not contents.strip():
        await sse_queue.put(("error", "No results found."))
        return

    await sse_queue.put(("step", "Generating Response..."))
    token_count = state['model'].get_num_tokens(contents)
    if token_count > MAX_TOKENS_ALLOWED:
        contents = await state['late_chunker'].chunker(contents, user_query, MAX_TOKENS_ALLOWED)
    await sse_queue.put(("sources_read", len(search_contents)))

    response = await _stream_answer(state["reasoner"].answer(user_query, contents), sse_queue)
    await _publish_answer(
        state, sse_queue, user_query, response, contents, search_results, search_contents,
    )


async def _answer_pro(state, sse_queue, user_query, user_context, user_links, category):
    """Decompose the query into sub-queries, search each concurrently and answer from all of them."""
    found_results = []
    fetched_contents = []
    await sse_queue.put(("step", "Thinking..."))
    start = time.time()
    intent = await state['query_processor'].get_query_intent(user_query)
    sub_queries, _ = await state['query_processor'].decompose_query(user_query, intent)

    async def sub_query_task(sub_query):
        try:
            await sse_queue.put(("step", "Searching..."))
            await sse_queue.put(("task", (sub_query, "RUNNING")))
            optimized_query = await state['search_engine'].generate_optimized_query(sub_query)
            search_results = await state['search_engine'].search(
                optimized_query,
                num_results=10,
                exclude_filetypes=["pdf"]
            )
            filtered_urls = await state['search_engine'].filter_urls(
                sub_query,
                category,
                search_results
            )
            found_results.extend(filtered_urls)

            # Combine search results with user-provided links
            # NOTE(review): the unfiltered results are crawled here while the
            # filtered ones are reported as sources -- confirm this is intended.
            all_search_results = search_results + [
                {"link": url, "title": f"User provided: {url}", "snippet": ""}
                for url in user_links
            ]
            urls = [r.get('link', 'No URL') for r in all_search_results]
            search_contents = await state['crawler'].fetch_page_contents(
                urls,
                sub_query,
                state["session_id"],
                max_attempts=1
            )
            fetched_contents.extend(search_contents)

            contents = user_context + _format_sources(search_contents)
            status = "DONE" if contents.strip() else "FAILED"
            await sse_queue.put(("task", (sub_query, status)))
            return contents
        except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError):
            await sse_queue.put(("task", (sub_query, "FAILED")))
            return ""

    tasks = []
    if len(sub_queries) > 1 and sub_queries[0] != user_query:
        tasks = [sub_query_task(sub_query) for sub_query in sub_queries]
    results = await asyncio.gather(*tasks)
    end = time.time()

    # Start with user-provided context, then add searched contents
    contents = user_context + "\n\n".join(r for r in results if r.strip())
    await _answer_from_search(
        state, sse_queue, user_query, contents,
        found_results, fetched_contents, int(end - start),
    )


async def _answer_super(state, sse_queue, user_query, user_context, category):
    """Decompose the query twice (sub-queries, then sub-sub-queries), search each and answer."""
    found_results = []
    fetched_contents = []
    await sse_queue.put(("step", "Thinking..."))
    start = time.time()
    main_query_intent = await state['query_processor'].get_query_intent(user_query)
    sub_queries, _ = await state['query_processor'].decompose_query(user_query, main_query_intent)
    await sse_queue.put(("step", "Searching..."))

    async def sub_sub_query_task(sub_sub_query):
        optimized_query = await state['search_engine'].generate_optimized_query(sub_sub_query)
        search_results = await state['search_engine'].search(
            optimized_query,
            num_results=10,
            exclude_filetypes=["pdf"]
        )
        filtered_urls = await state['search_engine'].filter_urls(
            sub_sub_query,
            category,
            search_results
        )
        found_results.extend(filtered_urls)
        urls = [r.get('link', 'No URL') for r in filtered_urls]
        search_contents = await state['crawler'].fetch_page_contents(
            urls,
            sub_sub_query,
            state["session_id"],
            max_attempts=1,
            timeout=20
        )
        fetched_contents.extend(search_contents)
        return _format_sources(search_contents)

    async def sub_query_task(sub_query):
        try:
            await sse_queue.put(("task", (sub_query, "RUNNING")))
            sub_sub_queries, _ = await state['query_processor'].decompose_query(sub_query)
            tasks = []
            if len(sub_sub_queries) > 1 and sub_sub_queries[0] != user_query:
                tasks = [sub_sub_query_task(q) for q in sub_sub_queries]
            results = await asyncio.gather(*tasks)
            status = "DONE" if any(result.strip() for result in results) else "FAILED"
            await sse_queue.put(("task", (sub_query, status)))
            return results
        except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError):
            await sse_queue.put(("task", (sub_query, "FAILED")))
            return []

    tasks = []
    if len(sub_queries) > 1 and sub_queries[0] != user_query:
        tasks = [sub_query_task(sub_query) for sub_query in sub_queries]
    results = await asyncio.gather(*tasks)
    end = time.time()

    # Start with user-provided context, then add every non-empty result
    previous_contents = [user_context] if user_context else []
    for result in results:
        for content in result or []:
            if isinstance(content, str) and content.strip():
                previous_contents.append(content)
    contents = "\n\n".join(previous_contents)

    await _answer_from_search(
        state, sse_queue, user_query, contents,
        found_results, fetched_contents, int(end - start),
    )


async def _answer_ultra(state, sse_queue, user_query):
    """Build or update the knowledge graph for the query and answer from the graph's output."""
    found_results = []
    fetched_contents = []

    # Keep only the current query when the text carries previous conversation context.
    match = re.search(
        r"^This is the previous context of the conversation:\s*.*?\s*Current Query:\s*(.*)$",
        user_query,
        flags=re.DOTALL | re.MULTILINE
    )
    if match:
        user_query = match.group(1)

    await sse_queue.put(("step", "Thinking..."))
    await asyncio.sleep(0.01)  # Sleep for a short time to allow the message to be sent

    graph_steps = {
        "creating_new_graph": "Creating New Graph...",
        "modifying_existing_graph": "Modifying Existing Graph...",
        "loading_existing_graph": "Loading Existing Graph...",
    }
    task_statuses = {
        "sub_query_created": "RUNNING",
        "sub_query_processed": "DONE",
        "sub_query_failed": "FAILED",
    }

    async def on_event_callback(event_type, data):
        if event_type == "graph_operation":
            step = graph_steps.get(data["operation_type"])
            if step is not None:
                await sse_queue.put(("step", step))
        elif event_type in task_statuses:
            await sse_queue.put(("task", (data["sub_query"], task_statuses[event_type])))
        elif event_type == "search_process_started":
            await sse_queue.put(("step", "Searching..."))
        elif event_type == "search_results_filtered":
            found_results.extend(data["filtered_urls"])
        elif event_type == "search_contents_fetched":
            fetched_contents.extend(data["contents"])
        elif event_type == "search_process_completed":
            await sse_queue.put(("step", "Processing final graph tasks..."))
            await asyncio.sleep(0.01)  # Sleep for a short time to allow the message to be sent

    state['graph_rag'].set_on_event_callback(on_event_callback)
    start = time.time()
    # state['graph_rag'].initialize_schema()
    await state['graph_rag'].process_graph(
        user_query,
        similarity_threshold=0.8,
        relevance_threshold=0.8,
        max_tokens_allowed=MAX_TOKENS_ALLOWED
    )
    end = time.time()

    unique_results = _dedupe_by_link(found_results)
    unique_contents = _dedupe_preserving_order(fetched_contents)

    await sse_queue.put(("step", "Generating Response..."))
    answer = state['graph_rag'].query_graph(user_query)
    if not answer:
        await sse_queue.put(("error", "No results found."))
        return

    # NOTE(review): the user-provided context is not merged into the graph
    # answer in this pipeline -- confirm whether it should be.
    # The text the response is generated from (the graph answer) is what is
    # stored as the session's source contents.
    await _generate_and_publish(
        state, sse_queue, user_query, answer,
        unique_results, unique_contents, int(end - start),
        include_graph=True,
    )
# Create a FastAPI app
app = FastAPI()  # application instance that the middleware, static mount and routes are registered on
# Define allowed origins
origins = [
    # One origin per entry. A missing comma between two string literals
    # silently concatenates them into a single, invalid origin.
    "http://localhost:3000",
    "http://localhost:7860",
    "http://localhost:8000",
    "http://localhost",
]
# Add the CORS middleware to your FastAPI app
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,      # only the origins listed above
    allow_credentials=True,
    allow_methods=["*"],        # every HTTP method (GET, POST, etc.)
    allow_headers=["*"],        # every request header
)

# Serve the static assets of the React production build under /static.
app.mount(
    "/static",
    StaticFiles(directory="frontend/build/static", html=True),
    name="static",
)
# Define the routes for the FastAPI app
# Define the route for sources action to display search results
@app.post("/action/sources")
def action_sources(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Turn search results into source cards (title, link, short description).

    ``payload["search_results"]`` and ``payload["search_contents"]`` are
    paired positionally; a result is kept only when its content is truthy.
    The description is the result's snippet with HTML tags and bracketed
    markers removed, cut to at most 15 words (with ``...`` appended when
    cut). Any failure is returned as a 500 JSON error.
    """
    max_words = 15  # Maximum number of words for the description
    try:
        fetched = payload.get("search_contents", [])
        results = payload.get("search_results", [])
        sources = []
        for result, contents in zip(results, fetched):
            if not contents:
                continue
            title = result.get('title', 'No Title')
            link = result.get('link', 'No URL')
            snippet = result.get('snippet', 'No snippet')
            words = re.sub(r'<[^>]+>|\[\/?.*?\]', '', snippet).split()
            description = " ".join(words[:max_words])
            if len(words) > max_words:
                description += "..."
            sources.append({"title": title, "link": link, "description": description})
        return {"result": sources}
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
# Define the route for graph action to display the graph
@app.post("/action/graph")
def action_graph() -> Dict[str, Any]:
    """Return the session graph rendered by ``graph_rag.display_graph()``.

    Any failure (including a missing ``graph_rag`` entry in the session
    store) is returned as a 500 JSON error.
    """
    try:
        return {"result": SESSION_STORE['graph_rag'].display_graph()}
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
# Define the route for evaluate action to display evaluation results
@app.post("/action/evaluate")
async def action_evaluate(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Evaluate a response with the session's evaluator.

    Reads ``query``, ``contents``, ``response`` and ``metrics`` from the
    payload (each optional) and returns the evaluator's result. Any failure
    is returned as a 500 JSON error.
    """
    try:
        query = payload.get("query", "")
        contents = payload.get("contents", [])
        response = payload.get("response", "")
        metrics = payload.get("metrics", [])
        result = await SESSION_STORE["evaluator"].evaluate_response(
            query, response, contents, include_metrics=metrics
        )
        return {"result": result}
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
# Define the route for excerpts action to display excerpts from the sources
@app.post("/action/excerpts")
async def action_excerpts() -> Dict[str, Any]:
    """Return excerpts from the stored sources that support the stored answer.

    Uses ``SESSION_STORE["answer"]`` and ``SESSION_STORE["source_contents"]``
    and asks the reasoner for excerpts. The reasoner's text output is
    stripped of code fences, parsed as a Python literal and validated as a
    list of ``{statement: {source_number: excerpt}}`` dictionaries. Any
    failure is returned as a 500 JSON error.
    """
    import ast  # local import so this handler does not depend on a file-level change

    def validate_excerpts_format(excerpts):
        """Check for a list of dicts mapping str -> dict[int | str, str]."""
        if not isinstance(excerpts, list):
            return False
        for item in excerpts:
            if not isinstance(item, dict):
                return False
            for statement, sources in item.items():
                if not isinstance(statement, str) or not isinstance(sources, dict):
                    return False
                for src_num, excerpt in sources.items():
                    if not isinstance(src_num, (int, str)):
                        return False
                    if not isinstance(excerpt, str):
                        return False
        return True

    try:
        state = SESSION_STORE
        response = state["answer"]
        contents = state["source_contents"]
        if not response or not contents:
            raise ValueError("Required data for excerpts not found")

        excerpts_list = await state["reasoner"].get_excerpts(response, contents)
        cleaned_excerpts = re.sub(
            r'```[\w\s]*\n?|```|~~~[\w\s]*\n?|~~~', '', excerpts_list, flags=re.MULTILINE | re.DOTALL
        ).strip()

        try:
            # SECURITY: this text is model output derived from crawled web
            # pages, so it is untrusted. ast.literal_eval accepts only Python
            # literals and cannot execute code, unlike eval().
            excerpts = ast.literal_eval(cleaned_excerpts)
        except Exception:
            print(f"Error parsing excerpts:\n{cleaned_excerpts}")
            raise ValueError("Excerpts could not be parsed as a Python list.")

        if not validate_excerpts_format(excerpts):
            print(f"Excerpts format validation failed:\n{excerpts}")
            raise ValueError("Excerpts are not in the required format.")

        print(f"Excerpts:\n{excerpts}")
        return {"result": excerpts}
    except Exception as e:
        print(f"Error in action_excerpts: {e}")
        return JSONResponse(content={"error": str(e)}, status_code=500)
# Define the route for settings to set or update the environment variables
@app.post("/settings")
async def update_settings(data: Dict[str, Any]):
    """Persist the submitted settings as environment variables and rebuild the components.

    Reads the model provider, model name, API keys, Brave search key, proxy
    list, temperature and top-p from ``data``, writes them through
    ``update_env_vars``, reloads the environment and calls
    ``initialize_components()``.
    """
    from src.helpers.helper import (
        prepare_provider_key_updates,
        prepare_proxy_list_updates,
        update_env_vars
    )

    def text_field(key: str) -> str:
        # Missing fields default to an empty string; surrounding whitespace is dropped.
        return data.get(key, "").strip()

    provider = text_field("Model_Provider")
    model_name = text_field("Model_Name")
    multiple_api_keys = text_field("Model_API_Keys")
    brave_api_key = text_field("Brave_Search_API_Key")
    proxy_list = text_field("Proxy_List")
    model_temperature = str(data.get("Model_Temperature", 0.0))
    model_top_p = str(data.get("Model_Top_P", 1.0))

    prov_lower = provider.lower()

    env_updates = {}
    env_updates.update(prepare_provider_key_updates(prov_lower, multiple_api_keys))
    proxy_updates = prepare_proxy_list_updates(proxy_list)
    if proxy_updates:
        env_updates.update(proxy_updates)
    env_updates.update({
        "BRAVE_API_KEY": brave_api_key,
        "MODEL_PROVIDER": prov_lower,
        "MODEL_NAME": model_name,
        "MODEL_TEMPERATURE": model_temperature,
        "MODEL_TOP_P": model_top_p,
    })

    update_env_vars(env_updates)
    # NOTE(review): this reloads the default .env location, whereas
    # initialize_components() loads ENV_FILE_PATH -- confirm both are wanted.
    load_dotenv(override=True)
    await initialize_components()
    return {"success": True}
# Define the route for adding/uploading content for a specific session
@app.post("/add-content")
async def add_content(files: Optional[List[UploadFile]] = File(None), urls: str = Form(...)):
    """Store uploaded files and user-provided links for the current session.

    Args:
        files: Optional uploaded files, saved under the session's upload
            directory.
        urls: JSON-encoded value holding the user-provided links.

    Returns:
        A dict with a message, the saved file names and the parsed links.

    Raises:
        HTTPException: 400 when no session is active, when the upload would
            push the session folder over ``MAX_FOLDER_SIZE``, or when
            ``urls`` is not valid JSON.
    """
    state = SESSION_STORE
    session_id = state.get("session_id")
    if not session_id:
        raise HTTPException(status_code=400, detail="Session ID is not set. Please start a session first.")

    session_upload_path = os.path.join(UPLOAD_DIRECTORY, session_id)
    os.makedirs(session_upload_path, exist_ok=True)

    saved_filenames = []
    if files:
        # UploadFile.size may be None when the size is unknown; count it as 0.
        total_new_files_size = sum(file.size or 0 for file in files)
        current_folder_size = get_folder_size(session_upload_path)

        # Check if the total size exceeds the maximum allowed folder size
        if current_folder_size + total_new_files_size > MAX_FOLDER_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"Cannot add files as total storage would exceed 10 MB. Current size: {current_folder_size / (1024 * 1024):.2f} MB"
            )

        for file in files:
            try:
                # SECURITY: the filename is chosen by the client. Keep only its
                # final path component so it cannot escape the upload directory
                # (e.g. "../../x" or an absolute path).
                safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
                if safe_name in ("", ".", ".."):
                    continue
                file_path = os.path.join(session_upload_path, safe_name)
                with open(file_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)
                saved_filenames.append(safe_name)
            finally:
                file.file.close()

    try:
        parsed_urls = json.loads(urls)
        print(f"Received links: {parsed_urls}")
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid URL format.")

    # Store user-provided links in session
    if parsed_urls:
        SESSION_STORE["user_provided_links"] = parsed_urls

    return {
        "message": "Content added successfully",
        "files_added": saved_filenames,
        "links_added": parsed_urls
    }
# Define the route to update the selected services for searching
@app.post("/api/selected-services")
async def update_selected_services(data: Dict[str, Any]):
    """Record which connected-app services should be searched.

    Stores ``data["services"]`` (an empty dict when absent) in the session
    store under ``"selected_services"`` and echoes it back.
    """
    services = data.get("services", {})
    SESSION_STORE["selected_services"] = services
    logger.info(f"Updated selected services: {services}")
    return {"success": True, "services": services}
# Define the route to receive OAuth tokens from the frontend
@app.post("/api/session-token")
async def receive_session_token(data: Dict[str, Any]):
    """Store an OAuth token sent by the frontend for a given provider.

    Args:
        data: Request body with a ``"provider"`` entry
            ('google', 'microsoft', 'slack') and a ``"token"`` entry.

    Returns:
        A dict with ``"success"`` and a confirmation ``"message"``.

    Raises:
        HTTPException: 400 when the provider or the token is missing/empty.
    """
    provider = data.get("provider")  # 'google', 'microsoft', 'slack'
    token = data.get("token")
    if not provider or not token:
        raise HTTPException(status_code=400, detail="Provider and token are required")

    # setdefault keeps this working even if the "oauth_tokens" entry has not
    # been created yet, instead of failing with a KeyError.
    SESSION_STORE.setdefault("oauth_tokens", {})[provider] = {
        "token": token,
        # Receipt time, stored alongside the token.
        "timestamp": time.time()
    }

    # Never write the credential itself to the logs.
    logger.info(f"Stored token for provider {provider}")
    return {"success": True, "message": f"{provider} token stored successfully"}
# Define the route for cleaning up the current session's uploads and cached state
@app.post("/cleanup")
async def cleanup_session():
    """Clear the uploads and per-session caches of the current session.

    Calls ``clear_folder`` on the session's folder under ``UPLOAD_DIRECTORY``
    and resets the user links, file/link caches, selected services and OAuth
    tokens held in ``SESSION_STORE``.

    Returns:
        A dict with a confirmation ``"message"``.

    Raises:
        HTTPException: 400 when no session ID is set.
    """
    session_id = SESSION_STORE.get("session_id")
    if not session_id:
        raise HTTPException(status_code=400, detail="Session ID is not set. Please start a session first.")

    # Past the guard above the session ID is always truthy, so no second
    # check (and no "cleanup skipped" fallback) is needed.
    # Clear the session upload directory
    clear_folder(os.path.join(UPLOAD_DIRECTORY, session_id))

    # Clear user-provided links and caches
    SESSION_STORE["user_provided_links"] = []
    SESSION_STORE["user_files_cache"] = {}
    SESSION_STORE["user_links_cache"] = {}
    SESSION_STORE["selected_services"] = {}
    SESSION_STORE["oauth_tokens"] = {}

    return {"message": "Cleanup successful."}
@app.on_event("startup")
def init_chat():
    """Create the upload directory and seed ``SESSION_STORE`` at startup.

    Does nothing when ``SESSION_STORE`` already holds data.

    Returns:
        ``{"success": True}`` when the store was initialised,
        ``{"success": False}`` when it was already populated.
    """
    if SESSION_STORE:
        print("Chat already initialized!")
        return {"success": False}

    print("Initializing chat...")

    # Create the upload directory if it doesn't exist
    print("Creating upload directory...")
    os.makedirs(UPLOAD_DIRECTORY, exist_ok=True)

    # Initialize the session store
    SESSION_STORE["settings_saved"] = False
    SESSION_STORE["session_id"] = None
    SESSION_STORE["answer"] = None
    SESSION_STORE["source_contents"] = None
    SESSION_STORE["chat_history"] = []
    SESSION_STORE["user_provided_links"] = []
    SESSION_STORE["user_files_cache"] = {}
    SESSION_STORE["user_links_cache"] = {}
    SESSION_STORE["selected_services"] = {}
    SESSION_STORE["oauth_tokens"] = {}

    print("Chat initialized!")
    # Key spelled "success" so both outcomes share the same response shape.
    return {"success": True}
@app.get("/message-sse")
async def sse_message(request: Request, user_message: str):
    """Process a user message and stream progress as Server-Sent Events.

    A background task runs ``process_query`` with the prompt and a queue;
    ``(event_type, data)`` items read from that queue are relayed to the
    client as SSE frames until a ``"complete"`` event arrives, the task
    finishes, or the client disconnects.

    Args:
        request: Incoming request, polled for client disconnection.
        user_message: The user's query.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    state = SESSION_STORE
    sse_queue = asyncio.Queue()

    # Event types relayed unchanged as "event: <type>" / "data: <data>".
    passthrough_events = ("token", "final_message", "step", "sources_read")

    async def event_generator():
        # Build the prompt, prefixing the last three chat-history entries
        # when there are any.
        context = state["chat_history"][-3:]
        if context:
            prompt = (
                "This is the previous context of the conversation:\n"
                f"{context}\n"
                "Current Query:\n"
                f"{user_message}"
            )
        else:
            prompt = user_message

        task = asyncio.create_task(process_query(prompt, sse_queue))
        state["process_task"] = task

        try:
            while True:
                if await request.is_disconnected():
                    task.cancel()
                    break

                try:
                    # Wake up every 5 seconds so a disconnect or a finished
                    # task is noticed even when no event is queued.
                    event_type, data = await asyncio.wait_for(sse_queue.get(), timeout=5)

                    if event_type in passthrough_events:
                        yield f"event: {event_type}\ndata: {data}\n\n"
                    elif event_type == "error":
                        stop_on_error()
                        yield format_error_sse("error", data)
                    elif event_type == "task":
                        # Payload is a (sub-query, status) pair.
                        subq, status = data
                        j = {"task": subq, "status": status}
                        yield f"event: task\ndata: {json.dumps(j)}\n\n"
                    elif event_type == "action":
                        yield f"event: action\ndata: {json.dumps(data)}\n\n"
                    elif event_type == "complete":
                        yield f"event: complete\ndata: {data}\n\n"
                        break
                    else:
                        # Unknown event types are sent as generic messages.
                        yield f"event: message\ndata: {data}\n\n"

                except asyncio.TimeoutError:
                    if task.done():
                        break
                    continue
                except asyncio.CancelledError:
                    break
        finally:
            # Runs however the generator exits (normal end, cancellation, or
            # being closed while suspended at a yield) so the background task
            # is never left running without a consumer.
            if not task.done():
                task.cancel()
            state.pop("process_task", None)

    return StreamingResponse(event_generator(), media_type="text/event-stream")
@app.post("/stop")
def stop():
    """Cancel the task stored under ``"process_task"``, if there is one.

    The entry is removed from ``SESSION_STORE`` after cancellation. The same
    message is returned whether or not a task was present.
    """
    try:
        running = SESSION_STORE["process_task"]
    except KeyError:
        # Nothing registered; nothing to cancel.
        pass
    else:
        running.cancel()
        del SESSION_STORE["process_task"]

    return {"message": "Stopped task manually"}
# Catch-all route for frontend paths.
@app.get("/{full_path:path}")
async def serve_frontend(full_path: str, request: Request):
    """Serve the frontend's ``index.html`` for any path matched by this route.

    Raises:
        HTTPException: 500 when ``frontend/build/index.html`` does not exist.
    """
    build_index = os.path.join("frontend", "build", "index.html")

    if os.path.exists(build_index):
        return FileResponse(build_index)

    raise HTTPException(status_code=500, detail="Frontend build not found")