from typing import List, Dict, Any, Optional
import pandas as pd
import time
from tqdm import tqdm
import logging
from pinecone import Pinecone, ServerlessSpec
from dataclasses import dataclass
from enum import Enum
from src.table_aware_chunker import TableRecursiveChunker
from src.processor import TableProcessor
from src.llm import LLMChat
from src.embedding import EmbeddingModel
from chonkie import RecursiveRules
from src.loader import MultiFormatDocumentLoader
from dotenv import load_dotenv
import os

load_dotenv()

# API keys
PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('table_aware_rag')
class ChunkType(Enum):
    TEXT = "text_chunk"
    TABLE = "table_chunk"
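# These string values are stored in Pinecone metadata and drive the optional
# chunk_type filter at query time (see PineconeRetriever.invoke below).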
@dataclass
class ProcessedChunk:
    text: str  # The embeddable text (the generated description for table chunks)
    chunk_type: ChunkType
    token_count: int
    markdown_table: Optional[str] = None  # Original markdown table, kept for display
    start_index: Optional[int] = None
    end_index: Optional[int] = None
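# Illustrative only: a table chunk pairs a short description (which is embedded)
# with the raw table (which is returned for display), e.g.:
#   ProcessedChunk(
#       text="Receipt table listing fee items and amounts paid",
#       chunk_type=ChunkType.TABLE,
#       token_count=8,
#       markdown_table="| Item | Amount |\n| --- | --- |\n| Tuition | 5000 |",
#   )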
def process_documents(
    file_paths: List[str],
    chunker: TableRecursiveChunker,
    processor: TableProcessor,
    output_path: str = './output.md'
) -> List[ProcessedChunk]:
    """
    Process documents into text and table chunks.
    """
    # Load documents
    loader = MultiFormatDocumentLoader(
        file_paths=file_paths,
        enable_ocr=False,
        enable_tables=True
    )
    # Save to markdown and read the content back
    with open(output_path, 'w') as f:
        for doc in loader.lazy_load():
            f.write(doc.page_content)
    with open(output_path, 'r') as file:
        text = file.read()

    # Split into text and table chunks
    text_chunks, table_chunks = chunker.chunk(text)

    processed_chunks = []

    # Text chunks are embedded as-is
    for chunk in text_chunks:
        processed_chunks.append(
            ProcessedChunk(
                text=chunk.text,
                chunk_type=ChunkType.TEXT,
                token_count=chunk.token_count,
                start_index=chunk.start_index,
                end_index=chunk.end_index
            )
        )

    # Table chunks: embed the generated description, keep the original table
    table_results = processor(table_chunks)
    for table in table_results:
        # Convert the table chunk to a string representation if needed
        table_str = str(table["text"].text)
        processed_chunks.append(
            ProcessedChunk(
                text=table["table_description"],  # Use the description for embedding
                chunk_type=ChunkType.TABLE,
                # Approximates the token count with a whitespace word count
                token_count=len(table["table_description"].split()),
                markdown_table=table_str  # Store the string version of the table
            )
        )
    return processed_chunks
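# Sketch of standalone usage (chunker/processor construction mirrors main();
# the file path here is hypothetical):
#   chunks = process_documents(
#       file_paths=['data/report.pdf'],
#       chunker=chunker,
#       processor=processor,
#   )
#   print(len(chunks), "chunks produced")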
class PineconeRetriever:
    def __init__(
        self,
        pinecone_client: Pinecone,
        index_name: str,
        namespace: str,
        embedding_model: Any,
        llm_model: Any
    ):
        """
        Initialize the retriever with configurable models.
        """
        self.pinecone = pinecone_client
        self.index = self.pinecone.Index(index_name)
        self.namespace = namespace
        self.embedding_model = embedding_model
        self.llm_model = llm_model

    def _prepare_query(self, question: str) -> List[float]:
        """Generate an embedding for the query."""
        return self.embedding_model.embed(question)

    def invoke(
        self,
        question: str,
        top_k: int = 5,
        chunk_type_filter: Optional[ChunkType] = None
    ) -> List[Dict[str, Any]]:
        """
        Retrieve similar documents, optionally filtered by chunk type.
        """
        query_embedding = self._prepare_query(question)

        # Prepare a metadata filter if a chunk type was specified
        filter_dict = None
        if chunk_type_filter:
            filter_dict = {"chunk_type": chunk_type_filter.value}

        results = self.index.query(
            namespace=self.namespace,
            vector=query_embedding,
            top_k=top_k,
            include_values=False,
            include_metadata=True,
            filter=filter_dict
        )

        retrieved_docs = []
        for match in results.matches:
            doc = {
                "score": match.score,
                "chunk_type": match.metadata["chunk_type"]
            }
            # Table hits carry both the embedded description and the raw table
            if match.metadata["chunk_type"] == ChunkType.TABLE.value:
                doc["table_description"] = match.metadata["text"]
                doc["markdown_table"] = match.metadata["markdown_table"]
            else:
                doc["page_content"] = match.metadata["text"]
            retrieved_docs.append(doc)
        return retrieved_docs
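# Note: llm_model is stored but not used by invoke(), which only performs
# vector retrieval; it is presumably kept for downstream answer generation.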
def ingest_data(
    processed_chunks: List[ProcessedChunk],
    embedding_model: Any,
    pinecone_client: Pinecone,
    index_name: str = "vector-index",
    namespace: str = "rag",
    batch_size: int = 100
):
    """
    Ingest processed chunks into Pinecone.
    """
    # Create the index if it does not exist yet
    if not pinecone_client.has_index(index_name):
        pinecone_client.create_index(
            name=index_name,
            dimension=768,  # Must match the embedder's output size (768 for nomic-embed-text)
            metric="cosine",
            spec=ServerlessSpec(
                cloud='aws',
                region='us-east-1'
            )
        )
        while not pinecone_client.describe_index(index_name).status['ready']:
            time.sleep(1)

    index = pinecone_client.Index(index_name)

    # Upsert in batches
    for i in tqdm(range(0, len(processed_chunks), batch_size)):
        batch = processed_chunks[i:i + batch_size]

        # Generate embeddings for the text content
        texts = [chunk.text for chunk in batch]
        embeddings = embedding_model.embed_batch(texts)

        # Prepare records
        records = []
        for idx, chunk in enumerate(batch):
            metadata = {
                "text": chunk.text,  # For table chunks this is the description
                "chunk_type": chunk.chunk_type.value,
                "token_count": chunk.token_count
            }
            # Attach the markdown table for table chunks
            if chunk.markdown_table is not None:
                # Ensure the table is stored as a string
                metadata["markdown_table"] = str(chunk.markdown_table)
            records.append({
                "id": f"chunk_{i + idx}",
                "values": embeddings[idx],
                "metadata": metadata
            })

        # Upsert to Pinecone
        try:
            index.upsert(vectors=records, namespace=namespace)
        except Exception as e:
            logger.error(f"Error during upsert: {str(e)}")
            logger.error(f"Problematic record metadata: {records[0]['metadata']}")
            raise

        time.sleep(0.5)  # Rate limiting
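# Optional post-ingest sanity check (run wherever an Index handle is in scope):
#   stats = pinecone_client.Index(index_name).describe_index_stats()
#   print(stats)  # per-namespace vector counts should reflect the upserts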
def main():
    # Initialize components
    pc = Pinecone(api_key=PINECONE_API_KEY)
    chunker = TableRecursiveChunker(
        tokenizer="gpt2",
        chunk_size=512,
        rules=RecursiveRules(),
        min_characters_per_chunk=12
    )
    llm = LLMChat("qwen2.5:0.5b")
    embedder = EmbeddingModel("nomic-embed-text")
    processor = TableProcessor(
        llm_model=llm,
        embedding_model=embedder,
        batch_size=8
    )
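    # Note: "qwen2.5:0.5b" and "nomic-embed-text" are local model identifiers
    # (Ollama-style names); the src.llm and src.embedding wrappers are assumed
    # to serve them, and PINECONE_API_KEY must be set in the environment/.env.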
    try:
        # Process documents
        processed_chunks = process_documents(
            file_paths=['/teamspace/studios/this_studio/TabularRAG/data/FeesPaymentReceipt_7thsem.pdf'],
            chunker=chunker,
            processor=processor
        )

        # Ingest data
        ingest_data(
            processed_chunks=processed_chunks,
            embedding_model=embedder,
            pinecone_client=pc
        )

        # Test retrieval
        retriever = PineconeRetriever(
            pinecone_client=pc,
            index_name="vector-index",
            namespace="rag",
            embedding_model=embedder,
            llm_model=llm
        )

        # Test text-only retrieval
        # text_results = retriever.invoke(
        #     question="What is the paid fees amount?",
        #     top_k=3,
        #     chunk_type_filter=ChunkType.TEXT
        # )
        # print("Text results:")
        # for result in text_results:
        #     print(result)

        # Test table-only retrieval
        # table_results = retriever.invoke(
        #     question="What is the paid fees amount?",
        #     top_k=3,
        #     chunk_type_filter=ChunkType.TABLE
        # )
        # print("Table results:")
        # for result in table_results:
        #     print(result)

        results = retriever.invoke(
            question="What is the paid fees amount?",
            top_k=3
        )
        for i, result in enumerate(results, 1):
            print(f"\nResult {i}:")
            if result["chunk_type"] == ChunkType.TABLE.value:
                print(f"Table Description: {result['table_description']}")
                print("Table Format:")
                print(result['markdown_table'])
            else:
                print(f"Content: {result['page_content']}")
            print(f"Score: {result['score']}")
    except Exception as e:
        logger.error(f"Error in pipeline: {str(e)}")
if __name__ == "__main__":
    main()