anindya-hf-2002 committed on
Commit fe52a97 · verified · 1 Parent(s): 3fb0088

Upload 12 files

Files changed (12)
  1. .gitignore +171 -0
  2. Dockerfile +39 -0
  3. app.py +353 -0
  4. execute.sh +61 -0
  5. notebooks/tabular_rag.ipynb +1476 -0
  6. requirements.txt +14 -0
  7. src/embedding.py +64 -0
  8. src/llm.py +143 -0
  9. src/loader.py +153 -0
  10. src/processor.py +103 -0
  11. src/table_aware_chunker.py +117 -0
  12. src/vectordb.py +314 -0
.gitignore ADDED
@@ -0,0 +1,171 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # PyPI configuration file
171
+ .pypirc
Dockerfile ADDED
@@ -0,0 +1,39 @@
1
+ # Use Python slim image as base
2
+ FROM python:3.10-slim
3
+
4
+ # Install system dependencies and wget (for downloading Ollama)
5
+ RUN apt-get update && \
6
+ apt-get install -y \
7
+ curl \
8
+ procps \
9
+ git \
10
+ wget \
11
+ lsof \
12
+ && rm -rf /var/lib/apt/lists/*
13
+
14
+ # Set working directory
15
+ WORKDIR /app
16
+
17
+ # Copy project files
18
+ COPY requirements.txt .
19
+ COPY src/ ./src/
20
+ COPY *.py ./
21
+ COPY execute.sh ./execute.sh
22
+
23
+ # Install Python dependencies
24
+ RUN pip install --no-cache-dir -r requirements.txt
25
+
26
+ # Set Python path
27
+ ENV PYTHONPATH=/app
28
+
29
+ # Create directory for Ollama models
30
+ RUN mkdir -p /root/.ollama
31
+
32
+ # Expose ports for both Streamlit and Ollama
33
+ EXPOSE 8501 11434
34
+
35
+ # Make sure execute.sh is executable
36
+ RUN chmod +x ./execute.sh
37
+
38
+ # Set the entrypoint
39
+ ENTRYPOINT ["./execute.sh"]
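To try the image locally, a minimal sketch (the tag table-rag is illustrative; the ports match the EXPOSE line above, and execute.sh runs as the entrypoint):

docker build -t table-rag .
docker run -p 8501:8501 -p 11434:11434 table-rag

The first start is slow because execute.sh installs Ollama and pulls the nomic-embed-text and mistral:7b models at container startup.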
app.py ADDED
@@ -0,0 +1,353 @@
1
+ import streamlit as st
2
+ from pathlib import Path
3
+ import tempfile
4
+ import os
5
+ import time
6
+ from typing import List, Dict
7
+ from pinecone import Pinecone
8
+ from src.table_aware_chunker import TableRecursiveChunker
9
+ from src.processor import TableProcessor
10
+ from src.llm import LLMChat
11
+ from src.embedding import EmbeddingModel
12
+ from chonkie import RecursiveRules
13
+ from src.vectordb import ChunkType, process_documents, ingest_data, PineconeRetriever
14
+
15
+ # Custom CSS for better UI
16
+ st.set_page_config(
17
+ page_title="📚 Table RAG Assistant",
18
+ layout="wide",
19
+ initial_sidebar_state="expanded"
20
+ )
21
+
22
+ st.markdown("""
23
+ <style>
24
+ .stApp {
25
+ max-width: 1200px;
26
+ margin: 0 auto;
27
+ }
28
+ .chat-message {
29
+ padding: 1.5rem;
30
+ border-radius: 0.5rem;
31
+ margin-bottom: 1rem;
32
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
33
+ }
34
+ .user-message {
35
+ background-color: #f0f2f6;
36
+ }
37
+ .assistant-message {
38
+ background-color: #e8f0fe;
39
+ }
40
+ .st-emotion-cache-1v0mbdj.e115fcil1 {
41
+ border-radius: 0.5rem;
42
+ }
43
+ </style>
44
+ """, unsafe_allow_html=True)
45
+
46
+ # Initialize session states
47
+ if "messages" not in st.session_state:
48
+ st.session_state.messages = []
49
+ if "documents_processed" not in st.session_state:
50
+ st.session_state.documents_processed = False
51
+ if "retriever" not in st.session_state:
52
+ st.session_state.retriever = None
53
+ if "llm" not in st.session_state:
54
+ st.session_state.llm = None
55
+ if "uploaded_files" not in st.session_state:
56
+ st.session_state.uploaded_files = []
57
+
58
+ # Enhanced RAG Template using LangChain's ChatPromptTemplate
59
+ RAG_TEMPLATE = [
60
+ {
61
+ "role": "system",
62
+ "content": """You are a knowledgeable assistant specialized in analyzing documents and tables.
63
+ Your responses should be:
64
+ - Accurate and based on the provided context
65
+ - Concise (three sentences maximum)
66
+ - Professional yet conversational
67
+ - Include specific references to tables when relevant
68
+
69
+ If you cannot find an answer in the context, acknowledge this clearly."""
70
+ },
71
+ {
72
+ "role": "human",
73
+ "content": "Context: {context}\n\nQuestion: {question}"
74
+ }
75
+ ]
76
+
77
+ def simulate_streaming_response(text: str, delay: float = 0.02) -> str:
78
+ """Simulate streaming response by yielding chunks of text with delay."""
79
+ words = text.split()
80
+ result = ""
81
+
82
+ for i, word in enumerate(words):
83
+ result += word + " "
84
+ time.sleep(delay)
85
+ # Add punctuation pause
86
+ if any(p in word for p in ['.', '!', '?', ',']):
87
+ time.sleep(delay * 2)
88
+ yield result
89
+
90
+ def clear_pinecone_index(pc, index_name="vector-index"):
91
+ """Clear the Pinecone index and reset app state."""
92
+ try:
93
+ if pc.has_index(index_name):
94
+ pc.delete_index(index_name)
95
+ st.session_state.documents_processed = False
96
+ st.session_state.retriever = None
97
+ st.session_state.messages = []
98
+ st.session_state.llm = None
99
+ st.session_state.uploaded_files = []
100
+ st.success("🧹 Database cleared successfully!")
101
+ except Exception as e:
102
+ st.error(f"❌ Error clearing database: {str(e)}")
103
+
104
+ def format_context(results: List[Dict]) -> str:
105
+ """Format retrieved results into context string."""
106
+ context_parts = []
107
+
108
+ for result in results:
109
+ if result.get("chunk_type") == ChunkType.TABLE.value:
110
+ table_text = f"Table: {result['markdown_table']}"
111
+ if result.get("table_description"):
112
+ table_text += f"\nDescription: {result['table_description']}"
113
+ context_parts.append(table_text)
114
+ else:
115
+ context_parts.append(result.get("page_content", ""))
116
+
117
+ return "\n\n".join(context_parts)
118
+
119
+ def format_chat_message(message: Dict[str, str], results: List[Dict] = None) -> str:
120
+ """Format chat message with retrieved tables in a visually appealing way."""
121
+ content = message["content"]
122
+
123
+ if results:
124
+ for result in results:
125
+ if result.get("chunk_type") == ChunkType.TABLE.value:
126
+ content += "\n\n---\n\n📊 **Relevant Table:**\n" + result['markdown_table']
127
+
128
+ return content
129
+
130
+ def initialize_components(pinecone_api_key: str):
131
+ """Initialize all required components with LangChain integration."""
132
+ try:
133
+ # Initialize Pinecone
134
+ pc = Pinecone(api_key=pinecone_api_key)
135
+
136
+ # Initialize LangChain LLM with custom parameters
137
+ llm = LLMChat(
138
+ model_name="mistral:7b",
139
+ temperature=0.3 # Lower temperature for more focused responses
140
+ )
141
+ st.session_state.llm = llm
142
+
143
+ # Initialize LangChain Embeddings
144
+ embedder = EmbeddingModel("nomic-embed-text")
145
+
146
+ # Initialize Chunker
147
+ chunker = TableRecursiveChunker(
148
+ tokenizer="gpt2",
149
+ chunk_size=512,
150
+ rules=RecursiveRules(),
151
+ min_characters_per_chunk=12
152
+ )
153
+
154
+ # Initialize Processor
155
+ processor = TableProcessor(
156
+ llm_model=llm,
157
+ embedding_model=embedder,
158
+ batch_size=8
159
+ )
160
+
161
+ return pc, llm, embedder, chunker, processor
162
+
163
+ except Exception as e:
164
+ st.error(f"❌ Error initializing components: {str(e)}")
165
+ return None, None, None, None, None
166
+
167
+ def process_all_documents(uploaded_files, chunker, processor, pc, embedder):
168
+ """Process uploaded documents with enhanced progress tracking."""
169
+ if not uploaded_files:
170
+ st.warning("📤 Please upload at least one document.")
171
+ return False
172
+
173
+ try:
174
+ temp_dir = tempfile.mkdtemp()
175
+ file_paths = []
176
+
177
+ with st.status("📑 Processing Documents", expanded=True) as status:
178
+ # Save uploaded files
179
+ st.write("📁 Saving uploaded files...")
180
+ for uploaded_file in uploaded_files:
181
+ st.write(f"Saving {uploaded_file.name}...")
182
+ file_path = Path(temp_dir) / uploaded_file.name
183
+ with open(file_path, "wb") as f:
184
+ f.write(uploaded_file.getvalue())
185
+ file_paths.append(str(file_path))
186
+
187
+ # Process documents
188
+ st.write("🔄 Processing documents...")
189
+ processed_chunks = process_documents(
190
+ file_paths=file_paths,
191
+ chunker=chunker,
192
+ processor=processor,
193
+ output_path='./output.md'
194
+ )
195
+
196
+ # Ingest data
197
+ st.write("📥 Ingesting data to vector database...")
198
+ ingest_data(
199
+ processed_chunks=processed_chunks,
200
+ embedding_model=embedder,
201
+ pinecone_client=pc
202
+ )
203
+
204
+ # Setup retriever
205
+ st.write("🎯 Setting up retriever...")
206
+ st.session_state.retriever = PineconeRetriever(
207
+ pinecone_client=pc,
208
+ index_name="vector-index",
209
+ namespace="rag",
210
+ embedding_model=embedder,
211
+ llm_model=st.session_state.llm
212
+ )
213
+
214
+ st.session_state.documents_processed = True
215
+ status.update(label="✅ Processing complete!", state="complete", expanded=False)
216
+
217
+ return True
218
+
219
+ except Exception as e:
220
+ st.error(f"❌ Error processing documents: {str(e)}")
221
+ return False
222
+
223
+ finally:
224
+ # Cleanup
225
+ for file_path in file_paths:
226
+ try:
227
+ os.remove(file_path)
228
+ except Exception:
229
+ pass
230
+ try:
231
+ os.rmdir(temp_dir)
232
+ except Exception:
233
+ pass
234
+
235
+ def main():
236
+ st.title("📚 Table RAG Assistant")
237
+ st.markdown("---")
238
+ pc = None
239
+ # Sidebar Configuration with improved styling
240
+ with st.sidebar:
241
+ st.title("⚙️ Configuration")
242
+ pinecone_api_key = st.text_input("🔑 Enter Pinecone API Key:", type="password")
243
+
244
+ st.markdown("---")
245
+ col1, col2 = st.columns(2)
246
+
247
+ with col1:
248
+ if st.button("🧹 Clear DB", use_container_width=True):
249
+ # pc is only created later in main(); guard against it being unset here
+ if pc is not None:
+ clear_pinecone_index(pc)
+ else:
+ st.warning("⚠️ Pinecone client not initialized yet.")
250
+
251
+ with col2:
252
+ if st.button("🗑️ Clear Chat", use_container_width=True):
253
+ st.session_state.messages = []
254
+ st.session_state.llm.clear_history()
255
+ st.rerun()
256
+
257
+ # Display uploaded files
258
+ if st.session_state.uploaded_files:
259
+ st.markdown("---")
260
+ st.subheader("📁 Uploaded Files")
261
+ for file in st.session_state.uploaded_files:
262
+ st.write(f"- {file.name}")
263
+
264
+ pc = None
265
+ if not pinecone_api_key:
266
+ st.sidebar.warning("⚠️ Please enter Pinecone API key to continue.")
267
+ st.stop()
268
+
269
+ # Initialize components if not already done
270
+ if st.session_state.retriever is None:
271
+ pc, llm, embedder, chunker, processor = initialize_components(pinecone_api_key)
272
+ clear_pinecone_index(pc)
273
+ if None in (pc, llm, embedder, chunker, processor):
274
+ st.stop()
275
+
276
+ # Document Upload Section with improved UI
277
+ if not st.session_state.documents_processed:
278
+ st.header("📄 Document Upload")
279
+ st.markdown("Upload your documents to get started. Supported formats: PDF, DOCX, TXT, CSV, XLSX")
280
+
281
+ uploaded_files = st.file_uploader(
282
+ "Drop your files here",
283
+ accept_multiple_files=True,
284
+ type=["pdf", "docx", "txt", "csv", "xlsx"]
285
+ )
286
+
287
+ if uploaded_files:
288
+ st.session_state.uploaded_files = uploaded_files
289
+
290
+ if st.button("🚀 Process Documents", use_container_width=True):
291
+ if process_all_documents(uploaded_files, chunker, processor, pc, embedder):
292
+ st.success("✨ Documents processed successfully!")
293
+
294
+ # Enhanced Chat Interface with Simulated Streaming
295
+ if st.session_state.documents_processed:
296
+ st.header("💬 Chat Interface")
297
+ st.markdown("Ask questions about your documents and tables")
298
+
299
+ # Display chat history with improved styling
300
+ for message in st.session_state.messages:
301
+ with st.chat_message(message["role"]):
302
+ st.markdown(format_chat_message(message, message.get("results")))
303
+
304
+ # Chat input with simulated streaming
305
+ if prompt := st.chat_input("Ask a question..."):
306
+ # Display user message
307
+ with st.chat_message("user"):
308
+ st.markdown(prompt)
309
+ st.session_state.messages.append({"role": "user", "content": prompt})
310
+
311
+ # Generate response with simulated streaming
312
+ with st.chat_message("assistant"):
313
+ response_placeholder = st.empty()
314
+
315
+ with st.spinner("🤔 Thinking..."):
316
+ # Retrieve relevant content
317
+ results = st.session_state.retriever.invoke(
318
+ question=prompt,
319
+ top_k=3
320
+ )
321
+
322
+ # Format context and get response from LLM
323
+ context = format_context(results)
324
+ chat = st.session_state.llm
325
+
326
+ input_vars = {
327
+ "question": prompt,
328
+ "context": context
329
+ }
330
+
331
+ # Get full response first
332
+ full_response = chat.chat_with_template(RAG_TEMPLATE, input_vars)
333
+
334
+ # Simulate streaming of the response
335
+ for partial_response in simulate_streaming_response(full_response):
336
+ response_placeholder.markdown(partial_response + "▌")
337
+
338
+ # Display final response with tables
339
+ formatted_response = format_chat_message(
340
+ {"role": "assistant", "content": full_response},
341
+ results
342
+ )
343
+ response_placeholder.markdown(formatted_response)
344
+
345
+ # Save to chat history
346
+ st.session_state.messages.append({
347
+ "role": "assistant",
348
+ "content": full_response,
349
+ "results": results
350
+ })
351
+
352
+ if __name__ == "__main__":
353
+ main()
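For a run outside Docker, a sketch assuming Ollama is installed on the host and the packages from requirements.txt are available (the model names match the ones hard-coded in initialize_components; the Pinecone API key is entered in the sidebar at startup):

ollama pull nomic-embed-text
ollama pull mistral:7b
pip install -r requirements.txt
streamlit run app.py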
execute.sh ADDED
@@ -0,0 +1,61 @@
1
+ #!/bin/bash
2
+
3
+ # Function to log messages
4
+ log() {
5
+ echo "$(date +'%Y-%m-%d %H:%M:%S') - $1"
6
+ }
7
+
8
+ # Install and run Ollama
9
+ log "Starting Ollama installation..."
10
+ if ! curl -fsSL https://ollama.com/install.sh | sh; then
11
+ log "Failed to install Ollama."
12
+ exit 1
13
+ fi
14
+ log "Ollama installation completed."
15
+
16
+ # Sleep for a short duration to ensure installation completes
17
+ sleep 5
18
+
19
+ # Check if ollama command is available
20
+ if ! command -v ollama &> /dev/null; then
21
+ log "Ollama command not found. Installation may have failed."
22
+ exit 1
23
+ fi
24
+
25
+ # Start the Ollama server in the background
26
+ log "Starting Ollama server..."
27
+ ollama serve &
28
+ OLLAMA_PID=$!
29
+
30
+ # Wait for Ollama server to start (you may need to adjust sleep duration)
31
+ log "Waiting for Ollama server to start..."
32
+ sleep 10
33
+
34
+ # Check if Ollama server is running
35
+ if ! pgrep -x "ollama" > /dev/null; then
36
+ log "Ollama server did not start successfully."
37
+ kill $OLLAMA_PID 2>/dev/null || true
38
+ exit 1
39
+ fi
40
+
41
+ # Pull the required Ollama model(s) during runtime
42
+ log "Pulling Ollama models..."
43
+ if ! ollama pull nomic-embed-text; then
44
+ log "Failed to pull nomic-embed-text model."
45
+ kill $OLLAMA_PID 2>/dev/null || true
46
+ exit 1
47
+ fi
48
+
49
+ if ! ollama pull mistral:7b; then
50
+ log "Failed to pull mistral:7b model."
51
+ kill $OLLAMA_PID 2>/dev/null || true
52
+ exit 1
53
+ fi
54
+ log "Models pulled successfully."
55
+
56
+ # Sleep for a short duration to ensure models are downloaded and ready
57
+ sleep 5
58
+
59
+ # Start Streamlit app
60
+ log "Starting Streamlit app..."
61
+ exec streamlit run --server.address 0.0.0.0 --server.port 8501 app.py
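The fixed sleep 10 before the pgrep check can be fragile on slower hosts; a polling variant (a sketch, assuming the Ollama HTTP API responds on localhost:11434 once the server is ready) would be:

# wait until the Ollama API answers before pulling models
until curl -sf http://localhost:11434/api/tags > /dev/null; do
    sleep 1
done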
notebooks/tabular_rag.ipynb ADDED
@@ -0,0 +1,1476 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "Note: you may need to restart the kernel to use updated packages.\n"
13
+ ]
14
+ }
15
+ ],
16
+ "source": [
17
+ "pip install langchain-community tiktoken langchainhub langchain langchain-huggingface sentence_transformers langchain-ollama ollama docling easyocr FlagEmbedding chonkie pinecone --quiet"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": null,
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "curl -fsSL https://ollama.com/install.sh | sh\n",
27
+ "sleep 1\n",
28
+ "ollama pull nomic-embed-text\n",
29
+ "ollama pull mistral:7b"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": 1,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "from pathlib import Path\n",
39
+ "from typing import List, Union\n",
40
+ "import logging\n",
41
+ "from dataclasses import dataclass\n",
42
+ "\n",
43
+ "from langchain_core.documents import Document as LCDocument\n",
44
+ "from langchain_core.document_loaders import BaseLoader\n",
45
+ "from docling.document_converter import DocumentConverter, PdfFormatOption\n",
46
+ "from docling.datamodel.base_models import InputFormat, ConversionStatus\n",
47
+ "from docling.datamodel.pipeline_options import (\n",
48
+ " PdfPipelineOptions,\n",
49
+ " EasyOcrOptions\n",
50
+ ")\n",
51
+ "\n",
52
+ "logging.basicConfig(level=logging.INFO)\n",
53
+ "_log = logging.getLogger(__name__)\n",
54
+ "\n",
55
+ "@dataclass\n",
56
+ "class ProcessingResult:\n",
57
+ " \"\"\"Store results of document processing\"\"\"\n",
58
+ " success_count: int = 0\n",
59
+ " failure_count: int = 0\n",
60
+ " partial_success_count: int = 0\n",
61
+ " failed_files: List[str] = None\n",
62
+ "\n",
63
+ " def __post_init__(self):\n",
64
+ " if self.failed_files is None:\n",
65
+ " self.failed_files = []\n",
66
+ "\n",
67
+ "class MultiFormatDocumentLoader(BaseLoader):\n",
68
+ " \"\"\"Loader for multiple document formats that converts to LangChain documents\"\"\"\n",
69
+ " \n",
70
+ " def __init__(\n",
71
+ " self,\n",
72
+ " file_paths: Union[str, List[str]],\n",
73
+ " enable_ocr: bool = True,\n",
74
+ " enable_tables: bool = True\n",
75
+ " ):\n",
76
+ " self._file_paths = [file_paths] if isinstance(file_paths, str) else file_paths\n",
77
+ " self._enable_ocr = enable_ocr\n",
78
+ " self._enable_tables = enable_tables\n",
79
+ " self._converter = self._setup_converter()\n",
80
+ " \n",
81
+ " def _setup_converter(self):\n",
82
+ " \"\"\"Set up the document converter with appropriate options\"\"\"\n",
83
+ " # Configure pipeline options\n",
84
+ " pipeline_options = PdfPipelineOptions(do_ocr=False, do_table_structure=False, ocr_options=EasyOcrOptions(\n",
85
+ " force_full_page_ocr=True\n",
86
+ " ))\n",
87
+ " if self._enable_ocr:\n",
88
+ " pipeline_options.do_ocr = True\n",
89
+ " if self._enable_tables:\n",
90
+ " pipeline_options.do_table_structure = True\n",
91
+ " pipeline_options.table_structure_options.do_cell_matching = True\n",
92
+ "\n",
93
+ " # Create converter with supported formats\n",
94
+ " return DocumentConverter(\n",
95
+ " allowed_formats=[\n",
96
+ " InputFormat.PDF,\n",
97
+ " InputFormat.IMAGE,\n",
98
+ " InputFormat.DOCX,\n",
99
+ " InputFormat.HTML,\n",
100
+ " InputFormat.PPTX,\n",
101
+ " InputFormat.ASCIIDOC,\n",
102
+ " InputFormat.MD,\n",
103
+ " ],\n",
104
+ " format_options={\n",
105
+ " InputFormat.PDF: PdfFormatOption(\n",
106
+ " pipeline_options=pipeline_options,\n",
107
+ " )}\n",
108
+ " )\n",
109
+ "\n",
110
+ " def lazy_load(self):\n",
111
+ " \"\"\"Convert documents and yield LangChain documents\"\"\"\n",
112
+ " results = ProcessingResult()\n",
113
+ " \n",
114
+ " for file_path in self._file_paths:\n",
115
+ " try:\n",
116
+ " path = Path(file_path)\n",
117
+ " if not path.exists():\n",
118
+ " _log.warning(f\"File not found: {file_path}\")\n",
119
+ " results.failure_count += 1\n",
120
+ " results.failed_files.append(file_path)\n",
121
+ " continue\n",
122
+ "\n",
123
+ " conversion_result = self._converter.convert(path)\n",
124
+ " \n",
125
+ " if conversion_result.status == ConversionStatus.SUCCESS:\n",
126
+ " results.success_count += 1\n",
127
+ " text = conversion_result.document.export_to_markdown()\n",
128
+ " metadata = {\n",
129
+ " 'source': str(path),\n",
130
+ " 'file_type': path.suffix,\n",
131
+ " }\n",
132
+ " yield LCDocument(\n",
133
+ " page_content=text,\n",
134
+ " metadata=metadata\n",
135
+ " )\n",
136
+ " elif conversion_result.status == ConversionStatus.PARTIAL_SUCCESS:\n",
137
+ " results.partial_success_count += 1\n",
138
+ " _log.warning(f\"Partial conversion for {file_path}\")\n",
139
+ " text = conversion_result.document.export_to_markdown()\n",
140
+ " metadata = {\n",
141
+ " 'source': str(path),\n",
142
+ " 'file_type': path.suffix,\n",
143
+ " 'conversion_status': 'partial'\n",
144
+ " }\n",
145
+ " yield LCDocument(\n",
146
+ " page_content=text,\n",
147
+ " metadata=metadata\n",
148
+ " )\n",
149
+ " else:\n",
150
+ " results.failure_count += 1\n",
151
+ " results.failed_files.append(file_path)\n",
152
+ " _log.error(f\"Failed to convert {file_path}\")\n",
153
+ " \n",
154
+ " except Exception as e:\n",
155
+ " _log.error(f\"Error processing {file_path}: {str(e)}\")\n",
156
+ " results.failure_count += 1\n",
157
+ " results.failed_files.append(file_path)\n",
158
+ "\n",
159
+ " # Log final results\n",
160
+ " total = results.success_count + results.partial_success_count + results.failure_count\n",
161
+ " _log.info(\n",
162
+ " f\"Processed {total} documents:\\n\"\n",
163
+ " f\"- Successfully converted: {results.success_count}\\n\"\n",
164
+ " f\"- Partially converted: {results.partial_success_count}\\n\"\n",
165
+ " f\"- Failed: {results.failure_count}\"\n",
166
+ " )\n",
167
+ " if results.failed_files:\n",
168
+ " _log.info(\"Failed files:\")\n",
169
+ " for file in results.failed_files:\n",
170
+ " _log.info(f\"- {file}\")\n",
171
+ " \n",
172
+ " \n",
173
+ "# if __name__ == '__main__':\n",
174
+ "# # Load documents from a list of file paths\n",
175
+ "# loader = MultiFormatDocumentLoader(\n",
176
+ "# file_paths=[\n",
177
+ "# # './data/2404.19756v1.pdf',\n",
178
+ "# # './data/OD429347375590223100.pdf',\n",
179
+ "# '/teamspace/studios/this_studio/TabularRAG/data/FeesPaymentReceipt_7thsem.pdf',\n",
180
+ "# # './data/UNIT 2 GENDER BASED VIOLENCE.pptx'\n",
181
+ "# ],\n",
182
+ "# enable_ocr=False,\n",
183
+ "# enable_tables=True\n",
184
+ "# )\n",
185
+ "# for doc in loader.lazy_load():\n",
186
+ "# print(doc.page_content)\n",
187
+ "# print(doc.metadata)\n",
188
+ "# # save document in .md file \n",
189
+ "# with open('/teamspace/studios/this_studio/TabularRAG/data/output.md', 'w') as f:\n",
190
+ "# f.write(doc.page_content)"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": 2,
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "from typing import List, Tuple, Union\n",
200
+ "import re\n",
201
+ "from dataclasses import dataclass\n",
202
+ "from chonkie.chunker import RecursiveChunker\n",
203
+ "from chonkie.types import RecursiveChunk\n",
204
+ "from chonkie import RecursiveRules\n",
205
+ "\n",
206
+ "@dataclass\n",
207
+ "class TableChunk:\n",
208
+ " \"\"\"Represents a table chunk from the markdown document.\"\"\"\n",
209
+ " text: str\n",
210
+ " start_index: int\n",
211
+ " end_index: int\n",
212
+ " token_count: int\n",
213
+ "\n",
214
+ "class TableRecursiveChunker(RecursiveChunker):\n",
215
+ " \"\"\"A recursive chunker that preserves markdown tables while chunking text.\n",
216
+ " \n",
217
+ " This chunker extends the base RecursiveChunker to handle markdown tables as special cases,\n",
218
+ " keeping them intact rather than splitting them according to the recursive rules.\n",
219
+ " \"\"\"\n",
220
+ "\n",
221
+ " def _extract_tables(self, text: str) -> Tuple[List[TableChunk], List[Tuple[int, int, str]]]:\n",
222
+ " \"\"\"\n",
223
+ " Extract markdown tables from text and return table chunks and remaining text segments.\n",
224
+ " \n",
225
+ " Args:\n",
226
+ " text: The input text containing markdown content\n",
227
+ " \n",
228
+ " Returns:\n",
229
+ " Tuple containing:\n",
230
+ " - List of TableChunk objects for tables\n",
231
+ " - List of (start_index, end_index, text) tuples for non-table segments\n",
232
+ " \"\"\"\n",
233
+ " # Regular expression for markdown tables (matches header, separator, and content rows)\n",
234
+ " table_pattern = r'(\\|[^\\n]+\\|\\n\\|[-:\\|\\s]+\\|\\n(?:\\|[^\\n]+\\|\\n)+)'\n",
235
+ " \n",
236
+ " table_chunks = []\n",
237
+ " non_table_segments = []\n",
238
+ " last_end = 0\n",
239
+ " \n",
240
+ " for match in re.finditer(table_pattern, text):\n",
241
+ " start, end = match.span()\n",
242
+ " \n",
243
+ " # Add non-table text before this table\n",
244
+ " if start > last_end:\n",
245
+ " non_table_segments.append((last_end, start, text[last_end:start]))\n",
246
+ " \n",
247
+ " # Create table chunk\n",
248
+ " table_text = match.group()\n",
249
+ " token_count = self._count_tokens(table_text)\n",
250
+ " table_chunks.append(TableChunk(\n",
251
+ " text=table_text,\n",
252
+ " start_index=start,\n",
253
+ " end_index=end,\n",
254
+ " token_count=token_count\n",
255
+ " ))\n",
256
+ " \n",
257
+ " last_end = end\n",
258
+ " \n",
259
+ " # Add remaining text after last table\n",
260
+ " if last_end < len(text):\n",
261
+ " non_table_segments.append((last_end, len(text), text[last_end:]))\n",
262
+ " \n",
263
+ " return table_chunks, non_table_segments\n",
264
+ "\n",
265
+ " def chunk(self, text: str) -> Tuple[List[RecursiveChunk], List[TableChunk]]:\n",
266
+ " \"\"\"\n",
267
+ " Chunk the text while preserving tables.\n",
268
+ " \n",
269
+ " This method overrides the base chunk method to handle tables separately from\n",
270
+ " regular text content.\n",
271
+ " \n",
272
+ " Args:\n",
273
+ " text: The input text to chunk\n",
274
+ " \n",
275
+ " Returns:\n",
276
+ " Tuple containing:\n",
277
+ " - List of RecursiveChunk objects for non-table text\n",
278
+ " - List of TableChunk objects for tables\n",
279
+ " \"\"\"\n",
280
+ " # First extract tables\n",
281
+ " table_chunks, non_table_segments = self._extract_tables(text)\n",
282
+ " \n",
283
+ " # Chunk each non-table segment using the parent class's recursive chunking\n",
284
+ " text_chunks = []\n",
285
+ " for start, end, segment in non_table_segments:\n",
286
+ " if segment.strip(): # Only process non-empty segments\n",
287
+ " # Use the parent class's recursive chunking logic\n",
288
+ " chunks = super()._recursive_chunk(segment, level=0, full_text=text)\n",
289
+ " text_chunks.extend(chunks)\n",
290
+ " \n",
291
+ " return text_chunks, table_chunks\n",
292
+ "\n",
293
+ " def chunk_batch(self, texts: List[str]) -> List[Tuple[List[RecursiveChunk], List[TableChunk]]]:\n",
294
+ " \"\"\"\n",
295
+ " Chunk multiple texts while preserving tables in each.\n",
296
+ " \n",
297
+ " Args:\n",
298
+ " texts: List of texts to chunk\n",
299
+ " \n",
300
+ " Returns:\n",
301
+ " List of tuples, each containing:\n",
302
+ " - List of RecursiveChunk objects for non-table text\n",
303
+ " - List of TableChunk objects for tables\n",
304
+ " \"\"\"\n",
305
+ " return [self.chunk(text) for text in texts]\n",
306
+ "\n",
307
+ " def __call__(self, texts: Union[str, List[str]]) -> Union[\n",
308
+ " Tuple[List[RecursiveChunk], List[TableChunk]],\n",
309
+ " List[Tuple[List[RecursiveChunk], List[TableChunk]]]\n",
310
+ " ]:\n",
311
+ " \"\"\"Make the chunker callable for convenience.\"\"\"\n",
312
+ " if isinstance(texts, str):\n",
313
+ " return self.chunk(texts)\n",
314
+ " return self.chunk_batch(texts)\n",
315
+ " \n"
316
+ ]
317
+ },
318
+ {
319
+ "cell_type": "code",
320
+ "execution_count": 3,
321
+ "metadata": {},
322
+ "outputs": [],
323
+ "source": [
324
+ "from typing import List\n",
325
+ "from langchain_ollama import OllamaEmbeddings\n",
326
+ "\n",
327
+ "class EmbeddingModel:\n",
328
+ " def __init__(self, model_name: str = \"llama3.2\"):\n",
329
+ " \"\"\"\n",
330
+ " Initialize embedding model with LangChain OllamaEmbeddings\n",
331
+ " \n",
332
+ " Args:\n",
333
+ " model_name (str): Name of the model to use\n",
334
+ " \"\"\"\n",
335
+ " self.model_name = model_name\n",
336
+ " self.embeddings = OllamaEmbeddings(\n",
337
+ " model=model_name\n",
338
+ " )\n",
339
+ "\n",
340
+ " def embed(self, text: str) -> List[float]:\n",
341
+ " \"\"\"\n",
342
+ " Generate embeddings for a single text input\n",
343
+ " \n",
344
+ " Args:\n",
345
+ " text (str): Input text to embed\n",
346
+ " \n",
347
+ " Returns:\n",
348
+ " List[float]: Embedding vector\n",
349
+ " \"\"\"\n",
350
+ " try:\n",
351
+ " # Use embed_query for single text embedding\n",
352
+ " return self.embeddings.embed_query(text)\n",
353
+ " except Exception as e:\n",
354
+ " print(f\"Error generating embedding: {e}\")\n",
355
+ " return []\n",
356
+ "\n",
357
+ " def embed_batch(self, texts: List[str]) -> List[List[float]]:\n",
358
+ " \"\"\"\n",
359
+ " Generate embeddings for multiple texts\n",
360
+ " \n",
361
+ " Args:\n",
362
+ " texts (List[str]): List of input texts to embed\n",
363
+ " \n",
364
+ " Returns:\n",
365
+ " List[List[float]]: List of embedding vectors\n",
366
+ " \"\"\"\n",
367
+ " try:\n",
368
+ " # Use embed_documents for batch embedding\n",
369
+ " return self.embeddings.embed_documents(texts)\n",
370
+ " except Exception as e:\n",
371
+ " print(f\"Error generating batch embeddings: {e}\")\n",
372
+ " return []\n",
373
+ " \n",
374
+ "# if __name__ == \"__main__\":\n",
375
+ "# # Initialize the embedding model\n",
376
+ "# embedding_model = EmbeddingModel(model_name=\"llama3.2\")\n",
377
+ "\n",
378
+ "# # Generate embedding for a single text\n",
379
+ "# single_text = \"The meaning of life is 42\"\n",
380
+ "# vector = embedding_model.embed(single_text)\n",
381
+ "# print(vector[:3]) # Print first 3 dimensions\n",
382
+ "\n",
383
+ "# # Generate embeddings for multiple texts\n",
384
+ "# texts = [\"Document 1...\", \"Document 2...\"]\n",
385
+ "# vectors = embedding_model.embed_batch(texts)\n",
386
+ "# print(len(vectors)) # Number of vectors\n",
387
+ "# print(vectors[0][:3]) # First 3 dimensions of first vector"
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": 4,
393
+ "metadata": {},
394
+ "outputs": [],
395
+ "source": [
396
+ "from typing import List, Dict, Optional\n",
397
+ "from langchain_ollama import ChatOllama\n",
398
+ "from langchain_core.messages import HumanMessage, AIMessage\n",
399
+ "from langchain_core.prompts import ChatPromptTemplate\n",
400
+ "\n",
401
+ "class LLMChat:\n",
402
+ " def __init__(self, model_name: str = \"llama3.2\", temperature: float = 0):\n",
403
+ " \"\"\"\n",
404
+ " Initialize LLMChat with LangChain ChatOllama\n",
405
+ " \n",
406
+ " Args:\n",
407
+ " model_name (str): Name of the model to use\n",
408
+ " temperature (float): Temperature parameter for response generation\n",
409
+ " \"\"\"\n",
410
+ " self.model_name = model_name\n",
411
+ " self.llm = ChatOllama(\n",
412
+ " model=model_name,\n",
413
+ " temperature=temperature\n",
414
+ " )\n",
415
+ " self.history: List[Dict[str, str]] = []\n",
416
+ "\n",
417
+ " def chat_once(self, message: str):\n",
418
+ " \"\"\"\n",
419
+ " Single chat interaction without maintaining history\n",
420
+ " \n",
421
+ " Args:\n",
422
+ " message (str): User input message\n",
423
+ " \n",
424
+ " Returns:\n",
425
+ " str: Model's response\n",
426
+ " \"\"\"\n",
427
+ " try:\n",
428
+ " # Create a simple prompt template for single messages\n",
429
+ " prompt = ChatPromptTemplate.from_messages([\n",
430
+ " (\"human\", \"{input}\")\n",
431
+ " ])\n",
432
+ " \n",
433
+ " # Create and invoke the chain\n",
434
+ " chain = prompt | self.llm\n",
435
+ " response = chain.invoke({\"input\": message})\n",
436
+ " \n",
437
+ " return response.content\n",
438
+ " except Exception as e:\n",
439
+ " print(f\"Error in chat: {e}\")\n",
440
+ " return \"\"\n",
441
+ "\n",
442
+ " def chat_with_history(self, message: str):\n",
443
+ " \"\"\"\n",
444
+ " Chat interaction maintaining conversation history\n",
445
+ " \n",
446
+ " Args:\n",
447
+ " message (str): User input message\n",
448
+ " \n",
449
+ " Returns:\n",
450
+ " str: Model's response\n",
451
+ " \"\"\"\n",
452
+ " try:\n",
453
+ " # Add user message to history\n",
454
+ " self.history.append({'role': 'human', 'content': message})\n",
455
+ " \n",
456
+ " # Convert history to LangChain message format\n",
457
+ " messages = [\n",
458
+ " HumanMessage(content=msg['content']) if msg['role'] == 'human'\n",
459
+ " else AIMessage(content=msg['content'])\n",
460
+ " for msg in self.history\n",
461
+ " ]\n",
462
+ " \n",
463
+ " # Get response using chat method\n",
464
+ " response = self.llm.invoke(messages)\n",
465
+ " assistant_message = response.content\n",
466
+ " \n",
467
+ " # Add assistant response to history\n",
468
+ " self.history.append({'role': 'assistant', 'content': assistant_message})\n",
469
+ " \n",
470
+ " return assistant_message\n",
471
+ " except Exception as e:\n",
472
+ " print(f\"Error in chat with history: {e}\")\n",
473
+ " return \"\"\n",
474
+ "\n",
475
+ " def chat_with_template(self, template_messages: List[Dict[str, str]], \n",
476
+ " input_variables: Dict[str, str]):\n",
477
+ " \"\"\"\n",
478
+ " Chat using a custom template\n",
479
+ " \n",
480
+ " Args:\n",
481
+ " template_messages (List[Dict[str, str]]): List of template messages\n",
482
+ " input_variables (Dict[str, str]): Variables to fill in the template\n",
483
+ " \n",
484
+ " Returns:\n",
485
+ " str: Model's response\n",
486
+ " \"\"\"\n",
487
+ " try:\n",
488
+ " # Create prompt template from messages\n",
489
+ " prompt = ChatPromptTemplate.from_messages([\n",
490
+ " (msg['role'], msg['content'])\n",
491
+ " for msg in template_messages\n",
492
+ " ])\n",
493
+ " \n",
494
+ " # Create and invoke the chain\n",
495
+ " chain = prompt | self.llm\n",
496
+ " response = chain.invoke(input_variables)\n",
497
+ " \n",
498
+ " return response.content\n",
499
+ " except Exception as e:\n",
500
+ " print(f\"Error in template chat: {e}\")\n",
501
+ " return \"\"\n",
502
+ "\n",
503
+ " def clear_history(self):\n",
504
+ " \"\"\"Clear the conversation history\"\"\"\n",
505
+ " self.history = []\n",
506
+ "\n",
507
+ " def get_history(self) -> List[Dict[str, str]]:\n",
508
+ " \"\"\"Return the current conversation history\"\"\"\n",
509
+ " return self.history\n",
510
+ " \n",
511
+ "# if __name__ == \"__main__\":\n",
512
+ "# # Initialize the chat\n",
513
+ "# chat = LLMChat(model_name=\"llama3.1\", temperature=0)\n",
514
+ "\n",
515
+ "# # Example of using a template for translation\n",
516
+ "# template_messages = [\n",
517
+ "# {\n",
518
+ "# \"role\": \"system\",\n",
519
+ "# \"content\": \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
520
+ "# },\n",
521
+ "# {\n",
522
+ "# \"role\": \"human\",\n",
523
+ "# \"content\": \"{input}\"\n",
524
+ "# }\n",
525
+ "# ]\n",
526
+ "\n",
527
+ "# input_vars = {\n",
528
+ "# \"input_language\": \"English\",\n",
529
+ "# \"output_language\": \"German\",\n",
530
+ "# \"input\": \"I love programming.\"\n",
531
+ "# }\n",
532
+ "\n",
533
+ "# response = chat.chat_with_template(template_messages, input_vars)\n",
534
+ "# # Simple chat without history\n",
535
+ "# response = chat.chat_once(\"Hello!\")\n",
536
+ "\n",
537
+ "# # Chat with history\n",
538
+ "# response = chat.chat_with_history(\"How are you?\")"
539
+ ]
540
+ },
541
+ {
542
+ "cell_type": "code",
543
+ "execution_count": 5,
544
+ "metadata": {},
545
+ "outputs": [],
546
+ "source": [
547
+ "from typing import List, Dict, Any\n",
548
+ "from tqdm import tqdm\n",
549
+ "import time\n",
550
+ "\n",
551
+ "# from src.embedding import EmbeddingModel\n",
552
+ "# from src.llm import LLMChat\n",
553
+ "\n",
554
+ "class TableProcessor:\n",
555
+ " def __init__(self, llm_model: LLMChat, embedding_model: EmbeddingModel, batch_size: int = 8):\n",
556
+ " \"\"\"\n",
557
+ " Initialize the TableProcessor with pre-initialized models.\n",
558
+ " \n",
559
+ " Args:\n",
560
+ " llm_model (LLMChat): Initialized LLM model\n",
561
+ " embedding_model (EmbeddingModel): Initialized embedding model\n",
562
+ " batch_size (int): Batch size for processing embeddings\n",
563
+ " \"\"\"\n",
564
+ " self.llm = llm_model\n",
565
+ " self.embedder = embedding_model\n",
566
+ " self.batch_size = batch_size\n",
567
+ " \n",
568
+ " def get_table_description(self, markdown_table: str) -> str:\n",
569
+ " \"\"\"\n",
570
+ " Generate description for a single markdown table using Ollama chat.\n",
571
+ " \n",
572
+ " Args:\n",
573
+ " markdown_table (str): Input markdown table\n",
574
+ " \n",
575
+ " Returns:\n",
576
+ " str: Generated description of the table\n",
577
+ " \"\"\"\n",
578
+ " system_prompt = \"\"\"You are an AI language model. Your task is to examine the provided table, taking into account both its rows and columns, and produce a concise summary of up to 200 words. Emphasize key patterns, trends, and notable data points that provide meaningful insights into the content of the table.\"\"\"\n",
579
+ " \n",
580
+ " try:\n",
581
+ " # Use chat_once to avoid maintaining history between tables\n",
582
+ " full_prompt = f\"{system_prompt}\\n\\nTable:\\n{markdown_table}\"\n",
583
+ " return self.llm.chat_once(full_prompt)\n",
584
+ " except Exception as e:\n",
585
+ " print(f\"Error generating table description: {e}\")\n",
586
+ " return \"\"\n",
587
+ " \n",
588
+ " def process_tables(self, markdown_tables) -> List[Dict[str, Any]]:\n",
589
+ " \"\"\"\n",
590
+ " Process a list of markdown tables: generate descriptions and embeddings.\n",
591
+ " \n",
592
+ " Args:\n",
593
+ " markdown_tables (List[str]): List of markdown tables to process\n",
594
+ " \n",
595
+ " Returns:\n",
596
+ " List[Dict[str, Any]]: List of dictionaries containing processed information\n",
597
+ " \"\"\"\n",
598
+ " results = []\n",
599
+ " descriptions = []\n",
600
+ " \n",
601
+ " # Generate descriptions for all tables\n",
602
+ " with tqdm(total=len(markdown_tables), desc=\"Generating table descriptions\") as pbar:\n",
603
+ " for i, table in enumerate(markdown_tables):\n",
604
+ " description = self.get_table_description(table.text)\n",
605
+ " print(f\"\\nTable {i+1}:\")\n",
606
+ " print(f\"Description: {description}\")\n",
607
+ " print(\"-\" * 50)\n",
608
+ " descriptions.append(description)\n",
609
+ " pbar.update(1)\n",
610
+ " time.sleep(1) # Rate limiting\n",
611
+ " \n",
612
+ " # Generate embeddings in batches\n",
613
+ " embeddings = []\n",
614
+ " total_batches = (len(descriptions) + self.batch_size - 1) // self.batch_size\n",
615
+ " \n",
616
+ " with tqdm(total=total_batches, desc=\"Generating embeddings\") as pbar:\n",
617
+ " for i in range(0, len(descriptions), self.batch_size):\n",
618
+ " batch = descriptions[i:i + self.batch_size]\n",
619
+ " if len(batch) == 1:\n",
620
+ " batch_embeddings = [self.embedder.embed(batch[0])]\n",
621
+ " else:\n",
622
+ " batch_embeddings = self.embedder.embed_batch(batch)\n",
623
+ " embeddings.extend(batch_embeddings)\n",
624
+ " pbar.update(1)\n",
625
+ " \n",
626
+ " # Combine results with progress bar\n",
627
+ " with tqdm(total=len(markdown_tables), desc=\"Combining results\") as pbar:\n",
628
+ " for table, description, embedding in zip(markdown_tables, descriptions, embeddings):\n",
629
+ " results.append({\n",
630
+ " \"embedding\": embedding,\n",
631
+ " \"text\": table,\n",
632
+ " \"table_description\": description,\n",
633
+ " \"type\": \"table_chunk\"\n",
634
+ " })\n",
635
+ " pbar.update(1)\n",
636
+ " \n",
637
+ " return results\n",
638
+ "\n",
639
+ " def __call__(self, markdown_tables) -> List[Dict[str, Any]]:\n",
640
+ " \"\"\"\n",
641
+ " Make the class callable for easier use.\n",
642
+ " \n",
643
+ " Args:\n",
644
+ " markdown_tables (List[str]): List of markdown tables to process\n",
645
+ " \n",
646
+ " Returns:\n",
647
+ " List[Dict[str, Any]]: Processed results\n",
648
+ " \"\"\"\n",
649
+ " return self.process_tables(markdown_tables)"
650
+ ]
651
+ },
652
+ {
653
+ "cell_type": "code",
654
+ "execution_count": 7,
655
+ "metadata": {},
656
+ "outputs": [],
657
+ "source": [
658
+ "from typing import List, Dict, Any, Optional\n",
659
+ "import pandas as pd\n",
660
+ "import time\n",
661
+ "from tqdm import tqdm\n",
662
+ "import logging\n",
663
+ "from pinecone import Pinecone, ServerlessSpec\n",
664
+ "from dataclasses import dataclass\n",
665
+ "from enum import Enum\n",
666
+ "# from src.table_aware_chunker import TableRecursiveChunker\n",
667
+ "# from src.processor import TableProcessor\n",
668
+ "# from src.llm import LLMChat\n",
669
+ "# from src.embedding import EmbeddingModel\n",
670
+ "from chonkie import RecursiveRules\n",
671
+ "# from src.loader import MultiFormatDocumentLoader\n",
672
+ "from dotenv import load_dotenv\n",
673
+ "import os\n",
674
+ "\n",
675
+ "load_dotenv()\n",
676
+ "# API Keys\n",
677
+ "PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')\n",
678
+ "\n",
679
+ "logging.basicConfig(\n",
680
+ " level=logging.INFO,\n",
681
+ " format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'\n",
682
+ ")\n",
683
+ "logger = logging.getLogger('table_aware_rag')\n",
684
+ "\n",
685
+ "class ChunkType(Enum):\n",
686
+ " TEXT = \"text_chunk\"\n",
687
+ " TABLE = \"table_chunk\"\n",
688
+ "\n",
689
+ "@dataclass\n",
690
+ "class ProcessedChunk:\n",
691
+ " text: str # This will be the embedable text (table description for tables)\n",
692
+ " chunk_type: ChunkType\n",
693
+ " token_count: int\n",
694
+ " markdown_table: Optional[str] = None # Store original markdown table format\n",
695
+ " start_index: Optional[int] = None\n",
696
+ " end_index: Optional[int] = None\n",
697
+ "\n",
698
+ "def process_documents(\n",
699
+ " file_paths: List[str],\n",
700
+ " chunker: TableRecursiveChunker,\n",
701
+ " processor: TableProcessor,\n",
702
+ " output_path: str = './output.md'\n",
703
+ ") -> List[ProcessedChunk]:\n",
704
+ " \"\"\"\n",
705
+ " Process documents into text and table chunks\n",
706
+ " \"\"\"\n",
707
+ " # Load documents\n",
708
+ " loader = MultiFormatDocumentLoader(\n",
709
+ " file_paths=file_paths,\n",
710
+ " enable_ocr=False,\n",
711
+ " enable_tables=True\n",
712
+ " )\n",
713
+ " \n",
714
+ " # Save to markdown and read content\n",
715
+ " with open(output_path, 'w') as f:\n",
716
+ " for doc in loader.lazy_load():\n",
717
+ " f.write(doc.page_content)\n",
718
+ " \n",
719
+ " with open(output_path, 'r') as file:\n",
720
+ " text = file.read()\n",
721
+ " \n",
722
+ " # Get text and table chunks\n",
723
+ " text_chunks, table_chunks = chunker.chunk(text)\n",
724
+ " \n",
725
+ " # Process chunks\n",
726
+ " processed_chunks = []\n",
727
+ " \n",
728
+ " # Process text chunks\n",
729
+ " for chunk in text_chunks:\n",
730
+ " processed_chunks.append(\n",
731
+ " ProcessedChunk(\n",
732
+ " text=chunk.text,\n",
733
+ " chunk_type=ChunkType.TEXT,\n",
734
+ " token_count=chunk.token_count,\n",
735
+ " start_index=chunk.start_index,\n",
736
+ " end_index=chunk.end_index\n",
737
+ " )\n",
738
+ " )\n",
739
+ " \n",
740
+ " # Process table chunks\n",
741
+ " table_results = processor(table_chunks)\n",
742
+ " for table in table_results:\n",
743
+ " # Convert table chunk to string representation if needed\n",
744
+ " table_str = str(table[\"text\"].text)\n",
745
+ " \n",
746
+ " processed_chunks.append(\n",
747
+ " ProcessedChunk(\n",
748
+ " text=table[\"table_description\"], # Use description for embedding\n",
749
+ " chunk_type=ChunkType.TABLE,\n",
750
+ " token_count=len(table[\"table_description\"].split()),\n",
751
+ " markdown_table=table_str # Store string version of table\n",
752
+ " )\n",
753
+ " )\n",
754
+ " \n",
755
+ " return processed_chunks\n",
756
+ "\n",
757
+ "class PineconeRetriever:\n",
758
+ " def __init__(\n",
759
+ " self,\n",
760
+ " pinecone_client: Pinecone,\n",
761
+ " index_name: str,\n",
762
+ " namespace: str,\n",
763
+ " embedding_model: Any,\n",
764
+ " llm_model: Any\n",
765
+ " ):\n",
766
+ " \"\"\"\n",
767
+ " Initialize retriever with configurable models\n",
768
+ " \"\"\"\n",
769
+ " self.pinecone = pinecone_client\n",
770
+ " self.index = self.pinecone.Index(index_name)\n",
771
+ " self.namespace = namespace\n",
772
+ " self.embedding_model = embedding_model\n",
773
+ " self.llm_model = llm_model\n",
774
+ " \n",
775
+ " def _prepare_query(self, question: str) -> List[float]:\n",
776
+ " \"\"\"Generate embedding for query\"\"\"\n",
777
+ " return self.embedding_model.embed(question)\n",
778
+ " \n",
779
+ " def invoke(\n",
780
+ " self,\n",
781
+ " question: str,\n",
782
+ " top_k: int = 5,\n",
783
+ " chunk_type_filter: Optional[ChunkType] = None\n",
784
+ " ) -> List[Dict[str, Any]]:\n",
785
+ " \"\"\"\n",
786
+ " Retrieve similar documents with optional filtering by chunk type\n",
787
+ " \"\"\"\n",
788
+ " query_embedding = self._prepare_query(question)\n",
789
+ " \n",
790
+ " # Prepare filter if chunk type specified\n",
791
+ " filter_dict = None\n",
792
+ " if chunk_type_filter:\n",
793
+ " filter_dict = {\"chunk_type\": chunk_type_filter.value}\n",
794
+ " \n",
795
+ " results = self.index.query(\n",
796
+ " namespace=self.namespace,\n",
797
+ " vector=query_embedding,\n",
798
+ " top_k=top_k,\n",
799
+ " include_values=False,\n",
800
+ " include_metadata=True,\n",
801
+ " filter=filter_dict\n",
802
+ " )\n",
803
+ " \n",
804
+ " retrieved_docs = []\n",
805
+ " for match in results.matches:\n",
806
+ " doc = {\n",
807
+ " \"score\": match.score,\n",
808
+ " \"chunk_type\": match.metadata[\"chunk_type\"]\n",
809
+ " }\n",
810
+ " \n",
811
+ " # Handle different chunk types\n",
812
+ " if match.metadata[\"chunk_type\"] == ChunkType.TABLE.value:\n",
813
+ " doc[\"table_description\"] = match.metadata[\"text\"] # The embedded description\n",
814
+ " doc[\"markdown_table\"] = match.metadata[\"markdown_table\"] # Original table format\n",
815
+ " else:\n",
816
+ " doc[\"page_content\"] = match.metadata[\"text\"]\n",
817
+ " \n",
818
+ " retrieved_docs.append(doc)\n",
819
+ " \n",
820
+ " return retrieved_docs\n",
821
+ "\n",
822
+ "def ingest_data(\n",
823
+ " processed_chunks: List[ProcessedChunk],\n",
824
+ " embedding_model: Any,\n",
825
+ " pinecone_client: Pinecone,\n",
826
+ " index_name: str = \"vector-index\",\n",
827
+ " namespace: str = \"rag\",\n",
828
+ " batch_size: int = 100\n",
829
+ "):\n",
830
+ " \"\"\"\n",
831
+ " Ingest processed chunks into Pinecone\n",
832
+ " \"\"\"\n",
833
+ " # Create or get index\n",
834
+ " if not pinecone_client.has_index(index_name):\n",
835
+ " pinecone_client.create_index(\n",
836
+ " name=index_name,\n",
837
+ " dimension=768,\n",
838
+ " metric=\"cosine\",\n",
839
+ " spec=ServerlessSpec(\n",
840
+ " cloud='aws',\n",
841
+ " region='us-east-1'\n",
842
+ " )\n",
843
+ " )\n",
844
+ " \n",
845
+ " while not pinecone_client.describe_index(index_name).status['ready']:\n",
846
+ " time.sleep(1)\n",
847
+ " \n",
848
+ " index = pinecone_client.Index(index_name)\n",
849
+ " \n",
850
+ " # Process in batches\n",
851
+ " for i in tqdm(range(0, len(processed_chunks), batch_size)):\n",
852
+ " batch = processed_chunks[i:i+batch_size]\n",
853
+ " \n",
854
+ " # Generate embeddings for the text content\n",
855
+ " texts = [chunk.text for chunk in batch]\n",
856
+ " embeddings = embedding_model.embed_batch(texts)\n",
857
+ " \n",
858
+ " # Prepare records\n",
859
+ " records = []\n",
860
+ " for idx, chunk in enumerate(batch):\n",
861
+ " metadata = {\n",
862
+ " \"text\": chunk.text, # This is the description for tables\n",
863
+ " \"chunk_type\": chunk.chunk_type.value,\n",
864
+ " \"token_count\": chunk.token_count\n",
865
+ " }\n",
866
+ " \n",
867
+ " # Add markdown table to metadata if it's a table chunk\n",
868
+ " if chunk.markdown_table is not None:\n",
869
+ " # Ensure the table is in string format\n",
870
+ " metadata[\"markdown_table\"] = str(chunk.markdown_table)\n",
871
+ " \n",
872
+ " records.append({\n",
873
+ " \"id\": f\"chunk_{i + idx}\",\n",
874
+ " \"values\": embeddings[idx],\n",
875
+ " \"metadata\": metadata\n",
876
+ " })\n",
877
+ " \n",
878
+ " # Upsert to Pinecone\n",
879
+ " try:\n",
880
+ " index.upsert(vectors=records, namespace=namespace)\n",
881
+ " except Exception as e:\n",
882
+ " logger.error(f\"Error during upsert: {str(e)}\")\n",
883
+ " logger.error(f\"Problematic record metadata: {records[0]['metadata']}\")\n",
884
+ " raise\n",
885
+ " \n",
886
+ " time.sleep(0.5) # Rate limiting\n",
887
+ "\n",
888
+ "\n",
889
+ "# def main():\n",
890
+ "# # Initialize components\n",
891
+ "# pc = Pinecone(api_key=PINECONE_API_KEY)\n",
892
+ " \n",
893
+ "# chunker = TableRecursiveChunker(\n",
894
+ "# tokenizer=\"gpt2\",\n",
895
+ "# chunk_size=512,\n",
896
+ "# rules=RecursiveRules(),\n",
897
+ "# min_characters_per_chunk=12\n",
898
+ "# )\n",
899
+ " \n",
900
+ "# llm = LLMChat(\"qwen2.5:0.5b\")\n",
901
+ "# embedder = EmbeddingModel(\"nomic-embed-text\")\n",
902
+ " \n",
903
+ "# processor = TableProcessor(\n",
904
+ "# llm_model=llm,\n",
905
+ "# embedding_model=embedder,\n",
906
+ "# batch_size=8\n",
907
+ "# )\n",
908
+ " \n",
909
+ "# try:\n",
910
+ "# # Process documents\n",
911
+ "# processed_chunks = process_documents(\n",
912
+ "# file_paths=['/teamspace/studios/this_studio/TabularRAG/data/FeesPaymentReceipt_7thsem.pdf'],\n",
913
+ "# chunker=chunker,\n",
914
+ "# processor=processor\n",
915
+ "# )\n",
916
+ " \n",
917
+ "# # Ingest data\n",
918
+ "# ingest_data(\n",
919
+ "# processed_chunks=processed_chunks,\n",
920
+ "# embedding_model=embedder,\n",
921
+ "# pinecone_client=pc\n",
922
+ "# )\n",
923
+ " \n",
924
+ "# # Test retrieval\n",
925
+ "# retriever = PineconeRetriever(\n",
926
+ "# pinecone_client=pc,\n",
927
+ "# index_name=\"vector-index\",\n",
928
+ "# namespace=\"rag\",\n",
929
+ "# embedding_model=embedder,\n",
930
+ "# llm_model=llm\n",
931
+ "# )\n",
932
+ " \n",
933
+ " # # Test text-only retrieval\n",
934
+ " # text_results = retriever.invoke(\n",
935
+ " # question=\"What is paid fees amount?\",\n",
936
+ " # top_k=3,\n",
937
+ " # chunk_type_filter=ChunkType.TEXT\n",
938
+ " # )\n",
939
+ " # print(\"Text results:\")\n",
940
+ " # for result in text_results:\n",
941
+ " # print(result)\n",
942
+ " # Test table-only retrieval\n",
943
+ " # table_results = retriever.invoke(\n",
944
+ " # question=\"What is paid fees amount?\",\n",
945
+ " # top_k=3,\n",
946
+ " # chunk_type_filter=ChunkType.TABLE\n",
947
+ " # )\n",
948
+ " # print(\"Table results:\")\n",
949
+ " # for result in table_results:\n",
950
+ " # print(result)\n",
951
+ " \n",
952
+ " # results = retriever.invoke(\n",
953
+ " # question=\"What is paid fees amount?\",\n",
954
+ " # top_k=3\n",
955
+ " # )\n",
956
+ " \n",
957
+ " # for i, result in enumerate(results, 1):\n",
958
+ " # print(f\"\\nResult {i}:\")\n",
959
+ " # if result[\"chunk_type\"] == ChunkType.TABLE.value:\n",
960
+ " # print(f\"Table Description: {result['table_description']}\")\n",
961
+ " # print(\"Table Format:\")\n",
962
+ " # print(result['markdown_table'])\n",
963
+ " # else:\n",
964
+ " # print(f\"Content: {result['page_content']}\")\n",
965
+ " # print(f\"Score: {result['score']}\")\n",
966
+ " \n",
967
+ " # except Exception as e:\n",
968
+ " # logger.error(f\"Error in pipeline: {str(e)}\")\n",
969
+ "\n",
970
+ "# if __name__ == \"__main__\":\n",
971
+ "# main()"
972
+ ]
973
+ },
974
+ {
975
+ "cell_type": "code",
976
+ "execution_count": 8,
977
+ "metadata": {},
978
+ "outputs": [],
979
+ "source": [
980
+ "from pathlib import Path\n",
981
+ "import tempfile\n",
982
+ "import os\n",
983
+ "from typing import List, Dict\n",
984
+ "from pinecone import Pinecone\n",
985
+ "# from src.table_aware_chunker import TableRecursiveChunker\n",
986
+ "# from src.processor import TableProcessor\n",
987
+ "# from src.llm import LLMChat\n",
988
+ "# from src.embedding import EmbeddingModel\n",
989
+ "from chonkie import RecursiveRules\n",
990
+ "# from src.vectordb import ChunkType, process_documents, ingest_data, PineconeRetriever\n",
991
+ "\n",
992
+ "class TableRAGSystem:\n",
993
+ " def __init__(self, pinecone_api_key: str):\n",
994
+ " \"\"\"Initialize the Table RAG system with necessary components.\"\"\"\n",
995
+ " self.pc = Pinecone(api_key=pinecone_api_key)\n",
996
+ " \n",
997
+ " # Initialize LLM\n",
998
+ " self.llm = LLMChat(\n",
999
+ " model_name=\"mistral:7b\",\n",
1000
+ " temperature=0.3\n",
1001
+ " )\n",
1002
+ " \n",
1003
+ " # Initialize Embeddings\n",
1004
+ " self.embedder = EmbeddingModel(\"nomic-embed-text\")\n",
1005
+ " \n",
1006
+ " # Initialize Chunker\n",
1007
+ " self.chunker = TableRecursiveChunker(\n",
1008
+ " tokenizer=\"gpt2\",\n",
1009
+ " chunk_size=512,\n",
1010
+ " rules=RecursiveRules(),\n",
1011
+ " min_characters_per_chunk=12\n",
1012
+ " )\n",
1013
+ " \n",
1014
+ " # Initialize Processor\n",
1015
+ " self.processor = TableProcessor(\n",
1016
+ " llm_model=self.llm,\n",
1017
+ " embedding_model=self.embedder,\n",
1018
+ " batch_size=8\n",
1019
+ " )\n",
1020
+ " \n",
1021
+ " self.retriever = None\n",
1022
+ " \n",
1023
+ " def process_documents(self, file_paths: List[str]) -> bool:\n",
1024
+ " \"\"\"Process documents and initialize the retriever.\"\"\"\n",
1025
+ " try:\n",
1026
+ " # Process documents\n",
1027
+ " print(\"Processing documents...\")\n",
1028
+ " processed_chunks = process_documents(\n",
1029
+ " file_paths=file_paths,\n",
1030
+ " chunker=self.chunker,\n",
1031
+ " processor=self.processor,\n",
1032
+ " output_path='./output.md'\n",
1033
+ " )\n",
1034
+ " \n",
1035
+ " # Ingest data\n",
1036
+ " print(\"Ingesting data to vector database...\")\n",
1037
+ " ingest_data(\n",
1038
+ " processed_chunks=processed_chunks,\n",
1039
+ " embedding_model=self.embedder,\n",
1040
+ " pinecone_client=self.pc\n",
1041
+ " )\n",
1042
+ " \n",
1043
+ " # Setup retriever\n",
1044
+ " print(\"Setting up retriever...\")\n",
1045
+ " self.retriever = PineconeRetriever(\n",
1046
+ " pinecone_client=self.pc,\n",
1047
+ " index_name=\"vector-index\",\n",
1048
+ " namespace=\"rag\",\n",
1049
+ " embedding_model=self.embedder,\n",
1050
+ " llm_model=self.llm\n",
1051
+ " )\n",
1052
+ " \n",
1053
+ " print(\"Processing complete!\")\n",
1054
+ " return True\n",
1055
+ "\n",
1056
+ " except Exception as e:\n",
1057
+ " print(f\"Error processing documents: {str(e)}\")\n",
1058
+ " return False\n",
1059
+ "\n",
1060
+ " def format_context(self, results: List[Dict]) -> str:\n",
1061
+ " \"\"\"Format retrieved results into context string.\"\"\"\n",
1062
+ " context_parts = []\n",
1063
+ " \n",
1064
+ " for result in results:\n",
1065
+ " if result.get(\"chunk_type\") == ChunkType.TABLE.value:\n",
1066
+ " table_text = f\"Table: {result['markdown_table']}\"\n",
1067
+ " if result.get(\"table_description\"):\n",
1068
+ " table_text += f\"\\nDescription: {result['table_description']}\"\n",
1069
+ " context_parts.append(table_text)\n",
1070
+ " else:\n",
1071
+ " context_parts.append(result.get(\"page_content\", \"\"))\n",
1072
+ " \n",
1073
+ " return \"\\n\\n\".join(context_parts)\n",
1074
+ "\n",
1075
+ " def query(self, question: str) -> Dict:\n",
1076
+ " \"\"\"Query the system with a question.\"\"\"\n",
1077
+ " if not self.retriever:\n",
1078
+ " raise ValueError(\"Documents must be processed before querying\")\n",
1079
+ " \n",
1080
+ " # Retrieve relevant content\n",
1081
+ " results = self.retriever.invoke(\n",
1082
+ " question=question,\n",
1083
+ " top_k=3\n",
1084
+ " )\n",
1085
+ " \n",
1086
+ " # Format context and get response from LLM\n",
1087
+ " context = self.format_context(results)\n",
1088
+ " \n",
1089
+ " # RAG Template\n",
1090
+ " rag_template = [\n",
1091
+ " {\n",
1092
+ " \"role\": \"system\",\n",
1093
+ " \"content\": \"\"\"You are a knowledgeable assistant specialized in analyzing documents and tables. \n",
1094
+ " Your responses should be:\n",
1095
+ " - Accurate and based on the provided context\n",
1096
+ " - Concise (three sentences maximum)\n",
1097
+ " - Professional yet conversational\n",
1098
+ " - Include specific references to tables when relevant\n",
1099
+ " \n",
1100
+ " If you cannot find an answer in the context, acknowledge this clearly.\"\"\"\n",
1101
+ " },\n",
1102
+ " {\n",
1103
+ " \"role\": \"human\",\n",
1104
+ " \"content\": \"Context: {context}\\n\\nQuestion: {question}\"\n",
1105
+ " }\n",
1106
+ " ]\n",
1107
+ " \n",
1108
+ " input_vars = {\n",
1109
+ " \"question\": question,\n",
1110
+ " \"context\": context\n",
1111
+ " }\n",
1112
+ "\n",
1113
+ " response = self.llm.chat_with_template(rag_template, input_vars)\n",
1114
+ " \n",
1115
+ " return {\n",
1116
+ " \"response\": response,\n",
1117
+ " \"context\": context,\n",
1118
+ " \"retrieved_results\": results\n",
1119
+ " }\n",
1120
+ "\n",
1121
+ " def clear_index(self, index_name: str = \"vector-index\"):\n",
1122
+ " \"\"\"Clear the Pinecone index.\"\"\"\n",
1123
+ " try:\n",
1124
+ " self.pc.delete_index(index_name)\n",
1125
+ " self.retriever = None\n",
1126
+ " print(\"Database cleared successfully!\")\n",
1127
+ " except Exception as e:\n",
1128
+ " print(f\"Error clearing database: {str(e)}\")"
1129
+ ]
1130
+ },
1131
+ {
1132
+ "cell_type": "code",
1133
+ "execution_count": 10,
1134
+ "metadata": {},
1135
+ "outputs": [
1136
+ {
1137
+ "name": "stderr",
1138
+ "output_type": "stream",
1139
+ "text": [
1140
+ "INFO:pinecone_plugin_interface.logging:Discovering subpackages in _NamespacePath(['/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pinecone_plugins'])\n",
1141
+ "INFO:pinecone_plugin_interface.logging:Looking for plugins in pinecone_plugins.inference\n",
1142
+ "INFO:pinecone_plugin_interface.logging:Installing plugin inference into Pinecone\n",
1143
+ "INFO:docling.document_converter:Going to convert document batch...\n",
1144
+ "INFO:docling.utils.accelerator_utils:Accelerator device: 'cpu'\n"
1145
+ ]
1146
+ },
1147
+ {
1148
+ "name": "stdout",
1149
+ "output_type": "stream",
1150
+ "text": [
1151
+ "Processing documents...\n"
1152
+ ]
1153
+ },
1154
+ {
1155
+ "name": "stderr",
1156
+ "output_type": "stream",
1157
+ "text": [
1158
+ "INFO:docling.utils.accelerator_utils:Accelerator device: 'cpu'\n",
1159
+ "INFO:docling.pipeline.base_pipeline:Processing document FeesPaymentReceipt_7thsem.pdf\n",
1160
+ "INFO:docling.document_converter:Finished converting document FeesPaymentReceipt_7thsem.pdf in 6.28 sec.\n",
1161
+ "INFO:__main__:Processed 1 documents:\n",
1162
+ "- Successfully converted: 1\n",
1163
+ "- Partially converted: 0\n",
1164
+ "- Failed: 0\n",
1165
+ "Generating table descriptions: 0%| | 0/1 [00:00<?, ?it/s]INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/chat \"HTTP/1.1 200 OK\"\n",
1166
+ "Generating table descriptions: 100%|██████████| 1/1 [01:36<00:00, 96.89s/it]"
1167
+ ]
1168
+ },
1169
+ {
1170
+ "name": "stdout",
1171
+ "output_type": "stream",
1172
+ "text": [
1173
+ "\n",
1174
+ "Table 1:\n",
1175
+ "Description: The table provides a breakdown of various costs associated with educational expenses, including tuition fees, lodging, fooding, and other charges. The most significant cost is the tuition fee at $22,500. It's interesting to note that there are two categories labeled as \"Outstanding\" for both tuition fees & others, and fooding, suggesting that these costs have not been fully paid.\n",
1176
+ "\n",
1177
+ " The lodging including facilities for one semester is also a substantial cost, although the amount is not specified in this table. The presence of an \"Excess\" and \"Late Fine 22500 Total\" categories implies that there may be additional fees for late payments or exceeding certain limits.\n",
1178
+ "\n",
1179
+ " Overall, the data suggests that the total educational costs can be quite high, with a significant portion of these costs being outstanding, potentially indicating a need for financial planning and budgeting strategies to manage these expenses effectively.\n",
1180
+ "--------------------------------------------------\n"
1181
+ ]
1182
+ },
1183
+ {
1184
+ "name": "stderr",
1185
+ "output_type": "stream",
1186
+ "text": [
1187
+ "Generating table descriptions: 100%|██████████| 1/1 [01:37<00:00, 97.89s/it]\n",
1188
+ "Generating embeddings: 0%| | 0/1 [00:00<?, ?it/s]INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/embed \"HTTP/1.1 200 OK\"\n",
1189
+ "Generating embeddings: 100%|██████████| 1/1 [00:02<00:00, 2.13s/it]\n",
1190
+ "Combining results: 100%|██████████| 1/1 [00:00<00:00, 24105.20it/s]\n"
1191
+ ]
1192
+ },
1193
+ {
1194
+ "name": "stdout",
1195
+ "output_type": "stream",
1196
+ "text": [
1197
+ "Ingesting data to vector database...\n"
1198
+ ]
1199
+ },
1200
+ {
1201
+ "name": "stderr",
1202
+ "output_type": "stream",
1203
+ "text": [
1204
+ "INFO:pinecone_plugin_interface.logging:Discovering subpackages in _NamespacePath(['/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pinecone_plugins'])\n",
1205
+ "INFO:pinecone_plugin_interface.logging:Looking for plugins in pinecone_plugins.inference\n",
1206
+ " 0%| | 0/1 [00:00<?, ?it/s]INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/embed \"HTTP/1.1 200 OK\"\n",
1207
+ "100%|██████████| 1/1 [00:02<00:00, 2.26s/it]\n",
1208
+ "INFO:pinecone_plugin_interface.logging:Discovering subpackages in _NamespacePath(['/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pinecone_plugins'])\n",
1209
+ "INFO:pinecone_plugin_interface.logging:Looking for plugins in pinecone_plugins.inference\n",
1210
+ "INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/embed \"HTTP/1.1 200 OK\"\n"
1211
+ ]
1212
+ },
1213
+ {
1214
+ "name": "stdout",
1215
+ "output_type": "stream",
1216
+ "text": [
1217
+ "Setting up retriever...\n",
1218
+ "Processing complete!\n"
1219
+ ]
1220
+ },
1221
+ {
1222
+ "name": "stderr",
1223
+ "output_type": "stream",
1224
+ "text": [
1225
+ "INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/chat \"HTTP/1.1 200 OK\"\n"
1226
+ ]
1227
+ },
1228
+ {
1229
+ "name": "stdout",
1230
+ "output_type": "stream",
1231
+ "text": [
1232
+ "Answer: Based on the provided context, I am unable to determine the exact paid amount as no numerical values related to payment are present in the given data. Please provide more specific details or numbers for a precise answer.\n",
1233
+ "\n",
1234
+ "Relevant Context: \n"
1235
+ ]
1236
+ }
1237
+ ],
1238
+ "source": [
1239
+ "# Initialize the system\n",
1240
+ "pinecone_api_key = \"pcsk_3AEjJe_So4D99WCivWvTLohkzAWp12gJiDcHMNXk3V8RkkaVUywB2jVitnciQbAEYZQEVS\"\n",
1241
+ "rag_system = TableRAGSystem(pinecone_api_key)\n",
1242
+ "\n",
1243
+ "# Process documents\n",
1244
+ "file_paths = [\n",
1245
+ " \"/teamspace/studios/this_studio/TabularRAG/data/FeesPaymentReceipt_7thsem.pdf\"\n",
1246
+ "]\n",
1247
+ "rag_system.process_documents(file_paths)\n",
1248
+ "\n",
1249
+ "# Query the system\n",
1250
+ "question = \"what is the paid amount?\"\n",
1251
+ "result = rag_system.query(question)\n",
1252
+ "\n",
1253
+ "# Access different parts of the response\n",
1254
+ "print(\"Answer:\", result[\"response\"])\n",
1255
+ "print(\"\\nRelevant Context:\", result[\"context\"])\n"
1256
+ ]
1257
+ },
1258
+ {
1259
+ "cell_type": "code",
1260
+ "execution_count": 11,
1261
+ "metadata": {},
1262
+ "outputs": [
1263
+ {
1264
+ "name": "stderr",
1265
+ "output_type": "stream",
1266
+ "text": [
1267
+ "INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/embed \"HTTP/1.1 200 OK\"\n",
1268
+ "INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/chat \"HTTP/1.1 200 OK\"\n"
1269
+ ]
1270
+ }
1271
+ ],
1272
+ "source": [
1273
+ "question = \"what is the paid amount?\"\n",
1274
+ "result = rag_system.query(question)"
1275
+ ]
1276
+ },
1277
+ {
1278
+ "cell_type": "code",
1279
+ "execution_count": 12,
1280
+ "metadata": {},
1281
+ "outputs": [
1282
+ {
1283
+ "name": "stdout",
1284
+ "output_type": "stream",
1285
+ "text": [
1286
+ "Answer: The paid amount for this receipt is $22,500. This can be found in the table under the \"Online Payment Total\" category.\n",
1287
+ "\n",
1288
+ "Relevant Context: <!-- image -->\n",
1289
+ "\n",
1290
+ "## THE NEOTIA UNIVERSITY\n",
1291
+ "\n",
1292
+ "Diamond Harbour Road, Sarisha Hat, Sarisha, West Bengal - 743368, India\n",
1293
+ "\n",
1294
+ "Payment Receipt\n",
1295
+ "\n",
1296
+ "Student Details\n",
1297
+ "\n",
1298
+ "Receipt Date\n",
1299
+ "\n",
1300
+ "03/07/2024\n",
1301
+ "\n",
1302
+ "Name\n",
1303
+ "\n",
1304
+ ":\n",
1305
+ "\n",
1306
+ "ANINDYA MITRA\n",
1307
+ "\n",
1308
+ "UID No.\n",
1309
+ "\n",
1310
+ "Course\n",
1311
+ "\n",
1312
+ ":\n",
1313
+ "\n",
1314
+ "Contact No.\n",
1315
+ "\n",
1316
+ "Installment\n",
1317
+ "\n",
1318
+ ":\n",
1319
+ "\n",
1320
+ "Payment Type :\n",
1321
+ "\n",
1322
+ ":\n",
1323
+ "\n",
1324
+ "TNU2021053100042\n",
1325
+ "\n",
1326
+ "Bachelor of Technology in Computer Science & Engineering with\n",
1327
+ "\n",
1328
+ "8240716218\n",
1329
+ "\n",
1330
+ "Semester Fee-7\n",
1331
+ "\n",
1332
+ "Online Payment\n",
1333
+ "\n",
1334
+ "\n",
1335
+ "\n",
1336
+ "Table: | Heads | Amount |\n",
1337
+ "|----------------------------------------------------------|----------------------------------------------------------|\n",
1338
+ "| Outstanding(Tuition Fees & Others) | Outstanding(Tuition Fees & Others) |\n",
1339
+ "| Outstanding(Fooding) | Outstanding(Fooding) |\n",
1340
+ "| Tuition Fee | 22500 |\n",
1341
+ "| Other Charges | |\n",
1342
+ "| Lodging including facilities(for one semester) e P A I D | Lodging including facilities(for one semester) e P A I D |\n",
1343
+ "| Excess | Excess |\n",
1344
+ "| Late Fine 22500 Total | Late Fine 22500 Total |\n",
1345
+ "\n",
1346
+ "Description: The table provides a breakdown of various costs associated with educational expenses, including tuition fees, lodging, fooding, and other charges. The most significant cost is the tuition fee at $22,500. It's interesting to note that there are two categories labeled as \"Outstanding\" for both tuition fees & others, and fooding, suggesting that these costs have not been fully paid.\n",
1347
+ "\n",
1348
+ " The lodging including facilities for one semester is also a substantial cost, although the amount is not specified in this table. The presence of an \"Excess\" and \"Late Fine 22500 Total\" categories implies that there may be additional fees for late payments or exceeding certain limits.\n",
1349
+ "\n",
1350
+ " Overall, the data suggests that the total educational costs can be quite high, with a significant portion of these costs being outstanding, potentially indicating a need for financial planning and budgeting strategies to manage these expenses effectively.\n",
1351
+ "\n",
1352
+ "\n",
1353
+ "For THE NEOTIA UNIVERSITY\n",
1354
+ "\n",
1355
+ ":\n",
1356
+ "\n",
1357
+ ":\n",
1358
+ "\n",
1359
+ "<!-- image -->\n",
1360
+ "\n",
1361
+ "## THE NEOTIA UNIVERSITY\n",
1362
+ "\n",
1363
+ "Diamond Harbour Road, Sarisha Hat, Sarisha, West Bengal - 743368, India\n",
1364
+ "\n",
1365
+ "Payment Receipt\n",
1366
+ "\n",
1367
+ "Semester Fee-7\n",
1368
+ "\n",
1369
+ ":\n",
1370
+ "\n",
1371
+ "Student Details\n",
1372
+ "\n",
1373
+ "8240716218\n",
1374
+ "\n",
1375
+ ": UID No.\n",
1376
+ "\n",
1377
+ "TNU2021053100042\n",
1378
+ "\n",
1379
+ "Name\n",
1380
+ "\n",
1381
+ ":\n",
1382
+ "\n",
1383
+ "ANINDYA MITRA\n",
1384
+ "\n",
1385
+ "Contact No.\n",
1386
+ "\n",
1387
+ "Installment\n",
1388
+ "\n",
1389
+ ":\n",
1390
+ "\n",
1391
+ "Receipt Date\n",
1392
+ "\n",
1393
+ "03/07/2024\n",
1394
+ "\n",
1395
+ ":\n",
1396
+ "\n",
1397
+ "Bachelor of Technology in Computer Science & Engineering with\n",
1398
+ "\n",
1399
+ "Course\n",
1400
+ "\n",
1401
+ ":\n",
1402
+ "\n",
1403
+ "Online Payment\n",
1404
+ "\n",
1405
+ "Payment Type :\n",
1406
+ "\n",
1407
+ ": 418511050700\n",
1408
+ "\n",
1409
+ "Bank Ref. No. e P A I D\n",
1410
+ "\n",
1411
+ ":\n",
1412
+ "\n",
1413
+ "Transaction Ref. No.\n",
1414
+ "\n",
1415
+ "Bank Merchant ID\n",
1416
+ "\n",
1417
+ "005693\n",
1418
+ "\n",
1419
+ "Transaction ID\n",
1420
+ "\n",
1421
+ ":\n",
1422
+ "\n",
1423
+ ":\n",
1424
+ "\n",
1425
+ "Service Charges : NA\n",
1426
+ "\n",
1427
+ "22500\n",
1428
+ "\n",
1429
+ "Online Payment Total\n",
1430
+ "\n",
1431
+ "On-line Payment Details\n",
1432
+ "\n",
1433
+ "For THE NEOTIA UNIVERSITY\n"
1434
+ ]
1435
+ }
1436
+ ],
1437
+ "source": [
1438
+ "print(\"Answer:\", result[\"response\"])\n",
1439
+ "print(\"\\nRelevant Context:\", result[\"context\"])"
1440
+ ]
1441
+ },
1442
+ {
1443
+ "cell_type": "code",
1444
+ "execution_count": 13,
1445
+ "metadata": {},
1446
+ "outputs": [
1447
+ {
1448
+ "name": "stdout",
1449
+ "output_type": "stream",
1450
+ "text": [
1451
+ "Database cleared successfully!\n"
1452
+ ]
1453
+ }
1454
+ ],
1455
+ "source": [
1456
+ "\n",
1457
+ "# Clear the database when done\n",
1458
+ "rag_system.clear_index()"
1459
+ ]
1460
+ },
1461
+ {
1462
+ "cell_type": "code",
1463
+ "execution_count": null,
1464
+ "metadata": {},
1465
+ "outputs": [],
1466
+ "source": []
1467
+ }
1468
+ ],
1469
+ "metadata": {
1470
+ "language_info": {
1471
+ "name": "python"
1472
+ }
1473
+ },
1474
+ "nbformat": 4,
1475
+ "nbformat_minor": 2
1476
+ }
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langchain-community
2
+ tiktoken
3
+ langchainhub
4
+ langchain
5
+ langchain-huggingface
6
+ sentence_transformers
7
+ langchain-ollama
8
+ ollama
9
+ docling
10
+ easyocr
11
+ FlagEmbedding
12
+ chonkie
13
+ pinecone
14
+ streamlit
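
The list above covers the full stack: LangChain + Ollama for chat and embeddings, Docling (with EasyOCR) for document conversion, Chonkie for chunking, Pinecone for the vector store, and Streamlit for the app. One import the source files rely on but this list omits is `python-dotenv` (`src/vectordb.py` calls `load_dotenv()`), so it is worth adding alongside these. Everything model-related talks to a local Ollama server (the notebook logs show requests to http://127.0.0.1:11434), so the referenced models have to be pulled there first — a minimal sketch, assuming the `ollama` client from this list exposes `pull` and `ollama serve` is already running:

```python
# Sketch: make sure the models referenced in the notebook are available on
# the local Ollama server; pull is effectively a no-op if already cached.
import ollama

for name in ("mistral:7b", "nomic-embed-text"):
    ollama.pull(name)
```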
src/embedding.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from langchain_ollama import OllamaEmbeddings
3
+
4
+ class EmbeddingModel:
5
+ def __init__(self, model_name: str = "llama3.2"):
6
+ """
7
+ Initialize embedding model with LangChain OllamaEmbeddings
8
+
9
+ Args:
10
+ model_name (str): Name of the model to use
11
+ """
12
+ self.model_name = model_name
13
+ self.embeddings = OllamaEmbeddings(
14
+ model=model_name
15
+ )
16
+
17
+ def embed(self, text: str) -> List[float]:
18
+ """
19
+ Generate embeddings for a single text input
20
+
21
+ Args:
22
+ text (str): Input text to embed
23
+
24
+ Returns:
25
+ List[float]: Embedding vector
26
+ """
27
+ try:
28
+ # Use embed_query for single text embedding
29
+ return self.embeddings.embed_query(text)
30
+ except Exception as e:
31
+ print(f"Error generating embedding: {e}")
32
+ return []
33
+
34
+ def embed_batch(self, texts: List[str]) -> List[List[float]]:
35
+ """
36
+ Generate embeddings for multiple texts
37
+
38
+ Args:
39
+ texts (List[str]): List of input texts to embed
40
+
41
+ Returns:
42
+ List[List[float]]: List of embedding vectors
43
+ """
44
+ try:
45
+ # Use embed_documents for batch embedding
46
+ return self.embeddings.embed_documents(texts)
47
+ except Exception as e:
48
+ print(f"Error generating batch embeddings: {e}")
49
+ return []
50
+
51
+ if __name__ == "__main__":
52
+ # Initialize the embedding model
53
+ embedding_model = EmbeddingModel(model_name="llama3.2")
54
+
55
+ # Generate embedding for a single text
56
+ single_text = "The meaning of life is 42"
57
+ vector = embedding_model.embed(single_text)
58
+ print(vector[:3]) # Print first 3 dimensions
59
+
60
+ # Generate embeddings for multiple texts
61
+ texts = ["Document 1...", "Document 2..."]
62
+ vectors = embedding_model.embed_batch(texts)
63
+ print(len(vectors)) # Number of vectors
64
+ print(vectors[0][:3]) # First 3 dimensions of first vector
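
A constraint worth flagging here: `ingest_data` in `src/vectordb.py` creates the Pinecone index with `dimension=768`, so whichever Ollama model backs `EmbeddingModel` must emit 768-dimensional vectors. `nomic-embed-text`, the model actually used in the notebook and in `main()`, does; the `llama3.2` default in this module most likely does not. A small sanity check, assuming a local Ollama server with `nomic-embed-text` pulled:

```python
# Sketch: confirm the embedding width matches the Pinecone index created in
# src/vectordb.py (dimension=768) before ingesting anything.
from src.embedding import EmbeddingModel

embedder = EmbeddingModel(model_name="nomic-embed-text")
vector = embedder.embed("dimension check")
assert len(vector) == 768, f"index expects 768 dimensions, got {len(vector)}"
```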
src/llm.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional
2
+ from langchain_ollama import ChatOllama
3
+ from langchain_core.messages import HumanMessage, AIMessage
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+
6
+ class LLMChat:
7
+ def __init__(self, model_name: str = "llama3.2", temperature: float = 0):
8
+ """
9
+ Initialize LLMChat with LangChain ChatOllama
10
+
11
+ Args:
12
+ model_name (str): Name of the model to use
13
+ temperature (float): Temperature parameter for response generation
14
+ """
15
+ self.model_name = model_name
16
+ self.llm = ChatOllama(
17
+ model=model_name,
18
+ temperature=temperature
19
+ )
20
+ self.history: List[Dict[str, str]] = []
21
+
22
+ def chat_once(self, message: str):
23
+ """
24
+ Single chat interaction without maintaining history
25
+
26
+ Args:
27
+ message (str): User input message
28
+
29
+ Returns:
30
+ str: Model's response
31
+ """
32
+ try:
33
+ # Create a simple prompt template for single messages
34
+ prompt = ChatPromptTemplate.from_messages([
35
+ ("human", "{input}")
36
+ ])
37
+
38
+ # Create and invoke the chain
39
+ chain = prompt | self.llm
40
+ response = chain.invoke({"input": message})
41
+
42
+ return response.content
43
+ except Exception as e:
44
+ print(f"Error in chat: {e}")
45
+ return ""
46
+
47
+ def chat_with_history(self, message: str):
48
+ """
49
+ Chat interaction maintaining conversation history
50
+
51
+ Args:
52
+ message (str): User input message
53
+
54
+ Returns:
55
+ str: Model's response
56
+ """
57
+ try:
58
+ # Add user message to history
59
+ self.history.append({'role': 'human', 'content': message})
60
+
61
+ # Convert history to LangChain message format
62
+ messages = [
63
+ HumanMessage(content=msg['content']) if msg['role'] == 'human'
64
+ else AIMessage(content=msg['content'])
65
+ for msg in self.history
66
+ ]
67
+
68
+ # Get response using chat method
69
+ response = self.llm.invoke(messages)
70
+ assistant_message = response.content
71
+
72
+ # Add assistant response to history
73
+ self.history.append({'role': 'assistant', 'content': assistant_message})
74
+
75
+ return assistant_message
76
+ except Exception as e:
77
+ print(f"Error in chat with history: {e}")
78
+ return ""
79
+
80
+ def chat_with_template(self, template_messages: List[Dict[str, str]],
81
+ input_variables: Dict[str, str]):
82
+ """
83
+ Chat using a custom template
84
+
85
+ Args:
86
+ template_messages (List[Dict[str, str]]): List of template messages
87
+ input_variables (Dict[str, str]): Variables to fill in the template
88
+
89
+ Returns:
90
+ str: Model's response
91
+ """
92
+ try:
93
+ # Create prompt template from messages
94
+ prompt = ChatPromptTemplate.from_messages([
95
+ (msg['role'], msg['content'])
96
+ for msg in template_messages
97
+ ])
98
+
99
+ # Create and invoke the chain
100
+ chain = prompt | self.llm
101
+ response = chain.invoke(input_variables)
102
+
103
+ return response.content
104
+ except Exception as e:
105
+ print(f"Error in template chat: {e}")
106
+ return ""
107
+
108
+ def clear_history(self):
109
+ """Clear the conversation history"""
110
+ self.history = []
111
+
112
+ def get_history(self) -> List[Dict[str, str]]:
113
+ """Return the current conversation history"""
114
+ return self.history
115
+
116
+ if __name__ == "__main__":
117
+ # Initialize the chat
118
+ chat = LLMChat(model_name="llama3.1", temperature=0)
119
+
120
+ # Example of using a template for translation
121
+ template_messages = [
122
+ {
123
+ "role": "system",
124
+ "content": "You are a helpful assistant that translates {input_language} to {output_language}."
125
+ },
126
+ {
127
+ "role": "human",
128
+ "content": "{input}"
129
+ }
130
+ ]
131
+
132
+ input_vars = {
133
+ "input_language": "English",
134
+ "output_language": "German",
135
+ "input": "I love programming."
136
+ }
137
+
138
+ response = chat.chat_with_template(template_messages, input_vars)
139
+ # Simple chat without history
140
+ response = chat.chat_once("Hello!")
141
+
142
+ # Chat with history
143
+ response = chat.chat_with_history("How are you?")
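
The notebook's `TableRAGSystem.query` drives its RAG prompt through `chat_with_template`: template messages are plain `{role, content}` dicts, and `ChatPromptTemplate` fills the `{placeholders}` from `input_variables`. A minimal sketch of that call pattern, assuming a local Ollama server with `mistral:7b` available (the context and question values here are made up for illustration):

```python
# Sketch: the same call pattern TableRAGSystem.query uses for its RAG prompt.
from src.llm import LLMChat

llm = LLMChat(model_name="mistral:7b", temperature=0.3)
template = [
    {"role": "system", "content": "Answer strictly from the provided context."},
    {"role": "human", "content": "Context: {context}\n\nQuestion: {question}"},
]
answer = llm.chat_with_template(
    template,
    {"context": "Tuition Fee: 22500 (paid online)", "question": "What is the paid amount?"},
)
print(answer)
```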
src/loader.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import List, Union
3
+ import logging
4
+ from dataclasses import dataclass
5
+
6
+ from langchain_core.documents import Document as LCDocument
7
+ from langchain_core.document_loaders import BaseLoader
8
+ from docling.document_converter import DocumentConverter, PdfFormatOption
9
+ from docling.datamodel.base_models import InputFormat, ConversionStatus
10
+ from docling.datamodel.pipeline_options import (
11
+ PdfPipelineOptions,
12
+ EasyOcrOptions
13
+ )
14
+
15
+ logging.basicConfig(level=logging.INFO)
16
+ _log = logging.getLogger(__name__)
17
+
18
+ @dataclass
19
+ class ProcessingResult:
20
+ """Store results of document processing"""
21
+ success_count: int = 0
22
+ failure_count: int = 0
23
+ partial_success_count: int = 0
24
+ failed_files: List[str] = None
25
+
26
+ def __post_init__(self):
27
+ if self.failed_files is None:
28
+ self.failed_files = []
29
+
30
+ class MultiFormatDocumentLoader(BaseLoader):
31
+ """Loader for multiple document formats that converts to LangChain documents"""
32
+
33
+ def __init__(
34
+ self,
35
+ file_paths: Union[str, List[str]],
36
+ enable_ocr: bool = True,
37
+ enable_tables: bool = True
38
+ ):
39
+ self._file_paths = [file_paths] if isinstance(file_paths, str) else file_paths
40
+ self._enable_ocr = enable_ocr
41
+ self._enable_tables = enable_tables
42
+ self._converter = self._setup_converter()
43
+
44
+ def _setup_converter(self):
45
+ """Set up the document converter with appropriate options"""
46
+ # Configure pipeline options
47
+ pipeline_options = PdfPipelineOptions(do_ocr=False, do_table_structure=False, ocr_options=EasyOcrOptions(
48
+ force_full_page_ocr=True
49
+ ))
50
+ if self._enable_ocr:
51
+ pipeline_options.do_ocr = True
52
+ if self._enable_tables:
53
+ pipeline_options.do_table_structure = True
54
+ pipeline_options.table_structure_options.do_cell_matching = True
55
+
56
+ # Create converter with supported formats
57
+ return DocumentConverter(
58
+ allowed_formats=[
59
+ InputFormat.PDF,
60
+ InputFormat.IMAGE,
61
+ InputFormat.DOCX,
62
+ InputFormat.HTML,
63
+ InputFormat.PPTX,
64
+ InputFormat.ASCIIDOC,
65
+ InputFormat.MD,
66
+ ],
67
+ format_options={
68
+ InputFormat.PDF: PdfFormatOption(
69
+ pipeline_options=pipeline_options,
70
+ )}
71
+ )
72
+
73
+ def lazy_load(self):
74
+ """Convert documents and yield LangChain documents"""
75
+ results = ProcessingResult()
76
+
77
+ for file_path in self._file_paths:
78
+ try:
79
+ path = Path(file_path)
80
+ if not path.exists():
81
+ _log.warning(f"File not found: {file_path}")
82
+ results.failure_count += 1
83
+ results.failed_files.append(file_path)
84
+ continue
85
+
86
+ conversion_result = self._converter.convert(path)
87
+
88
+ if conversion_result.status == ConversionStatus.SUCCESS:
89
+ results.success_count += 1
90
+ text = conversion_result.document.export_to_markdown()
91
+ metadata = {
92
+ 'source': str(path),
93
+ 'file_type': path.suffix,
94
+ }
95
+ yield LCDocument(
96
+ page_content=text,
97
+ metadata=metadata
98
+ )
99
+ elif conversion_result.status == ConversionStatus.PARTIAL_SUCCESS:
100
+ results.partial_success_count += 1
101
+ _log.warning(f"Partial conversion for {file_path}")
102
+ text = conversion_result.document.export_to_markdown()
103
+ metadata = {
104
+ 'source': str(path),
105
+ 'file_type': path.suffix,
106
+ 'conversion_status': 'partial'
107
+ }
108
+ yield LCDocument(
109
+ page_content=text,
110
+ metadata=metadata
111
+ )
112
+ else:
113
+ results.failure_count += 1
114
+ results.failed_files.append(file_path)
115
+ _log.error(f"Failed to convert {file_path}")
116
+
117
+ except Exception as e:
118
+ _log.error(f"Error processing {file_path}: {str(e)}")
119
+ results.failure_count += 1
120
+ results.failed_files.append(file_path)
121
+
122
+ # Log final results
123
+ total = results.success_count + results.partial_success_count + results.failure_count
124
+ _log.info(
125
+ f"Processed {total} documents:\n"
126
+ f"- Successfully converted: {results.success_count}\n"
127
+ f"- Partially converted: {results.partial_success_count}\n"
128
+ f"- Failed: {results.failure_count}"
129
+ )
130
+ if results.failed_files:
131
+ _log.info("Failed files:")
132
+ for file in results.failed_files:
133
+ _log.info(f"- {file}")
134
+
135
+
136
+ if __name__ == '__main__':
137
+ # Load documents from a list of file paths
138
+ loader = MultiFormatDocumentLoader(
139
+ file_paths=[
140
+ # './data/2404.19756v1.pdf',
141
+ # './data/OD429347375590223100.pdf',
142
+ '/teamspace/studios/this_studio/TabularRAG/data/FeesPaymentReceipt_7thsem.pdf',
143
+ # './data/UNIT 2 GENDER BASED VIOLENCE.pptx'
144
+ ],
145
+ enable_ocr=False,
146
+ enable_tables=True
147
+ )
148
+ for doc in loader.lazy_load():
149
+ print(doc.page_content)
150
+ print(doc.metadata)
151
+ # save document in .md file
152
+ with open('/teamspace/studios/this_studio/TabularRAG/data/output.md', 'w') as f:
153
+ f.write(doc.page_content)
src/processor.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Any
2
+ from tqdm import tqdm
3
+ import time
4
+
5
+ from src.embedding import EmbeddingModel
6
+ from src.llm import LLMChat
7
+
8
+ class TableProcessor:
9
+ def __init__(self, llm_model: LLMChat, embedding_model: EmbeddingModel, batch_size: int = 8):
10
+ """
11
+ Initialize the TableProcessor with pre-initialized models.
12
+
13
+ Args:
14
+ llm_model (LLMChat): Initialized LLM model
15
+ embedding_model (EmbeddingModel): Initialized embedding model
16
+ batch_size (int): Batch size for processing embeddings
17
+ """
18
+ self.llm = llm_model
19
+ self.embedder = embedding_model
20
+ self.batch_size = batch_size
21
+
22
+ def get_table_description(self, markdown_table: str) -> str:
23
+ """
24
+ Generate description for a single markdown table using Ollama chat.
25
+
26
+ Args:
27
+ markdown_table (str): Input markdown table
28
+
29
+ Returns:
30
+ str: Generated description of the table
31
+ """
32
+ system_prompt = """You are an AI language model. Your task is to examine the provided table, taking into account both its rows and columns, and produce a concise summary of up to 200 words. Emphasize key patterns, trends, and notable data points that provide meaningful insights into the content of the table."""
33
+
34
+ try:
35
+ # Use chat_once to avoid maintaining history between tables
36
+ full_prompt = f"{system_prompt}\n\nTable:\n{markdown_table}"
37
+ return self.llm.chat_once(full_prompt)
38
+ except Exception as e:
39
+ print(f"Error generating table description: {e}")
40
+ return ""
41
+
42
+ def process_tables(self, markdown_tables) -> List[Dict[str, Any]]:
43
+ """
44
+ Process a list of markdown tables: generate descriptions and embeddings.
45
+
46
+ Args:
47
+ markdown_tables (List[TableChunk]): Table chunks to process; each must expose its markdown text via a .text attribute
48
+
49
+ Returns:
50
+ List[Dict[str, Any]]: List of dictionaries containing processed information
51
+ """
52
+ results = []
53
+ descriptions = []
54
+
55
+ # Generate descriptions for all tables
56
+ with tqdm(total=len(markdown_tables), desc="Generating table descriptions") as pbar:
57
+ for i, table in enumerate(markdown_tables):
58
+ description = self.get_table_description(table.text)
59
+ print(f"\nTable {i+1}:")
60
+ print(f"Description: {description}")
61
+ print("-" * 50)
62
+ descriptions.append(description)
63
+ pbar.update(1)
64
+ time.sleep(1) # Rate limiting
65
+
66
+ # Generate embeddings in batches
67
+ embeddings = []
68
+ total_batches = (len(descriptions) + self.batch_size - 1) // self.batch_size
69
+
70
+ with tqdm(total=total_batches, desc="Generating embeddings") as pbar:
71
+ for i in range(0, len(descriptions), self.batch_size):
72
+ batch = descriptions[i:i + self.batch_size]
73
+ if len(batch) == 1:
74
+ batch_embeddings = [self.embedder.embed(batch[0])]
75
+ else:
76
+ batch_embeddings = self.embedder.embed_batch(batch)
77
+ embeddings.extend(batch_embeddings)
78
+ pbar.update(1)
79
+
80
+ # Combine results with progress bar
81
+ with tqdm(total=len(markdown_tables), desc="Combining results") as pbar:
82
+ for table, description, embedding in zip(markdown_tables, descriptions, embeddings):
83
+ results.append({
84
+ "embedding": embedding,
85
+ "text": table,
86
+ "table_description": description,
87
+ "type": "table_chunk"
88
+ })
89
+ pbar.update(1)
90
+
91
+ return results
92
+
93
+ def __call__(self, markdown_tables) -> List[Dict[str, Any]]:
94
+ """
95
+ Make the class callable for easier use.
96
+
97
+ Args:
98
+ markdown_tables (List[TableChunk]): Table chunks to process (see process_tables)
99
+
100
+ Returns:
101
+ List[Dict[str, Any]]: Processed results
102
+ """
103
+ return self.process_tables(markdown_tables)
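
As the docstrings above note, `TableProcessor` is fed the `TableChunk` objects produced by `TableRecursiveChunker`, not raw strings: `get_table_description` reads `table.text`, and each result carries the chunk itself under `"text"` next to the generated `"table_description"` and its `"embedding"`. A minimal sketch of that input/output shape, assuming a local Ollama server with `mistral:7b` and `nomic-embed-text` pulled:

```python
# Sketch: process a single hand-built TableChunk and inspect the result keys.
from src.table_aware_chunker import TableChunk
from src.processor import TableProcessor
from src.llm import LLMChat
from src.embedding import EmbeddingModel

table_md = (
    "| Heads | Amount |\n"
    "|-------|--------|\n"
    "| Tuition Fee | 22500 |\n"
)
chunk = TableChunk(text=table_md, start_index=0, end_index=len(table_md), token_count=20)

processor = TableProcessor(
    llm_model=LLMChat("mistral:7b"),
    embedding_model=EmbeddingModel("nomic-embed-text"),
    batch_size=8,
)
results = processor([chunk])
print(results[0]["table_description"])
print(len(results[0]["embedding"]))  # 768 for nomic-embed-text
```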
src/table_aware_chunker.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple, Union
2
+ import re
3
+ from dataclasses import dataclass
4
+ from chonkie.chunker import RecursiveChunker
5
+ from chonkie.types import RecursiveChunk
6
+ from chonkie import RecursiveRules
7
+
8
+ @dataclass
9
+ class TableChunk:
10
+ """Represents a table chunk from the markdown document."""
11
+ text: str
12
+ start_index: int
13
+ end_index: int
14
+ token_count: int
15
+
16
+ class TableRecursiveChunker(RecursiveChunker):
17
+ """A recursive chunker that preserves markdown tables while chunking text.
18
+
19
+ This chunker extends the base RecursiveChunker to handle markdown tables as special cases,
20
+ keeping them intact rather than splitting them according to the recursive rules.
21
+ """
22
+
23
+ def _extract_tables(self, text: str) -> Tuple[List[TableChunk], List[Tuple[int, int, str]]]:
24
+ """
25
+ Extract markdown tables from text and return table chunks and remaining text segments.
26
+
27
+ Args:
28
+ text: The input text containing markdown content
29
+
30
+ Returns:
31
+ Tuple containing:
32
+ - List of TableChunk objects for tables
33
+ - List of (start_index, end_index, text) tuples for non-table segments
34
+ """
35
+ # Regular expression for markdown tables (matches header, separator, and content rows)
36
+ table_pattern = r'(\|[^\n]+\|\n\|[-:\|\s]+\|\n(?:\|[^\n]+\|\n)+)'
37
+
38
+ table_chunks = []
39
+ non_table_segments = []
40
+ last_end = 0
41
+
42
+ for match in re.finditer(table_pattern, text):
43
+ start, end = match.span()
44
+
45
+ # Add non-table text before this table
46
+ if start > last_end:
47
+ non_table_segments.append((last_end, start, text[last_end:start]))
48
+
49
+ # Create table chunk
50
+ table_text = match.group()
51
+ token_count = self._count_tokens(table_text)
52
+ table_chunks.append(TableChunk(
53
+ text=table_text,
54
+ start_index=start,
55
+ end_index=end,
56
+ token_count=token_count
57
+ ))
58
+
59
+ last_end = end
60
+
61
+ # Add remaining text after last table
62
+ if last_end < len(text):
63
+ non_table_segments.append((last_end, len(text), text[last_end:]))
64
+
65
+ return table_chunks, non_table_segments
66
+
67
+ def chunk(self, text: str) -> Tuple[List[RecursiveChunk], List[TableChunk]]:
68
+ """
69
+ Chunk the text while preserving tables.
70
+
71
+ This method overrides the base chunk method to handle tables separately from
72
+ regular text content.
73
+
74
+ Args:
75
+ text: The input text to chunk
76
+
77
+ Returns:
78
+ Tuple containing:
79
+ - List of RecursiveChunk objects for non-table text
80
+ - List of TableChunk objects for tables
81
+ """
82
+ # First extract tables
83
+ table_chunks, non_table_segments = self._extract_tables(text)
84
+
85
+ # Chunk each non-table segment using the parent class's recursive chunking
86
+ text_chunks = []
87
+ for start, end, segment in non_table_segments:
88
+ if segment.strip(): # Only process non-empty segments
89
+ # Use the parent class's recursive chunking logic
90
+ chunks = super()._recursive_chunk(segment, level=0, full_text=text)
91
+ text_chunks.extend(chunks)
92
+
93
+ return text_chunks, table_chunks
94
+
95
+ def chunk_batch(self, texts: List[str]) -> List[Tuple[List[RecursiveChunk], List[TableChunk]]]:
96
+ """
97
+ Chunk multiple texts while preserving tables in each.
98
+
99
+ Args:
100
+ texts: List of texts to chunk
101
+
102
+ Returns:
103
+ List of tuples, each containing:
104
+ - List of RecursiveChunk objects for non-table text
105
+ - List of TableChunk objects for tables
106
+ """
107
+ return [self.chunk(text) for text in texts]
108
+
109
+ def __call__(self, texts: Union[str, List[str]]) -> Union[
110
+ Tuple[List[RecursiveChunk], List[TableChunk]],
111
+ List[Tuple[List[RecursiveChunk], List[TableChunk]]]
112
+ ]:
113
+ """Make the chunker callable for convenience."""
114
+ if isinstance(texts, str):
115
+ return self.chunk(texts)
116
+ return self.chunk_batch(texts)
117
+
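
A quick illustration of the split behaviour, assuming `chonkie` is installed: the pipe-table regex lifts each markdown table out as a single intact `TableChunk`, while the surrounding prose goes through the parent chunker's recursive rules.

```python
# Sketch: tables survive intact as TableChunk objects; prose is chunked
# recursively by the parent RecursiveChunker.
from chonkie import RecursiveRules
from src.table_aware_chunker import TableRecursiveChunker

chunker = TableRecursiveChunker(
    tokenizer="gpt2",
    chunk_size=512,
    rules=RecursiveRules(),
    min_characters_per_chunk=12,
)

markdown = (
    "Payment summary for semester 7.\n\n"
    "| Heads | Amount |\n"
    "|-------|--------|\n"
    "| Tuition Fee | 22500 |\n"
    "\nPaid via online payment.\n"
)
text_chunks, table_chunks = chunker.chunk(markdown)
print(len(text_chunks), "text chunks,", len(table_chunks), "table chunk(s)")
print(table_chunks[0].text)
```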
src/vectordb.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Any, Optional
2
+ import pandas as pd
3
+ import time
4
+ from tqdm import tqdm
5
+ import logging
6
+ from pinecone import Pinecone, ServerlessSpec
7
+ from dataclasses import dataclass
8
+ from enum import Enum
9
+ from src.table_aware_chunker import TableRecursiveChunker
10
+ from src.processor import TableProcessor
11
+ from src.llm import LLMChat
12
+ from src.embedding import EmbeddingModel
13
+ from chonkie import RecursiveRules
14
+ from src.loader import MultiFormatDocumentLoader
15
+ from dotenv import load_dotenv
16
+ import os
17
+
18
+ load_dotenv()
19
+ # API Keys
20
+ PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
21
+
22
+ logging.basicConfig(
23
+ level=logging.INFO,
24
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
25
+ )
26
+ logger = logging.getLogger('table_aware_rag')
27
+
28
+ class ChunkType(Enum):
29
+ TEXT = "text_chunk"
30
+ TABLE = "table_chunk"
31
+
32
+ @dataclass
33
+ class ProcessedChunk:
34
+ text: str # This will be the embedable text (table description for tables)
35
+ chunk_type: ChunkType
36
+ token_count: int
37
+ markdown_table: Optional[str] = None # Store original markdown table format
38
+ start_index: Optional[int] = None
39
+ end_index: Optional[int] = None
40
+
41
+ def process_documents(
42
+ file_paths: List[str],
43
+ chunker: TableRecursiveChunker,
44
+ processor: TableProcessor,
45
+ output_path: str = './output.md'
46
+ ) -> List[ProcessedChunk]:
47
+ """
48
+ Process documents into text and table chunks
49
+ """
50
+ # Load documents
51
+ loader = MultiFormatDocumentLoader(
52
+ file_paths=file_paths,
53
+ enable_ocr=False,
54
+ enable_tables=True
55
+ )
56
+
57
+ # Save to markdown and read content
58
+ with open(output_path, 'w') as f:
59
+ for doc in loader.lazy_load():
60
+ f.write(doc.page_content)
61
+
62
+ with open(output_path, 'r') as file:
63
+ text = file.read()
64
+
65
+ # Get text and table chunks
66
+ text_chunks, table_chunks = chunker.chunk(text)
67
+
68
+ # Process chunks
69
+ processed_chunks = []
70
+
71
+ # Process text chunks
72
+ for chunk in text_chunks:
73
+ processed_chunks.append(
74
+ ProcessedChunk(
75
+ text=chunk.text,
76
+ chunk_type=ChunkType.TEXT,
77
+ token_count=chunk.token_count,
78
+ start_index=chunk.start_index,
79
+ end_index=chunk.end_index
80
+ )
81
+ )
82
+
83
+ # Process table chunks
84
+ table_results = processor(table_chunks)
85
+ for table in table_results:
86
+ # Convert table chunk to string representation if needed
87
+ table_str = str(table["text"].text)
88
+
89
+ processed_chunks.append(
90
+ ProcessedChunk(
91
+ text=table["table_description"], # Use description for embedding
92
+ chunk_type=ChunkType.TABLE,
93
+ token_count=len(table["table_description"].split()),
94
+ markdown_table=table_str # Store string version of table
95
+ )
96
+ )
97
+
98
+ return processed_chunks
99
+
100
+ class PineconeRetriever:
101
+ def __init__(
102
+ self,
103
+ pinecone_client: Pinecone,
104
+ index_name: str,
105
+ namespace: str,
106
+ embedding_model: Any,
107
+ llm_model: Any
108
+ ):
109
+ """
110
+ Initialize retriever with configurable models
111
+ """
112
+ self.pinecone = pinecone_client
113
+ self.index = self.pinecone.Index(index_name)
114
+ self.namespace = namespace
115
+ self.embedding_model = embedding_model
116
+ self.llm_model = llm_model
117
+
118
+ def _prepare_query(self, question: str) -> List[float]:
119
+ """Generate embedding for query"""
120
+ return self.embedding_model.embed(question)
121
+
122
+ def invoke(
123
+ self,
124
+ question: str,
125
+ top_k: int = 5,
126
+ chunk_type_filter: Optional[ChunkType] = None
127
+ ) -> List[Dict[str, Any]]:
128
+ """
129
+ Retrieve similar documents with optional filtering by chunk type
130
+ """
131
+ query_embedding = self._prepare_query(question)
132
+
133
+ # Prepare filter if chunk type specified
134
+ filter_dict = None
135
+ if chunk_type_filter:
136
+ filter_dict = {"chunk_type": chunk_type_filter.value}
137
+
138
+ results = self.index.query(
139
+ namespace=self.namespace,
140
+ vector=query_embedding,
141
+ top_k=top_k,
142
+ include_values=False,
143
+ include_metadata=True,
144
+ filter=filter_dict
145
+ )
146
+
147
+ retrieved_docs = []
148
+ for match in results.matches:
149
+ doc = {
150
+ "score": match.score,
151
+ "chunk_type": match.metadata["chunk_type"]
152
+ }
153
+
154
+ # Handle different chunk types
155
+ if match.metadata["chunk_type"] == ChunkType.TABLE.value:
156
+ doc["table_description"] = match.metadata["text"] # The embedded description
157
+ doc["markdown_table"] = match.metadata["markdown_table"] # Original table format
158
+ else:
159
+ doc["page_content"] = match.metadata["text"]
160
+
161
+ retrieved_docs.append(doc)
162
+
163
+ return retrieved_docs
164
+
165
+ def ingest_data(
166
+ processed_chunks: List[ProcessedChunk],
167
+ embedding_model: Any,
168
+ pinecone_client: Pinecone,
169
+ index_name: str = "vector-index",
170
+ namespace: str = "rag",
171
+ batch_size: int = 100
172
+ ):
173
+ """
174
+ Ingest processed chunks into Pinecone
175
+ """
176
+ # Create or get index
177
+ if not pinecone_client.has_index(index_name):
178
+ pinecone_client.create_index(
179
+ name=index_name,
180
+ dimension=768,
181
+ metric="cosine",
182
+ spec=ServerlessSpec(
183
+ cloud='aws',
184
+ region='us-east-1'
185
+ )
186
+ )
187
+
188
+ while not pinecone_client.describe_index(index_name).status['ready']:
189
+ time.sleep(1)
190
+
191
+ index = pinecone_client.Index(index_name)
192
+
193
+ # Process in batches
194
+ for i in tqdm(range(0, len(processed_chunks), batch_size)):
195
+ batch = processed_chunks[i:i+batch_size]
196
+
197
+ # Generate embeddings for the text content
198
+ texts = [chunk.text for chunk in batch]
199
+ embeddings = embedding_model.embed_batch(texts)
200
+
201
+ # Prepare records
202
+ records = []
203
+ for idx, chunk in enumerate(batch):
204
+ metadata = {
205
+ "text": chunk.text, # This is the description for tables
206
+ "chunk_type": chunk.chunk_type.value,
207
+ "token_count": chunk.token_count
208
+ }
209
+
210
+ # Add markdown table to metadata if it's a table chunk
211
+ if chunk.markdown_table is not None:
212
+ # Ensure the table is in string format
213
+ metadata["markdown_table"] = str(chunk.markdown_table)
214
+
215
+ records.append({
216
+ "id": f"chunk_{i + idx}",
217
+ "values": embeddings[idx],
218
+ "metadata": metadata
219
+ })
220
+
221
+ # Upsert to Pinecone
222
+ try:
223
+ index.upsert(vectors=records, namespace=namespace)
224
+ except Exception as e:
225
+ logger.error(f"Error during upsert: {str(e)}")
226
+ logger.error(f"Problematic record metadata: {records[0]['metadata']}")
227
+ raise
228
+
229
+ time.sleep(0.5) # Rate limiting
230
+
231
+
232
+ def main():
233
+ # Initialize components
234
+ pc = Pinecone(api_key=PINECONE_API_KEY)
235
+
236
+ chunker = TableRecursiveChunker(
237
+ tokenizer="gpt2",
238
+ chunk_size=512,
239
+ rules=RecursiveRules(),
240
+ min_characters_per_chunk=12
241
+ )
242
+
243
+ llm = LLMChat("qwen2.5:0.5b")
244
+ embedder = EmbeddingModel("nomic-embed-text")
245
+
246
+ processor = TableProcessor(
247
+ llm_model=llm,
248
+ embedding_model=embedder,
249
+ batch_size=8
250
+ )
251
+
252
+ try:
253
+ # Process documents
254
+ processed_chunks = process_documents(
255
+ file_paths=['/teamspace/studios/this_studio/TabularRAG/data/FeesPaymentReceipt_7thsem.pdf'],
256
+ chunker=chunker,
257
+ processor=processor
258
+ )
259
+
260
+ # Ingest data
261
+ ingest_data(
262
+ processed_chunks=processed_chunks,
263
+ embedding_model=embedder,
264
+ pinecone_client=pc
265
+ )
266
+
267
+ # Test retrieval
268
+ retriever = PineconeRetriever(
269
+ pinecone_client=pc,
270
+ index_name="vector-index",
271
+ namespace="rag",
272
+ embedding_model=embedder,
273
+ llm_model=llm
274
+ )
275
+
276
+ # # Test text-only retrieval
277
+ # text_results = retriever.invoke(
278
+ # question="What is paid fees amount?",
279
+ # top_k=3,
280
+ # chunk_type_filter=ChunkType.TEXT
281
+ # )
282
+ # print("Text results:")
283
+ # for result in text_results:
284
+ # print(result)
285
+ # Test table-only retrieval
286
+ # table_results = retriever.invoke(
287
+ # question="What is paid fees amount?",
288
+ # top_k=3,
289
+ # chunk_type_filter=ChunkType.TABLE
290
+ # )
291
+ # print("Table results:")
292
+ # for result in table_results:
293
+ # print(result)
294
+
295
+ results = retriever.invoke(
296
+ question="What is paid fees amount?",
297
+ top_k=3
298
+ )
299
+
300
+ for i, result in enumerate(results, 1):
301
+ print(f"\nResult {i}:")
302
+ if result["chunk_type"] == ChunkType.TABLE.value:
303
+ print(f"Table Description: {result['table_description']}")
304
+ print("Table Format:")
305
+ print(result['markdown_table'])
306
+ else:
307
+ print(f"Content: {result['page_content']}")
308
+ print(f"Score: {result['score']}")
309
+
310
+ except Exception as e:
311
+ logger.error(f"Error in pipeline: {str(e)}")
312
+
313
+ if __name__ == "__main__":
314
+ main()
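
Once an index has been populated, retrieval does not require re-running ingestion: the retriever can be pointed at the existing index directly, and `chunk_type_filter` restricts matches to table or text chunks via the stored metadata. A minimal sketch, assuming the `vector-index` index and `rag` namespace created by `main()` already exist and the same Ollama models are available:

```python
# Sketch: query an already-populated index, restricting results to table chunks.
import os
from pinecone import Pinecone
from src.vectordb import ChunkType, PineconeRetriever
from src.embedding import EmbeddingModel
from src.llm import LLMChat

pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
retriever = PineconeRetriever(
    pinecone_client=pc,
    index_name="vector-index",
    namespace="rag",
    embedding_model=EmbeddingModel("nomic-embed-text"),
    llm_model=LLMChat("qwen2.5:0.5b"),
)

for doc in retriever.invoke("What is the paid fees amount?", top_k=3,
                            chunk_type_filter=ChunkType.TABLE):
    print(doc["score"], doc["table_description"])
    print(doc["markdown_table"])
```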