import os
import re
import ast
import json
import time
import shutil
import asyncio
import logging
import traceback
from typing import List, Dict, Any, Optional
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, Request, HTTPException, UploadFile, File, Form
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from tenacity import RetryError
from openai import RateLimitError
from anthropic import RateLimitError as AnthropicRateLimitError
from google.api_core.exceptions import ResourceExhausted
from src.helpers.helper import get_folder_size, clear_folder
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Path to the .env file
ENV_FILE_PATH = os.getenv("WRITABLE_DIR", "/tmp") + "/.env"
# Define the upload directory and maximum folder size
UPLOAD_DIRECTORY = os.getenv("WRITABLE_DIR", "/tmp") + "/uploads"
MAX_FOLDER_SIZE = 10 * 1024 * 1024  # 10 MB in bytes
CONTEXT_LENGTH = 128000
BUFFER = 10000
MAX_TOKENS_ALLOWED = CONTEXT_LENGTH - BUFFER
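# With the defaults above this leaves 128000 - 10000 = 118000 tokens of budget for prompt content.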
# Per-session state
SESSION_STORE: Dict[str, Dict[str, Any]] = {}
# Format error message for SSE | |
def format_error_sse(event_type: str, data: str) -> str:
    lines = data.splitlines()
    sse_message = f"event: {event_type}\n"
    for line in lines:
        sse_message += f"data: {line}\n"
    sse_message += "\n"
    return sse_message
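# For example, format_error_sse("error", "boom\nfailed") yields:
#   event: error
#   data: boom
#   data: failed
# followed by the blank line that terminates the SSE message.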
# Stop the task on error (non-fastapi) | |
def stop_on_error(): | |
state = SESSION_STORE | |
if "process_task" in state: | |
state["process_task"].cancel() | |
del state["process_task"] | |
# Get OAuth tokens for MCP tools | |
def get_oauth_token(provider: str) -> Optional[str]:
    if "oauth_tokens" in SESSION_STORE and provider in SESSION_STORE["oauth_tokens"]:
        token_data = SESSION_STORE["oauth_tokens"][provider]
        # Check if token is expired (1 hour)
        if time.time() - token_data["timestamp"] < 3600:
            return token_data["token"]
        else:
            # Token expired, remove it
            del SESSION_STORE["oauth_tokens"][provider]
            logger.info(f"{provider} token expired and removed")
    return None
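# Tokens are stored by receive_session_token() below and silently expire after one hour;
# callers receive None when no valid token is available and must handle that case.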
# Initialize the components | |
async def initialize_components(): | |
load_dotenv(ENV_FILE_PATH, override=True) | |
from src.search.search_engine import SearchEngine | |
from src.query_processing.query_processor import QueryProcessor | |
# from src.rag.neo4j_graphrag import Neo4jGraphRAG | |
from src.rag.graph_rag import GraphRAG | |
from src.evaluation.evaluator import Evaluator | |
from src.reasoning.reasoner import Reasoner | |
from src.crawl.crawler import CustomCrawler | |
from src.utils.api_key_manager import APIKeyManager | |
from src.query_processing.late_chunking.late_chunker import LateChunker | |
from src.integrations.mcp_client import MCPClient | |
state = SESSION_STORE | |
manager = APIKeyManager() | |
manager._reinit() | |
state['search_engine'] = SearchEngine() | |
state['query_processor'] = QueryProcessor() | |
state['crawler'] = CustomCrawler(max_concurrent_requests=1000) | |
# state['graph_rag'] = Neo4jGraphRAG(num_workers=os.cpu_count() * 2) | |
state['graph_rag'] = GraphRAG(num_workers=os.cpu_count() * 2) | |
state['evaluator'] = Evaluator() | |
state['reasoner'] = Reasoner() | |
state['model'] = manager.get_llm() | |
state['late_chunker'] = LateChunker() | |
state["mcp_client"] = MCPClient() | |
state["initialized"] = True | |
state["session_id"] = await state["crawler"].create_session() | |
# Main function to process user queries | |
async def process_query(user_query: str, sse_queue: asyncio.Queue): | |
state = SESSION_STORE | |
try: | |
# --- Categorize the query --- | |
category = await state["query_processor"].classify_query(user_query) | |
cat_lower = category.lower().strip() | |
user_query = re.sub(r'category:.*', '', user_query, flags=re.IGNORECASE).strip() | |
# --- Read and extract user-provided files and links --- | |
# Initialize caches if not present | |
if "user_files_cache" not in state: | |
state["user_files_cache"] = {} | |
if "user_links_cache" not in state: | |
state["user_links_cache"] = {} | |
# Extract user-provided context | |
user_context = "" | |
user_links = state.get("user_provided_links", []) | |
# Read new uploaded files | |
if state["session_id"]: | |
session_upload_path = os.path.join(UPLOAD_DIRECTORY, state["session_id"]) | |
if os.path.exists(session_upload_path): | |
for filename in os.listdir(session_upload_path): | |
file_path = os.path.join(session_upload_path, filename) | |
if os.path.isfile(file_path): | |
# Check if file is already in cache | |
if filename not in state["user_files_cache"]: | |
try: | |
await sse_queue.put(("step", "Reading User-Provided Files...")) | |
with open(file_path, 'r', encoding='utf-8') as f: | |
file_content = f.read() | |
state["user_files_cache"][filename] = file_content | |
except Exception as e: | |
logger.error(f"Error reading file {filename}: {e}") | |
# Try reading as binary and decode | |
try: | |
with open(file_path, 'rb') as f: | |
file_content = f.read().decode('utf-8', errors='ignore') | |
state["user_files_cache"][filename] = file_content | |
except Exception as e2: | |
logger.error(f"Error reading file {filename} as binary: {e2}") | |
state["user_files_cache"][filename] = "" # Cache empty to avoid retrying | |
# Add all cached file contents | |
for filename, content in state["user_files_cache"].items(): | |
if content: | |
user_context += f"\n[USER PROVIDED FILE: {filename} START]\n{content}\n[USER PROVIDED FILE: {filename} END]\n\n" | |
# Crawl new user-provided links | |
if user_links: | |
await sse_queue.put(("step", "Crawling User-Provided Links...")) | |
new_links = [link for link in user_links if link not in state["user_links_cache"]] | |
if new_links: | |
# Only crawl new links | |
link_contents = await state['crawler'].fetch_page_contents( | |
new_links, | |
user_query, | |
state["session_id"], | |
max_attempts=1 | |
) | |
# Cache the new contents | |
for link, content in zip(new_links, link_contents): | |
if not isinstance(content, Exception) and content: | |
state["user_links_cache"][link] = content | |
else: | |
state["user_links_cache"][link] = "" # Cache empty to avoid retrying | |
# Add all cached link contents | |
for link, content in state["user_links_cache"].items(): | |
if content: | |
idx = user_links.index(link) + 1 if link in user_links else 0 | |
user_context += f"\n[USER PROVIDED LINK {idx} START]\n{content}\n[USER PROVIDED LINK {idx} END]\n\n" | |
# --- Fetch apps data from MCP service --- | |
app_context = "" | |
selected_services = state.get("selected_services", {}) | |
# Check if any services are selected | |
has_google = selected_services.get("google", []) | |
has_microsoft = selected_services.get("microsoft", []) | |
has_slack = selected_services.get("slack", False) | |
if has_google or has_microsoft or has_slack: | |
await sse_queue.put(("step", "Fetching Data From Connected Apps...")) | |
# Fetch from each provider in parallel | |
tasks = [] | |
# Google services | |
if has_google and len(has_google) > 0: | |
google_token = get_oauth_token("google") | |
tasks.append( | |
state['mcp_client'].fetch_app_data( | |
provider="google", | |
services=has_google, | |
query=user_query, | |
user_id=state["session_id"], | |
access_token=google_token | |
) | |
) | |
# Microsoft services | |
if has_microsoft and len(has_microsoft) > 0: | |
microsoft_token = get_oauth_token("microsoft") | |
tasks.append( | |
state['mcp_client'].fetch_app_data( | |
provider="microsoft", | |
services=has_microsoft, | |
query=user_query, | |
user_id=state["session_id"], | |
access_token=microsoft_token | |
) | |
) | |
# Slack | |
if has_slack: | |
slack_token = get_oauth_token("slack") | |
tasks.append( | |
state['mcp_client'].fetch_app_data( | |
provider="slack", | |
services=["messages"], # Slack doesn't have sub-services | |
query=user_query, | |
user_id=state["session_id"], | |
access_token=slack_token | |
) | |
) | |
# Execute all requests in parallel | |
if tasks: | |
results = await asyncio.gather(*tasks, return_exceptions=True) | |
                # Process results; recover each result's provider by rebuilding
                # the order in which the fetch tasks were created above
                provider_order = []
                if has_google:
                    provider_order.append("google")
                if has_microsoft:
                    provider_order.append("microsoft")
                if has_slack:
                    provider_order.append("slack")
                for i, result in enumerate(results):
                    if isinstance(result, Exception):
                        logger.error(f"Error fetching app data: {result}")
                    elif isinstance(result, dict):
                        provider = provider_order[i]
                        # Format the data
                        formatted_context = state['mcp_client'].format_as_context(provider, result)
                        if formatted_context:
                            app_context += formatted_context
# Log how much app data we got | |
if app_context: | |
logger.info(f"Retrieved app data: {len(app_context)} characters") | |
# Prepend app context to user context | |
if app_context: | |
user_context = app_context + "\n\n" + user_context | |
# Upgrade basic to advanced if user has provided links | |
if cat_lower == "basic" and user_links: | |
cat_lower = "advanced" | |
# --- Process the query based on the category --- | |
if cat_lower == "basic": | |
response = "" | |
chunk_counter = 1 | |
if user_context: # Include user context if available | |
await sse_queue.put(("step", "Generating Response...")) | |
async for chunk in state["reasoner"].answer(user_query, user_context, query_type="basic"): | |
await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter}))) | |
response += chunk | |
chunk_counter += 1 | |
else: # No user context provided | |
async for chunk in state["reasoner"].answer(user_query): | |
await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter}))) | |
response += chunk | |
chunk_counter += 1 | |
await sse_queue.put(("final_message", response)) | |
SESSION_STORE["chat_history"].append({"query": user_query, "response": response}) | |
await sse_queue.put(("action", { | |
"name": "evaluate", | |
"payload": {"query": user_query, "response": response} | |
})) | |
await sse_queue.put(("complete", "done")) | |
elif cat_lower == "advanced": | |
await sse_queue.put(("step", "Searching...")) | |
optimized_query = await state['search_engine'].generate_optimized_query(user_query) | |
search_results = await state['search_engine'].search( | |
optimized_query, | |
num_results=3, | |
exclude_filetypes=["pdf"] | |
) | |
urls = [r.get('link', 'No URL') for r in search_results] | |
search_contents = await state['crawler'].fetch_page_contents( | |
urls, | |
user_query, | |
state["session_id"], | |
max_attempts=1 | |
) | |
# Start with user-provided context | |
contents = user_context | |
# Add crawled contents | |
if search_contents: | |
for k, content in enumerate(search_contents, 1): | |
if isinstance(content, Exception): | |
print(f"Error fetching content: {content}") | |
elif content: | |
contents += f"[SOURCE {k} START]\n{content}\n[SOURCE {k} END]\n\n" | |
if len(contents.strip()) > 0: | |
await sse_queue.put(("step", "Generating Response...")) | |
token_count = state['model'].get_num_tokens(contents) | |
if token_count > MAX_TOKENS_ALLOWED: | |
contents = await state['late_chunker'].chunker(contents, user_query, MAX_TOKENS_ALLOWED) | |
await sse_queue.put(("sources_read", len(search_contents))) | |
response = "" | |
chunk_counter = 1 | |
async for chunk in state["reasoner"].answer(user_query, contents): | |
await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter}))) | |
response += chunk | |
chunk_counter += 1 | |
                sources_for_answer = []
                for idx, result in enumerate(search_results, 1):
                    content = search_contents[idx - 1]
                    # Only include sources whose content was successfully fetched
                    if content and not isinstance(content, Exception):
                        sources_for_answer.append({
                            "id": idx,
                            "title": result.get('title', 'No Title'),
                            "link": result.get('link', 'No URL')
                        })
await sse_queue.put(("final_message", response)) | |
await sse_queue.put(("final_sources", json.dumps(sources_for_answer))) | |
SESSION_STORE["chat_history"].append({"query": user_query, "response": response}) | |
SESSION_STORE["answer"] = response | |
SESSION_STORE["source_contents"] = contents | |
await sse_queue.put(("action", { | |
"name": "sources", | |
"payload": {"search_results": search_results, "search_contents": search_contents} | |
})) | |
await sse_queue.put(("action", { | |
"name": "evaluate", | |
"payload": {"query": user_query, "contents": [contents], "response": response} | |
})) | |
await sse_queue.put(("complete", "done")) | |
else: | |
await sse_queue.put(("error", "No results found.")) | |
elif cat_lower == "pro": | |
current_search_results = [] | |
current_search_contents = [] | |
await sse_queue.put(("step", "Thinking...")) | |
start = time.time() | |
intent = await state['query_processor'].get_query_intent(user_query) | |
sub_queries, _ = await state['query_processor'].decompose_query(user_query, intent) | |
async def sub_query_task(sub_query): | |
try: | |
await sse_queue.put(("step", "Searching...")) | |
await sse_queue.put(("task", (sub_query, "RUNNING"))) | |
optimized_query = await state['search_engine'].generate_optimized_query(sub_query) | |
search_results = await state['search_engine'].search( | |
optimized_query, | |
num_results=10, | |
exclude_filetypes=["pdf"] | |
) | |
filtered_urls = await state['search_engine'].filter_urls( | |
sub_query, | |
category, | |
search_results | |
) | |
current_search_results.extend(filtered_urls) | |
# Combine search results with user-provided links | |
all_search_results = search_results + \ | |
[{"link": url, "title": f"User provided: {url}", "snippet": ""} for url in user_links] | |
urls = [r.get('link', 'No URL') for r in all_search_results] | |
search_contents = await state['crawler'].fetch_page_contents( | |
urls, | |
sub_query, | |
state["session_id"], | |
max_attempts=1 | |
) | |
current_search_contents.extend(search_contents) | |
contents = user_context | |
if search_contents: | |
for k, c in enumerate(search_contents, 1): | |
if isinstance(c, Exception): | |
logger.info(f"Error fetching content: {c}") | |
elif c: | |
contents += f"[SOURCE {k} START]\n{c}\n[SOURCE {k} END]\n\n" | |
if len(contents.strip()) > 0: | |
await sse_queue.put(("task", (sub_query, "DONE"))) | |
else: | |
await sse_queue.put(("task", (sub_query, "FAILED"))) | |
return contents | |
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError): | |
await sse_queue.put(("task", (sub_query, "FAILED"))) | |
return "" | |
tasks = [] | |
if len(sub_queries) > 1 and sub_queries[0] != user_query: | |
for sub_query in sub_queries: | |
tasks.append(sub_query_task(sub_query)) | |
results = await asyncio.gather(*tasks) | |
end = time.time() | |
# Start with user-provided context | |
contents = user_context | |
# Add searched contents | |
contents += "\n\n".join(r for r in results if r.strip()) | |
unique_results = [] | |
seen = set() | |
for entry in current_search_results: | |
link = entry["link"] | |
if link not in seen: | |
seen.add(link) | |
unique_results.append(entry) | |
current_search_results = unique_results | |
current_search_contents = list(set(current_search_contents)) | |
if len(contents.strip()) > 0: | |
await sse_queue.put(("step", "Generating Response...")) | |
token_count = state['model'].get_num_tokens(contents) | |
if token_count > MAX_TOKENS_ALLOWED: | |
contents = await state['late_chunker'].chunker( | |
text=contents, | |
query=user_query, | |
max_tokens=MAX_TOKENS_ALLOWED | |
) | |
logger.info(f"Number of tokens in the answer: {token_count}") | |
logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(contents)}") | |
await sse_queue.put(("sources_read", len(current_search_contents))) | |
response = "" | |
chunk_counter = 1 | |
is_first_chunk = True | |
async for chunk in state['reasoner'].answer(user_query, contents): | |
if is_first_chunk: | |
await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds")) | |
is_first_chunk = False | |
await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter}))) | |
response += chunk | |
chunk_counter += 1 | |
sources_for_answer = [] | |
for idx, (result, content) in enumerate(zip(current_search_results, current_search_contents), 1): | |
                    if content and not isinstance(content, Exception):  # Only include if content was successfully fetched
sources_for_answer.append({ | |
"id": idx, | |
"title": result.get('title', 'No Title'), | |
"link": result.get('link', 'No URL') | |
}) | |
await sse_queue.put(("final_message", response)) | |
await sse_queue.put(("final_sources", json.dumps(sources_for_answer))) | |
SESSION_STORE["chat_history"].append({"query": user_query, "response": response}) | |
SESSION_STORE["answer"] = response | |
SESSION_STORE["source_contents"] = contents | |
await sse_queue.put(("action", { | |
"name": "sources", | |
"payload": { | |
"search_results": current_search_results, | |
"search_contents": current_search_contents | |
} | |
})) | |
await sse_queue.put(("action", { | |
"name": "evaluate", | |
"payload": {"query": user_query, "contents": [contents], "response": response} | |
})) | |
await sse_queue.put(("complete", "done")) | |
else: | |
await sse_queue.put(("error", "No results found.")) | |
elif cat_lower == "super": | |
current_search_results = [] | |
current_search_contents = [] | |
await sse_queue.put(("step", "Thinking...")) | |
start = time.time() | |
main_query_intent = await state['query_processor'].get_query_intent(user_query) | |
sub_queries, _ = await state['query_processor'].decompose_query(user_query, main_query_intent) | |
await sse_queue.put(("step", "Searching...")) | |
async def sub_query_task(sub_query): | |
try: | |
async def sub_sub_query_task(sub_sub_query): | |
optimized_query = await state['search_engine'].generate_optimized_query(sub_sub_query) | |
search_results = await state['search_engine'].search( | |
optimized_query, | |
num_results=10, | |
exclude_filetypes=["pdf"] | |
) | |
filtered_urls = await state['search_engine'].filter_urls( | |
sub_sub_query, | |
category, | |
search_results | |
) | |
current_search_results.extend(filtered_urls) | |
urls = [r.get('link', 'No URL') for r in filtered_urls] | |
search_contents = await state['crawler'].fetch_page_contents( | |
urls, | |
sub_sub_query, | |
state["session_id"], | |
max_attempts=1, | |
timeout=20 | |
) | |
current_search_contents.extend(search_contents) | |
contents = "" | |
if search_contents: | |
for k, c in enumerate(search_contents, 1): | |
if isinstance(c, Exception): | |
logger.info(f"Error fetching content: {c}") | |
elif c: | |
contents += f"[SOURCE {k} START]\n{c}\n[SOURCE {k} END]\n\n" | |
return contents | |
await sse_queue.put(("task", (sub_query, "RUNNING"))) | |
sub_sub_queries, _ = await state['query_processor'].decompose_query(sub_query) | |
tasks = [] | |
if len(sub_sub_queries) > 1 and sub_sub_queries[0] != user_query: | |
for sub_sub_query in sub_sub_queries: | |
tasks.append(sub_sub_query_task(sub_sub_query)) | |
results = await asyncio.gather(*tasks) | |
if any(result.strip() for result in results): | |
await sse_queue.put(("task", (sub_query, "DONE"))) | |
else: | |
await sse_queue.put(("task", (sub_query, "FAILED"))) | |
return results | |
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError): | |
await sse_queue.put(("task", (sub_query, "FAILED"))) | |
return [] | |
tasks = [] | |
if len(sub_queries) > 1 and sub_queries[0] != user_query: | |
for sub_query in sub_queries: | |
tasks.append(sub_query_task(sub_query)) | |
results = await asyncio.gather(*tasks) | |
end = time.time() | |
# Start with user-provided context | |
previous_contents = [] | |
if user_context: | |
previous_contents.append(user_context) | |
for result in results: | |
if result: | |
for content in result: | |
if isinstance(content, str) and len(content.strip()) > 0: | |
previous_contents.append(content) | |
contents = "\n\n".join(previous_contents) | |
unique_results = [] | |
seen = set() | |
for entry in current_search_results: | |
link = entry["link"] | |
if link not in seen: | |
seen.add(link) | |
unique_results.append(entry) | |
current_search_results = unique_results | |
current_search_contents = list(set(current_search_contents)) | |
if len(contents.strip()) > 0: | |
await sse_queue.put(("step", "Generating Response...")) | |
token_count = state['model'].get_num_tokens(contents) | |
if token_count > MAX_TOKENS_ALLOWED: | |
contents = await state['late_chunker'].chunker( | |
text=contents, | |
query=user_query, | |
max_tokens=MAX_TOKENS_ALLOWED | |
) | |
logger.info(f"Number of tokens in the answer: {token_count}") | |
logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(contents)}") | |
await sse_queue.put(("sources_read", len(current_search_contents))) | |
response = "" | |
chunk_counter = 1 | |
is_first_chunk = True | |
async for chunk in state['reasoner'].answer(user_query, contents): | |
if is_first_chunk: | |
await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds")) | |
is_first_chunk = False | |
await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter}))) | |
response += chunk | |
chunk_counter += 1 | |
sources_for_answer = [] | |
for idx, (result, content) in enumerate(zip(current_search_results, current_search_contents), 1): | |
                    if content and not isinstance(content, Exception):  # Only include if content was successfully fetched
sources_for_answer.append({ | |
"id": idx, | |
"title": result.get('title', 'No Title'), | |
"link": result.get('link', 'No URL') | |
}) | |
await sse_queue.put(("final_message", response)) | |
await sse_queue.put(("final_sources", json.dumps(sources_for_answer))) | |
SESSION_STORE["chat_history"].append({"query": user_query, "response": response}) | |
SESSION_STORE["answer"] = response | |
SESSION_STORE["source_contents"] = contents | |
await sse_queue.put(("action", { | |
"name": "sources", | |
"payload": { | |
"search_results": current_search_results, | |
"search_contents": current_search_contents | |
} | |
})) | |
await sse_queue.put(("action", { | |
"name": "evaluate", | |
"payload": {"query": user_query, "contents": [contents], "response": response} | |
})) | |
await sse_queue.put(("complete", "done")) | |
else: | |
await sse_queue.put(("error", "No results found.")) | |
elif cat_lower == "ultra": | |
current_search_results = [] | |
current_search_contents = [] | |
match = re.search( | |
r"^This is the previous context of the conversation:\s*.*?\s*Current Query:\s*(.*)$", | |
user_query, | |
flags=re.DOTALL | re.MULTILINE | |
) | |
if match: | |
user_query = match.group(1) | |
await sse_queue.put(("step", "Thinking...")) | |
await asyncio.sleep(0.01) # Sleep for a short time to allow the message to be sent | |
async def on_event_callback(event_type, data): | |
if event_type == "graph_operation": | |
if data["operation_type"] == "creating_new_graph": | |
await sse_queue.put(("step", "Creating New Graph...")) | |
elif data["operation_type"] == "modifying_existing_graph": | |
await sse_queue.put(("step", "Modifying Existing Graph...")) | |
elif data["operation_type"] == "loading_existing_graph": | |
await sse_queue.put(("step", "Loading Existing Graph...")) | |
elif event_type == "sub_query_created": | |
sub_query = data["sub_query"] | |
await sse_queue.put(("task", (sub_query, "RUNNING"))) | |
elif event_type == "search_process_started": | |
await sse_queue.put(("step", "Searching...")) | |
elif event_type == "sub_query_processed": | |
sub_query = data["sub_query"] | |
await sse_queue.put(("task", (sub_query, "DONE"))) | |
elif event_type == "sub_query_failed": | |
sub_query = data["sub_query"] | |
await sse_queue.put(("task", (sub_query, "FAILED"))) | |
elif event_type == "search_results_filtered": | |
current_search_results.extend(data["filtered_urls"]) | |
filtered_urls = data["filtered_urls"] | |
current_search_results.extend(filtered_urls) | |
elif event_type == "search_contents_fetched": | |
current_search_contents.extend(data["contents"]) | |
contents = data["contents"] | |
current_search_contents.extend(contents) | |
elif event_type == "search_process_completed": | |
await sse_queue.put(("step", "Processing final graph tasks...")) | |
await asyncio.sleep(0.01) # Sleep for a short time to allow the message to be sent | |
state['graph_rag'].set_on_event_callback(on_event_callback) | |
start = time.time() | |
# state['graph_rag'].initialize_schema() | |
await state['graph_rag'].process_graph( | |
user_query, | |
similarity_threshold=0.8, | |
relevance_threshold=0.8, | |
max_tokens_allowed=MAX_TOKENS_ALLOWED | |
) | |
end = time.time() | |
unique_results = [] | |
seen = set() | |
for entry in current_search_results: | |
link = entry["link"] | |
if link not in seen: | |
seen.add(link) | |
unique_results.append(entry) | |
current_search_results = unique_results | |
current_search_contents = list(set(current_search_contents)) | |
await sse_queue.put(("step", "Generating Response...")) | |
answer = state['graph_rag'].query_graph(user_query) | |
if answer: | |
# Start with user-provided context | |
previous_contents = [] | |
if user_context: | |
previous_contents.append(user_context) | |
token_count = state['model'].get_num_tokens(answer) | |
if token_count > MAX_TOKENS_ALLOWED: | |
answer = await state['late_chunker'].chunker( | |
text=answer, | |
query=user_query, | |
max_tokens=MAX_TOKENS_ALLOWED | |
) | |
logger.info(f"Number of tokens in the answer: {token_count}") | |
logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(answer)}") | |
await sse_queue.put(("sources_read", len(current_search_contents))) | |
response = "" | |
chunk_counter = 1 | |
is_first_chunk = True | |
async for chunk in state['reasoner'].answer(user_query, answer): | |
if is_first_chunk: | |
await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds")) | |
is_first_chunk = False | |
await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter}))) | |
response += chunk | |
chunk_counter += 1 | |
sources_for_answer = [] | |
for idx, (result, content) in enumerate(zip(current_search_results, current_search_contents), 1): | |
                    if content and not isinstance(content, Exception):  # Only include if content was successfully fetched
sources_for_answer.append({ | |
"id": idx, | |
"title": result.get('title', 'No Title'), | |
"link": result.get('link', 'No URL') | |
}) | |
await sse_queue.put(("final_message", response)) | |
await sse_queue.put(("final_sources", json.dumps(sources_for_answer))) | |
SESSION_STORE["chat_history"].append({"query": user_query, "response": response}) | |
SESSION_STORE["answer"] = response | |
SESSION_STORE["source_contents"] = contents | |
await sse_queue.put(("action", { | |
"name": "sources", | |
"payload": {"search_results": current_search_results, "search_contents": current_search_contents}, | |
})) | |
await sse_queue.put(("action", { | |
"name": "graph", | |
"payload": {"query": user_query}, | |
})) | |
await sse_queue.put(("action", { | |
"name": "evaluate", | |
"payload": {"query": user_query, "contents": [answer], "response": response}, | |
})) | |
await sse_queue.put(("complete", "done")) | |
else: | |
await sse_queue.put(("error", "No results found.")) | |
else: | |
await sse_queue.put(("final_message", "I'm not sure how to handle your query.")) | |
except Exception as e: | |
await sse_queue.put(("error", str(e))) | |
traceback.print_exc() | |
        stop_on_error()
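# process_query communicates with the SSE route below purely through the queue; the event types it
# emits are "step", "task", "token", "sources_read", "final_message", "final_sources", "action",
# "complete" and "error", which sse_message() translates into the corresponding SSE events.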
# Create a FastAPI app | |
app = FastAPI() | |
# Define allowed origins | |
origins = [
    "http://localhost:3000",
    "http://localhost:7860",
    "http://localhost:8000",
    "http://localhost"
]
# Add the CORS middleware to your FastAPI app | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=origins, # Allows only these origins | |
allow_credentials=True, | |
allow_methods=["*"], # Allows all HTTP methods (GET, POST, etc.) | |
allow_headers=["*"], # Allows all headers | |
) | |
# Serve the React app's static assets (production build); index.html itself is returned by the catch-all route below.
app.mount("/static", StaticFiles(directory="frontend/build/static", html=True), name="static")
# Define the routes for the FastAPI app | |
# Define the route for sources action to display search results | |
def action_sources(payload: Dict[str, Any]) -> Dict[str, Any]: | |
try: | |
search_contents = payload.get("search_contents", []) | |
search_results = payload.get("search_results", []) | |
sources = [] | |
word_limit = 15 # Maximum number of words for the description | |
for result, contents in zip(search_results, search_contents): | |
if contents: | |
title = result.get('title', 'No Title') | |
link = result.get('link', 'No URL') | |
snippet = result.get('snippet', 'No snippet') | |
cleaned = re.sub(r'<[^>]+>|\[\/?.*?\]', '', snippet) | |
words = cleaned.split() | |
if len(words) > word_limit: | |
description = " ".join(words[:word_limit]) + "..." | |
else: | |
description = " ".join(words) | |
source_obj = { | |
"title": title, | |
"link": link, | |
"description": description | |
} | |
sources.append(source_obj) | |
return {"result": sources} | |
except Exception as e: | |
return JSONResponse(content={"error": str(e)}, status_code=500) | |
# Define the route for graph action to display the graph | |
def action_graph() -> Dict[str, Any]: | |
state = SESSION_STORE | |
try: | |
html_str = state['graph_rag'].display_graph() | |
return {"result": html_str} | |
except Exception as e: | |
return JSONResponse(content={"error": str(e)}, status_code=500) | |
# Define the route for evaluate action to display evaluation results | |
async def action_evaluate(payload: Dict[str, Any]) -> Dict[str, Any]: | |
try: | |
query = payload.get("query", "") | |
contents = payload.get("contents", []) | |
response = payload.get("response", "") | |
metrics = payload.get("metrics", []) | |
state = SESSION_STORE | |
evaluator = state["evaluator"] | |
result = await evaluator.evaluate_response(query, response, contents, include_metrics=metrics) | |
return {"result": result} | |
except Exception as e: | |
return JSONResponse(content={"error": str(e)}, status_code=500) | |
# Define the route for excerpts action to display excerpts from the sources | |
async def action_excerpts() -> Dict[str, Any]: | |
def validate_excerpts_format(excerpts): | |
if not isinstance(excerpts, list): | |
return False | |
for item in excerpts: | |
if not isinstance(item, dict): | |
return False | |
for statement, sources in item.items(): | |
if not isinstance(statement, str) or not isinstance(sources, dict): | |
return False | |
for src_num, excerpt in sources.items(): | |
if not (isinstance(src_num, int) or isinstance(src_num, str)): | |
return False | |
if not isinstance(excerpt, str): | |
return False | |
return True | |
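    # Expected shape (enforced by the validator above): a list of {statement: {source_number: excerpt}}
    # mappings, e.g. [{"Claim text": {1: "supporting excerpt from source 1"}}]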
try: | |
state = SESSION_STORE | |
response = state["answer"] | |
contents = state["source_contents"] | |
if not response or not contents: | |
raise ValueError("Required data for excerpts not found") | |
excerpts_list = await state["reasoner"].get_excerpts(response, contents) | |
cleaned_excerpts = re.sub( | |
r'```[\w\s]*\n?|```|~~~[\w\s]*\n?|~~~', '', excerpts_list, flags=re.MULTILINE | re.DOTALL | |
).strip() | |
        try:
            # Parse the model output as a Python literal instead of using eval()
            excerpts = ast.literal_eval(cleaned_excerpts)
        except Exception:
            print(f"Error parsing excerpts:\n{cleaned_excerpts}")
            raise ValueError("Excerpts could not be parsed as a Python list.")
if not validate_excerpts_format(excerpts): | |
print(f"Excerpts format validation failed:\n{excerpts}") | |
raise ValueError("Excerpts are not in the required format.") | |
print(f"Excerpts:\n{excerpts}") | |
return {"result": excerpts} | |
except Exception as e: | |
print(f"Error in action_excerpts: {e}") | |
return JSONResponse(content={"error": str(e)}, status_code=500) | |
# Define the route for settings to set or update the environment variables | |
async def update_settings(data: Dict[str, Any]): | |
from src.helpers.helper import ( | |
prepare_provider_key_updates, | |
prepare_proxy_list_updates, | |
update_env_vars | |
) | |
provider = data.get("Model_Provider", "").strip() | |
model_name = data.get("Model_Name", "").strip() | |
multiple_api_keys = data.get("Model_API_Keys", "").strip() | |
brave_api_key = data.get("Brave_Search_API_Key", "").strip() | |
proxy_list = data.get("Proxy_List", "").strip() | |
model_temperature = str(data.get("Model_Temperature", 0.0)) | |
model_top_p = str(data.get("Model_Top_P", 1.0)) | |
prov_lower = provider.lower() | |
key_updates = prepare_provider_key_updates(prov_lower, multiple_api_keys) | |
env_updates = {} | |
env_updates.update(key_updates) | |
px = prepare_proxy_list_updates(proxy_list) | |
if px: | |
env_updates.update(px) | |
env_updates["BRAVE_API_KEY"] = brave_api_key | |
env_updates["MODEL_PROVIDER"] = prov_lower | |
env_updates["MODEL_NAME"] = model_name | |
env_updates["MODEL_TEMPERATURE"] = model_temperature | |
env_updates["MODEL_TOP_P"] = model_top_p | |
update_env_vars(env_updates) | |
load_dotenv(override=True) | |
await initialize_components() | |
return {"success": True} | |
# Define the route for adding/uploading content for a specific session | |
async def add_content(files: Optional[List[UploadFile]] = File(None), urls: str = Form(...)): | |
state = SESSION_STORE | |
session_id = state.get("session_id") | |
if not session_id: | |
raise HTTPException(status_code=400, detail="Session ID is not set. Please start a session first.") | |
session_upload_path = os.path.join(UPLOAD_DIRECTORY, session_id) | |
os.makedirs(session_upload_path, exist_ok=True) | |
saved_filenames = [] | |
if files: | |
total_new_files_size = sum(file.size for file in files) | |
current_folder_size = get_folder_size(session_upload_path) | |
# Check if the total size exceeds the maximum allowed folder size | |
if current_folder_size + total_new_files_size > MAX_FOLDER_SIZE: | |
raise HTTPException( | |
status_code=400, | |
detail=f"Cannot add files as total storage would exceed 10 MB. Current size: {current_folder_size / (1024 * 1024):.2f} MB" | |
) | |
for file in files: | |
file_path = os.path.join(session_upload_path, file.filename) | |
try: | |
with open(file_path, "wb") as buffer: | |
shutil.copyfileobj(file.file, buffer) | |
saved_filenames.append(file.filename) | |
finally: | |
file.file.close() | |
try: | |
parsed_urls = json.loads(urls) | |
print(f"Received links: {parsed_urls}") | |
except json.JSONDecodeError: | |
raise HTTPException(status_code=400, detail="Invalid URL format.") | |
# Store user-provided links in session | |
if parsed_urls: | |
SESSION_STORE["user_provided_links"] = parsed_urls | |
return { | |
"message": "Content added successfully", | |
"files_added": saved_filenames, | |
"links_added": parsed_urls | |
} | |
# Define the route to update the selected services for searching | |
async def update_selected_services(data: Dict[str, Any]): | |
state = SESSION_STORE | |
selected_services = data.get("services", {}) | |
state["selected_services"] = selected_services | |
logger.info(f"Updated selected services: {selected_services}") | |
return {"success": True, "services": selected_services} | |
# Define the route to receive OAuth tokens from the frontend | |
async def receive_session_token(data: Dict[str, Any]): | |
provider = data.get("provider") # 'google', 'microsoft', 'slack' | |
token = data.get("token") | |
if not provider or not token: | |
raise HTTPException(status_code=400, detail="Provider and token are required") | |
SESSION_STORE["oauth_tokens"][provider] = { | |
"token": token, | |
"timestamp": time.time() | |
} | |
logger.info(f"Stored token {token} for provider {provider}") | |
return {"success": True, "message": f"{provider} token stored successfully"} | |
# Define the route for cleaning up a session if the session ID matches | |
async def cleanup_session(): | |
state = SESSION_STORE | |
session_id = state.get("session_id") | |
if not session_id: | |
raise HTTPException(status_code=400, detail="Session ID is not set. Please start a session first.") | |
session_upload_path = os.path.join(UPLOAD_DIRECTORY, session_id) | |
if session_id: | |
# Clear the session upload directory | |
clear_folder(session_upload_path) | |
# Clear user-provided links and caches | |
SESSION_STORE["user_provided_links"] = [] | |
SESSION_STORE["user_files_cache"] = {} | |
SESSION_STORE["user_links_cache"] = {} | |
SESSION_STORE["selected_services"] = {} | |
SESSION_STORE["oauth_tokens"] = {} | |
return {"message": "Cleanup successful."} | |
return {"message": "No session ID provided, cleanup skipped."} | |
def init_chat(): | |
if not SESSION_STORE: | |
print("Initializing chat...") | |
# Create the upload directory if it doesn't exist | |
print("Creating upload directory...") | |
os.makedirs(UPLOAD_DIRECTORY, exist_ok=True) | |
# Initialize the session store | |
SESSION_STORE["settings_saved"] = False | |
SESSION_STORE["session_id"] = None | |
SESSION_STORE["answer"] = None | |
SESSION_STORE["source_contents"] = None | |
SESSION_STORE["chat_history"] = [] | |
SESSION_STORE["user_provided_links"] = [] | |
SESSION_STORE["user_files_cache"] = {} | |
SESSION_STORE["user_links_cache"] = {} | |
SESSION_STORE["selected_services"] = {} | |
SESSION_STORE["oauth_tokens"] = {} | |
print("Chat initialized!") | |
return {"sucess": True} | |
else: | |
print("Chat already initialized!") | |
return {"success": False} | |
async def sse_message(request: Request, user_message: str): | |
state = SESSION_STORE | |
sse_queue = asyncio.Queue() | |
async def event_generator(): | |
# Build the prompt | |
context = state["chat_history"][-3:] | |
if context: | |
prompt = \ | |
f"""This is the previous context of the conversation: | |
{context} | |
Current Query: | |
{user_message}""" | |
else: | |
prompt = user_message | |
task = asyncio.create_task(process_query(prompt, sse_queue)) | |
state["process_task"] = task | |
while True: | |
if await request.is_disconnected(): | |
task.cancel() | |
break | |
try: | |
event_type, data = await asyncio.wait_for(sse_queue.get(), timeout=5) | |
if event_type == "token": | |
yield f"event: token\ndata: {data}\n\n" | |
elif event_type == "final_message": | |
yield f"event: final_message\ndata: {data}\n\n" | |
elif event_type == "error": | |
stop_on_error() | |
yield format_error_sse("error", data) | |
elif event_type == "step": | |
yield f"event: step\ndata: {data}\n\n" | |
elif event_type == "task": | |
subq, status = data | |
j = {"task": subq, "status": status} | |
yield f"event: task\ndata: {json.dumps(j)}\n\n" | |
elif event_type == "sources_read": | |
yield f"event: sources_read\ndata: {data}\n\n" | |
elif event_type == "action": | |
yield f"event: action\ndata: {json.dumps(data)}\n\n" | |
elif event_type == "complete": | |
yield f"event: complete\ndata: {data}\n\n" | |
break | |
else: | |
yield f"event: message\ndata: {data}\n\n" | |
except asyncio.TimeoutError: | |
if task.done(): | |
break | |
continue | |
except asyncio.CancelledError: | |
break | |
if not task.done(): | |
task.cancel() | |
if "process_task" in state: | |
del state["process_task"] | |
return StreamingResponse(event_generator(), media_type="text/event-stream") | |
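# Define the route to stop the current processing task manually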
def stop(): | |
state = SESSION_STORE | |
if "process_task" in state: | |
state["process_task"].cancel() | |
del state["process_task"] | |
return {"message": "Stopped task manually"} | |
# Catch-all route for frontend paths. | |
async def serve_frontend(full_path: str, request: Request): | |
index_path = os.path.join("frontend", "build", "index.html") | |
if not os.path.exists(index_path): | |
raise HTTPException(status_code=500, detail="Frontend build not found") | |
return FileResponse(index_path) |