Upload 39 files
- .gitignore +1 -0
- Dockerfile +74 -0
- LLM/__init__.py +13 -0
- LLM/image_answerer.py +136 -0
- LLM/llm_handler.py +235 -0
- LLM/one_shotter.py +218 -0
- LLM/tabular_answer.py +128 -0
- RAG/__init__.py +1 -0
- RAG/advanced_rag_processor.py +169 -0
- RAG/rag_embeddings/.gitkeep +0 -0
- RAG/rag_modules/__init__.py +1 -0
- RAG/rag_modules/answer_generator.py +97 -0
- RAG/rag_modules/context_manager.py +81 -0
- RAG/rag_modules/embedding_manager.py +42 -0
- RAG/rag_modules/query_expansion.py +146 -0
- RAG/rag_modules/reranking_manager.py +63 -0
- RAG/rag_modules/search_manager.py +334 -0
- api/__init__.py +1 -0
- api/api.py +498 -0
- config/__init__.py +1 -0
- config/config.py +142 -0
- logger/__init__.py +1 -0
- logger/logger.py +340 -0
- preprocessing/__init__.py +23 -0
- preprocessing/preprocessing.py +63 -0
- preprocessing/preprocessing_modules/__init__.py +29 -0
- preprocessing/preprocessing_modules/docx_extractor.py +94 -0
- preprocessing/preprocessing_modules/embedding_manager.py +118 -0
- preprocessing/preprocessing_modules/file_downloader.py +108 -0
- preprocessing/preprocessing_modules/image_extractor.py +120 -0
- preprocessing/preprocessing_modules/metadata_manager.py +262 -0
- preprocessing/preprocessing_modules/modular_preprocessor.py +290 -0
- preprocessing/preprocessing_modules/pdf_downloader.py +112 -0
- preprocessing/preprocessing_modules/pptx_extractor.py +118 -0
- preprocessing/preprocessing_modules/text_chunker.py +167 -0
- preprocessing/preprocessing_modules/text_extractor.py +62 -0
- preprocessing/preprocessing_modules/vector_storage.py +212 -0
- preprocessing/preprocessing_modules/xlsx_extractor.py +119 -0
- requirements.txt +33 -0
.gitignore
ADDED
@@ -0,0 +1 @@
.env
Dockerfile
ADDED
@@ -0,0 +1,74 @@
FROM python:3.9-slim

# Set working directory
WORKDIR /app

# Install system dependencies for multi-format document processing
RUN apt-get update && apt-get install -y \
    curl \
    libgl1-mesa-glx \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libgomp1 \
    libgstreamer1.0-0 \
    libavcodec-dev \
    libavformat-dev \
    libswscale-dev \
    && rm -rf /var/lib/apt/lists/*

# Create cache directories with proper permissions before copying files
RUN mkdir -p /app/.cache/huggingface && \
    mkdir -p /app/.cache/sentence_transformers && \
    chmod -R 777 /app/.cache

# Set environment variables for HuggingFace cache (before installing packages)
ENV HF_HOME=/app/.cache/huggingface
ENV TRANSFORMERS_CACHE=/app/.cache/huggingface
ENV SENTENCE_TRANSFORMERS_HOME=/app/.cache/sentence_transformers
ENV HF_HUB_CACHE=/app/.cache/huggingface

# Copy requirements first for better Docker layer caching
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application files AND folder structure
COPY *.py ./
COPY *.sh ./
COPY requirements.txt ./
COPY *.md ./
COPY api/ ./api/
COPY config/ ./config/
COPY LLM/ ./LLM/
COPY RAG/ ./RAG/
COPY logger/ ./logger/
COPY preprocessing/ ./preprocessing/

# Set up directories and permissions during build (when we have root access)
RUN mkdir -p /app/.cache/huggingface && \
    mkdir -p /app/.cache/sentence_transformers && \
    chmod -R 777 /app/.cache && \
    chmod +x startup.sh && \
    if [ -d "RAG/rag_embeddings" ]; then \
        find RAG/rag_embeddings -name "*.lock" -delete 2>/dev/null || true; \
        find RAG/rag_embeddings -name "*.db-shm" -delete 2>/dev/null || true; \
        find RAG/rag_embeddings -name "*.db-wal" -delete 2>/dev/null || true; \
        chmod -R 755 RAG/rag_embeddings; \
    fi

# Expose port
EXPOSE 7860

# Set environment variables
ENV HOST=0.0.0.0
ENV PORT=7860

# Health check
HEALTHCHECK --interval=30s --timeout=30s --start-period=60s --retries=3 \
    CMD curl -f http://localhost:7860/health || exit 1

# Run the application
CMD ["bash", "startup.sh"]
LLM/__init__.py
ADDED
@@ -0,0 +1,13 @@
# LLM Handler Package

from .llm_handler import llm_handler
from .tabular_answer import get_answer_for_tabluar
from .image_answerer import get_answer_for_image
from .one_shotter import get_oneshot_answer

__all__ = [
    'llm_handler',
    'get_answer_for_tabluar',
    'get_answer_for_image',
    'get_oneshot_answer'
]
LLM/image_answerer.py
ADDED
@@ -0,0 +1,136 @@
import os
import requests
import google.generativeai as genai
from PIL import Image
from io import BytesIO
from typing import List, Union
import logging

from dotenv import load_dotenv

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

load_dotenv()

# Configure Gemini API for image processing
genai.configure(api_key=os.getenv("GEMINI_API_KEY_IMAGE"))

def load_image(image_source: str) -> Image.Image:
    """Load image from a URL or local path."""
    try:
        if image_source.startswith("http://") or image_source.startswith("https://"):
            logger.info(f"Loading image from URL: {image_source}")
            response = requests.get(image_source, timeout=30)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")
        elif os.path.isfile(image_source):
            logger.info(f"Loading image from file: {image_source}")
            return Image.open(image_source).convert("RGB")
        else:
            raise ValueError("Invalid image source: must be a valid URL or file path")
    except Exception as e:
        logger.error(f"Failed to load image from {image_source}: {e}")
        raise RuntimeError(f"Failed to load image: {e}")

def get_answer_for_image(image_source: str, questions: List[str], retries: int = 3) -> List[str]:
    """Ask questions about an image using Gemini Vision model."""
    try:
        logger.info(f"Processing image with {len(questions)} questions")
        image = load_image(image_source)

        prompt = """
Answer the following questions about the image. Give the answers in the same order as the questions.
Answers should be descriptive and detailed. Give one answer per line with numbering as "1. 2. 3. ..".

Example answer format:
1. Answer 1, Explanation
2. Answer 2, Explanation
3. Answer 3, Explanation

Questions:
"""
        prompt += "\n".join(f"{i+1}. {q}" for i, q in enumerate(questions))

        model = genai.GenerativeModel("gemini-1.5-flash")

        for attempt in range(retries):
            try:
                logger.info(f"Attempt {attempt + 1} of {retries} to get response from Gemini")
                response = model.generate_content(
                    [prompt, image],
                    generation_config=genai.types.GenerationConfig(
                        temperature=0.4,
                        max_output_tokens=2048
                    )
                )
                raw_text = response.text.strip()
                logger.info(f"Received response from Gemini: {len(raw_text)} characters")

                answers = extract_ordered_answers(raw_text, len(questions))
                if len(answers) == len(questions):
                    logger.info(f"Successfully extracted {len(answers)} answers")
                    return answers
                else:
                    logger.warning(f"Expected {len(questions)} answers, got {len(answers)}")

            except Exception as e:
                logger.error(f"Attempt {attempt + 1} failed: {e}")
                if attempt == retries - 1:
                    raise RuntimeError(f"Failed after {retries} attempts: {e}")

        raise RuntimeError("Failed to get valid response from Gemini.")

    except Exception as e:
        logger.error(f"Error in get_answer_for_image: {e}")
        raise

def extract_ordered_answers(raw_text: str, expected_count: int) -> List[str]:
    """Parse the raw Gemini output into a clean list of answers."""
    import re

    logger.debug(f"Extracting {expected_count} answers from raw text")
    lines = raw_text.splitlines()
    answers = []

    for line in lines:
        # Match numbered lines: "1. Answer", "1) Answer", "1 - Answer", etc.
        match = re.match(r"^\s*(\d+)[\).\s-]*\s*(.+)", line)
        if match:
            answer_text = match.group(2).strip()
            if answer_text:  # Only add non-empty answers
                answers.append(answer_text)

    # Fallback: if numbering failed, use plain lines
    if len(answers) < expected_count:
        logger.warning("Numbered extraction failed, using fallback method")
        answers = [line.strip() for line in lines if line.strip()]

    # Return exactly the expected number of answers
    result = answers[:expected_count]

    # If we still don't have enough answers, pad with error messages
    while len(result) < expected_count:
        result.append("Unable to extract answer from image")

    logger.info(f"Extracted {len(result)} answers")
    return result

def process_image_query(image_path: str, query: str) -> str:
    """Process a single query about an image."""
    try:
        questions = [query]
        answers = get_answer_for_image(image_path, questions)
        return answers[0] if answers else "Unable to process image query"
    except Exception as e:
        logger.error(f"Error processing image query: {e}")
        return f"Error processing image: {str(e)}"

def process_multiple_image_queries(image_path: str, queries: List[str]) -> List[str]:
    """Process multiple queries about an image."""
    try:
        return get_answer_for_image(image_path, queries)
    except Exception as e:
        logger.error(f"Error processing multiple image queries: {e}")
        return [f"Error processing image: {str(e)}"] * len(queries)
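A minimal usage sketch for this module follows; the image URL and questions are hypothetical, and GEMINI_API_KEY_IMAGE must be present in the environment for the call to succeed.

# Hypothetical usage sketch (assumes GEMINI_API_KEY_IMAGE is set in .env)
from LLM.image_answerer import get_answer_for_image, process_image_query

answers = get_answer_for_image(
    "https://example.com/diagram.png",  # hypothetical image URL
    ["What does the diagram depict?", "How many components are shown?"],
)
for answer in answers:
    print(answer)

# Single-question convenience wrapper
print(process_image_query("local_chart.png", "What is the trend?"))  # hypothetical local file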
LLM/llm_handler.py
ADDED
@@ -0,0 +1,235 @@
"""
Multi-LLM Handler with failover support
Uses Groq, Gemini, and OpenAI with automatic failover for reliability
"""

import asyncio
import re
import time
from typing import Optional, Dict, Any, List
import os
import requests
import google.generativeai as genai
import openai
from dotenv import load_dotenv
from config.config import get_provider_configs

load_dotenv()

class MultiLLMHandler:
    """Multi-LLM handler with automatic failover across providers."""

    def __init__(self):
        """Initialize the multi-LLM handler with all available providers."""
        self.providers = get_provider_configs()
        self.current_provider = None
        self.current_config = None

        # Initialize the first available provider (prefer Gemini/OpenAI for general RAG)
        self._initialize_provider()

        print(f"✅ Initialized Multi-LLM Handler with {self.provider.upper()}: {self.model_name}")

    def _initialize_provider(self):
        """Initialize the first available provider."""
        # Prefer Gemini first for general text tasks
        if self.providers["gemini"]:
            self.current_provider = "gemini"
            self.current_config = self.providers["gemini"][0]
            genai.configure(api_key=self.current_config["api_key"])
        # Then OpenAI
        elif self.providers["openai"]:
            self.current_provider = "openai"
            self.current_config = self.providers["openai"][0]
            openai.api_key = self.current_config["api_key"]
        # Finally Groq
        elif self.providers["groq"]:
            self.current_provider = "groq"
            self.current_config = self.providers["groq"][0]
        else:
            raise ValueError("No LLM providers available with valid API keys")

    @property
    def provider(self):
        """Get current provider name."""
        return self.current_provider

    @property
    def model_name(self):
        """Get current model name."""
        return self.current_config["model"] if self.current_config else "unknown"

    async def _call_groq(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Call Groq API."""
        headers = {
            "Authorization": f"Bearer {self.current_config['api_key']}",
            "Content-Type": "application/json"
        }

        data = {
            "model": self.current_config["model"],
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
            "max_tokens": max_tokens
        }

        # Hide reasoning tokens (e.g., <think>) for Qwen reasoning models
        try:
            model_name = (self.current_config.get("model") or "").lower()
            if "qwen" in model_name:
                # Per request, use the chat completion parameter to hide reasoning content
                data["reasoning_effort"] = "hidden"
        except Exception:
            # Be resilient if config shape changes
            pass

        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers=headers,
            json=data,
            timeout=30
        )
        response.raise_for_status()

        result = response.json()
        text = result["choices"][0]["message"]["content"].strip()
        # Safety net: strip any <think>...</think> blocks if present
        try:
            model_name = (self.current_config.get("model") or "").lower()
            if "qwen" in model_name and "<think>" in text.lower():
                text = re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE).strip()
        except Exception:
            pass
        return text

    async def _call_gemini(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Call Gemini API."""
        model = genai.GenerativeModel(self.current_config["model"])

        generation_config = genai.types.GenerationConfig(
            temperature=temperature,
            max_output_tokens=max_tokens
        )

        response = await asyncio.to_thread(
            model.generate_content,
            prompt,
            generation_config=generation_config
        )
        return response.text.strip()

    async def _call_openai(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Call OpenAI API."""
        response = await asyncio.to_thread(
            openai.ChatCompletion.create,
            model=self.current_config["model"],
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content.strip()

    async def _try_with_failover(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Try to generate text with automatic failover."""
        # Get all available providers in order
        provider_order = []
        # Prefer Gemini -> OpenAI -> Groq for general text
        if self.providers["gemini"]:
            provider_order.extend([("gemini", config) for config in self.providers["gemini"]])
        if self.providers["openai"]:
            provider_order.extend([("openai", config) for config in self.providers["openai"]])
        if self.providers["groq"]:
            provider_order.extend([("groq", config) for config in self.providers["groq"]])

        last_error = None

        for provider_name, config in provider_order:
            try:
                # Set current provider
                old_provider = self.current_provider
                old_config = self.current_config

                self.current_provider = provider_name
                self.current_config = config

                # Configure API if needed
                if provider_name == "gemini":
                    genai.configure(api_key=config["api_key"])
                elif provider_name == "openai":
                    openai.api_key = config["api_key"]

                # Try the API call
                if provider_name == "groq":
                    return await self._call_groq(prompt, temperature, max_tokens)
                elif provider_name == "gemini":
                    return await self._call_gemini(prompt, temperature, max_tokens)
                elif provider_name == "openai":
                    return await self._call_openai(prompt, temperature, max_tokens)

            except Exception as e:
                print(f"⚠️ {provider_name.upper()} ({config['name']}) failed: {str(e)}")
                last_error = e

                # Restore previous provider
                self.current_provider = old_provider
                self.current_config = old_config
                continue

        # If all providers failed
        raise RuntimeError(f"All LLM providers failed. Last error: {last_error}")

    async def generate_text(self,
                            prompt: Optional[str] = None,
                            system_prompt: Optional[str] = None,
                            user_prompt: Optional[str] = None,
                            temperature: Optional[float] = 0.4,
                            max_tokens: Optional[int] = 1200) -> str:
        """Generate text using multi-LLM with failover."""
        # Handle both single prompt and system/user prompt formats
        if prompt:
            final_prompt = prompt
        elif system_prompt and user_prompt:
            final_prompt = f"{system_prompt}\n\n{user_prompt}"
        elif user_prompt:
            final_prompt = user_prompt
        else:
            raise ValueError("Must provide either 'prompt' or 'user_prompt'")

        return await self._try_with_failover(
            final_prompt,
            temperature or 0.4,
            max_tokens or 1200
        )

    async def generate_simple(self,
                              prompt: str,
                              temperature: Optional[float] = 0.4,
                              max_tokens: Optional[int] = 1200) -> str:
        """Simple text generation (alias for generate_text for compatibility)."""
        return await self.generate_text(prompt=prompt, temperature=temperature, max_tokens=max_tokens)

    def get_provider_info(self) -> Dict[str, Any]:
        """Get information about the current provider."""
        return {
            "provider": self.current_provider,
            "model": self.model_name,
            "config_name": self.current_config["name"] if self.current_config else "none",
            "available_providers": {
                "groq": len(self.providers["groq"]),
                "gemini": len(self.providers["gemini"]),
                "openai": len(self.providers["openai"])
            }
        }

    async def test_connection(self) -> bool:
        """Test the connection to the current LLM provider."""
        try:
            test_prompt = "Say 'Hello' if you can read this."
            response = await self.generate_simple(test_prompt, temperature=0.1, max_tokens=10)
            return "hello" in response.lower()
        except Exception as e:
            print(f"❌ Connection test failed: {str(e)}")
            return False

# Create a global instance
llm_handler = MultiLLMHandler()
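Since the handler is exported as a module-level singleton, a quick smoke test might look like the sketch below, assuming at least one provider API key is configured via config.get_provider_configs.

# Hedged sketch: exercising the singleton from a script
import asyncio
from LLM.llm_handler import llm_handler

async def main():
    print(llm_handler.get_provider_info())    # which provider/model is active
    ok = await llm_handler.test_connection()  # round-trip "Hello" check
    print(f"Connection OK: {ok}")
    text = await llm_handler.generate_text(prompt="Summarize RAG in one sentence.")
    print(text)

asyncio.run(main())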
LLM/one_shotter.py
ADDED
@@ -0,0 +1,218 @@
import re
import asyncio
from typing import List, Dict
from urllib.parse import urlparse
import httpx
from bs4 import BeautifulSoup
import os
from dotenv import load_dotenv

load_dotenv()

# Import our multi-LLM handler
from LLM.llm_handler import llm_handler

# URL extraction pattern (same as ShastraDocs)
URL_PATTERN = re.compile(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')

def extract_urls_from_text(text: str) -> List[str]:
    urls = URL_PATTERN.findall(text or "")
    seen = set()
    clean_urls = []
    for url in urls:
        clean_url = url.rstrip('.,;:!?)')
        if clean_url and clean_url not in seen and validate_url(clean_url):
            seen.add(clean_url)
            clean_urls.append(clean_url)
    return clean_urls

def validate_url(url: str) -> bool:
    try:
        result = urlparse(url)
        return bool(result.scheme and result.netloc)
    except Exception:
        return False

async def scrape_url(url: str, max_chars: int = 4000) -> Dict[str, str]:
    """Async URL scraping using httpx + BeautifulSoup (FastAPI-friendly)."""
    try:
        timeout = httpx.Timeout(20.0)
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        async with httpx.AsyncClient(timeout=timeout, headers=headers, follow_redirects=True) as client:
            resp = await client.get(url)
            resp.raise_for_status()
            soup = BeautifulSoup(resp.content, 'html.parser')
            for tag in soup(['script', 'style', 'nav', 'footer', 'header', 'aside']):
                tag.decompose()
            text_content = soup.get_text(separator=' ', strip=True)
            cleaned = ' '.join(text_content.split())
            if len(cleaned) > max_chars:
                cleaned = cleaned[:max_chars] + "..."
            return {
                'url': url,
                'content': cleaned,
                'status': 'success',
                'length': len(cleaned),
                'title': soup.title.string if soup.title else 'No title'
            }
    except httpx.TimeoutException:
        return {'url': url, 'content': 'Timeout error', 'status': 'timeout', 'length': 0, 'title': 'Timeout'}
    except Exception as e:
        return {'url': url, 'content': f'Error: {str(e)[:100]}', 'status': 'error', 'length': 0, 'title': 'Error'}

async def scrape_urls(urls: List[str], max_chars: int = 4000) -> List[Dict[str, str]]:
    if not urls:
        return []
    sem = asyncio.Semaphore(5)
    async def _scrape(u):
        async with sem:
            return await scrape_url(u, max_chars)
    results = await asyncio.gather(*[_scrape(u) for u in urls], return_exceptions=True)
    final = []
    for i, r in enumerate(results):
        if isinstance(r, Exception):
            final.append({'url': urls[i], 'content': f'Exception: {str(r)[:100]}', 'status': 'exception', 'length': 0, 'title': 'Exception'})
        else:
            final.append(r)
    return final

def build_additional_content(scrapes: List[Dict[str, str]]) -> str:
    parts = []
    for r in scrapes:
        if r.get('status') == 'success' and r.get('length', 0) > 50:
            parts.append("\n" + "="*50)
            parts.append("SOURCE: Additional Source")
            parts.append(f"URL: {r.get('url','')}")
            parts.append(f"TITLE: {r.get('title','No title')}")
            parts.append("-"*30 + " CONTENT " + "-"*30)
            parts.append(r.get('content',''))
            parts.append("="*50)
    return "\n".join(parts)

def parse_numbered_answers(text: str, expected_count: int) -> List[str]:
    """Parse numbered answers, with sane fallbacks."""
    pattern = re.compile(r'^\s*(\d+)[\).\-]\s*(.+)$', re.MULTILINE)
    matches = pattern.findall(text or "")
    result: Dict[int, str] = {}
    for num_str, answer in matches:
        try:
            num = int(num_str)
            if 1 <= num <= expected_count:
                clean_answer = re.sub(r'\s+', ' ', answer).strip()
                if clean_answer:
                    result[num] = clean_answer
        except Exception:
            continue
    answers: List[str] = []
    for i in range(1, expected_count + 1):
        answers.append(result.get(i, f"Unable to find answer for question {i}"))
    return answers

def parse_answers_from_json(raw: str, expected_count: int) -> List[str]:
    import json, re
    # Try direct JSON
    try:
        obj = json.loads(raw)
        if isinstance(obj, dict) and isinstance(obj.get('answers'), list):
            out = [str(x).strip() for x in obj['answers']][:expected_count]
            while len(out) < expected_count:
                out.append(f"Unable to find answer for question {len(out)+1}")
            return out
    except Exception:
        pass
    # Try to extract JSON fragment
    m = re.search(r'\{[^\{\}]*"answers"[^\{\}]*\}', raw or "", re.DOTALL)
    if m:
        try:
            obj = json.loads(m.group(0))
            if isinstance(obj, dict) and isinstance(obj.get('answers'), list):
                out = [str(x).strip() for x in obj['answers']][:expected_count]
                while len(out) < expected_count:
                    out.append(f"Unable to find answer for question {len(out)+1}")
                return out
        except Exception:
            pass
    # Fallback to numbered parsing
    return parse_numbered_answers(raw or "", expected_count)

async def get_oneshot_answer(content: str, questions: List[str]) -> List[str]:
    """
    Enhanced oneshot QA flow with ShastraDocs-style URL extraction and scraping.
    - Extract URLs from content and questions
    - Scrape relevant pages
    - Merge additional content and feed to LLM
    - Return per-question answers
    """
    if not questions:
        return []

    try:
        # Build numbered questions
        numbered_questions = "\n".join([f"{i+1}. {q}" for i, q in enumerate(questions)])

        # Find URLs from content and questions
        combined = (content or "") + "\n" + "\n".join(questions or [])
        found_urls = extract_urls_from_text(combined)

        # Special case: content starts with URL marker
        if content.startswith("URL for Context:"):
            only_url = content.replace("URL for Context:", "").strip()
            if validate_url(only_url):
                if only_url not in found_urls:
                    found_urls.insert(0, only_url)

        # Scrape URLs if any
        additional_content = ""
        if found_urls:
            print(f"🚀 Scraping {len(found_urls)} URL(s) for additional context...")
            scrape_results = await scrape_urls(found_urls, max_chars=4000)
            additional_content = build_additional_content(scrape_results)
            print(f"📄 Additional content length: {len(additional_content)}")

        # Merge final context
        if additional_content:
            final_context = (content or "") + "\n\nADDITIONAL INFORMATION FROM SCRAPED SOURCES:\n" + additional_content
        else:
            final_context = content or ""

        print(f"📊 Final context length: {len(final_context)}")

        # Prompts (ask for JSON answers to improve parsing)
        system_prompt = (
            "You are an expert assistant. Read ALL provided context (including any 'ADDITIONAL INFORMATION FROM\n"
            "SCRAPED SOURCES') and answer the questions comprehensively. If info is missing, say so."
        )

        user_prompt = f"""FULL CONTEXT:
{final_context[:8000]}{"..." if len(final_context) > 8000 else ""}

QUESTIONS:
{numbered_questions}

Respond in this EXACT JSON format:
{{
  "answers": [
    "<Answer to question 1>",
    "<Answer to question 2>",
    "<Answer to question 3>"
  ]
}}"""

        print(f"🤖 Using {llm_handler.provider.upper()} model: {llm_handler.model_name}")
        raw = await llm_handler.generate_text(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            temperature=0.4,
            max_tokens=1800
        )

        print(f"🔄 LLM response length: {len(raw) if raw else 0}")
        answers = parse_answers_from_json(raw, len(questions))
        print(f"✅ Parsed {len(answers)} answers")
        return answers

    except Exception as e:
        print(f"❌ Error in oneshot answer generation: {str(e)}")
        return [f"Error processing question: {str(e)}" for _ in questions]
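A hedged example of driving the oneshot flow end to end; the context string and URL are hypothetical, and a working LLM provider key is assumed.

# Hypothetical usage sketch for the oneshot QA flow
import asyncio
from LLM.one_shotter import get_oneshot_answer

context = "URL for Context: https://example.com/policy.html"  # hypothetical URL
questions = ["What is the grace period?", "Who should be contacted for claims?"]

answers = asyncio.run(get_oneshot_answer(context, questions))
for q, a in zip(questions, answers):
    print(f"Q: {q}\nA: {a}\n")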
LLM/tabular_answer.py
ADDED
@@ -0,0 +1,128 @@
import os
import re
import math
from typing import List
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_groq import ChatGroq
from dotenv import load_dotenv

load_dotenv()
TABULAR_VERBOSE = os.environ.get("TABULAR_VERBOSE", "0") in ("1", "true", "True", "yes", "YES")

# Initialize Groq LLM for tabular data using specialized API key
TABULAR_MODEL = os.environ.get("GROQ_TABULAR_MODEL", os.environ.get("GROQ_MODEL_TABULAR", "qwen/qwen3-32b"))
GROQ_LLM = ChatGroq(
    groq_api_key=os.environ.get("GROQ_API_KEY_TABULAR", os.environ.get("GROQ_API_KEY")),
    model_name=TABULAR_MODEL
)

def get_answer_for_tabluar(
    data: str,
    questions: List[str],
    batch_size: int = 10,
    verbose: bool = False
) -> List[str]:
    """
    Query Groq LLM for tabular data analysis, handling batches and preserving the order of answers.

    Args:
        data (str): Tabular context in markdown or plain text.
        questions (List[str]): List of questions to ask.
        batch_size (int): Max number of questions per batch.
        verbose (bool): If True, print raw LLM responses.

    Returns:
        List[str]: Ordered list of answers corresponding to input questions.
    """

    def parse_numbered_answers(text: str, expected: int) -> List[str]:
        """
        Parse answers from a numbered list format ('1.', '2.', etc.).
        Use non-greedy capture with a lookahead to stop at the next number or end.
        """
        pattern = re.compile(r"^\s*(\d{1,2})[\.)\-]\s*(.*?)(?=\n\s*\d{1,2}[\.)\-]\s*|$)", re.MULTILINE | re.DOTALL)
        matches = pattern.findall(text)

        result = {}
        for num_str, answer in matches:
            try:
                num = int(num_str)
            except ValueError:
                continue
            if 1 <= num <= expected:
                clean_answer = re.sub(r'\s+', ' ', answer).strip()
                result[num] = clean_answer

        # If no structured matches, fall back to line-based heuristic
        if not result:
            lines = [ln.strip() for ln in text.strip().splitlines() if ln.strip()]
            for i in range(min(expected, len(lines))):
                result[i + 1] = lines[i]

        # Build fixed-length list
        answers = []
        for i in range(1, expected + 1):
            answers.append(result.get(i, f"Unable to answer question {i}"))

        return answers

    if not questions:
        return []

    # Process questions in batches
    all_answers = []
    total_batches = math.ceil(len(questions) / batch_size)

    for batch_idx in range(total_batches):
        start = batch_idx * batch_size
        end = min(start + batch_size, len(questions))
        batch_questions = questions[start:end]

        print(f"Processing batch {batch_idx + 1}/{total_batches} ({len(batch_questions)} questions)")

        # Create numbered question list (join with real newlines, not a literal backslash-n)
        numbered_questions = "\n".join([f"{i+1}. {q}" for i, q in enumerate(batch_questions)])

        # Create prompt
        system_prompt = """You are an expert data analyst. Analyze the provided tabular data and answer the questions accurately.

Instructions:
- Answer each question based ONLY on the data provided
- If data is insufficient, state "Information not available in the provided data"
- Provide clear, concise answers
- Format your response as a numbered list (1., 2., 3., etc.)
- Do not add explanations unless specifically asked"""

        user_prompt = f"""Data:
{data}

Questions:
{numbered_questions}

Please provide numbered answers (1., 2., 3., etc.) for each question."""

        try:
            # Create messages
            messages = [
                SystemMessage(content=system_prompt),
                HumanMessage(content=user_prompt)
            ]

            # Get response from LLM
            response = GROQ_LLM.invoke(messages)
            raw_response = response.content or ""

            if verbose or TABULAR_VERBOSE:
                print(f"🟢 Raw LLM Response (batch {batch_idx + 1}):\n{raw_response[:1200]}\n--- END RAW ---")

            # Parse the response
            batch_answers = parse_numbered_answers(raw_response, len(batch_questions))
            all_answers.extend(batch_answers)

        except Exception as e:
            print(f"Error processing batch {batch_idx + 1}: {str(e)}")
            # Add error answers for this batch
            error_answers = [f"Error processing question: {str(e)}" for _ in batch_questions]
            all_answers.extend(error_answers)

    return all_answers
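A short sketch with a hypothetical markdown table; it assumes GROQ_API_KEY_TABULAR or GROQ_API_KEY is set (note the function name keeps the repository's "tabluar" spelling).

# Hypothetical usage sketch for the tabular answerer
from LLM.tabular_answer import get_answer_for_tabluar

table = """| Product | Q1 Sales | Q2 Sales |
|---------|----------|----------|
| Widget  | 120      | 150      |
| Gadget  | 80       | 95       |"""

answers = get_answer_for_tabluar(table, ["Which product sold more in Q2?"], verbose=True)
print(answers[0])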
RAG/__init__.py
ADDED
@@ -0,0 +1 @@
# RAG Package
RAG/advanced_rag_processor.py
ADDED
@@ -0,0 +1,169 @@
"""
Advanced RAG Processor - Modular Version
Orchestrates all RAG components for document question answering.
Version: 3.0 - Modular Architecture
"""

import time
from typing import Dict, Tuple
from pathlib import Path

# Import all modular components
from RAG.rag_modules.query_expansion import QueryExpansionManager
from RAG.rag_modules.embedding_manager import EmbeddingManager
from RAG.rag_modules.search_manager import SearchManager
from RAG.rag_modules.reranking_manager import RerankingManager
from RAG.rag_modules.context_manager import ContextManager
from RAG.rag_modules.answer_generator import AnswerGenerator

from LLM.llm_handler import llm_handler
from config.config import OUTPUT_DIR, TOP_K


class AdvancedRAGProcessor:
    """
    Advanced RAG processor with modular architecture for better maintainability.
    Orchestrates query expansion, hybrid search, reranking, and answer generation.
    """

    def __init__(self):
        """Initialize the advanced RAG processor with all modules."""
        self.base_db_path = Path(OUTPUT_DIR)

        # Initialize all managers
        print("🚀 Initializing Advanced RAG Processor (Modular)...")

        # Core components
        self.embedding_manager = EmbeddingManager()
        self.query_expansion_manager = QueryExpansionManager()
        self.search_manager = SearchManager(self.embedding_manager)
        self.reranking_manager = RerankingManager()
        self.context_manager = ContextManager()
        self.answer_generator = AnswerGenerator()

        # Keep reference to LLM handler for info
        self.llm_handler = llm_handler

        print(f"✅ Advanced RAG Processor initialized with {self.llm_handler.provider.upper()} LLM")
        print("📦 All modules loaded successfully:")
        print("   🔄 Query Expansion Manager")
        print("   🧠 Embedding Manager")
        print("   🔍 Search Manager (Hybrid)")
        print("   🎯 Reranking Manager")
        print("   📝 Context Manager")
        print("   💬 Answer Generator")

    async def answer_question(self, question: str, doc_id: str, logger=None, request_id: str = None) -> Tuple[str, Dict[str, float]]:
        """
        Answer a question using advanced RAG techniques with detailed timing.

        Args:
            question: The question to answer
            doc_id: Document ID to search in
            logger: Optional logger for tracking
            request_id: Optional request ID for logging

        Returns:
            Tuple of (answer, timing_breakdown)
        """
        timings = {}
        overall_start = time.time()

        try:
            # Check if collection exists
            collection_name = f"{doc_id}_collection"
            try:
                client = self.search_manager.get_qdrant_client(doc_id)
                collection_info = client.get_collection(collection_name)
            except Exception:
                return "I don't have information about this document. Please ensure the document has been processed.", timings

            print(f"🚀 Advanced RAG processing for: {question[:100]}...")

            # Step 1: Query Expansion
            step_start = time.time()
            expanded_queries = await self.query_expansion_manager.expand_query(question)
            expansion_time = time.time() - step_start
            timings['query_expansion'] = expansion_time
            if logger and request_id:
                logger.log_pipeline_stage(request_id, "query_expansion", expansion_time)

            # Step 2: Hybrid Search with Fusion
            step_start = time.time()
            search_results = await self.search_manager.hybrid_search(expanded_queries, doc_id, TOP_K)
            search_time = time.time() - step_start
            timings['hybrid_search'] = search_time
            if logger and request_id:
                logger.log_pipeline_stage(request_id, "hybrid_search", search_time)

            if not search_results:
                return "I couldn't find relevant information to answer your question.", timings

            # Step 3: Reranking
            step_start = time.time()
            reranked_results = await self.reranking_manager.rerank_results(question, search_results)
            rerank_time = time.time() - step_start
            timings['reranking'] = rerank_time
            if logger and request_id:
                logger.log_pipeline_stage(request_id, "reranking", rerank_time)

            # Step 4: Multi-perspective Context Creation
            step_start = time.time()
            context = self.context_manager.create_enhanced_context(question, reranked_results)
            context_time = time.time() - step_start
            timings['context_creation'] = context_time
            if logger and request_id:
                logger.log_pipeline_stage(request_id, "context_creation", context_time)

            # Step 5: Enhanced Answer Generation
            step_start = time.time()
            answer = await self.answer_generator.generate_enhanced_answer(question, context, expanded_queries)
            generation_time = time.time() - step_start
            timings['llm_generation'] = generation_time
            if logger and request_id:
                logger.log_pipeline_stage(request_id, "llm_generation", generation_time)

            # Calculate total time
            total_time = time.time() - overall_start
            timings['total_pipeline'] = total_time

            print(f"\n✅ Advanced RAG processing completed in {total_time:.4f}s")
            print(f"   🔍 Query expansion: {expansion_time:.4f}s")
            print(f"   🔎 Hybrid search: {search_time:.4f}s")
            print(f"   🎯 Reranking: {rerank_time:.4f}s")
            print(f"   📝 Context creation: {context_time:.4f}s")
            print(f"   💬 LLM generation: {generation_time:.4f}s")

            return answer, timings

        except Exception as e:
            error_time = time.time() - overall_start
            timings['error_time'] = error_time
            print(f"❌ Error in advanced RAG processing: {str(e)}")
            return f"I encountered an error while processing your question: {str(e)}", timings

    def cleanup(self):
        """Cleanup all manager resources."""
        print("🧹 Cleaning up Advanced RAG processor resources...")

        # Cleanup search manager (which has the most resources)
        self.search_manager.cleanup()

        print("✅ Advanced RAG cleanup completed")

    def get_system_info(self) -> Dict:
        """Get information about the RAG system."""
        return {
            "version": "3.0 - Modular",
            "llm_provider": self.llm_handler.provider,
            "llm_model": self.llm_handler.model_name,
            "modules": [
                "QueryExpansionManager",
                "EmbeddingManager",
                "SearchManager",
                "RerankingManager",
                "ContextManager",
                "AnswerGenerator"
            ],
            "base_db_path": str(self.base_db_path)
        }
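End to end, the processor might be driven as in this sketch; the doc_id is a hypothetical identifier for a document already processed into a Qdrant collection.

# Hedged sketch: answering one question against a processed document
import asyncio
from RAG.advanced_rag_processor import AdvancedRAGProcessor

async def main():
    processor = AdvancedRAGProcessor()
    answer, timings = await processor.answer_question(
        "What is the claim submission process?",
        doc_id="policy_doc_001",  # hypothetical document ID
    )
    print(answer)
    print(f"Total pipeline time: {timings.get('total_pipeline', 0):.2f}s")
    processor.cleanup()

asyncio.run(main())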
RAG/rag_embeddings/.gitkeep
ADDED
File without changes
RAG/rag_modules/__init__.py
ADDED
@@ -0,0 +1 @@
# RAG Modules Package
RAG/rag_modules/answer_generator.py
ADDED
@@ -0,0 +1,97 @@
"""
Answer Generation Module for Advanced RAG
Handles LLM-based answer generation with enhanced prompting.
"""

from typing import List
from LLM.llm_handler import llm_handler
from config.config import TEMPERATURE, MAX_TOKENS


class AnswerGenerator:
    """Manages answer generation using LLM."""

    def __init__(self):
        """Initialize the answer generator."""
        self.llm_handler = llm_handler
        print("✅ Answer Generator initialized")

    async def generate_enhanced_answer(self, original_question: str, context: str, expanded_queries: List[str]) -> str:
        """Generate an enhanced answer using the original question with retrieved context."""

        # Use only the original question for LLM generation
        query_context = f"Question: {original_question}"

        system_prompt = """
You are an expert AI assistant specializing in document analysis and policy-related question answering. You have access to relevant document excerpts and must respond only based on this information. You are designed specifically for analyzing official documents and answering user queries related to them.

STRICT RULES AND RESPONSE CONDITIONS:

Irrelevant/Out-of-Scope Queries (e.g., programming help, general product info, coding tasks):
Respond EXACTLY:

"I cannot help with that. I am designed only to answer queries related to the provided document excerpts."

Illegal or Prohibited Requests (e.g., forgery, fraud, bypassing regulations):
Respond CLEARLY that the request is illegal. Example format:

"This request is illegal and cannot be supported. According to the applicable regulations in the document, [explain why it's illegal if mentioned]. Engaging in such activity may lead to legal consequences."
If illegality is not explicitly in the documents, use:
"This request involves illegal activity and is against policy. I cannot assist with this."

Nonexistent Concepts, Schemes, or Entities:
Respond by stating the concept does not exist and offer clarification by pointing to related valid information. Example:

"There is no mention of such a scheme in the document. However, the following related schemes are described: [summarize relevant ones]."

Valid Topics with Missing or Incomplete Information:
Respond that the exact answer is unavailable, then provide all related details and recommend official contact. Example:

"The exact information is not available in the provided document. However, here is what is relevant: [details]. For further clarification, you may contact: [official contact details if included in the document]."

Valid Questions Answerable from Document:
Provide a concise and accurate answer with clear reference to the document content. Also include any related notes that might aid understanding. Example:

"[Answer]. According to the policy document, [quote/summary from actual document content]."

GENERAL ANSWERING RULES:

Use ONLY the provided document excerpts. Never use external knowledge.

Be concise: 5-6 sentences per answer, with all the details available for that particular query.

Start directly with the answer. Do not restate or rephrase the question.

Never speculate or elaborate beyond what is explicitly stated.

When referencing information, mention "according to the document" or "as stated in the policy" rather than using internal labels like "Query X Doc Y".

Do not reference internal organizational labels like [Query 1 Doc 2] or [Relevance: X.XX] - these are for processing only.

Focus on the actual document content and policy information when providing answers.

The user may phrase questions in various ways — always infer the intent, apply the rules above, and respond accordingly.
"""

        user_prompt = f"""{query_context}

Document Excerpts:
{context}

Provide a comprehensive answer based on the document excerpts above:"""

        try:
            answer = await self.llm_handler.generate_text(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS
            )

            return answer.strip()

        except Exception as e:
            print(f"❌ Error generating enhanced response with {self.llm_handler.provider.upper()}: {str(e)}")
            return "I encountered an error while generating the response."
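In isolation the generator can be exercised with hand-built excerpts; the context string below is hypothetical and mimics the labels the context manager produces.

# Hedged sketch: calling the generator directly with hypothetical excerpts
import asyncio
from RAG.rag_modules.answer_generator import AnswerGenerator

context = "[Query 1 Doc 1] [Relevance: 0.92]\nThe grace period for premium payment is thirty days."
generator = AnswerGenerator()
answer = asyncio.run(generator.generate_enhanced_answer(
    "What is the grace period?", context, expanded_queries=[]
))
print(answer)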
RAG/rag_modules/context_manager.py
ADDED
@@ -0,0 +1,81 @@
"""
Context Management Module for Advanced RAG
Handles context creation and management for LLM generation.
"""

from typing import List, Dict
from collections import defaultdict
from config.config import MAX_CONTEXT_LENGTH


class ContextManager:
    """Manages context creation for LLM generation."""

    def __init__(self):
        """Initialize the context manager."""
        print("✅ Context Manager initialized")

    def create_enhanced_context(self, question: str, results: List[Dict], max_length: int = MAX_CONTEXT_LENGTH) -> str:
        """Create enhanced context ensuring each query contributes equally."""
        # Group results by expanded query index
        query_to_chunks = defaultdict(list)
        for i, result in enumerate(results):
            # Find the most relevant expanded query for this chunk
            if 'contributing_queries' in result and result['contributing_queries']:
                # Use the highest scoring contributing query
                best_contrib = max(result['contributing_queries'], key=lambda cq: cq.get('semantic_score', cq.get('bm25_score', 0)))
                query_idx = best_contrib['query_idx']
            else:
                query_idx = 0  # fallback to first query
            query_to_chunks[query_idx].append((i, result))

        # Sort chunks within each query by their relevance scores
        for q_idx in query_to_chunks:
            query_to_chunks[q_idx].sort(key=lambda x: x[1].get('rerank_score', x[1].get('final_score', x[1].get('score', 0))), reverse=True)

        # Calculate chunks per query (should be 3 for each query with total budget = 9 and 3 queries)
        num_queries = len(query_to_chunks)
        if num_queries == 0:
            return ""

        # Ensure each query contributes equally (round-robin with guaranteed slots)
        context_parts = []
        current_length = 0
        added_chunks = set()

        # Calculate how many chunks each query should contribute
        chunks_per_query = len(results) // num_queries if num_queries > 0 else len(results)
        extra_chunks = len(results) % num_queries

        print(f"📊 Context Creation: {num_queries} queries, {chunks_per_query} chunks per query (+{extra_chunks} extra)")

        for q_idx in sorted(query_to_chunks.keys()):
            # Determine how many chunks this query should contribute
            query_chunk_limit = chunks_per_query + (1 if q_idx < extra_chunks else 0)
            query_chunks_added = 0

            print(f"   Query {q_idx+1}: Adding up to {query_chunk_limit} chunks")

            for i, result in query_to_chunks[q_idx]:
                if i not in added_chunks and query_chunks_added < query_chunk_limit:
                    text = result['payload'].get('text', '')
                    relevance_info = ""
                    if 'rerank_score' in result:
                        relevance_info = f" [Relevance: {result['rerank_score']:.2f}]"
                    elif 'final_score' in result:
                        relevance_info = f" [Score: {result['final_score']:.2f}]"
                    doc_text = f"[Query {q_idx+1} Doc {len(added_chunks)+1}]{relevance_info}\n{text}\n"

                    if current_length + len(doc_text) > max_length:
                        print(f"   ⚠️ Context length limit reached at {current_length} chars")
                        break

                    context_parts.append(doc_text)
                    current_length += len(doc_text)
                    added_chunks.add(i)
                    query_chunks_added += 1

            print(f"   Query {q_idx+1}: Added {query_chunks_added} chunks")

        print(f"📝 Final context: {len(added_chunks)} chunks, {current_length} chars")
        return "\n".join(context_parts)
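The equal-contribution budget is plain integer division plus a remainder; a quick illustration with hypothetical counts:

# Illustration of the per-query chunk budget (hypothetical counts)
results_count, num_queries = 10, 3
chunks_per_query = results_count // num_queries  # 3 guaranteed slots per query
extra_chunks = results_count % num_queries       # 1 leftover slot
# Queries with index < extra_chunks get one extra chunk: budgets are [4, 3, 3]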
RAG/rag_modules/embedding_manager.py
ADDED
@@ -0,0 +1,42 @@
"""
Embedding Management Module for Advanced RAG
Handles text encoding and embedding operations.
"""

import asyncio
from typing import List
from sentence_transformers import SentenceTransformer
from config.config import EMBEDDING_MODEL


class EmbeddingManager:
    """Manages text embeddings for RAG operations."""

    def __init__(self):
        """Initialize the embedding manager."""
        self.embedding_model = None
        self._init_embedding_model()

    def _init_embedding_model(self):
        """Initialize the embedding model."""
        print(f"🔄 Loading embedding model: {EMBEDDING_MODEL}")
        self.embedding_model = SentenceTransformer(EMBEDDING_MODEL)
        print("✅ Embedding model loaded successfully")

    async def encode_query(self, query: str) -> List[float]:
        """Encode a query into embeddings."""
        def encode_sync():
            embedding = self.embedding_model.encode([query], normalize_embeddings=True)
            return embedding[0].astype("float32").tolist()

        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, encode_sync)

    async def encode_texts(self, texts: List[str]) -> List[List[float]]:
        """Encode multiple texts into embeddings."""
        def encode_sync():
            embeddings = self.embedding_model.encode(texts, normalize_embeddings=True)
            return [emb.astype("float32").tolist() for emb in embeddings]

        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, encode_sync)
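A usage sketch for the module above (assuming EMBEDDING_MODEL points at a downloadable sentence-transformers checkpoint; the demo function is illustrative, not part of the upload):

import asyncio
from RAG.rag_modules.embedding_manager import EmbeddingManager

async def demo():
    manager = EmbeddingManager()
    vector = await manager.encode_query("What is the grace period for premium payment?")
    # BAAI/bge-large-en produces 1024-dimensional vectors.
    print(len(vector))

asyncio.run(demo())
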
RAG/rag_modules/query_expansion.py
ADDED
@@ -0,0 +1,146 @@
"""
Query Expansion Module for Advanced RAG
Now uses Groq's llama3-8b-8192 model directly for generating focused sub-queries.
"""

import re
import time
import asyncio
from typing import List
from groq import Groq
from config.config import (
    ENABLE_QUERY_EXPANSION,
    QUERY_EXPANSION_COUNT,
    GROQ_API_KEY_LITE,
    GROQ_MODEL_LITE,
)


class QueryExpansionManager:
    """Manages query expansion for better information retrieval."""

    def __init__(self):
        """Initialize the query expansion manager with a Groq client."""
        # Initialize the Groq client with the lite key and llama3-8b-8192 model
        self.model = GROQ_MODEL_LITE or "llama3-8b-8192"
        if not GROQ_API_KEY_LITE:
            print("⚠️ GROQ_API_KEY_LITE is not set. Query expansion will fall back to the original query.")
            self.client = None
        else:
            self.client = Groq(api_key=GROQ_API_KEY_LITE)
            print(f"✅ Query Expansion Manager initialized using Groq model: {self.model}")

    async def expand_query(self, original_query: str) -> List[str]:
        """Break complex queries into focused parts for better information retrieval using Groq."""
        if not ENABLE_QUERY_EXPANSION:
            return [original_query]
        if not self.client:
            return [original_query]

        try:
            expansion_prompt = f"""Analyze this question and break it down into exactly {QUERY_EXPANSION_COUNT} specific, focused sub-questions that can be searched independently in a document. Each sub-question should target a distinct piece of information or process.

For complex questions with multiple parts, identify:
1. Different processes or procedures mentioned
2. Specific information requests (emails, contact details, forms, etc.)
3. Different entities or subjects involved
4. Sequential steps that might be documented separately

Original question: {original_query}

Break this into exactly {QUERY_EXPANSION_COUNT} focused search queries that target different aspects:

Examples of good breakdown:
- "What is the dental claim submission process?"
- "How to update surname/name in policy records?"
- "What are the company contact details and grievance email?"

Provide only {QUERY_EXPANSION_COUNT} focused sub-questions, one per line, without numbering or additional formatting:"""

            # Call Groq's chat completions in a thread to avoid blocking the event loop
            response = await asyncio.to_thread(
                self.client.chat.completions.create,
                messages=[{"role": "user", "content": expansion_prompt}],
                model=self.model,
                temperature=0.3,
                max_tokens=300,
            )

            expanded_queries = []  # Start with an empty list - don't include the original

            if response and response.choices:
                content = response.choices[0].message.content if response.choices[0].message else ""
                sub_queries = (content or "").strip().split('\n')
                for query in sub_queries:
                    if len(expanded_queries) >= QUERY_EXPANSION_COUNT:  # Stop when we have enough
                        break
                    query = query.strip()
                    # Remove any numbering or bullet points that might have been added
                    query = re.sub(r'^[\d\.\-\*\s]+', '', query).strip()
                    if query and len(query) > 10:
                        expanded_queries.append(query)

            # If we don't have enough sub-queries, fall back to using the original
            if len(expanded_queries) < QUERY_EXPANSION_COUNT:
                expanded_queries = [original_query] * QUERY_EXPANSION_COUNT

            # Ensure we have exactly QUERY_EXPANSION_COUNT queries
            final_queries = expanded_queries[:QUERY_EXPANSION_COUNT]

            print(f"🔄 Query broken down from 1 complex question to {len(final_queries)} focused sub-queries using Groq {self.model}")
            print("📌 Original query will be used for final LLM generation only")
            for i, q in enumerate(final_queries):
                print(f"   Sub-query {i+1}: {q[:80]}...")

            return final_queries

        except Exception as e:
            print(f"⚠️ Query expansion failed: {e}")
            return [original_query]

    def _identify_query_components(self, query: str) -> dict:
        """Identify different components in a complex query for better breakdown."""
        components = {
            'processes': [],
            'documents': [],
            'contacts': [],
            'eligibility': [],
            'timelines': [],
            'benefits': []
        }

        # Define keywords for different component types
        process_keywords = ['process', 'procedure', 'steps', 'how to', 'submit', 'apply', 'claim', 'update', 'change', 'enroll']
        document_keywords = ['documents', 'forms', 'papers', 'certificate', 'proof', 'evidence', 'requirements']
        contact_keywords = ['email', 'phone', 'contact', 'grievance', 'customer service', 'support', 'helpline']
        eligibility_keywords = ['eligibility', 'criteria', 'qualify', 'eligible', 'conditions', 'requirements']
        timeline_keywords = ['timeline', 'period', 'duration', 'time', 'days', 'months', 'waiting', 'grace']
        benefit_keywords = ['benefits', 'coverage', 'limits', 'amount', 'reimbursement', 'claim amount']

        query_lower = query.lower()

        # Check for process-related content
        if any(keyword in query_lower for keyword in process_keywords):
            components['processes'].append('process identification')

        # Check for document-related content
        if any(keyword in query_lower for keyword in document_keywords):
            components['documents'].append('document requirements')

        # Check for contact-related content
        if any(keyword in query_lower for keyword in contact_keywords):
            components['contacts'].append('contact information')

        # Check for eligibility-related content
        if any(keyword in query_lower for keyword in eligibility_keywords):
            components['eligibility'].append('eligibility criteria')

        # Check for timeline-related content
        if any(keyword in query_lower for keyword in timeline_keywords):
            components['timelines'].append('timeline information')

        # Check for benefit-related content
        if any(keyword in query_lower for keyword in benefit_keywords):
            components['benefits'].append('benefit details')

        return components
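A behaviour sketch (illustrative, not part of the upload): the manager degrades gracefully, so without GROQ_API_KEY_LITE, or with ENABLE_QUERY_EXPANSION off, expand_query simply returns the original question as a single-element list.

import asyncio
from RAG.rag_modules.query_expansion import QueryExpansionManager

async def demo():
    manager = QueryExpansionManager()
    queries = await manager.expand_query(
        "How do I submit a dental claim, and what is the grievance email?"
    )
    # With a key: QUERY_EXPANSION_COUNT sub-queries; without: the original question.
    print(queries)

asyncio.run(demo())
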
RAG/rag_modules/reranking_manager.py
ADDED
@@ -0,0 +1,63 @@
"""
Reranking Module for Advanced RAG
Handles result reranking using cross-encoder models.
"""

from typing import List, Dict
from sentence_transformers import CrossEncoder
from config.config import ENABLE_RERANKING, RERANKER_MODEL, RERANK_TOP_K


class RerankingManager:
    """Manages result reranking using cross-encoder models."""

    def __init__(self):
        """Initialize the reranking manager."""
        self.reranker_model = None
        if ENABLE_RERANKING:
            self._init_reranker_model()
        print("✅ Reranking Manager initialized")

    def _init_reranker_model(self):
        """Initialize the reranker model."""
        print(f"🔄 Loading reranker model: {RERANKER_MODEL}")
        self.reranker_model = CrossEncoder(RERANKER_MODEL)
        print("✅ Reranker model loaded successfully")

    async def rerank_results(self, query: str, search_results: List[Dict]) -> List[Dict]:
        """Rerank search results using the cross-encoder."""
        if not ENABLE_RERANKING or not self.reranker_model or len(search_results) <= 1:
            return search_results

        try:
            # Prepare query-document pairs for reranking
            query_doc_pairs = []
            for result in search_results:
                doc_text = result['payload'].get('text', '')[:512]  # Limit text length
                query_doc_pairs.append([query, doc_text])

            # Get reranking scores
            rerank_scores = self.reranker_model.predict(query_doc_pairs)

            # Combine with original scores
            for i, result in enumerate(search_results):
                original_score = result.get('score', 0)
                rerank_score = float(rerank_scores[i])

                # Weighted combination of original and rerank scores
                result['rerank_score'] = rerank_score
                result['final_score'] = 0.3 * original_score + 0.7 * rerank_score

            # Sort by final score
            reranked_results = sorted(
                search_results,
                key=lambda x: x['final_score'],
                reverse=True
            )

            print(f"🎯 Reranked {len(search_results)} results")
            return reranked_results[:RERANK_TOP_K]

        except Exception as e:
            print(f"⚠️ Reranking failed: {e}")
            return search_results[:RERANK_TOP_K]
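A worked example of the 0.3/0.7 weighting above on toy numbers (illustrative, not part of the upload): the cross-encoder score dominates the final ranking.

original_score = 0.62  # retrieval (fusion) score
rerank_score = 0.91    # cross-encoder relevance
final_score = 0.3 * original_score + 0.7 * rerank_score
print(round(final_score, 3))  # -> 0.823
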
RAG/rag_modules/search_manager.py
ADDED
@@ -0,0 +1,334 @@
"""
Search Module for Advanced RAG
Handles hybrid search combining BM25 and semantic search with score fusion.
"""

import re
import time
import numpy as np
from typing import List, Dict, Any
from pathlib import Path
from rank_bm25 import BM25Okapi
from qdrant_client import QdrantClient

from config.config import (
    OUTPUT_DIR, TOP_K, SCORE_THRESHOLD, ENABLE_HYBRID_SEARCH,
    BM25_WEIGHT, SEMANTIC_WEIGHT, USE_TOTAL_BUDGET_APPROACH
)


class SearchManager:
    """Manages hybrid search operations combining BM25 and semantic search."""

    def __init__(self, embedding_manager):
        """Initialize the search manager."""
        self.embedding_manager = embedding_manager
        self.base_db_path = Path(OUTPUT_DIR)
        self.qdrant_clients = {}
        self.bm25_indexes = {}  # Cache BM25 indexes per document
        self.document_chunks = {}  # Cache chunks for BM25
        print("✅ Search Manager initialized")

    def get_qdrant_client(self, doc_id: str) -> QdrantClient:
        """Get or create a Qdrant client for a specific document."""
        if doc_id not in self.qdrant_clients:
            db_path = self.base_db_path / f"{doc_id}_collection.db"
            if not db_path.exists():
                raise FileNotFoundError(f"Database not found for document {doc_id}")
            self.qdrant_clients[doc_id] = QdrantClient(path=str(db_path))
        return self.qdrant_clients[doc_id]

    def _load_bm25_index(self, doc_id: str):
        """Load or create the BM25 index for a document."""
        if doc_id not in self.bm25_indexes:
            print(f"🔄 Loading BM25 index for {doc_id}")

            # Get all chunks from Qdrant
            client = self.get_qdrant_client(doc_id)
            collection_name = f"{doc_id}_collection"

            try:
                # Get all points from the collection
                result = client.scroll(
                    collection_name=collection_name,
                    limit=10000,  # Adjust based on your chunk count
                    with_payload=True,
                    with_vectors=False
                )

                chunks = []
                chunk_ids = []

                for point in result[0]:
                    chunk_text = point.payload.get('text', '')
                    chunks.append(chunk_text)
                    chunk_ids.append(point.id)

                # Tokenize chunks for BM25
                tokenized_chunks = [self._tokenize_text(chunk) for chunk in chunks]

                # Create the BM25 index
                self.bm25_indexes[doc_id] = BM25Okapi(tokenized_chunks)
                self.document_chunks[doc_id] = {
                    'chunks': chunks,
                    'chunk_ids': chunk_ids,
                    'tokenized_chunks': tokenized_chunks
                }

                print(f"✅ BM25 index loaded for {doc_id} with {len(chunks)} chunks")

            except Exception as e:
                print(f"❌ Error loading BM25 index for {doc_id}: {e}")
                # Fallback: empty index
                self.bm25_indexes[doc_id] = BM25Okapi([[]])
                self.document_chunks[doc_id] = {'chunks': [], 'chunk_ids': [], 'tokenized_chunks': []}

    def _tokenize_text(self, text: str) -> List[str]:
        """Simple tokenization for BM25."""
        # Remove special characters and convert to lowercase
        text = re.sub(r'[^\w\s]', ' ', text.lower())
        # Split and filter out short tokens
        tokens = [token for token in text.split() if len(token) > 2]
        return tokens

    async def hybrid_search(self, queries: List[str], doc_id: str, top_k: int = TOP_K) -> List[Dict]:
        """
        Perform hybrid search combining BM25 and semantic search.
        Optimized for focused sub-queries from query breakdown.
        Uses a total-budget approach to distribute retrieval across queries.
        """
        collection_name = f"{doc_id}_collection"
        client = self.get_qdrant_client(doc_id)

        # Ensure the BM25 index is loaded
        if doc_id not in self.bm25_indexes:
            self._load_bm25_index(doc_id)

        # Calculate the per-query budget based on the configured approach
        if USE_TOTAL_BUDGET_APPROACH and len(queries) > 1:
            per_query_budget = max(1, top_k // len(queries))
            extra_budget = top_k % len(queries)  # Distribute the remaining budget
            print(f"🎯 Total Budget Approach: Distributing {top_k} candidates across {len(queries)} queries")
            print(f"   📊 Base budget per query: {per_query_budget}")
            if extra_budget > 0:
                print(f"   ➕ Extra budget for first {extra_budget} queries: +1 each")
        else:
            per_query_budget = top_k
            extra_budget = 0
            print(f"🔍 Per-Query Approach: Each query gets {per_query_budget} candidates")

        all_candidates = {}  # point_id -> {'semantic_score', 'bm25_score', 'payload', 'fusion_score', 'contributing_queries'}
        query_performance = {}  # Track the performance of each sub-query

        print(f"🔍 Running hybrid search with {len(queries)} focused queries...")

        for query_idx, query in enumerate(queries):
            query_candidates = 0
            query_start = time.time()

            # Calculate this query's budget
            if USE_TOTAL_BUDGET_APPROACH and len(queries) > 1:
                query_budget = per_query_budget + (1 if query_idx < extra_budget else 0)
            else:
                query_budget = per_query_budget
            search_limit = query_budget * 2  # Fetch extra for better selection

            print(f"   Q{query_idx+1} Budget: {query_budget} candidates (searching {search_limit})")

            # 1. Semantic search (always performed)
            try:
                query_vector = await self.embedding_manager.encode_query(query)
                semantic_results = client.search(
                    collection_name=collection_name,
                    query_vector=query_vector,
                    limit=search_limit,  # Use the query-specific limit
                    score_threshold=SCORE_THRESHOLD
                )

                # Process semantic results within the budget limit
                semantic_count = 0
                for result in semantic_results:
                    if USE_TOTAL_BUDGET_APPROACH and semantic_count >= query_budget:
                        break  # Respect the budget limit

                    point_id = str(result.id)
                    semantic_score = float(result.score)

                    if point_id not in all_candidates:
                        all_candidates[point_id] = {
                            'semantic_score': 0,
                            'bm25_score': 0,
                            'payload': result.payload,
                            'fusion_score': 0,
                            'contributing_queries': []
                        }

                    # Use the max semantic score across queries, but track which queries contributed
                    if semantic_score > all_candidates[point_id]['semantic_score']:
                        all_candidates[point_id]['semantic_score'] = semantic_score

                    all_candidates[point_id]['contributing_queries'].append({
                        'query_idx': query_idx,
                        'query_text': query[:50] + '...' if len(query) > 50 else query,
                        'semantic_score': semantic_score,
                        'type': 'semantic'
                    })
                    query_candidates += 1
                    semantic_count += 1

            except Exception as e:
                print(f"⚠️ Semantic search failed for query '{query[:50]}...': {e}")

            # 2. BM25 search (if enabled)
            if ENABLE_HYBRID_SEARCH and doc_id in self.bm25_indexes:
                try:
                    tokenized_query = self._tokenize_text(query)
                    bm25_scores = self.bm25_indexes[doc_id].get_scores(tokenized_query)

                    # Get top BM25 results with budget consideration
                    chunk_data = self.document_chunks[doc_id]
                    bm25_top_indices = np.argsort(bm25_scores)[::-1][:search_limit]

                    # Process BM25 results within the budget limit
                    bm25_count = 0
                    for idx in bm25_top_indices:
                        if USE_TOTAL_BUDGET_APPROACH and bm25_count >= query_budget:
                            break  # Respect the budget limit

                        if idx < len(chunk_data['chunk_ids']) and bm25_scores[idx] > 0:
                            point_id = str(chunk_data['chunk_ids'][idx])
                            bm25_score = float(bm25_scores[idx])

                            if point_id not in all_candidates:
                                all_candidates[point_id] = {
                                    'semantic_score': 0,
                                    'bm25_score': 0,
                                    'payload': {'text': chunk_data['chunks'][idx]},
                                    'fusion_score': 0,
                                    'contributing_queries': []
                                }

                            # Use the max BM25 score across queries, but track which queries contributed
                            if bm25_score > all_candidates[point_id]['bm25_score']:
                                all_candidates[point_id]['bm25_score'] = bm25_score

                            all_candidates[point_id]['contributing_queries'].append({
                                'query_idx': query_idx,
                                'query_text': query[:50] + '...' if len(query) > 50 else query,
                                'bm25_score': bm25_score,
                                'type': 'bm25'
                            })
                            query_candidates += 1
                            bm25_count += 1

                except Exception as e:
                    print(f"⚠️ BM25 search failed for query '{query[:50]}...': {e}")

            # Track query performance with budget info
            query_time = time.time() - query_start
            query_performance[query_idx] = {
                'query': query[:80] + '...' if len(query) > 80 else query,
                'candidates_found': query_candidates,
                'budget_allocated': query_budget if USE_TOTAL_BUDGET_APPROACH else 'unlimited',
                'time': query_time
            }

        # 3. Score fusion (min-max normalization + weighted combination)
        self._apply_score_fusion(all_candidates)

        # 4. Sort by fusion score and return the top results
        sorted_candidates = sorted(
            all_candidates.items(),
            key=lambda x: x[1]['fusion_score'],
            reverse=True
        )

        # Convert to the result format with enhanced metadata
        hybrid_results = []
        for point_id, data in sorted_candidates[:top_k]:
            hybrid_results.append({
                'id': point_id,
                'score': data['fusion_score'],
                'payload': data['payload'],
                'semantic_score': data['semantic_score'],
                'bm25_score': data['bm25_score'],
                'contributing_queries': data['contributing_queries']
            })

        # Log a performance summary
        approach_name = "Total Budget" if USE_TOTAL_BUDGET_APPROACH else "Per-Query"
        print(f"🔍 Hybrid search completed ({approach_name} Approach):")
        print(f"   📊 {len(all_candidates)} total candidates from {len(queries)} focused queries")
        print(f"   🎯 Top {len(hybrid_results)} results selected")

        # Log per-query performance with budget info
        total_budget_used = 0
        for idx, perf in query_performance.items():
            budget_info = f" (budget: {perf['budget_allocated']})" if USE_TOTAL_BUDGET_APPROACH else ""
            print(f"   Q{idx+1}: {perf['candidates_found']} candidates{budget_info} in {perf['time']:.3f}s")
            print(f"      Query: {perf['query']}")
            if USE_TOTAL_BUDGET_APPROACH and isinstance(perf['budget_allocated'], int):
                total_budget_used += perf['candidates_found']

        if USE_TOTAL_BUDGET_APPROACH:
            print(f"   💰 Total budget efficiency: {total_budget_used}/{top_k} candidates used")

        return hybrid_results

    def _apply_score_fusion(self, candidates: Dict):
        """Apply score fusion: min-max normalization plus a weighted combination."""
        if not candidates:
            return

        # Collect non-zero scores for normalization
        semantic_scores = [data['semantic_score'] for data in candidates.values() if data['semantic_score'] > 0]
        bm25_scores = [data['bm25_score'] for data in candidates.values() if data['bm25_score'] > 0]

        # Min-max normalization parameters
        if semantic_scores:
            sem_min, sem_max = min(semantic_scores), max(semantic_scores)
            sem_range = sem_max - sem_min if sem_max > sem_min else 1
        else:
            sem_min, sem_range = 0, 1

        if bm25_scores:
            bm25_min, bm25_max = min(bm25_scores), max(bm25_scores)
            bm25_range = bm25_max - bm25_min if bm25_max > bm25_min else 1
        else:
            bm25_min, bm25_range = 0, 1

        # Calculate fusion scores
        for point_id, data in candidates.items():
            # Normalize scores
            norm_semantic = (data['semantic_score'] - sem_min) / sem_range if data['semantic_score'] > 0 else 0
            norm_bm25 = (data['bm25_score'] - bm25_min) / bm25_range if data['bm25_score'] > 0 else 0

            # Weighted combination
            if ENABLE_HYBRID_SEARCH:
                fusion_score = (SEMANTIC_WEIGHT * norm_semantic) + (BM25_WEIGHT * norm_bm25)
            else:
                fusion_score = norm_semantic

            # Add a small inverse rank bonus (helps with ranking diversity)
            rank_bonus = 1.0 / (1.0 + max(norm_semantic, norm_bm25) * 10)
            fusion_score += rank_bonus * 0.1

            data['fusion_score'] = fusion_score

    def cleanup(self):
        """Clean up search manager resources."""
        print("🧹 Cleaning up Search Manager resources...")

        # Close all Qdrant clients
        for client in self.qdrant_clients.values():
            try:
                client.close()
            except Exception:
                pass

        self.qdrant_clients.clear()
        self.bm25_indexes.clear()
        self.document_chunks.clear()
        print("✅ Search Manager cleanup completed")
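A standalone sketch of the arithmetic in _apply_score_fusion (the fuse helper and its toy bounds are hypothetical, not part of the upload): min-max normalise each channel over the candidate pool, weight them 0.7/0.3, then add the small inverse bonus from the module.

SEMANTIC_WEIGHT, BM25_WEIGHT = 0.7, 0.3

def fuse(semantic: float, bm25: float,
         sem_bounds=(0.30, 0.80), bm25_bounds=(1.0, 9.0)) -> float:
    sem_min, sem_max = sem_bounds
    bm_min, bm_max = bm25_bounds
    norm_sem = (semantic - sem_min) / (sem_max - sem_min) if semantic > 0 else 0
    norm_bm = (bm25 - bm_min) / (bm_max - bm_min) if bm25 > 0 else 0
    score = SEMANTIC_WEIGHT * norm_sem + BM25_WEIGHT * norm_bm
    # Inverse bonus, as in the module: slightly larger for lower-scoring candidates.
    return score + 0.1 / (1.0 + max(norm_sem, norm_bm) * 10)

print(fuse(0.75, 6.0))
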
api/__init__.py
ADDED
@@ -0,0 +1 @@
# API Package
api/api.py
ADDED
@@ -0,0 +1,498 @@
from fastapi import FastAPI, HTTPException, Depends, Query
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, HttpUrl
from typing import List, Dict, Any, Optional
import tempfile
import os
import hashlib
import asyncio
import aiohttp
import time
from contextlib import asynccontextmanager

from RAG.advanced_rag_processor import AdvancedRAGProcessor
from preprocessing.preprocessing import DocumentPreprocessor
from logger.logger import rag_logger
from LLM.llm_handler import llm_handler
from LLM.tabular_answer import get_answer_for_tabluar
from LLM.image_answerer import get_answer_for_image
from LLM.one_shotter import get_oneshot_answer
from config.config import *
import config.config as config

# Initialize security
security = HTTPBearer()
admin_security = HTTPBearer()

def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Verify the bearer token for the main API."""
    if credentials.credentials != BEARER_TOKEN:
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token"
        )
    return credentials.credentials

def verify_admin_token(credentials: HTTPAuthorizationCredentials = Depends(admin_security)):
    """Verify the bearer token for admin endpoints."""
    if credentials.credentials != "9420689497":
        raise HTTPException(
            status_code=401,
            detail="Invalid admin authentication token"
        )
    return credentials.credentials

# Pydantic models for request/response
class ProcessDocumentRequest(BaseModel):
    documents: HttpUrl  # URL to the PDF document
    questions: List[str]  # List of questions to answer

class ProcessDocumentResponse(BaseModel):
    answers: List[str]

class HealthResponse(BaseModel):
    status: str
    message: str

class PreprocessingResponse(BaseModel):
    status: str
    message: str
    doc_id: str
    chunk_count: int

class LogsResponse(BaseModel):
    export_timestamp: str
    metadata: Dict[str, Any]
    logs: List[Dict[str, Any]]

class LogsSummaryResponse(BaseModel):
    summary: Dict[str, Any]

# Global instances
rag_processor: Optional[AdvancedRAGProcessor] = None
document_preprocessor: Optional[DocumentPreprocessor] = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize and clean up the RAG processor."""
    global rag_processor, document_preprocessor

    # Startup
    print("🚀 Initializing Advanced RAG System...")
    rag_processor = AdvancedRAGProcessor()  # Use the advanced processor for better accuracy
    document_preprocessor = DocumentPreprocessor()
    print("✅ Advanced RAG System initialized successfully")

    yield

    # Shutdown
    print("🔄 Shutting down RAG System...")
    if rag_processor:
        rag_processor.cleanup()
    print("✅ Cleanup completed")

# FastAPI app with lifespan management
app = FastAPI(
    title="Advanced RAG API",
    description="API for document processing and question answering using RAG",
    version="1.0.0",
    lifespan=lifespan
)

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint."""
    return HealthResponse(
        status="healthy",
        message="RAG API is running successfully"
    )

@app.post("/hackrx/run", response_model=ProcessDocumentResponse)
async def process_document(
    request: ProcessDocumentRequest,
    token: str = Depends(verify_token)
):
    """
    Process a PDF document and answer questions about it.

    This endpoint implements an optimized flow:
    1. Check if the document is already processed (pre-computed embeddings)
    2. If yes, use existing embeddings for fast retrieval + generation
    3. If no, run the full RAG pipeline (download + process + embed + store + answer)

    Args:
        request: Contains the document URL and list of questions
        token: Bearer token for authentication

    Returns:
        ProcessDocumentResponse: List of answers corresponding to the questions
    """
    global rag_processor, document_preprocessor

    if not rag_processor or not document_preprocessor:
        raise HTTPException(
            status_code=503,
            detail="RAG system not initialized"
        )

    # Start timing and logging
    start_time = time.time()
    document_url = str(request.documents)
    questions = request.questions
    # Initialize answers for safe logging in the finally block
    final_answers = []
    status = "success"
    error_message = None
    doc_id = None
    was_preprocessed = False

    # Initialize enhanced logging
    request_id = rag_logger.generate_request_id()
    rag_logger.start_request_timing(request_id)

    try:
        print(f"📋 [{request_id}] Processing document: {document_url[:50]}...")
        print(f"🤔 [{request_id}] Number of questions: {len(questions)}")
        print()
        print(f"🚀 [{request_id}] ===== STARTING RAG PIPELINE =====")
        print(f"📌 [{request_id}] PRIORITY 1: Checking stored embeddings database...")

        # Generate the document ID
        doc_id = document_preprocessor.generate_doc_id(document_url)

        # Step 1: Check if the document is already processed (stored embeddings)
        is_processed = document_preprocessor.is_document_processed(document_url)
        was_preprocessed = is_processed

        if is_processed:
            print(f"✅ [{request_id}] FOUND STORED EMBEDDINGS for {doc_id}")
            print(f"⚡ [{request_id}] Using fast path with pre-computed embeddings")
            # Fast path: use existing embeddings
            doc_info = document_preprocessor.get_document_info(document_url)
            print(f"📊 [{request_id}] Using existing collection with {doc_info.get('chunk_count', 'N/A')} chunks")
        else:
            print(f"❌ [{request_id}] No stored embeddings found for {doc_id}")
            print(f"📌 [{request_id}] PRIORITY 2: Running full RAG pipeline (download + process + embed)...")
            # Full path: download and process the document
            resp = await document_preprocessor.process_document(document_url)

            # Handle different return formats: [content, type] or [content, type, no_cleanup_flag]
            if isinstance(resp, list):
                content, _type = resp[0], resp[1]
                if content == 'unsupported':
                    # Unsupported file type: respond gracefully without throwing a server error
                    msg = f"Unsupported file type: {_type.lstrip('.')}"
                    final_answers = [msg]
                    status = "success"  # ensure no 500 is raised in the finally block
                    return ProcessDocumentResponse(answers=final_answers)

                if _type == "image":
                    try:
                        final_answers = get_answer_for_image(content, questions)
                        status = "success"
                        return ProcessDocumentResponse(answers=final_answers)
                    finally:
                        # Clean up the image file after processing
                        if os.path.exists(content):
                            os.unlink(content)
                            print(f"🗑️ Cleaned up image file: {content}")

                if _type == "tabular":
                    final_answers = get_answer_for_tabluar(content, questions)
                    status = "success"
                    return ProcessDocumentResponse(answers=final_answers)

                if _type == "oneshot":
                    # Process questions in batches of three for oneshot answering
                    tasks = [
                        get_oneshot_answer(content, questions[i:i + 3])
                        for i in range(0, len(questions), 3)
                    ]

                    # Run all batches in parallel
                    results = await asyncio.gather(*tasks)

                    # Flatten the results
                    final_answers = [ans for batch in results for ans in batch]
                    status = "success"
                    return ProcessDocumentResponse(answers=final_answers)
            else:
                doc_id = resp

            print(f"✅ [{request_id}] Document {doc_id} processed and stored")

        # Answer all questions using parallel processing for better latency
        print(f"🚀 [{request_id}] Processing {len(questions)} questions in parallel...")

        async def answer_single_question(question: str, index: int) -> tuple[str, Dict[str, float]]:
            """Answer a single question with error handling and timing."""
            try:
                question_start = time.time()
                print(f"❓ [{request_id}] Q{index+1}: {question[:50]}...")

                answer, pipeline_timings = await rag_processor.answer_question(
                    question=question,
                    doc_id=doc_id,
                    logger=rag_logger,
                    request_id=request_id
                )

                question_time = time.time() - question_start

                # Log question timing
                rag_logger.log_question_timing(
                    request_id, index, question, answer, question_time, pipeline_timings
                )

                print(f"✅ [{request_id}] Q{index+1} completed in {question_time:.4f}s")
                return answer, pipeline_timings
            except Exception as e:
                print(f"❌ [{request_id}] Q{index+1} Error: {str(e)}")
                return f"I encountered an error while processing this question: {str(e)}", {}

        # Process questions in parallel with controlled concurrency
        semaphore = asyncio.Semaphore(3)  # Reduced concurrency for better logging visibility

        async def bounded_answer(question: str, index: int) -> tuple[str, Dict[str, float]]:
            async with semaphore:
                return await answer_single_question(question, index)

        # Execute all questions concurrently
        tasks = [
            bounded_answer(question, i)
            for i, question in enumerate(questions)
        ]

        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Handle any exceptions in the answers
        final_answers = []
        error_count = 0
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                error_count += 1
                final_answers.append(f"Error processing question {i+1}: {str(result)}")
            else:
                answer, _ = result
                final_answers.append(answer)

        # Determine the final status
        if error_count == 0:
            status = "success"
        elif error_count == len(questions):
            status = "error"
        else:
            status = "partial"

        print(f"✅ [{request_id}] Successfully processed {len(questions) - error_count}/{len(questions)} questions")

    except Exception as e:
        print(f"❌ [{request_id}] Error processing request: {str(e)}")
        status = "error"
        error_message = str(e)
        final_answers = [f"Error: {str(e)}" for _ in questions]

    finally:
        # End request timing and collect detailed timing data
        timing_data = rag_logger.end_request_timing(request_id)

        # Log the request with enhanced timing
        processing_time = time.time() - start_time
        logged_request_id = rag_logger.log_request(
            document_url=document_url,
            questions=questions,
            answers=final_answers,
            processing_time=processing_time,
            status=status,
            error_message=error_message,
            document_id=doc_id,
            was_preprocessed=was_preprocessed,
            timing_data=timing_data
        )

        print(f"📊 Request logged with ID: {logged_request_id} (Status: {status}, Time: {processing_time:.2f}s)")

    if status == "error":
        raise HTTPException(
            status_code=500,
            detail=f"Failed to process document: {error_message}"
        )

    return ProcessDocumentResponse(answers=final_answers)

@app.post("/preprocess", response_model=PreprocessingResponse)
async def preprocess_document(document_url: str, force: bool = False, token: str = Depends(verify_admin_token)):
    """
    Preprocess a document (for batch preprocessing).

    Args:
        document_url: URL of the PDF to preprocess
        force: Whether to reprocess if already processed

    Returns:
        PreprocessingResponse: Status and document info
    """
    global document_preprocessor

    if not document_preprocessor:
        raise HTTPException(
            status_code=503,
            detail="Document preprocessor not initialized"
        )

    try:
        doc_id = await document_preprocessor.process_document(document_url, force)
        doc_info = document_preprocessor.get_document_info(document_url)

        return PreprocessingResponse(
            status="success",
            message="Document processed successfully",
            doc_id=doc_id,
            chunk_count=doc_info.get("chunk_count", 0)
        )

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to preprocess document: {str(e)}"
        )

@app.get("/collections")
async def list_collections(token: str = Depends(verify_admin_token)):
    """List all available document collections."""
    global document_preprocessor

    if not document_preprocessor:
        raise HTTPException(
            status_code=503,
            detail="Document preprocessor not initialized"
        )

    try:
        processed_docs = document_preprocessor.list_processed_documents()
        return {"collections": processed_docs}
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to list collections: {str(e)}"
        )

@app.get("/collections/stats")
async def get_collection_stats(token: str = Depends(verify_admin_token)):
    """Get statistics about all collections."""
    global document_preprocessor

    if not document_preprocessor:
        raise HTTPException(
            status_code=503,
            detail="Document preprocessor not initialized"
        )

    try:
        stats = document_preprocessor.get_collection_stats()
        return stats
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get collection stats: {str(e)}"
        )

# Logging Endpoints
@app.get("/logs", response_model=LogsResponse)
async def get_logs(
    token: str = Depends(verify_admin_token),
    limit: Optional[int] = Query(None, description="Maximum number of logs to return"),
    minutes: Optional[int] = Query(None, description="Get logs from last N minutes"),
    document_url: Optional[str] = Query(None, description="Filter logs by document URL")
):
    """
    Export all API request logs as JSON.

    Query Parameters:
        limit: Maximum number of recent logs to return
        minutes: Get logs from the last N minutes
        document_url: Filter logs for a specific document URL

    Returns:
        LogsResponse: Complete logs export with metadata
    """
    try:
        if document_url:
            # Get logs for a specific document
            logs = rag_logger.get_logs_by_document(document_url)
            metadata = {
                "filtered_by": "document_url",
                "document_url": document_url,
                "total_logs": len(logs)
            }
        elif minutes:
            # Get recent logs
            logs = rag_logger.get_recent_logs(minutes)
            metadata = {
                "filtered_by": "time_range",
                "minutes": minutes,
                "total_logs": len(logs)
            }
        else:
            # Get all logs (with an optional limit)
            if limit:
                logs = rag_logger.get_logs(limit)
                metadata = rag_logger.get_logs_summary()
                metadata["limited_to"] = limit
            else:
                logs_export = rag_logger.export_logs()
                return LogsResponse(**logs_export)

        return LogsResponse(
            export_timestamp=rag_logger.export_logs()["export_timestamp"],
            metadata=metadata,
            logs=logs
        )

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to export logs: {str(e)}"
        )

@app.get("/logs/summary", response_model=LogsSummaryResponse)
async def get_logs_summary(token: str = Depends(verify_admin_token)):
    """
    Get summary statistics of all logs.

    Returns:
        LogsSummaryResponse: Summary statistics
    """
    try:
        summary = rag_logger.get_logs_summary()
        return LogsSummaryResponse(summary=summary)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get logs summary: {str(e)}"
        )

if __name__ == "__main__":
    import uvicorn

    # Run the FastAPI server
    uvicorn.run(
        "api:app",
        host=API_HOST,
        port=API_PORT,
        reload=API_RELOAD,
        log_level="info"
    )
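A client-side sketch for the /hackrx/run endpoint above (URL, token, and document values are placeholders, not part of the upload):

import requests

resp = requests.post(
    "http://localhost:8000/hackrx/run",
    headers={"Authorization": "Bearer <BEARER_TOKEN>"},
    json={
        "documents": "https://example.com/policy.pdf",
        "questions": ["What is the waiting period for pre-existing diseases?"],
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["answers"])
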
config/__init__.py
ADDED
@@ -0,0 +1 @@
# Config Package
config/config.py
ADDED
@@ -0,0 +1,142 @@
# RAG Configuration File
# Update these settings as needed

import os
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Common LLM Settings
MAX_TOKENS = 1200
TEMPERATURE = 0.4

# OCR Settings
OCR_SPACE_API_KEY = os.getenv("OCR_SPACE_API_KEY", "")

# OpenAI Settings
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
OPENAI_MODEL = "gpt-3.5-turbo"

# Gemini Settings
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
GEMINI_MODEL = "gemini-1.5-flash"

# Groq Settings
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
GROQ_MODEL = os.getenv("GROQ_MODEL", "llama3-70b-8192")

GROQ_API_KEY_LITE = os.getenv("GROQ_API_KEY_LITE")
GROQ_MODEL_LITE = "llama3-8b-8192"

# API Authentication
BEARER_TOKEN = os.getenv("BEARER_TOKEN", "c6cee5b5046310e401632a7effe9c684d071a9ef5ce09b96c9ec5c3ebd13085e")

# Chunking (token-based)
CHUNK_SIZE = 1600
CHUNK_OVERLAP = 200

# Retrieval Settings
TOP_K = 12
SCORE_THRESHOLD = 0.3
RERANK_TOP_K = 9
BM25_WEIGHT = 0.3
SEMANTIC_WEIGHT = 0.7

# Advanced RAG Settings
ENABLE_RERANKING = True
ENABLE_HYBRID_SEARCH = True
ENABLE_QUERY_EXPANSION = True
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
QUERY_EXPANSION_COUNT = 3
MAX_CONTEXT_LENGTH = 15000

USE_TOTAL_BUDGET_APPROACH = True

# Embedding Settings
EMBEDDING_MODEL = "BAAI/bge-large-en"
BATCH_SIZE = 16

# Paths
OUTPUT_DIR = os.getenv("RAG_EMBEDDINGS_PATH", "./RAG/rag_embeddings")

# API Settings
API_HOST = "0.0.0.0"
API_PORT = 8000
API_RELOAD = True

# Multi-LLM failover system
sequence = ["primary", "secondary", "ternary", "quaternary", "quinary", "senary", "septenary", "octonary", "nonary", "denary"]

def get_provider_configs():
    """
    Get configurations for all provider instances with failover support.
    Supports multiple instances of each provider type for reliability.
    """
    configs = {
        "groq": [],
        "gemini": [],
        "openai": []
    }

    # Groq configurations with multiple API keys for failover
    DEFAULT_GROQ_MODEL = "qwen/qwen3-32b"
    configs["groq"] = [
        {
            "name": sequence[i - 1],
            "api_key": os.getenv(f"GROQ_API_KEY_{i}"),
            "model": os.getenv(f"GROQ_MODEL_{i}", DEFAULT_GROQ_MODEL),
        }
        for i in range(1, 10) if os.getenv(f"GROQ_API_KEY_{i}", "")
    ]

    # Add the main GROQ key as primary
    if os.getenv("GROQ_API_KEY"):
        configs["groq"].insert(0, {
            "name": "main",
            "api_key": os.getenv("GROQ_API_KEY"),
            "model": DEFAULT_GROQ_MODEL
        })

    # Gemini configurations with multiple API keys for failover
    DEFAULT_GEMINI_MODEL = "gemini-1.5-flash"
    configs["gemini"] = [
        {
            "name": sequence[i - 1],
            "api_key": os.getenv(f"GEMINI_API_KEY_{i}"),
            "model": os.getenv(f"GEMINI_MODEL_{i}", DEFAULT_GEMINI_MODEL),
        }
        for i in range(1, 10) if os.getenv(f"GEMINI_API_KEY_{i}", "")
    ]

    # Add the main GEMINI key as primary
    if os.getenv("GEMINI_API_KEY"):
        configs["gemini"].insert(0, {
            "name": "main",
            "api_key": os.getenv("GEMINI_API_KEY"),
            "model": DEFAULT_GEMINI_MODEL
        })

    # OpenAI configurations with multiple API keys for failover
    DEFAULT_OPENAI_MODEL = "gpt-4o-mini"
    configs["openai"] = [
        {
            "name": sequence[i - 1],
            "api_key": os.getenv(f"OPENAI_API_KEY_{i}"),
            "model": os.getenv(f"OPENAI_MODEL_{i}", DEFAULT_OPENAI_MODEL),
        }
        for i in range(1, 10) if os.getenv(f"OPENAI_API_KEY_{i}", "")
    ]

    # Add the main OPENAI key as primary
    if os.getenv("OPENAI_API_KEY"):
        configs["openai"].insert(0, {
            "name": "main",
            "api_key": os.getenv("OPENAI_API_KEY"),
            "model": DEFAULT_OPENAI_MODEL
        })

    return configs

# Specialized API keys for different tasks
GROQ_API_KEY_TABULAR = os.getenv("GROQ_API_KEY_TABULAR", GROQ_API_KEY)
GEMINI_API_KEY_IMAGE = os.getenv("GEMINI_API_KEY_IMAGE", GEMINI_API_KEY)
GEMINI_API_KEY_MULTILINGUAL = os.getenv("GEMINI_API_KEY_MULTILINGUAL", GEMINI_API_KEY)

# Validation (optional - comment out for production)
# assert OPENAI_API_KEY, "OPENAI KEY NOT SET"
# assert GEMINI_API_KEY, "GEMINI KEY NOT SET"
# assert GROQ_API_KEY, "GROQ KEY NOT SET"
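A failover sketch (illustrative, not part of the upload): get_provider_configs() returns, per provider, the main key first, followed by any GROQ_API_KEY_1..9 (and Gemini/OpenAI equivalents) under the names "primary", "secondary", and so on; a caller walks the list in order and stops at the first client that answers.

from config.config import get_provider_configs

for cfg in get_provider_configs()["groq"]:
    print(cfg["name"], cfg["model"])  # e.g. "main qwen/qwen3-32b" first
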
logger/__init__.py
ADDED
@@ -0,0 +1 @@
# Logger Package
logger/logger.py
ADDED
@@ -0,0 +1,340 @@
"""
Enhanced in-memory logging system for RAG API with detailed pipeline timing.
Since HuggingFace doesn't allow persistent file storage, logs are stored in memory.
"""

import json
import time
from datetime import datetime
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, asdict, field
import threading

@dataclass
class PipelineTimings:
    """Detailed timing for each stage of the RAG pipeline."""
    query_expansion_time: float = 0.0
    hybrid_search_time: float = 0.0
    semantic_search_time: float = 0.0
    bm25_search_time: float = 0.0
    score_fusion_time: float = 0.0
    reranking_time: float = 0.0
    context_creation_time: float = 0.0
    llm_generation_time: float = 0.0
    total_pipeline_time: float = 0.0

@dataclass
class LogEntry:
    """Enhanced structure for a single log entry with detailed timing."""
    timestamp: str
    request_id: str
    document_url: str
    questions: List[str]
    answers: List[str]
    processing_time_seconds: float
    total_questions: int
    status: str  # 'success', 'error', 'partial'
    error_message: Optional[str] = None
    document_id: Optional[str] = None
    was_preprocessed: bool = False
    # Enhanced timing details
    request_start_time: str = ""
    request_end_time: str = ""
    pipeline_timings: Dict[str, Any] = field(default_factory=dict)
    # Per-question timings
    question_timings: List[Dict[str, Any]] = field(default_factory=list)

class RAGLogger:
    """Enhanced in-memory logging system for RAG API requests with detailed pipeline timing."""

    def __init__(self):
        self.logs: List[LogEntry] = []
        self.server_start_time = datetime.now().isoformat()
        self.request_counter = 0
        self._lock = threading.Lock()
        # Active request tracking for timing
        self._active_requests: Dict[str, Dict[str, Any]] = {}

    def generate_request_id(self) -> str:
        """Generate a unique request ID."""
        with self._lock:
            self.request_counter += 1
            return f"req_{self.request_counter:06d}"

    def start_request_timing(self, request_id: str) -> None:
        """Start timing for a new request."""
        self._active_requests[request_id] = {
            'start_time': time.time(),
            'start_timestamp': datetime.now().isoformat(),
            'pipeline_stages': {},
            'question_timings': []
        }

    def log_pipeline_stage(self, request_id: str, stage_name: str, duration: float) -> None:
        """Log the timing for a specific pipeline stage."""
        if request_id in self._active_requests:
            self._active_requests[request_id]['pipeline_stages'][stage_name] = {
                'duration_seconds': round(duration, 4),
                'timestamp': datetime.now().isoformat()
            }
            print(f"⏱️ [{request_id}] {stage_name}: {duration:.4f}s")

    def log_question_timing(self, request_id: str, question_index: int, question: str,
                            answer: str, duration: float, pipeline_timings: Dict[str, float]) -> None:
        """Log timing for individual question processing."""
        if request_id in self._active_requests:
            question_timing = {
                'question_index': question_index,
                'question': question[:100] + "..." if len(question) > 100 else question,
                'answer_length': len(answer),
                'total_time_seconds': round(duration, 4),
                'pipeline_breakdown': {k: round(v, 4) for k, v in pipeline_timings.items()},
                'timestamp': datetime.now().isoformat()
            }
            self._active_requests[request_id]['question_timings'].append(question_timing)

            # Enhanced console logging
            print(f"\n❓ [{request_id}] Question {question_index + 1}: {question[:60]}...")
            print(f"   📊 Processing time: {duration:.4f}s")
            if pipeline_timings:
                breakdown_str = " | ".join([f"{k}: {v:.4f}s" for k, v in pipeline_timings.items() if v > 0])
                if breakdown_str:
                    print(f"   ⚙️ Pipeline breakdown: {breakdown_str}")
            print(f"   💬 Answer length: {len(answer)} characters")

    def end_request_timing(self, request_id: str) -> Dict[str, Any]:
        """End timing for a request and return the timing data."""
        if request_id not in self._active_requests:
            return {}

        request_data = self._active_requests[request_id]
        total_time = time.time() - request_data['start_time']

        timing_data = {
            'start_time': request_data['start_timestamp'],
            'end_time': datetime.now().isoformat(),
            'total_time_seconds': round(total_time, 4),
            'pipeline_stages': request_data['pipeline_stages'],
            'question_timings': request_data['question_timings']
        }

        # Cleanup
        del self._active_requests[request_id]

        return timing_data

    def log_request(
        self,
        document_url: str,
        questions: List[str],
        answers: List[str],
        processing_time: float,
        status: str = "success",
        error_message: Optional[str] = None,
        document_id: Optional[str] = None,
        was_preprocessed: bool = False,
        timing_data: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Log a RAG API request with enhanced timing information.

        Args:
            document_url: URL of the document processed
            questions: List of questions asked
            answers: List of answers generated
            processing_time: Time taken in seconds
            status: Request status ('success', 'error', 'partial')
            error_message: Error message if any
            document_id: Generated document ID
            was_preprocessed: Whether the document was already processed
            timing_data: Detailed timing breakdown from the pipeline

        Returns:
            str: Request ID
        """
        request_id = self.generate_request_id()

        # Extract timing information
        pipeline_timings = {}
        question_timings = []
        request_start_time = ""
        request_end_time = ""

        if timing_data:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Enhanced in-memory logging system for RAG API with detailed pipeline timing.
Since Hugging Face Spaces does not provide persistent file storage, logs are stored in memory.
"""

import json
import time
from datetime import datetime
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, asdict, field
import threading

@dataclass
class PipelineTimings:
    """Detailed timing for each stage of the RAG pipeline."""
    query_expansion_time: float = 0.0
    hybrid_search_time: float = 0.0
    semantic_search_time: float = 0.0
    bm25_search_time: float = 0.0
    score_fusion_time: float = 0.0
    reranking_time: float = 0.0
    context_creation_time: float = 0.0
    llm_generation_time: float = 0.0
    total_pipeline_time: float = 0.0

@dataclass
class LogEntry:
    """Enhanced structure for a single log entry with detailed timing."""
    timestamp: str
    request_id: str
    document_url: str
    questions: List[str]
    answers: List[str]
    processing_time_seconds: float
    total_questions: int
    status: str  # 'success', 'error', 'partial'
    error_message: Optional[str] = None
    document_id: Optional[str] = None
    was_preprocessed: bool = False
    # Enhanced timing details
    request_start_time: str = ""
    request_end_time: str = ""
    pipeline_timings: Dict[str, Any] = field(default_factory=dict)
    # Per-question timings
    question_timings: List[Dict[str, Any]] = field(default_factory=list)

class RAGLogger:
    """Enhanced in-memory logging system for RAG API requests with detailed pipeline timing."""

    def __init__(self):
        self.logs: List[LogEntry] = []
        self.server_start_time = datetime.now().isoformat()
        self.request_counter = 0
        self._lock = threading.Lock()
        # Active request tracking for timing
        self._active_requests: Dict[str, Dict[str, Any]] = {}

    def generate_request_id(self) -> str:
        """Generate a unique request ID."""
        with self._lock:
            self.request_counter += 1
            return f"req_{self.request_counter:06d}"

    def start_request_timing(self, request_id: str) -> None:
        """Start timing for a new request."""
        self._active_requests[request_id] = {
            'start_time': time.time(),
            'start_timestamp': datetime.now().isoformat(),
            'pipeline_stages': {},
            'question_timings': []
        }

    def log_pipeline_stage(self, request_id: str, stage_name: str, duration: float) -> None:
        """Log the timing for a specific pipeline stage."""
        if request_id in self._active_requests:
            self._active_requests[request_id]['pipeline_stages'][stage_name] = {
                'duration_seconds': round(duration, 4),
                'timestamp': datetime.now().isoformat()
            }
            print(f"⏱️ [{request_id}] {stage_name}: {duration:.4f}s")

    def log_question_timing(self, request_id: str, question_index: int, question: str,
                            answer: str, duration: float, pipeline_timings: Dict[str, float]) -> None:
        """Log timing for individual question processing."""
        if request_id in self._active_requests:
            question_timing = {
                'question_index': question_index,
                'question': question[:100] + "..." if len(question) > 100 else question,
                'answer_length': len(answer),
                'total_time_seconds': round(duration, 4),
                'pipeline_breakdown': {k: round(v, 4) for k, v in pipeline_timings.items()},
                'timestamp': datetime.now().isoformat()
            }
            self._active_requests[request_id]['question_timings'].append(question_timing)

            # Enhanced console logging
            print(f"\n❓ [{request_id}] Question {question_index + 1}: {question[:60]}...")
            print(f"   📊 Processing time: {duration:.4f}s")
            if pipeline_timings:
                breakdown_str = " | ".join([f"{k}: {v:.4f}s" for k, v in pipeline_timings.items() if v > 0])
                if breakdown_str:
                    print(f"   ⚙️ Pipeline breakdown: {breakdown_str}")
            print(f"   💬 Answer length: {len(answer)} characters")

    def end_request_timing(self, request_id: str) -> Dict[str, Any]:
        """End timing for a request and return timing data."""
        if request_id not in self._active_requests:
            return {}

        request_data = self._active_requests[request_id]
        total_time = time.time() - request_data['start_time']

        timing_data = {
            'start_time': request_data['start_timestamp'],
            'end_time': datetime.now().isoformat(),
            'total_time_seconds': round(total_time, 4),
            'pipeline_stages': request_data['pipeline_stages'],
            'question_timings': request_data['question_timings']
        }

        # Cleanup
        del self._active_requests[request_id]

        return timing_data

    def log_request(
        self,
        document_url: str,
        questions: List[str],
        answers: List[str],
        processing_time: float,
        status: str = "success",
        error_message: Optional[str] = None,
        document_id: Optional[str] = None,
        was_preprocessed: bool = False,
        timing_data: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Log a RAG API request with enhanced timing information.

        Args:
            document_url: URL of the document processed
            questions: List of questions asked
            answers: List of answers generated
            processing_time: Time taken in seconds
            status: Request status ('success', 'error', 'partial')
            error_message: Error message if any
            document_id: Generated document ID
            was_preprocessed: Whether document was already processed
            timing_data: Detailed timing breakdown from pipeline

        Returns:
            str: Request ID
        """
        request_id = self.generate_request_id()

        # Extract timing information
        pipeline_timings = {}
        question_timings = []
        request_start_time = ""
        request_end_time = ""

        if timing_data:
            request_start_time = timing_data.get('start_time', '')
            request_end_time = timing_data.get('end_time', '')
            pipeline_timings = timing_data.get('pipeline_stages', {})
            question_timings = timing_data.get('question_timings', [])

        log_entry = LogEntry(
            timestamp=datetime.now().isoformat(),
            request_id=request_id,
            document_url=document_url,
            questions=questions,
            answers=answers,
            processing_time_seconds=round(processing_time, 2),
            total_questions=len(questions),
            status=status,
            error_message=error_message,
            document_id=document_id,
            was_preprocessed=was_preprocessed,
            request_start_time=request_start_time,
            request_end_time=request_end_time,
            pipeline_timings=pipeline_timings,
            question_timings=question_timings
        )

        with self._lock:
            self.logs.append(log_entry)

        # Enhanced console logging summary
        print(f"\n📊 [{request_id}] REQUEST COMPLETED:")
        print(f"   🕐 Duration: {processing_time:.2f}s")
        print(f"   📄 Document: {document_url[:60]}...")
        print(f"   ❓ Questions processed: {len(questions)}")
        print(f"   ✅ Status: {status.upper()}")

        if pipeline_timings:
            print(f"   ⚙️ Pipeline performance:")
            for stage, data in pipeline_timings.items():
                duration = data.get('duration_seconds', 0)
                print(f"      • {stage.replace('_', ' ').title()}: {duration:.4f}s")

        if error_message:
            print(f"   ❌ Error: {error_message}")

        print(f"   🆔 Request ID: {request_id}")
        print("   " + "=" * 50)

        return request_id

    def get_logs(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get all logs as a list of dictionaries.

        Args:
            limit: Maximum number of logs to return (most recent first)

        Returns:
            List of log entries as dictionaries
        """
        with self._lock:
            logs_list = [asdict(log) for log in self.logs]

        # Return most recent first
        logs_list.reverse()

        if limit:
            logs_list = logs_list[:limit]

        return logs_list

    def get_logs_summary(self) -> Dict[str, Any]:
        """Get summary statistics of all logs."""
        with self._lock:
            total_requests = len(self.logs)
            if total_requests == 0:
                return {
                    "server_start_time": self.server_start_time,
                    "total_requests": 0,
                    "successful_requests": 0,
                    "error_requests": 0,
                    "average_processing_time": 0,
                    "total_questions_processed": 0,
                    "total_documents_processed": 0
                }

            successful_requests = len([log for log in self.logs if log.status == "success"])
            error_requests = len([log for log in self.logs if log.status == "error"])
            total_processing_time = sum(log.processing_time_seconds for log in self.logs)
            total_questions = sum(log.total_questions for log in self.logs)
            unique_documents = len(set(log.document_url for log in self.logs))
            preprocessed_count = len([log for log in self.logs if log.was_preprocessed])

            # Enhanced timing statistics
            pipeline_times = []
            question_times = []
            stage_times = {'query_expansion': [], 'hybrid_search': [], 'reranking': [],
                           'context_creation': [], 'llm_generation': []}

            for log in self.logs:
                # Collect question timing data
                for q_timing in log.question_timings:
                    question_times.append(q_timing.get('total_time_seconds', 0))
                    # Collect stage-specific timings
                    breakdown = q_timing.get('pipeline_breakdown', {})
                    for stage, duration in breakdown.items():
                        if stage in stage_times:
                            stage_times[stage].append(duration)

            # Calculate averages for each stage
            avg_stage_times = {}
            for stage, times in stage_times.items():
                if times:
                    avg_stage_times[f'avg_{stage}_time'] = round(sum(times) / len(times), 4)
                    avg_stage_times[f'max_{stage}_time'] = round(max(times), 4)
                else:
                    avg_stage_times[f'avg_{stage}_time'] = 0
                    avg_stage_times[f'max_{stage}_time'] = 0

            return {
                "server_start_time": self.server_start_time,
                "total_requests": total_requests,
                "successful_requests": successful_requests,
                "error_requests": error_requests,
                "partial_requests": total_requests - successful_requests - error_requests,
                "success_rate": round((successful_requests / total_requests) * 100, 2),
                "average_processing_time": round(total_processing_time / total_requests, 2),
                "total_questions_processed": total_questions,
                "total_documents_processed": unique_documents,
                "documents_already_preprocessed": preprocessed_count,
                "documents_newly_processed": total_requests - preprocessed_count,
                "average_question_time": round(sum(question_times) / len(question_times), 4) if question_times else 0,
                "pipeline_performance": avg_stage_times
            }

    def export_logs(self) -> Dict[str, Any]:
        """
        Export all logs in a structured format for external consumption.

        Returns:
            Dict containing metadata and all logs
        """
        summary = self.get_logs_summary()
        logs = self.get_logs()

        return {
            "export_timestamp": datetime.now().isoformat(),
            "metadata": summary,
            "logs": logs
        }

    def get_logs_by_document(self, document_url: str) -> List[Dict[str, Any]]:
        """Get all logs for a specific document URL."""
        with self._lock:
            filtered_logs = [
                asdict(log) for log in self.logs
                if log.document_url == document_url
            ]

        # Return most recent first
        filtered_logs.reverse()
        return filtered_logs

    def get_recent_logs(self, minutes: int = 60) -> List[Dict[str, Any]]:
        """Get logs from the last N minutes."""
        cutoff_time = datetime.now().timestamp() - (minutes * 60)

        with self._lock:
            recent_logs = []
            for log in self.logs:
                log_time = datetime.fromisoformat(log.timestamp).timestamp()
                if log_time >= cutoff_time:
                    recent_logs.append(asdict(log))

        # Return most recent first
        recent_logs.reverse()
        return recent_logs

# Global logger instance
rag_logger = RAGLogger()
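A minimal usage sketch (not part of the uploaded files) of how the timing API above is meant to be driven from the API layer; the URL, question, and stage name are placeholders:

import time
from logger.logger import rag_logger

req_id = rag_logger.generate_request_id()
rag_logger.start_request_timing(req_id)

t0 = time.time()
# ... run one pipeline stage (e.g. hybrid search) here ...
rag_logger.log_pipeline_stage(req_id, "hybrid_search", time.time() - t0)

# Collect the per-stage breakdown, then write the final log entry.
# Note: log_request issues a fresh request ID for the stored entry.
timing = rag_logger.end_request_timing(req_id)
rag_logger.log_request(
    document_url="https://example.com/policy.pdf",  # placeholder
    questions=["What is the waiting period?"],
    answers=["30 days."],
    processing_time=timing.get("total_time_seconds", 0.0),
    timing_data=timing,
)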
preprocessing/__init__.py
ADDED
@@ -0,0 +1,23 @@
# Preprocessing package

from .preprocessing import DocumentPreprocessor
from .preprocessing_modules import (
    PDFDownloader,
    TextExtractor,
    TextChunker,
    EmbeddingManager,
    VectorStorage,
    MetadataManager,
    ModularDocumentPreprocessor
)

__all__ = [
    'DocumentPreprocessor',
    'PDFDownloader',
    'TextExtractor',
    'TextChunker',
    'EmbeddingManager',
    'VectorStorage',
    'MetadataManager',
    'ModularDocumentPreprocessor'
]
preprocessing/preprocessing.py
ADDED
@@ -0,0 +1,63 @@
import os
import asyncio
from typing import List, Dict, Any

from config.config import *
from .preprocessing_modules.modular_preprocessor import ModularDocumentPreprocessor

# For backward compatibility, create an alias
class DocumentPreprocessor(ModularDocumentPreprocessor):
    """Backward compatibility alias for the modular document preprocessor."""
    pass

# CLI interface for preprocessing
async def main():
    """Main function for command-line usage."""
    import argparse

    parser = argparse.ArgumentParser(description="Document Preprocessing for RAG")
    parser.add_argument("--url", type=str, help="Single PDF URL to process")
    parser.add_argument("--urls-file", type=str, help="File containing PDF URLs (one per line)")
    parser.add_argument("--force", action="store_true", help="Force reprocessing even if already processed")
    parser.add_argument("--list", action="store_true", help="List all processed documents")
    parser.add_argument("--stats", action="store_true", help="Show collection statistics")

    args = parser.parse_args()

    preprocessor = DocumentPreprocessor()

    if args.list:
        docs = preprocessor.list_processed_documents()
        print(f"\n📚 Processed Documents ({len(docs)}):")
        for doc_id, info in docs.items():
            print(f"  • {doc_id}: {info['document_url'][:50]}... ({info.get('chunk_count', 'N/A')} chunks)")

    elif args.stats:
        stats = preprocessor.get_collection_stats()
        print(f"\n📊 Collection Statistics:")
        print(f"  • Total documents: {stats['total_documents']}")
        print(f"  • Total collections: {stats['total_collections']}")
        print(f"  • Total chunks: {stats['total_chunks']}")

    elif args.url:
        await preprocessor.process_document(args.url, args.force)

    elif args.urls_file:
        if not os.path.exists(args.urls_file):
            print(f"❌ File not found: {args.urls_file}")
            return

        with open(args.urls_file, 'r') as f:
            urls = [line.strip() for line in f if line.strip()]

        if urls:
            await preprocessor.process_multiple_documents(urls, args.force)
        else:
            print("❌ No URLs found in file")

    else:
        print("❌ Please provide --url, --urls-file, --list, or --stats")
        parser.print_help()

if __name__ == "__main__":
    asyncio.run(main())
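For reference (assuming the repository root is on PYTHONPATH; not part of the upload), the CLI above is invoked along the lines of `python -m preprocessing.preprocessing --url <pdf-url>`, `python -m preprocessing.preprocessing --urls-file urls.txt --force`, or with `--list` / `--stats` for inspection.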
preprocessing/preprocessing_modules/__init__.py
ADDED
@@ -0,0 +1,29 @@
# Preprocessing modules

from .pdf_downloader import PDFDownloader
from .file_downloader import FileDownloader
from .text_extractor import TextExtractor
from .text_chunker import TextChunker
from .embedding_manager import EmbeddingManager
from .vector_storage import VectorStorage
from .metadata_manager import MetadataManager
from .modular_preprocessor import ModularDocumentPreprocessor
from .docx_extractor import extract_docx
from .pptx_extractor import extract_pptx
from .xlsx_extractor import extract_xlsx
from .image_extractor import extract_image_content

__all__ = [
    'PDFDownloader',
    'FileDownloader',
    'TextExtractor',
    'TextChunker',
    'EmbeddingManager',
    'VectorStorage',
    'MetadataManager',
    'ModularDocumentPreprocessor',
    'extract_docx',
    'extract_pptx',
    'extract_xlsx',
    'extract_image_content'
]
preprocessing/preprocessing_modules/docx_extractor.py
ADDED
@@ -0,0 +1,94 @@
from docx import Document
from docx.document import Document as _Document
from docx.table import Table
from docx.text.paragraph import Paragraph
from typing import Union, List, Dict, Any
from PIL import Image
from io import BytesIO
import pytesseract
import os

from zipfile import ZipFile
from lxml import etree
from pathlib import Path
import io

def extract_docx(docx_input) -> str:
    """Extract text from DOCX files with table and text handling."""
    zipf = ZipFile(docx_input)
    xml_content = zipf.read("word/document.xml")
    tree = etree.fromstring(xml_content)

    ns = {
        "w": "http://schemas.openxmlformats.org/wordprocessingml/2006/main",
        "a": "http://schemas.openxmlformats.org/drawingml/2006/main",
        "wps": "http://schemas.microsoft.com/office/word/2010/wordprocessingShape"
    }

    text_blocks = []

    # Extract all tables with gridSpan handling
    tables = tree.xpath("//w:tbl", namespaces=ns)
    table_elements = set(tables)
    table_index = 0

    for tbl in tables:
        rows = tbl.xpath("./w:tr", namespaces=ns)
        sub_tables = []
        current_table = []

        prev_col_count = None
        for row in rows:
            row_texts = []
            cells = row.xpath("./w:tc", namespaces=ns)
            col_count = 0

            for cell in cells:
                cell_text = ""
                paragraphs = cell.xpath(".//w:p", namespaces=ns)
                for para in paragraphs:
                    text_nodes = para.xpath(".//w:t", namespaces=ns)
                    para_text = "".join(node.text for node in text_nodes if node.text)
                    if para_text.strip():
                        cell_text += para_text + " "

                # Handle gridSpan (merged cells); lxml attribute lookups need
                # Clark notation "{namespace-uri}localname"
                gridspan_elem = cell.xpath(".//w:gridSpan", namespaces=ns)
                span = int(gridspan_elem[0].get(f"{{{ns['w']}}}val", "1")) if gridspan_elem else 1

                row_texts.append(cell_text.strip())
                col_count += span

            if row_texts and any(text.strip() for text in row_texts):
                if prev_col_count is not None and col_count != prev_col_count:
                    # Column count changed, save current table and start new one
                    if current_table:
                        sub_tables.append(current_table)
                        current_table = []

                current_table.append(row_texts)
                prev_col_count = col_count

        if current_table:
            sub_tables.append(current_table)

        # Format tables
        for sub_table in sub_tables:
            table_text = f"\n--- Table {table_index + 1} ---\n"
            for row in sub_table:
                table_text += " | ".join(row) + "\n"
            text_blocks.append(table_text)
            table_index += 1

    # Extract non-table paragraphs
    paragraphs = tree.xpath("//w:p", namespaces=ns)
    for para in paragraphs:
        # Check if paragraph is inside a table
        is_in_table = any(table in para.xpath("ancestor::*") for table in table_elements)
        if not is_in_table:
            text_nodes = para.xpath(".//w:t", namespaces=ns)
            para_text = "".join(node.text for node in text_nodes if node.text)
            if para_text.strip():
                text_blocks.append(para_text.strip())

    return "\n\n".join(text_blocks)
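A usage sketch (path is a placeholder, not part of the upload); since extract_docx reads word/document.xml straight out of the archive, it accepts a path or any file-like object that ZipFile understands:

from preprocessing.preprocessing_modules.docx_extractor import extract_docx

text = extract_docx("report.docx")  # placeholder path
print(text[:500])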
preprocessing/preprocessing_modules/embedding_manager.py
ADDED
@@ -0,0 +1,118 @@
"""
Embedding Manager Module

Handles creation of embeddings for text chunks using sentence transformers.
"""

import asyncio
import numpy as np
from typing import List
from sentence_transformers import SentenceTransformer
from config.config import EMBEDDING_MODEL, BATCH_SIZE


class EmbeddingManager:
    """Handles embedding creation for text chunks."""

    def __init__(self):
        """Initialize the embedding manager."""
        self.embedding_model = None
        self._init_embedding_model()

    def _init_embedding_model(self):
        """Initialize the embedding model."""
        print(f"🔄 Loading embedding model: {EMBEDDING_MODEL}")
        self.embedding_model = SentenceTransformer(EMBEDDING_MODEL)
        print("✅ Embedding model loaded successfully")

    async def create_embeddings(self, chunks: List[str]) -> np.ndarray:
        """
        Create embeddings for text chunks.

        Args:
            chunks: List of text chunks to embed

        Returns:
            np.ndarray: Array of embeddings with shape (num_chunks, embedding_dim)
        """
        print(f"🧠 Creating embeddings for {len(chunks)} chunks")

        if not chunks:
            raise ValueError("No chunks provided for embedding creation")

        def create_embeddings_sync():
            """Synchronous embedding creation to run in thread pool."""
            embeddings = self.embedding_model.encode(
                chunks,
                batch_size=BATCH_SIZE,
                show_progress_bar=True,
                normalize_embeddings=True
            )
            return np.array(embeddings).astype("float32")

        # Run in thread pool to avoid blocking the event loop
        loop = asyncio.get_event_loop()
        embeddings = await loop.run_in_executor(None, create_embeddings_sync)

        print(f"✅ Created embeddings with shape: {embeddings.shape}")
        return embeddings

    def get_embedding_dimension(self) -> int:
        """
        Get the dimension of embeddings produced by the model.

        Returns:
            int: Embedding dimension
        """
        if self.embedding_model is None:
            raise RuntimeError("Embedding model not initialized")

        # Get dimension from model
        return self.embedding_model.get_sentence_embedding_dimension()

    def validate_embeddings(self, embeddings: np.ndarray, expected_count: int) -> bool:
        """
        Validate that embeddings have the expected shape and properties.

        Args:
            embeddings: The embeddings array to validate
            expected_count: Expected number of embeddings

        Returns:
            bool: True if embeddings are valid, False otherwise
        """
        if embeddings is None:
            return False

        if embeddings.shape[0] != expected_count:
            print(f"❌ Embedding count mismatch: expected {expected_count}, got {embeddings.shape[0]}")
            return False

        if embeddings.dtype != np.float32:
            print(f"❌ Embedding dtype mismatch: expected float32, got {embeddings.dtype}")
            return False

        # Check for NaN or infinite values
        if np.any(np.isnan(embeddings)) or np.any(np.isinf(embeddings)):
            print("❌ Embeddings contain NaN or infinite values")
            return False

        print(f"✅ Embeddings validation passed: {embeddings.shape}")
        return True

    def get_model_info(self) -> dict:
        """
        Get information about the embedding model.

        Returns:
            dict: Model information
        """
        if self.embedding_model is None:
            return {"model_name": EMBEDDING_MODEL, "status": "not_loaded"}

        return {
            "model_name": EMBEDDING_MODEL,
            "embedding_dimension": self.get_embedding_dimension(),
            "max_sequence_length": getattr(self.embedding_model, 'max_seq_length', 'unknown'),
            "status": "loaded"
        }
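A minimal sketch (not part of the upload) of driving this manager directly; note that constructing it loads the configured SentenceTransformer, which can take a while on first run:

import asyncio
from preprocessing.preprocessing_modules.embedding_manager import EmbeddingManager

async def demo():
    mgr = EmbeddingManager()  # loads EMBEDDING_MODEL from config
    embeddings = await mgr.create_embeddings(["first chunk", "second chunk"])
    assert mgr.validate_embeddings(embeddings, expected_count=2)
    print(mgr.get_model_info())

asyncio.run(demo())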
preprocessing/preprocessing_modules/file_downloader.py
ADDED
@@ -0,0 +1,108 @@
import aiohttp
import asyncio
import tempfile
import os
import re
from urllib.parse import urlparse
from typing import List, Tuple

class FileDownloader:
    """Enhanced file downloader that supports multiple file types."""

    async def download_file(self, url: str, timeout: int = 300, max_retries: int = 3) -> Tuple[str, str]:
        """Download any file type from a URL to a temporary file with enhanced error handling."""
        print(f"📥 Downloading file from: {url[:60]}...")

        for attempt in range(max_retries):
            try:
                timeout_config = aiohttp.ClientTimeout(
                    total=timeout,
                    connect=30,
                    sock_read=120
                )

                async with aiohttp.ClientSession(timeout=timeout_config) as session:
                    print(f"   Attempt {attempt + 1}/{max_retries} (timeout: {timeout}s)")

                    async with session.get(url) as response:
                        if response.status != 200:
                            raise Exception(f"Failed to download file: HTTP {response.status}")

                        # Extract filename from header or URL
                        cd = response.headers.get('Content-Disposition', '')
                        filename_match = re.findall('filename="?([^"]+)"?', cd)
                        if filename_match:
                            filename = filename_match[0]
                        else:
                            from urllib.parse import unquote
                            path = urlparse(url).path
                            filename = os.path.basename(unquote(path))  # Decode URL encoding

                        if not filename:
                            filename = "downloaded_file"

                        ext = os.path.splitext(filename)[1]
                        if not ext:
                            return url, "url"

                        print(f"   📁 Detected filename: {filename}, extension: {ext}")

                        # Check if file type is supported
                        supported_extensions = ['.pdf', '.docx', '.pptx', '.png', '.xlsx', '.jpeg', '.jpg', '.txt', '.csv']
                        if ext not in supported_extensions:
                            # Return extension without dot for consistency
                            ext_without_dot = ext[1:] if ext.startswith('.') else ext
                            print(f"   ❌ File type not supported: {ext}")
                            return 'not supported', ext_without_dot

                        # Get content length
                        content_length = response.headers.get('content-length')
                        if content_length:
                            total_size = int(content_length)
                            print(f"   File size: {total_size / (1024 * 1024):.1f} MB")

                        # Create temp file with same extension
                        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=ext, prefix="download_")

                        # Write to file
                        downloaded = 0
                        async for chunk in response.content.iter_chunked(16384):
                            temp_file.write(chunk)
                            downloaded += len(chunk)

                            if content_length and downloaded % (1024 * 1024) == 0:
                                progress = (downloaded / total_size) * 100
                                print(f"   Progress: {progress:.1f}% ({downloaded / (1024*1024):.1f} MB)")

                        temp_file.close()
                        print(f"✅ File downloaded successfully: {temp_file.name}")
                        # Return extension without the dot for consistency with modular_preprocessor
                        ext_without_dot = ext[1:] if ext.startswith('.') else ext
                        return temp_file.name, ext_without_dot

            except asyncio.TimeoutError:
                print(f"   ⏰ Timeout on attempt {attempt + 1}")
                if attempt < max_retries - 1:
                    wait_time = (attempt + 1) * 30
                    print(f"   ⏳ Waiting {wait_time}s before retry...")
                    await asyncio.sleep(wait_time)
                    continue

            except Exception as e:
                print(f"   ❌ Error on attempt {attempt + 1}: {str(e)}")
                if attempt < max_retries - 1:
                    wait_time = (attempt + 1) * 15
                    print(f"   ⏳ Waiting {wait_time}s before retry...")
                    await asyncio.sleep(wait_time)
                    continue

        raise Exception(f"Failed to download file after {max_retries} attempts")

    def cleanup_temp_file(self, temp_path: str) -> None:
        """Clean up temporary file."""
        try:
            if os.path.exists(temp_path):
                os.unlink(temp_path)
                print(f"🗑️ Cleaned up temporary file: {temp_path}")
        except Exception as e:
            print(f"⚠️ Warning: Could not cleanup temp file {temp_path}: {e}")
preprocessing/preprocessing_modules/image_extractor.py
ADDED
@@ -0,0 +1,120 @@
import cv2
import pytesseract
import numpy as np
import pandas as pd
from PIL import Image, ImageFile
from typing import List, Dict, Any

ImageFile.LOAD_TRUNCATED_IMAGES = True

def load_local_image(path: str) -> np.ndarray:
    """Load image from local path."""
    img = Image.open(path).convert("RGB")
    return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

def sort_contours(cnts, method="top-to-bottom"):
    """Sort contours based on the specified method."""
    reverse = False
    i = 1 if method == "top-to-bottom" or method == "bottom-to-top" else 0
    if method == "right-to-left" or method == "bottom-to-top":
        reverse = True
    boundingBoxes = [cv2.boundingRect(c) for c in cnts]
    (cnts, boundingBoxes) = zip(*sorted(zip(cnts, boundingBoxes),
                                        key=lambda b: b[1][i], reverse=reverse))
    return cnts, boundingBoxes

def extract_cells_from_grid(table_img: np.ndarray) -> pd.DataFrame:
    """Extract table structure from image using OpenCV."""
    gray = cv2.cvtColor(table_img, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(~gray, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

    # Detect horizontal lines
    horizontal = binary.copy()
    cols = horizontal.shape[1]
    horizontal_size = cols // 15
    horizontal_structure = cv2.getStructuringElement(cv2.MORPH_RECT, (horizontal_size, 1))
    horizontal = cv2.erode(horizontal, horizontal_structure)
    horizontal = cv2.dilate(horizontal, horizontal_structure)

    # Detect vertical lines
    vertical = binary.copy()
    rows = vertical.shape[0]
    vertical_size = rows // 15
    vertical_structure = cv2.getStructuringElement(cv2.MORPH_RECT, (1, vertical_size))
    vertical = cv2.erode(vertical, vertical_structure)
    vertical = cv2.dilate(vertical, vertical_structure)

    # Combine mask
    mask = cv2.add(horizontal, vertical)
    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    cells = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        if w > 30 and h > 20:  # Filter small contours
            cell_img = table_img[y:y+h, x:x+w]
            try:
                text = pytesseract.image_to_string(cell_img, config='--psm 7').strip()
                cells.append({'x': x, 'y': y, 'w': w, 'h': h, 'text': text})
            except Exception:
                cells.append({'x': x, 'y': y, 'w': w, 'h': h, 'text': ''})

    # Sort cells by position to create table structure
    cells.sort(key=lambda cell: (cell['y'], cell['x']))

    # Group cells into rows
    rows = []
    current_row = []
    current_y = 0

    for cell in cells:
        if abs(cell['y'] - current_y) > 20:  # New row threshold
            if current_row:
                rows.append(current_row)
            current_row = [cell]
            current_y = cell['y']
        else:
            current_row.append(cell)

    if current_row:
        rows.append(current_row)

    # Convert to DataFrame
    table_data = []
    for row in rows:
        row_data = [cell['text'] for cell in sorted(row, key=lambda c: c['x'])]
        table_data.append(row_data)

    if table_data:
        max_cols = max(len(row) for row in table_data)
        for row in table_data:
            while len(row) < max_cols:
                row.append('')
        return pd.DataFrame(table_data)
    else:
        return pd.DataFrame()

def extract_image_content(image_path: str) -> str:
    """Extract text content from images using OCR."""
    try:
        # Load image
        img = load_local_image(image_path)

        # Basic OCR
        text = pytesseract.image_to_string(img)

        # Try to detect if it's a table
        if '|' in text or '\t' in text or len(text.split('\n')) > 3:
            # Try table extraction
            try:
                table_df = extract_cells_from_grid(img)
                if not table_df.empty:
                    table_text = "\n".join([" | ".join(row) for row in table_df.values])
                    return f"[Table detected]\n{table_text}\n\n[OCR Text]\n{text}"
            except Exception:
                pass

        return text.strip() if text.strip() else "[No text detected in image]"

    except Exception as e:
        return f"[Error processing image: {str(e)}]"
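A usage sketch (placeholder path, not part of the upload); OCR here requires the tesseract binary to be installed and on PATH, which the Dockerfile's system dependencies are meant to cover:

from preprocessing.preprocessing_modules.image_extractor import extract_image_content

print(extract_image_content("scan.png"))  # placeholder path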
preprocessing/preprocessing_modules/metadata_manager.py
ADDED
@@ -0,0 +1,262 @@
"""
Metadata Manager Module

Handles document metadata storage and retrieval operations.
"""

import json
import asyncio
import hashlib
from typing import List, Dict, Any
from pathlib import Path
from config.config import EMBEDDING_MODEL, CHUNK_SIZE, CHUNK_OVERLAP


class MetadataManager:
    """Handles document metadata operations."""

    def __init__(self, base_db_path: Path):
        """
        Initialize the metadata manager.

        Args:
            base_db_path: Base path for storing metadata files
        """
        self.base_db_path = base_db_path
        self.processed_docs_file = self.base_db_path / "processed_documents.json"
        self.processed_docs = self._load_processed_docs()

    def _load_processed_docs(self) -> Dict[str, Dict]:
        """Load the registry of processed documents."""
        if self.processed_docs_file.exists():
            try:
                with open(self.processed_docs_file, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except Exception as e:
                print(f"⚠️ Warning: Could not load processed docs registry: {e}")
        return {}

    def _save_processed_docs(self):
        """Save the registry of processed documents."""
        try:
            with open(self.processed_docs_file, 'w', encoding='utf-8') as f:
                json.dump(self.processed_docs, f, indent=2, ensure_ascii=False)
        except Exception as e:
            print(f"⚠️ Warning: Could not save processed docs registry: {e}")

    def generate_doc_id(self, document_url: str) -> str:
        """
        Generate a unique document ID from the URL.

        Args:
            document_url: URL of the document

        Returns:
            str: Unique document ID
        """
        url_hash = hashlib.md5(document_url.encode()).hexdigest()[:12]
        return f"doc_{url_hash}"

    def is_document_processed(self, document_url: str) -> bool:
        """
        Check if a document has already been processed.

        Args:
            document_url: URL of the document

        Returns:
            bool: True if document is already processed
        """
        doc_id = self.generate_doc_id(document_url)
        return doc_id in self.processed_docs

    def get_document_info(self, document_url: str) -> Dict[str, Any]:
        """
        Get information about a processed document.

        Args:
            document_url: URL of the document

        Returns:
            Dict[str, Any]: Document information or empty dict if not found
        """
        doc_id = self.generate_doc_id(document_url)
        return self.processed_docs.get(doc_id, {})

    def save_document_metadata(self, chunks: List[str], doc_id: str, document_url: str):
        """
        Save document metadata to JSON file and update registry.

        Args:
            chunks: List of text chunks
            doc_id: Document identifier
            document_url: Original document URL
        """
        # Calculate statistics
        total_chars = sum(len(chunk) for chunk in chunks)
        total_words = sum(len(chunk.split()) for chunk in chunks)
        avg_chunk_size = total_chars / len(chunks) if chunks else 0

        # Create metadata object
        metadata = {
            "doc_id": doc_id,
            "document_url": document_url,
            "chunk_count": len(chunks),
            "total_chars": total_chars,
            "total_words": total_words,
            "avg_chunk_size": avg_chunk_size,
            "processed_at": asyncio.get_event_loop().time(),
            "embedding_model": EMBEDDING_MODEL,
            "chunk_size": CHUNK_SIZE,
            "chunk_overlap": CHUNK_OVERLAP,
            "processing_config": {
                "chunk_size": CHUNK_SIZE,
                "chunk_overlap": CHUNK_OVERLAP,
                "embedding_model": EMBEDDING_MODEL
            }
        }

        # Save individual document metadata
        metadata_path = self.base_db_path / f"{doc_id}_metadata.json"
        try:
            with open(metadata_path, "w", encoding="utf-8") as f:
                json.dump(metadata, f, indent=2, ensure_ascii=False)
            print(f"✅ Saved individual metadata for {doc_id}")
        except Exception as e:
            print(f"⚠️ Warning: Could not save individual metadata for {doc_id}: {e}")

        # Update processed documents registry
        self.processed_docs[doc_id] = {
            "document_url": document_url,
            "chunk_count": len(chunks),
            "processed_at": metadata["processed_at"],
            "collection_name": f"{doc_id}_collection",
            "total_chars": total_chars,
            "total_words": total_words
        }
        self._save_processed_docs()

        print(f"✅ Updated registry for document {doc_id}")

    def get_document_metadata(self, doc_id: str) -> Dict[str, Any]:
        """
        Load individual document metadata from file.

        Args:
            doc_id: Document identifier

        Returns:
            Dict[str, Any]: Document metadata or empty dict if not found
        """
        metadata_path = self.base_db_path / f"{doc_id}_metadata.json"

        if not metadata_path.exists():
            return {}

        try:
            with open(metadata_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            print(f"⚠️ Warning: Could not load metadata for {doc_id}: {e}")
            return {}

    def list_processed_documents(self) -> Dict[str, Dict]:
        """
        List all processed documents.

        Returns:
            Dict[str, Dict]: Copy of processed documents registry
        """
        return self.processed_docs.copy()

    def get_collection_stats(self) -> Dict[str, Any]:
        """
        Get statistics about all collections.

        Returns:
            Dict[str, Any]: Collection statistics
        """
        stats = {
            "total_documents": len(self.processed_docs),
            "total_collections": 0,
            "total_chunks": 0,
            "total_characters": 0,
            "total_words": 0,
            "documents": []
        }

        for doc_id, info in self.processed_docs.items():
            collection_path = self.base_db_path / f"{info['collection_name']}.db"
            if collection_path.exists():
                stats["total_collections"] += 1
            stats["total_chunks"] += info.get("chunk_count", 0)
            stats["total_characters"] += info.get("total_chars", 0)
            stats["total_words"] += info.get("total_words", 0)

            stats["documents"].append({
                "doc_id": doc_id,
                "url": info["document_url"],
                "chunk_count": info.get("chunk_count", 0),
                "total_chars": info.get("total_chars", 0),
                "total_words": info.get("total_words", 0),
                "processed_at": info.get("processed_at", "unknown")
            })

        # Add averages
        if stats["total_documents"] > 0:
            stats["avg_chunks_per_doc"] = stats["total_chunks"] / stats["total_documents"]
            stats["avg_chars_per_doc"] = stats["total_characters"] / stats["total_documents"]
            stats["avg_words_per_doc"] = stats["total_words"] / stats["total_documents"]

        return stats

    def remove_document_metadata(self, doc_id: str) -> bool:
        """
        Remove document metadata and registry entry.

        Args:
            doc_id: Document identifier

        Returns:
            bool: True if successfully removed, False otherwise
        """
        try:
            # Remove individual metadata file
            metadata_path = self.base_db_path / f"{doc_id}_metadata.json"
            if metadata_path.exists():
                metadata_path.unlink()
                print(f"🗑️ Removed metadata file for {doc_id}")

            # Remove from registry
            if doc_id in self.processed_docs:
                del self.processed_docs[doc_id]
                self._save_processed_docs()
                print(f"🗑️ Removed registry entry for {doc_id}")

            return True

        except Exception as e:
            print(f"❌ Error removing metadata for {doc_id}: {e}")
            return False

    def update_document_status(self, doc_id: str, status_info: Dict[str, Any]):
        """
        Update status information for a document.

        Args:
            doc_id: Document identifier
            status_info: Status information to update
        """
        if doc_id in self.processed_docs:
            self.processed_docs[doc_id].update(status_info)
            self._save_processed_docs()
            print(f"✅ Updated status for document {doc_id}")

    def get_registry_path(self) -> str:
        """
        Get the path to the processed documents registry.

        Returns:
            str: Path to registry file
        """
        return str(self.processed_docs_file)
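A small sketch (not part of the upload) of the registry lookups; the base path below is a placeholder, in the pipeline it comes from the configured OUTPUT_DIR:

from pathlib import Path
from preprocessing.preprocessing_modules.metadata_manager import MetadataManager

mm = MetadataManager(Path("RAG/rag_embeddings"))   # placeholder base path
url = "https://example.com/sample.pdf"             # placeholder URL
print(mm.generate_doc_id(url))                     # deterministic: "doc_" + md5(url)[:12]
print(mm.is_document_processed(url))               # in-memory registry lookup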
preprocessing/preprocessing_modules/modular_preprocessor.py
ADDED
@@ -0,0 +1,290 @@
1 |
+
"""
|
2 |
+
Modular Document Preprocessor
|
3 |
+
|
4 |
+
Main orchestrator class that uses all preprocessing modules to process documents.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
import asyncio
|
9 |
+
from typing import List, Dict, Any, Union
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
from config.config import OUTPUT_DIR
|
13 |
+
from .pdf_downloader import PDFDownloader
|
14 |
+
from .file_downloader import FileDownloader
|
15 |
+
from .text_extractor import TextExtractor
|
16 |
+
from .text_chunker import TextChunker
|
17 |
+
from .embedding_manager import EmbeddingManager
|
18 |
+
from .vector_storage import VectorStorage
|
19 |
+
from .metadata_manager import MetadataManager
|
20 |
+
|
21 |
+
# Import new extractors
|
22 |
+
from .docx_extractor import extract_docx
|
23 |
+
from .pptx_extractor import extract_pptx
|
24 |
+
from .xlsx_extractor import extract_xlsx
|
25 |
+
from .image_extractor import extract_image_content
|
26 |
+
|
27 |
+
|
28 |
+
class ModularDocumentPreprocessor:
|
29 |
+
"""
|
30 |
+
Modular document preprocessor that orchestrates the entire preprocessing pipeline.
|
31 |
+
|
32 |
+
This class combines all preprocessing modules to provide a clean interface
|
33 |
+
for document processing while maintaining separation of concerns.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self):
|
37 |
+
"""Initialize the modular document preprocessor."""
|
38 |
+
# Set up base database path
|
39 |
+
self.base_db_path = Path(OUTPUT_DIR).resolve()
|
40 |
+
self._ensure_base_directory()
|
41 |
+
|
42 |
+
# Initialize all modules
|
43 |
+
self.pdf_downloader = PDFDownloader() # Keep for backward compatibility
|
44 |
+
self.file_downloader = FileDownloader() # New enhanced downloader
|
45 |
+
self.text_extractor = TextExtractor()
|
46 |
+
self.text_chunker = TextChunker()
|
47 |
+
self.embedding_manager = EmbeddingManager()
|
48 |
+
self.vector_storage = VectorStorage(self.base_db_path)
|
49 |
+
self.metadata_manager = MetadataManager(self.base_db_path)
|
50 |
+
|
51 |
+
print("✅ Modular Document Preprocessor initialized successfully")
|
52 |
+
|
53 |
+
def _ensure_base_directory(self):
|
54 |
+
"""Ensure the base directory exists."""
|
55 |
+
if not self.base_db_path.exists():
|
56 |
+
try:
|
57 |
+
self.base_db_path.mkdir(parents=True, exist_ok=True)
|
58 |
+
print(f"✅ Created directory: {self.base_db_path}")
|
59 |
+
except PermissionError:
|
60 |
+
print(f"⚠️ Directory {self.base_db_path} should exist in production environment")
|
61 |
+
if not self.base_db_path.exists():
|
62 |
+
raise RuntimeError(f"Required directory {self.base_db_path} does not exist and cannot be created")
|
63 |
+
|
64 |
+
# Delegate metadata operations to metadata manager
|
65 |
+
def generate_doc_id(self, document_url: str) -> str:
|
66 |
+
"""Generate a unique document ID from the URL."""
|
67 |
+
return self.metadata_manager.generate_doc_id(document_url)
|
68 |
+
|
69 |
+
def is_document_processed(self, document_url: str) -> bool:
|
70 |
+
"""Check if a document has already been processed."""
|
71 |
+
return self.metadata_manager.is_document_processed(document_url)
|
72 |
+
|
73 |
+
def get_document_info(self, document_url: str) -> Dict[str, Any]:
|
74 |
+
"""Get information about a processed document."""
|
75 |
+
return self.metadata_manager.get_document_info(document_url)
|
76 |
+
|
77 |
+
def list_processed_documents(self) -> Dict[str, Dict]:
|
78 |
+
"""List all processed documents."""
|
79 |
+
return self.metadata_manager.list_processed_documents()
|
80 |
+
|
81 |
+
def get_collection_stats(self) -> Dict[str, Any]:
|
82 |
+
"""Get statistics about all collections."""
|
83 |
+
return self.metadata_manager.get_collection_stats()
|
84 |
+
|
85 |
+
async def process_document(self, document_url: str, force_reprocess: bool = False, timeout: int = 300) -> Union[str, List]:
|
86 |
+
"""
|
87 |
+
Process a single document: download, extract, chunk, embed, and store.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
document_url: URL of the document (PDF, DOCX, PPTX, XLSX, images, etc.)
|
91 |
+
force_reprocess: If True, reprocess even if already processed
|
92 |
+
timeout: Download timeout in seconds (default: 300s/5min)
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
str: Document ID for normal processing
|
96 |
+
List: [content, type] for special handling (oneshot, tabular, image)
|
97 |
+
"""
|
98 |
+
doc_id = self.generate_doc_id(document_url)
|
99 |
+
|
100 |
+
# Check if already processed
|
101 |
+
if not force_reprocess and self.is_document_processed(document_url):
|
102 |
+
print(f"✅ Document {doc_id} already processed, skipping...")
|
103 |
+
return doc_id
|
104 |
+
|
105 |
+
print(f"🚀 Processing document: {doc_id}")
|
106 |
+
print(f"📄 URL: {document_url}")
|
107 |
+
|
108 |
+
temp_file_path = None
|
109 |
+
try:
|
110 |
+
# Step 1: Download file (enhanced to handle multiple types)
|
111 |
+
temp_file_path, ext = await self.file_downloader.download_file(document_url, timeout=timeout)
|
112 |
+
|
113 |
+
if temp_file_path == 'not supported':
|
114 |
+
return ['unsupported', ext]
|
115 |
+
|
116 |
+
# Step 2: Extract text based on file type
|
117 |
+
full_text = ""
|
118 |
+
match ext:
|
119 |
+
case 'pdf':
|
120 |
+
full_text = await self.text_extractor.extract_text_from_pdf(temp_file_path)
|
121 |
+
|
122 |
+
case 'docx':
|
123 |
+
full_text = extract_docx(temp_file_path)
|
124 |
+
|
125 |
+
case 'pptx':
|
126 |
+
full_text = extract_pptx(temp_file_path)
|
127 |
+
return [full_text, 'oneshot']
|
128 |
+
|
129 |
+
case 'url':
|
130 |
+
new_context = "URL for Context: " + temp_file_path
|
131 |
+
return [new_context, 'oneshot']
|
132 |
+
|
133 |
+
case 'txt':
|
134 |
+
with open(temp_file_path, 'r', encoding='utf-8') as f:
|
135 |
+
full_text = f.read()
|
136 |
+
|
137 |
+
case 'xlsx':
|
138 |
+
full_text = extract_xlsx(temp_file_path)
|
139 |
+
# Print a short preview (10-15 chars) to verify extraction
|
140 |
+
try:
|
141 |
+
preview = ''.join(full_text.split())[:15]
|
142 |
+
if preview:
|
143 |
+
print(f"🔎 XLSX extracted preview: {preview}")
|
144 |
+
except Exception:
|
145 |
+
pass
|
146 |
+
return [full_text, 'tabular']
|
147 |
+
|
148 |
+
case 'csv':
|
149 |
+
with open(temp_file_path, 'r', encoding='utf-8') as f:
|
150 |
+
full_text = f.read()
|
151 |
+
return [full_text, 'tabular']
|
152 |
+
|
153 |
+
case 'png' | 'jpeg' | 'jpg':
|
154 |
+
# Don't clean up image files - they'll be cleaned up by the caller
|
155 |
+
return [temp_file_path, 'image', True] # Third element indicates no cleanup needed
|
156 |
+
|
157 |
+
case _:
|
158 |
+
raise Exception(f"Unsupported file type: {ext}")
|
159 |
+
|
160 |
+
# Validate extracted text
|
161 |
+
if not self.text_extractor.validate_extracted_text(full_text):
|
162 |
+
raise Exception("No meaningful text extracted from document")
|
163 |
+
|
164 |
+
# Step 3: Create chunks
|
165 |
+
chunks = self.text_chunker.chunk_text(full_text)
|
166 |
+
|
167 |
+
# Check if document is too short for chunking
|
168 |
+
if len(chunks) < 5:
|
169 |
+
print(f"Only {len(chunks)} chunks formed, going for oneshot.")
|
170 |
+
return [full_text, 'oneshot']
|
171 |
+
|
172 |
+
if not chunks:
|
173 |
+
raise Exception("No chunks created from text")
|
174 |
+
|
175 |
+
            # Log chunk statistics
            chunk_stats = self.text_chunker.get_chunk_stats(chunks)
            print(f"📊 Chunk Statistics: {chunk_stats['total_chunks']} chunks, "
                  f"avg size: {chunk_stats['avg_chunk_size']:.0f} chars")

            # Step 4: Create embeddings
            embeddings = await self.embedding_manager.create_embeddings(chunks)

            # Validate embeddings
            if not self.embedding_manager.validate_embeddings(embeddings, len(chunks)):
                raise Exception("Invalid embeddings generated")

            # Step 5: Store in Qdrant
            await self.vector_storage.store_in_qdrant(chunks, embeddings, doc_id)

            # Step 6: Save metadata
            self.metadata_manager.save_document_metadata(chunks, doc_id, document_url)

            print(f"✅ Document {doc_id} processed successfully: {len(chunks)} chunks")
            return doc_id

        except Exception as e:
            print(f"❌ Error processing document {doc_id}: {str(e)}")
            raise
        finally:
            # Clean up the temporary file - but NOT for images, since they still need the file path
            if temp_file_path and ext not in ['png', 'jpeg', 'jpg']:
                self.file_downloader.cleanup_temp_file(temp_file_path)

    async def process_multiple_documents(self, document_urls: List[str], force_reprocess: bool = False) -> Dict[str, str]:
        """
        Process multiple documents concurrently.

        Args:
            document_urls: List of document URLs to process
            force_reprocess: If True, reprocess even if already processed

        Returns:
            Dict[str, str]: Mapping of URLs to document IDs
        """
        print(f"🚀 Processing {len(document_urls)} documents...")

        results = {}

        # Process documents concurrently (with limited concurrency)
        semaphore = asyncio.Semaphore(3)  # Limit to 3 concurrent downloads

        async def process_single(url):
            async with semaphore:
                try:
                    doc_id = await self.process_document(url, force_reprocess)
                    return url, doc_id
                except Exception as e:
                    print(f"❌ Failed to process {url}: {str(e)}")
                    return url, None

        tasks = [process_single(url) for url in document_urls]
        completed_tasks = await asyncio.gather(*tasks, return_exceptions=True)

        for result in completed_tasks:
            if isinstance(result, tuple):
                url, doc_id = result
                if doc_id:
                    results[url] = doc_id

        print(f"✅ Successfully processed {len(results)}/{len(document_urls)} documents")
        return results

    def get_system_info(self) -> Dict[str, Any]:
        """
        Get information about the preprocessing system.

        Returns:
            Dict[str, Any]: System information
        """
        return {
            "base_db_path": str(self.base_db_path),
            "embedding_model": self.embedding_manager.get_model_info(),
            "text_chunker_config": {
                "chunk_size": self.text_chunker.chunk_size,
                "chunk_overlap": self.text_chunker.chunk_overlap
            },
            "processed_documents_registry": self.metadata_manager.get_registry_path(),
            "collection_stats": self.get_collection_stats()
        }

    def cleanup_document(self, document_url: str) -> bool:
        """
        Remove all data for a specific document.

        Args:
            document_url: URL of the document to clean up

        Returns:
            bool: True if successfully cleaned up
        """
        doc_id = self.generate_doc_id(document_url)

        try:
            # Remove vector storage
            vector_removed = self.vector_storage.delete_collection(doc_id)

            # Remove metadata
            metadata_removed = self.metadata_manager.remove_document_metadata(doc_id)

            success = vector_removed and metadata_removed
            if success:
                print(f"✅ Successfully cleaned up document {doc_id}")
            else:
                print(f"⚠️ Partial cleanup for document {doc_id}")

            return success

        except Exception as e:
            print(f"❌ Error cleaning up document {doc_id}: {e}")
            return False
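For orientation, a minimal sketch of driving this batch pipeline from a script. The class name ModularPreprocessor and its no-argument constructor are assumptions (the constructor is defined earlier in this file), and the URLs are placeholders:

import asyncio

from preprocessing.preprocessing_modules.modular_preprocessor import ModularPreprocessor  # assumed class name

async def main():
    preprocessor = ModularPreprocessor()  # assumed no-arg constructor
    urls = [
        "https://example.com/report.pdf",   # placeholder URLs
        "https://example.com/slides.pptx",
    ]
    # Returns {url: doc_id} for every document that processed successfully
    results = await preprocessor.process_multiple_documents(urls)
    for url, doc_id in results.items():
        print(f"{url} -> {doc_id}")

if __name__ == "__main__":
    asyncio.run(main())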
preprocessing/preprocessing_modules/pdf_downloader.py
ADDED
@@ -0,0 +1,112 @@
"""
PDF Downloader Module

Handles downloading PDFs from URLs with retry logic and progress tracking.
"""

import os
import asyncio
import tempfile
import aiohttp
from typing import Optional


class PDFDownloader:
    """Handles PDF downloading with enhanced error handling and retry logic."""

    def __init__(self):
        """Initialize the PDF downloader."""
        pass

    async def download_pdf(self, url: str, timeout: int = 300, max_retries: int = 3) -> str:
        """
        Download a PDF from a URL to a temporary file with enhanced error handling.

        Args:
            url: URL of the PDF to download
            timeout: Download timeout in seconds (default: 300s/5min)
            max_retries: Maximum number of retry attempts

        Returns:
            str: Path to the downloaded temporary file

        Raises:
            Exception: If the download fails after all retries
        """
        print(f"📥 Downloading PDF from: {url[:50]}...")

        for attempt in range(max_retries):
            try:
                # Enhanced timeout settings for large files
                timeout_config = aiohttp.ClientTimeout(
                    total=timeout,   # Total timeout
                    connect=30,      # Connection timeout
                    sock_read=120    # Socket read timeout
                )

                async with aiohttp.ClientSession(timeout=timeout_config) as session:
                    print(f"   Attempt {attempt + 1}/{max_retries} (timeout: {timeout}s)")

                    async with session.get(url) as response:
                        if response.status != 200:
                            raise Exception(f"Failed to download PDF: HTTP {response.status}")

                        # Get content length for progress tracking
                        content_length = response.headers.get('content-length')
                        if content_length:
                            total_size = int(content_length)
                            print(f"   File size: {total_size / (1024*1024):.1f} MB")

                        # Create a temporary file
                        temp_file = tempfile.NamedTemporaryFile(
                            delete=False,
                            suffix=".pdf",
                            prefix="preprocess_"
                        )

                        # Write content to the temporary file with progress tracking
                        downloaded = 0
                        async for chunk in response.content.iter_chunked(16384):  # Larger chunks
                            temp_file.write(chunk)
                            downloaded += len(chunk)

                            # Show progress for large files
                            if content_length and downloaded % (1024*1024) == 0:  # Every MB
                                progress = (downloaded / total_size) * 100
                                print(f"   Progress: {progress:.1f}% ({downloaded/(1024*1024):.1f} MB)")

                        temp_file.close()
                        print(f"✅ PDF downloaded successfully: {temp_file.name}")
                        return temp_file.name

            except asyncio.TimeoutError:
                print(f"   ⏰ Timeout on attempt {attempt + 1}")
                if attempt < max_retries - 1:
                    wait_time = (attempt + 1) * 30  # Increasing wait time
                    print(f"   ⏳ Waiting {wait_time}s before retry...")
                    await asyncio.sleep(wait_time)
                    continue

            except Exception as e:
                print(f"   ❌ Error on attempt {attempt + 1}: {str(e)}")
                if attempt < max_retries - 1:
                    wait_time = (attempt + 1) * 15
                    print(f"   ⏳ Waiting {wait_time}s before retry...")
                    await asyncio.sleep(wait_time)
                    continue

        raise Exception(f"Failed to download PDF after {max_retries} attempts")

    def cleanup_temp_file(self, temp_path: str) -> None:
        """
        Clean up a temporary file.

        Args:
            temp_path: Path to the temporary file to delete
        """
        if temp_path and os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
                print(f"🗑️ Cleaned up temporary file: {temp_path}")
            except Exception as e:
                print(f"⚠️ Warning: Could not delete temporary file {temp_path}: {e}")
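A minimal sketch of the download/cleanup lifecycle using the class above; the URL is a placeholder:

import asyncio

async def fetch_one():
    downloader = PDFDownloader()
    # 120-second total budget per attempt, up to 2 attempts (placeholder URL)
    temp_path = await downloader.download_pdf(
        "https://example.com/sample.pdf", timeout=120, max_retries=2
    )
    try:
        print(f"Downloaded to {temp_path}")
        # ... hand temp_path to the text extractor here ...
    finally:
        downloader.cleanup_temp_file(temp_path)

asyncio.run(fetch_one())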
preprocessing/preprocessing_modules/pptx_extractor.py
ADDED
@@ -0,0 +1,118 @@
from pptx import Presentation
from pptx.enum.shapes import MSO_SHAPE_TYPE
from typing import List, Dict, Any
from PIL import Image
from io import BytesIO
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed
import tempfile
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
from config import config

# OCR Space API configuration
API_KEY = getattr(config, 'OCR_SPACE_API_KEY', None)
API_URL = "https://api.ocr.space/parse/image"

def ocr_space_file(filename, api_key=API_KEY, overlay=False, language="eng"):
    """Extract text from an image file using the OCR Space API."""
    if not api_key:
        return filename, "OCR API key not configured"

    payload = {
        "isOverlayRequired": overlay,
        "apikey": api_key,
        "language": language,
        "detectOrientation": True,
        "scale": True,
        "isTable": False,
        "OCREngine": 2
    }
    try:
        with open(filename, "rb") as f:
            response = requests.post(API_URL, files={filename: f}, data=payload, timeout=30)

        if response.status_code != 200:
            return filename, f"API Error: HTTP {response.status_code}"

        parsed = response.json()

        if parsed.get("OCRExitCode") == 1:
            parsed_text = parsed.get("ParsedResults", [{}])[0].get("ParsedText", "")
            return filename, parsed_text
        else:
            error_msg = parsed.get("ErrorMessage", ["Unknown error"])[0] if parsed.get("ErrorMessage") else "Unknown OCR error"
            return filename, f"OCR Error: {error_msg}"

    except requests.exceptions.Timeout:
        return filename, "Error: Request timeout"
    except requests.exceptions.RequestException as e:
        return filename, f"Error: Network error - {str(e)}"
    except Exception as e:
        return filename, f"Error: {e}"

def extract_pptx(pptx_path: str) -> str:
    """Extract text and images from PowerPoint presentations."""
    try:
        prs = Presentation(pptx_path)
    except Exception as e:
        return f"Error loading PowerPoint file: {str(e)}"

    all_content = []
    temp_files = []

    try:
        for slide_idx, slide in enumerate(prs.slides):
            slide_content = [f"\n=== Slide {slide_idx + 1} ===\n"]
            slide_images = []

            for shape in slide.shapes:
                # Extract text
                if hasattr(shape, "text") and shape.text.strip():
                    slide_content.append(shape.text.strip())

                # Extract images
                elif shape.shape_type == MSO_SHAPE_TYPE.PICTURE:
                    try:
                        image = shape.image
                        image_bytes = image.blob

                        # Save the image to a temp file
                        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
                        temp_file.write(image_bytes)
                        temp_file.close()
                        temp_files.append(temp_file.name)
                        slide_images.append(temp_file.name)
                    except Exception as e:
                        slide_content.append(f"[Image extraction error: {str(e)}]")

            # Process images with OCR if an API key is available
            if slide_images and API_KEY:
                try:
                    with ThreadPoolExecutor(max_workers=3) as executor:
                        future_to_filename = {
                            executor.submit(ocr_space_file, img_file): img_file
                            for img_file in slide_images
                        }

                        for future in as_completed(future_to_filename):
                            filename, ocr_result = future.result()
                            if ocr_result and not ocr_result.startswith("Error") and not ocr_result.startswith("OCR Error"):
                                slide_content.append(f"[Image Text]: {ocr_result}")
                except Exception as e:
                    slide_content.append(f"[OCR processing error: {str(e)}]")
            elif slide_images:
                slide_content.append(f"[{len(slide_images)} images found - OCR not available]")

            all_content.append("\n".join(slide_content))

    finally:
        # Clean up temp files
        for temp_file in temp_files:
            try:
                os.unlink(temp_file)
            except OSError:
                pass

    return "\n\n".join(all_content)
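A minimal usage sketch; the path is a placeholder, and OCR of embedded images only runs when OCR_SPACE_API_KEY is set in the config:

# Slides are separated by blank lines in the returned string
text = extract_pptx("decks/quarterly_review.pptx")  # placeholder path
for slide_block in text.split("\n\n"):
    print(slide_block[:80])  # first 80 characters per slide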
preprocessing/preprocessing_modules/text_chunker.py
ADDED
@@ -0,0 +1,167 @@
"""
Text Chunker Module

Handles chunking text into smaller pieces with overlap for better context preservation.
"""

import re
from typing import List
from config.config import CHUNK_SIZE, CHUNK_OVERLAP


class TextChunker:
    """Handles text chunking with overlap and smart boundary detection."""

    def __init__(self):
        """Initialize the text chunker."""
        self.chunk_size = CHUNK_SIZE
        self.chunk_overlap = CHUNK_OVERLAP

    def chunk_text(self, text: str) -> List[str]:
        """
        Chunk text into smaller pieces with overlap.

        Args:
            text: The input text to chunk

        Returns:
            List[str]: List of text chunks
        """
        print(f"✂️ Chunking text into {self.chunk_size}-character chunks with {self.chunk_overlap} overlap")

        # Clean the text
        cleaned_text = self._clean_text(text)

        chunks = []
        start = 0

        while start < len(cleaned_text):
            end = start + self.chunk_size

            # Try to end at a sentence boundary
            if end < len(cleaned_text):
                end = self._find_sentence_boundary(cleaned_text, start, end)

            chunk = cleaned_text[start:end].strip()

            # Only add the chunk if it's meaningful
            if chunk and len(chunk) > 50:
                chunks.append(chunk)

            # Move the start position forward, keeping the overlap
            start = end - self.chunk_overlap
            if start >= len(cleaned_text):
                break

        print(f"✅ Created {len(chunks)} chunks (size={self.chunk_size}, overlap={self.chunk_overlap})")
        return chunks

    def _clean_text(self, text: str) -> str:
        """
        Clean text by normalizing whitespace and removing excessive line breaks.

        Args:
            text: Raw text to clean

        Returns:
            str: Cleaned text
        """
        # Replace runs of whitespace with a single space
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

    def _find_sentence_boundary(self, text: str, start: int, preferred_end: int) -> int:
        """
        Find the best sentence boundary near the preferred end position.

        Args:
            text: The full text
            start: Start position of the chunk
            preferred_end: Preferred end position

        Returns:
            int: Adjusted end position at a sentence boundary
        """
        # Look for sentence endings within a reasonable range
        search_start = max(start, preferred_end - 100)
        search_end = min(len(text), preferred_end + 50)

        sentence_endings = ['.', '!', '?']
        best_end = preferred_end

        # Search backwards from the preferred end for a sentence boundary
        for i in range(preferred_end - 1, search_start - 1, -1):
            if text[i] in sentence_endings:
                # Check if this looks like a real sentence ending
                if self._is_valid_sentence_ending(text, i):
                    best_end = i + 1
                    break

        return best_end

    def _is_valid_sentence_ending(self, text: str, pos: int) -> bool:
        """
        Check if a punctuation mark represents a valid sentence ending.

        Args:
            text: The full text
            pos: Position of the punctuation mark

        Returns:
            bool: True if it's a valid sentence ending
        """
        # Avoid breaking on abbreviations like "Dr.", "Mr.", etc.
        if pos > 0 and text[pos] == '.':
            # Look at the character before the period
            char_before = text[pos - 1]
            if char_before.isupper():
                # Might be an abbreviation
                word_start = pos - 1
                while word_start > 0 and text[word_start - 1].isalpha():
                    word_start -= 1

                word = text[word_start:pos]
                # Common abbreviations to avoid breaking on
                abbreviations = {'Dr', 'Mr', 'Mrs', 'Ms', 'Prof', 'Inc', 'Ltd', 'Corp', 'Co'}
                if word in abbreviations:
                    return False

        # Check for a space or an uppercase letter after the punctuation
        if pos + 1 < len(text):
            next_char = text[pos + 1]
            return next_char.isspace() or next_char.isupper()

        return True

    def get_chunk_stats(self, chunks: List[str]) -> dict:
        """
        Get statistics about the created chunks.

        Args:
            chunks: List of text chunks

        Returns:
            dict: Statistics about the chunks
        """
        if not chunks:
            return {
                "total_chunks": 0,
                "total_characters": 0,
                "total_words": 0,
                "avg_chunk_size": 0,
                "min_chunk_size": 0,
                "max_chunk_size": 0
            }

        chunk_sizes = [len(chunk) for chunk in chunks]
        total_chars = sum(chunk_sizes)
        total_words = sum(len(chunk.split()) for chunk in chunks)

        return {
            "total_chunks": len(chunks),
            "total_characters": total_chars,
            "total_words": total_words,
            "avg_chunk_size": total_chars / len(chunks),
            "min_chunk_size": min(chunk_sizes),
            "max_chunk_size": max(chunk_sizes)
        }
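A minimal sketch of the chunker's behavior; exact counts depend on the CHUNK_SIZE and CHUNK_OVERLAP values in config.config:

chunker = TextChunker()
chunks = chunker.chunk_text("This is a sentence about preprocessing. " * 200)
stats = chunker.get_chunk_stats(chunks)
print(stats["total_chunks"], round(stats["avg_chunk_size"]))

# Adjacent chunks share roughly chunk_overlap characters, so text that
# straddles a boundary still appears intact in at least one chunk
if len(chunks) >= 2:
    print(chunks[0][-40:])
    print(chunks[1][:40])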
preprocessing/preprocessing_modules/text_extractor.py
ADDED
@@ -0,0 +1,62 @@
"""
Text Extractor Module

Handles extracting text content from PDF files.
"""

import pdfplumber


class TextExtractor:
    """Handles text extraction from PDF files."""

    def __init__(self):
        """Initialize the text extractor."""
        pass

    async def extract_text_from_pdf(self, pdf_path: str) -> str:
        """
        Extract text from a PDF file.

        Args:
            pdf_path: Path to the PDF file

        Returns:
            str: Extracted text content

        Raises:
            Exception: If text extraction fails
        """
        print("📖 Extracting text from PDF...")

        full_text = ""
        try:
            with pdfplumber.open(pdf_path) as pdf:
                for page_num, page in enumerate(pdf.pages):
                    text = page.extract_text()
                    if text:
                        full_text += f"\n--- Page {page_num + 1} ---\n"
                        full_text += text

            print(f"✅ Extracted {len(full_text)} characters from PDF")
            return full_text

        except Exception as e:
            raise Exception(f"Failed to extract text from PDF: {str(e)}")

    def validate_extracted_text(self, text: str) -> bool:
        """
        Validate that the extracted text is not empty and contains meaningful content.

        Args:
            text: The extracted text to validate

        Returns:
            bool: True if the text is valid, False otherwise
        """
        if not text or not text.strip():
            return False

        # Check that the text has at least some alphabetic characters
        alphabetic_chars = sum(1 for char in text if char.isalpha())
        return alphabetic_chars > 50  # At least 50 alphabetic characters
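A minimal sketch pairing extraction with validation; the path is a placeholder:

import asyncio

async def extract_one(path: str) -> str:
    extractor = TextExtractor()
    text = await extractor.extract_text_from_pdf(path)
    if not extractor.validate_extracted_text(text):
        # Fewer than ~50 letters usually means a scanned or empty PDF
        print("⚠️ Extraction produced no meaningful text; OCR may be needed")
    return text

text = asyncio.run(extract_one("docs/sample.pdf"))  # placeholder path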
preprocessing/preprocessing_modules/vector_storage.py
ADDED
@@ -0,0 +1,212 @@
"""
Vector Storage Module

Handles storing chunks and embeddings in the Qdrant vector database.
"""

import shutil

import numpy as np
from typing import List
from pathlib import Path
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct


class VectorStorage:
    """Handles vector storage operations with Qdrant."""

    def __init__(self, base_db_path: Path):
        """
        Initialize the vector storage.

        Args:
            base_db_path: Base path for storing Qdrant databases
        """
        self.base_db_path = base_db_path

    async def store_in_qdrant(self, chunks: List[str], embeddings: np.ndarray, doc_id: str):
        """
        Store chunks and embeddings in Qdrant.

        Args:
            chunks: List of text chunks
            embeddings: Corresponding embeddings array
            doc_id: Document identifier
        """
        if len(chunks) != embeddings.shape[0]:
            raise ValueError(f"Chunk count ({len(chunks)}) doesn't match embedding count ({embeddings.shape[0]})")

        collection_name = f"{doc_id}_collection"
        db_path = self.base_db_path / f"{collection_name}.db"
        client = QdrantClient(path=str(db_path))

        print(f"💾 Storing {len(chunks)} vectors in collection: {collection_name}")

        try:
            # Create or recreate the collection
            await self._setup_collection(client, collection_name, embeddings.shape[1])

            # Prepare and upload points
            await self._upload_points(client, collection_name, chunks, embeddings, doc_id)

            print("✅ Successfully stored all vectors in Qdrant")

        finally:
            client.close()

    async def _setup_collection(self, client: QdrantClient, collection_name: str, embedding_dim: int):
        """
        Set up a Qdrant collection, recreating it if it already exists.

        Args:
            client: Qdrant client
            collection_name: Name of the collection
            embedding_dim: Dimension of the embeddings
        """
        # Delete the existing collection if it exists
        try:
            client.delete_collection(collection_name)
            print(f"🗑️ Deleted existing collection: {collection_name}")
        except Exception:
            pass  # Collection might not exist

        # Create a new collection
        client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(
                size=embedding_dim,
                distance=Distance.COSINE
            )
        )
        print(f"✅ Created new collection: {collection_name}")

    async def _upload_points(self, client: QdrantClient, collection_name: str,
                             chunks: List[str], embeddings: np.ndarray, doc_id: str):
        """
        Upload points to a Qdrant collection in batches.

        Args:
            client: Qdrant client
            collection_name: Name of the collection
            chunks: Text chunks
            embeddings: Embedding vectors
            doc_id: Document identifier
        """
        # Prepare points
        points = []
        for i in range(len(chunks)):
            points.append(
                PointStruct(
                    id=i,
                    vector=embeddings[i].tolist(),
                    payload={
                        "text": chunks[i],
                        "chunk_id": i,
                        "doc_id": doc_id,
                        "char_count": len(chunks[i]),
                        "word_count": len(chunks[i].split())
                    }
                )
            )

        # Upload in batches to handle large documents
        batch_size = 100
        total_batches = (len(points) + batch_size - 1) // batch_size

        for i in range(0, len(points), batch_size):
            batch = points[i:i + batch_size]
            batch_num = (i // batch_size) + 1

            print(f"   Uploading batch {batch_num}/{total_batches} ({len(batch)} points)")
            client.upsert(collection_name=collection_name, points=batch)

        print(f"✅ Uploaded {len(points)} points in {total_batches} batches")

    def collection_exists(self, doc_id: str) -> bool:
        """
        Check whether a collection exists for the given document ID.

        Args:
            doc_id: Document identifier

        Returns:
            bool: True if the collection exists, False otherwise
        """
        collection_name = f"{doc_id}_collection"
        db_path = self.base_db_path / f"{collection_name}.db"
        return db_path.exists()

    def get_collection_info(self, doc_id: str) -> dict:
        """
        Get information about a collection.

        Args:
            doc_id: Document identifier

        Returns:
            dict: Collection information
        """
        collection_name = f"{doc_id}_collection"
        db_path = self.base_db_path / f"{collection_name}.db"

        if not db_path.exists():
            return {
                "collection_name": collection_name,
                "exists": False,
                "path": str(db_path)
            }

        try:
            client = QdrantClient(path=str(db_path))
            try:
                collection_info = client.get_collection(collection_name)
                return {
                    "collection_name": collection_name,
                    "exists": True,
                    "path": str(db_path),
                    "vectors_count": collection_info.vectors_count,
                    "status": collection_info.status
                }
            finally:
                client.close()
        except Exception as e:
            return {
                "collection_name": collection_name,
                "exists": True,
                "path": str(db_path),
                "error": str(e)
            }

    def delete_collection(self, doc_id: str) -> bool:
        """
        Delete a collection and its database directory.

        Args:
            doc_id: Document identifier

        Returns:
            bool: True if successfully deleted, False otherwise
        """
        collection_name = f"{doc_id}_collection"
        db_path = self.base_db_path / f"{collection_name}.db"

        try:
            if db_path.exists():
                # Try to delete the collection properly first
                try:
                    client = QdrantClient(path=str(db_path))
                    client.delete_collection(collection_name)
                    client.close()
                except Exception:
                    pass  # Collection might not exist or might be corrupted

                # Remove the database directory
                shutil.rmtree(db_path, ignore_errors=True)
                print(f"🗑️ Deleted collection: {collection_name}")
                return True

            return True  # Nothing to delete

        except Exception as e:
            print(f"❌ Error deleting collection {collection_name}: {e}")
            return False
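A minimal round-trip sketch with dummy vectors; the 384-dimension size is an assumption matching common sentence-transformers models, and the base path is a placeholder:

import asyncio
import numpy as np
from pathlib import Path

storage = VectorStorage(Path("rag_embeddings"))  # placeholder base path
chunks = ["first chunk of text", "second chunk of text"]
embeddings = np.random.rand(2, 384).astype(np.float32)  # stand-in for real embeddings

asyncio.run(storage.store_in_qdrant(chunks, embeddings, "demo123"))
print(storage.collection_exists("demo123"))    # True
print(storage.get_collection_info("demo123"))  # includes vectors_count
storage.delete_collection("demo123")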
preprocessing/preprocessing_modules/xlsx_extractor.py
ADDED
@@ -0,0 +1,119 @@
from openpyxl import load_workbook
from openpyxl.drawing.image import Image as OpenPyXLImage
from typing import List, Dict, Any
from PIL import Image
from io import BytesIO
import pytesseract
import os
import pandas as pd

def extract_xlsx(xlsx_path: str, tesseract_cmd: str = None) -> str:
    """Extract data from Excel files, including text and images."""
    if tesseract_cmd:
        pytesseract.pytesseract.tesseract_cmd = tesseract_cmd

    try:
        wb = load_workbook(xlsx_path, data_only=True)
    except Exception as e:
        return f"Error loading Excel file: {str(e)}"

    all_sheets_content: list[str] = []
    preview_text: str | None = None
    any_data_found = False

    for sheet in wb.worksheets:
        sheet_content = [f"\n=== Sheet: {sheet.title} ===\n"]

        # Extract table data
        has_data = False
        non_empty_rows = 0
        for row in sheet.iter_rows(max_row=sheet.max_row, values_only=True):
            if row is None or all(cell is None for cell in row):
                continue  # Skip completely empty rows
            has_data = True
            non_empty_rows += 1
            any_data_found = True
            row_data = [str(cell).strip() if cell is not None else "" for cell in row]
            joined = " | ".join(row_data)
            sheet_content.append(joined)
            if preview_text is None and joined.strip():
                preview_text = joined[:15]

        if not has_data:
            sheet_content.append("[No data in this sheet]")
            print(f"ℹ️ XLSX: Sheet '{sheet.title}' has no data (openpyxl)")
        else:
            print(f"🧾 XLSX: Sheet '{sheet.title}' non-empty rows: {non_empty_rows}")

        # Extract images from the sheet
        if hasattr(sheet, '_images'):
            image_count = 0
            for img in sheet._images:
                try:
                    if hasattr(img, '_data'):  # A real openpyxl Image
                        image_data = img._data()
                    elif hasattr(img, '_ref'):
                        continue  # Cell-reference-only images; ignore
                    else:
                        continue

                    pil_img = Image.open(BytesIO(image_data))
                    try:
                        ocr_text = pytesseract.image_to_string(pil_img).strip()
                        if ocr_text:
                            sheet_content.append(f"[Image {image_count + 1} Text]: {ocr_text}")
                        else:
                            sheet_content.append(f"[Image {image_count + 1}]: No text detected")
                    except Exception as ocr_e:
                        sheet_content.append(f"[Image {image_count + 1}]: OCR failed - {str(ocr_e)}")

                    image_count += 1
                except Exception as e:
                    sheet_content.append(f"[Image extraction error: {str(e)}]")

            if image_count == 0:
                sheet_content.append("[No images found in this sheet]")

        all_sheets_content.append("\n".join(sheet_content))

    # If no data was found via openpyxl, try a pandas fallback (handles some edge cases better)
    if not any_data_found:
        print("ℹ️ XLSX: No data via openpyxl, trying pandas fallback…")
        try:
            xls = pd.ExcelFile(xlsx_path, engine="openpyxl")
            pandas_parts = []
            extracted_sheets = 0
            for sheet_name in xls.sheet_names:
                df = pd.read_excel(xls, sheet_name=sheet_name, dtype=str)
                if not df.empty:
                    any_data_found = True
                    header = f"\n=== Sheet: {sheet_name} ===\n"
                    csv_like = df.fillna("").astype(str).to_csv(index=False)
                    pandas_parts.append(header + csv_like)
                    extracted_sheets += 1
                    if preview_text is None:
                        flat = "".join(csv_like.splitlines())
                        if flat:
                            preview_text = flat[:15]
                else:
                    pandas_parts.append(f"\n=== Sheet: {sheet_name} ===\n[No data in this sheet]")
            if pandas_parts:
                all_sheets_content = pandas_parts
                print(f"✅ XLSX: Pandas fallback extracted {extracted_sheets} non-empty sheet(s)")
        except Exception as pe:
            # If pandas also fails, keep whatever we had
            all_sheets_content.append(f"[Pandas fallback failed: {str(pe)}]")
            print(f"❌ XLSX: Pandas fallback failed: {pe}")

    combined = "\n\n".join(all_sheets_content)

    # Print a small preview for verification
    if preview_text is None:
        # Fallback: take the preview from the combined text
        flat_combined = "".join(combined.splitlines()).strip()
        if flat_combined:
            preview_text = flat_combined[:15]
    if preview_text:
        print(f"🔎 XLSX content preview: {preview_text}")

    return combined
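A minimal usage sketch; the workbook path is a placeholder, and if the tesseract binary is not on PATH its location can be passed explicitly (the /usr/bin path below is an assumption):

content = extract_xlsx("data/inventory.xlsx", tesseract_cmd="/usr/bin/tesseract")
print(content[:500])  # rows are rendered as " | "-joined cells per sheet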
requirements.txt
ADDED
@@ -0,0 +1,33 @@
fastapi
uvicorn
pydantic
aiohttp
pdfplumber
sentence-transformers
qdrant-client
openai
google-generativeai
numpy
tqdm
python-multipart
jinja2
python-dotenv
rank-bm25
transformers
torch

# New dependencies for multiple file format support
python-docx
python-pptx
openpyxl
pillow
pytesseract
opencv-python
pandas
beautifulsoup4
lxml
langchain-groq
langchain-google-genai
langchain-core
httpx
groq