TalatMasood committed
Commit d161383 · 1 Parent(s): e9d730a

Update knowledge upload API and link ChromaDB to MongoDB

.vscode/launch.json CHANGED
@@ -2,7 +2,7 @@
     "version": "0.2.0",
     "configurations": [
         {
-            "name": "Python: FastAPI",
+            "name": "Chatbot",
             "type": "python",
             "request": "launch",
             "module": "uvicorn",
@@ -17,7 +17,7 @@
             }
         },
         {
-            "name": "Python: Test",
+            "name": "Chatbot: Tests",
            "type": "python",
            "request": "launch",
            "module": "pytest",
DocKnowledge-based chatbot.docx ADDED
Binary file (16.9 kB)
 
src/__pycache__/main.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/main.cpython-312.pyc and b/src/__pycache__/main.cpython-312.pyc differ
 
src/db/__pycache__/mongodb_store.cpython-312.pyc CHANGED
Binary files a/src/db/__pycache__/mongodb_store.cpython-312.pyc and b/src/db/__pycache__/mongodb_store.cpython-312.pyc differ
 
src/db/mongodb_store.py CHANGED
@@ -2,7 +2,7 @@
 from motor.motor_asyncio import AsyncIOMotorClient
 from datetime import datetime
 import json
-from typing import List, Dict, Optional
+from typing import List, Dict, Optional, Any
 from bson import ObjectId
 
 class MongoDBStore:
@@ -11,6 +11,40 @@ class MongoDBStore:
         self.client = AsyncIOMotorClient(mongo_uri)
         self.db = self.client.rag_chatbot
         self.chat_history = self.db.chat_history
+        self.documents = self.db.documents  # Collection for original documents
+
+    async def store_document(
+        self,
+        document_id: str,
+        filename: str,
+        content: str,
+        content_type: str,
+        file_size: int
+    ) -> str:
+        """Store original document in MongoDB"""
+        document = {
+            "document_id": document_id,
+            "filename": filename,
+            "content": content,
+            "content_type": content_type,
+            "file_size": file_size,
+            "upload_timestamp": datetime.now()
+        }
+
+        await self.documents.insert_one(document)
+        return document_id
+
+    async def get_document(self, document_id: str) -> Optional[Dict]:
+        """Retrieve document by ID"""
+        return await self.documents.find_one(
+            {"document_id": document_id},
+            {"_id": 0}  # Exclude MongoDB's _id
+        )
+
+    async def get_all_documents(self) -> List[Dict]:
+        """Retrieve all documents"""
+        cursor = self.documents.find({}, {"_id": 0})
+        return await cursor.to_list(length=None)
 
     async def store_message(
         self,
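
For context, a minimal usage sketch of the new document collection (this assumes a MongoDB instance reachable at the URI below; the filename, content, and size are placeholders, not values from the commit):

import asyncio
from uuid import uuid4

from src.db.mongodb_store import MongoDBStore

async def demo():
    store = MongoDBStore("mongodb://localhost:27017")  # assumed local URI
    doc_id = str(uuid4())
    await store.store_document(
        document_id=doc_id,
        filename="example.txt",                 # placeholder filename
        content="full extracted document text", # placeholder content
        content_type="text/plain",
        file_size=28,
    )
    print(await store.get_document(doc_id))       # single record, _id excluded
    print(len(await store.get_all_documents()))   # count of stored documents

asyncio.run(demo())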
src/implementations/__pycache__/document_service.cpython-312.pyc CHANGED
Binary files a/src/implementations/__pycache__/document_service.cpython-312.pyc and b/src/implementations/__pycache__/document_service.cpython-312.pyc differ
 
src/implementations/document_service.py CHANGED
@@ -2,17 +2,24 @@
 from pathlib import Path
 import shutil
 import os
-import uuid
-from typing import List, Tuple
+from uuid import uuid4
+from typing import List, Tuple, Dict
 from fastapi import UploadFile, BackgroundTasks
-from ..vectorstores.chroma_vectorstore import ChromaVectorStore
-from ..utils.document_processor import DocumentProcessor
-from ..models import DocumentResponse, DocumentInfo, BatchUploadResponse
-from ..utils.logger import logger
+
+from src.vectorstores.chroma_vectorstore import ChromaVectorStore
+from src.utils.document_processor import DocumentProcessor
+from src.models import DocumentResponse, DocumentInfo, BatchUploadResponse
+from src.utils.logger import logger
+from src.db.mongodb_store import MongoDBStore
 
 class DocumentService:
-    def __init__(self, doc_processor: DocumentProcessor):
+    def __init__(
+        self,
+        doc_processor: DocumentProcessor,
+        mongodb: MongoDBStore
+    ):
         self.doc_processor = doc_processor
+        self.mongodb = mongodb
         self.upload_dir = Path("temp_uploads")
         self.upload_dir.mkdir(exist_ok=True)
 
@@ -70,11 +77,6 @@ class DocumentService:
 
         return processed_files, failed_files
 
-    def _is_supported_format(self, filename: str) -> bool:
-        """Check if file format is supported"""
-        return any(filename.lower().endswith(ext)
-                   for ext in self.doc_processor.supported_formats)
-
     async def _process_single_file(
         self,
         file: UploadFile,
@@ -82,57 +84,93 @@ class DocumentService:
        background_tasks: BackgroundTasks
    ) -> DocumentResponse:
        """Process a single file upload"""
-        document_id = str(uuid.uuid4())
+        # Generate UUID for document
+        document_id = str(uuid4())
        temp_path = self.upload_dir / f"{document_id}_{file.filename}"
 
-        # Save file
-        with open(temp_path, "wb") as buffer:
-            shutil.copyfileobj(file.file, buffer)
-
-        # Add background task for processing
-        background_tasks.add_task(
-            self._process_and_store_document,
-            temp_path,
-            vector_store,
-            document_id
-        )
+        try:
+            # Save file temporarily
+            with open(temp_path, "wb") as buffer:
+                shutil.copyfileobj(file.file, buffer)
+
+            # Process the document to get content and metadata
+            processed_doc = await self.doc_processor.process_document(temp_path)
+            content = processed_doc['content']
+
+            # First, store in MongoDB
+            await self.mongodb.store_document(
+                document_id=document_id,
+                filename=file.filename,
+                content=content,
+                content_type=file.content_type,
+                file_size=os.path.getsize(temp_path)
+            )
 
-        return DocumentResponse(
-            message="Document queued for processing",
-            document_id=document_id,
-            status="processing",
-            document_info=DocumentInfo(
-                original_filename=file.filename,
-                size=os.path.getsize(temp_path),
-                content_type=file.content_type
+            # Then process for vector store in background
+            background_tasks.add_task(
+                self._process_for_vector_store,
+                processed_doc['chunks'],  # Use the chunks from processed document
+                vector_store,
+                document_id,
+                file.filename
             )
-        )
 
-    async def _process_and_store_document(
+            return DocumentResponse(
+                message="Document uploaded successfully",
+                document_id=document_id,
+                status="processing",
+                document_info=DocumentInfo(
+                    original_filename=file.filename,
+                    size=os.path.getsize(temp_path),
+                    content_type=file.content_type
+                )
+            )
+        finally:
+            # Clean up temporary file
+            if temp_path.exists():
+                temp_path.unlink()
+
+    async def _process_for_vector_store(
         self,
-        file_path: Path,
+        chunks: List[str],  # Now accepting pre-processed chunks
         vector_store: ChromaVectorStore,
-        document_id: str
+        document_id: str,
+        filename: str
     ):
-        """Process document and store in vector database"""
+        """Process document content for vector store"""
        try:
-            processed_doc = await self.doc_processor.process_document(file_path)
+            # Generate chunk IDs using document_id
+            chunk_ids = [f"{document_id}-chunk-{i}" for i in range(len(chunks))]
 
+            # Get embeddings
+            embeddings = vector_store.embedding_function(chunks)
+
+            # Prepare metadata for each chunk
+            metadatas = [{
+                'document_id': document_id,  # MongoDB document ID
+                'source_file': filename,
+                'chunk_index': i,
+                'total_chunks': len(chunks)
+            } for i in range(len(chunks))]
+
+            # Store in vector store
            vector_store.add_documents(
-                documents=processed_doc['chunks'],
-                metadatas=[{
-                    'document_id': document_id,
-                    'chunk_id': i,
-                    'source': str(file_path.name),
-                    'metadata': processed_doc['metadata']
-                } for i in range(len(processed_doc['chunks']))],
-                ids=[f"{document_id}_chunk_{i}" for i in range(len(processed_doc['chunks']))]
+                documents=chunks,
+                embeddings=embeddings,
+                metadatas=metadatas,
+                ids=chunk_ids
            )
 
-            return processed_doc
-        finally:
-            if file_path.exists():
-                file_path.unlink()
+            logger.info(f"Successfully processed document {filename} (ID: {document_id}) into {len(chunks)} chunks")
+
+        except Exception as e:
+            logger.error(f"Error processing document {filename} (ID: {document_id}) for vector store: {str(e)}")
+            raise
+
+    def _is_supported_format(self, filename: str) -> bool:
+        """Check if file format is supported"""
+        return any(filename.lower().endswith(ext)
+                   for ext in self.doc_processor.supported_formats)
 
    def _create_failed_file_entry(self, filename: str, error: str) -> dict:
        """Create a failed file entry"""
src/main.py CHANGED
@@ -7,6 +7,7 @@ from datetime import datetime
 
 # Import custom modules
 from src.agents.rag_agent import RAGAgent
+from src.models.document import AllDocumentsResponse, StoredDocument
 from src.utils.document_processor import DocumentProcessor
 from src.utils.conversation_summarizer import ConversationSummarizer
 from src.utils.logger import logger
@@ -16,6 +17,7 @@ from src.implementations.document_service import DocumentService
 from src.models import (
     ChatRequest,
     ChatResponse,
+    DocumentResponse,
     BatchUploadResponse,
     SummarizeRequest,
     SummaryResponse,
@@ -25,6 +27,9 @@ from config.config import settings
 
 app = FastAPI(title="RAG Chatbot API")
 
+# Initialize MongoDB
+mongodb = MongoDBStore(settings.MONGODB_URI)
+
 # Initialize core components
 doc_processor = DocumentProcessor(
     chunk_size=1000,
@@ -32,10 +37,7 @@ doc_processor = DocumentProcessor(
     max_file_size=10 * 1024 * 1024
 )
 summarizer = ConversationSummarizer()
-document_service = DocumentService(doc_processor)
-
-# Initialize MongoDB
-mongodb = MongoDBStore(settings.MONGODB_URI)
+document_service = DocumentService(doc_processor, mongodb)
 
 @app.post("/documents/upload", response_model=BatchUploadResponse)
 async def upload_documents(
@@ -57,6 +59,52 @@ async def upload_documents(
     finally:
         document_service.cleanup()
 
+@app.get("/documents", response_model=AllDocumentsResponse)
+async def get_all_documents(include_embeddings: bool = False):
+    """
+    Get all documents stored in the system
+
+    Args:
+        include_embeddings (bool): Whether to include embeddings in the response
+    """
+    try:
+        vector_store, _ = await get_vector_store()
+        documents = vector_store.get_all_documents(include_embeddings=include_embeddings)
+
+        return AllDocumentsResponse(
+            total_documents=len(documents),
+            documents=[
+                StoredDocument(
+                    id=doc['id'],
+                    text=doc['text'],
+                    embedding=doc.get('embedding'),
+                    metadata=doc.get('metadata')
+                ) for doc in documents
+            ]
+        )
+    except Exception as e:
+        logger.error(f"Error retrieving documents: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/documentchunks/{document_id}")
+async def get_document_chunks(document_id: str):
+    """Get all chunks for a specific document"""
+    try:
+        vector_store, _ = await get_vector_store()
+        chunks = vector_store.get_document_chunks(document_id)
+
+        if not chunks:
+            raise HTTPException(status_code=404, detail="Document not found")
+
+        return {
+            "document_id": document_id,
+            "total_chunks": len(chunks),
+            "chunks": chunks
+        }
+    except Exception as e:
+        logger.error(f"Error retrieving document chunks: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
 @app.post("/chat", response_model=ChatResponse)
 async def chat_endpoint(
     request: ChatRequest,
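
A minimal client-side sketch for the two new read endpoints (assuming the API is served locally on port 8000 and the requests package is installed; adjust the base URL for other deployments):

import requests

BASE_URL = "http://localhost:8000"  # assumed local server

# List every stored document (embeddings omitted by default)
docs = requests.get(f"{BASE_URL}/documents", params={"include_embeddings": False}).json()
print(docs["total_documents"])

# Fetch the chunks of the first document, if any exist
if docs["documents"]:
    doc_id = docs["documents"][0]["id"]
    chunks = requests.get(f"{BASE_URL}/documentchunks/{doc_id}").json()
    print(chunks["total_chunks"])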
src/models/__pycache__/document.cpython-312.pyc CHANGED
Binary files a/src/models/__pycache__/document.cpython-312.pyc and b/src/models/__pycache__/document.cpython-312.pyc differ
 
src/models/document.py CHANGED
@@ -1,6 +1,6 @@
 # src/models/document.py
 from pydantic import BaseModel
-from typing import Optional, List
+from typing import Optional, List, Dict, Any
 
 class DocumentInfo(BaseModel):
     """Document information model"""
@@ -19,4 +19,16 @@ class BatchUploadResponse(BaseModel):
     """Response model for batch document upload"""
     message: str
     processed_files: List[DocumentResponse]
-    failed_files: List[dict]
+    failed_files: List[dict]
+
+class StoredDocument(BaseModel):
+    """Model for a document stored in the vector store"""
+    id: str
+    text: str
+    embedding: Optional[List[float]] = None
+    metadata: Optional[Dict[str, Any]] = None
+
+class AllDocumentsResponse(BaseModel):
+    """Response model for retrieving all documents"""
+    total_documents: int
+    documents: List[StoredDocument]
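
A short sketch of how the new response models compose (field values are illustrative, not from the commit):

from src.models.document import AllDocumentsResponse, StoredDocument

response = AllDocumentsResponse(
    total_documents=1,
    documents=[
        StoredDocument(
            id="example-document-id",
            text="first chunk text",
            embedding=None,  # populated only when include_embeddings=True
            metadata={"document_id": "example-document-id", "chunk_index": 0},
        )
    ],
)
print(response.dict())  # .model_dump() on Pydantic v2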
src/vectorstores/__pycache__/base_vectorstore.cpython-312.pyc CHANGED
Binary files a/src/vectorstores/__pycache__/base_vectorstore.cpython-312.pyc and b/src/vectorstores/__pycache__/base_vectorstore.cpython-312.pyc differ
 
src/vectorstores/__pycache__/chroma_vectorstore.cpython-312.pyc CHANGED
Binary files a/src/vectorstores/__pycache__/chroma_vectorstore.cpython-312.pyc and b/src/vectorstores/__pycache__/chroma_vectorstore.cpython-312.pyc differ
 
src/vectorstores/base_vectorstore.py CHANGED
@@ -1,20 +1,21 @@
 # src/vectorstores/base_vectorstore.py
 from abc import ABC, abstractmethod
-from typing import List, Callable, Any
+from typing import List, Callable, Any, Dict, Optional
 
 class BaseVectorStore(ABC):
     @abstractmethod
     def add_documents(
         self,
         documents: List[str],
-        embeddings: List[List[float]]
+        embeddings: Optional[List[List[float]]] = None
     ) -> None:
         """
         Add documents to the vector store
 
         Args:
             documents (List[str]): List of document texts
-            embeddings (List[List[float]]): Corresponding embeddings
+            embeddings (Optional[List[List[float]]]): Corresponding embeddings.
+                If not provided, they will be generated using the embedding function.
         """
         pass
 
@@ -22,7 +23,8 @@ class BaseVectorStore(ABC):
     def similarity_search(
         self,
         query_embedding: List[float],
-        top_k: int = 3
+        top_k: int = 3,
+        **kwargs
     ) -> List[str]:
         """
         Perform similarity search
@@ -30,8 +32,25 @@ class BaseVectorStore(ABC):
         Args:
             query_embedding (List[float]): Embedding of the query
             top_k (int): Number of top similar documents to retrieve
+            **kwargs: Additional search parameters
 
         Returns:
             List[str]: List of most similar documents
         """
+        pass
+
+    @abstractmethod
+    def get_all_documents(
+        self,
+        include_embeddings: bool = False
+    ) -> List[Dict[str, Any]]:
+        """
+        Retrieve all documents from the vector store
+
+        Args:
+            include_embeddings (bool): Whether to include embeddings in the response
+
+        Returns:
+            List[Dict[str, Any]]: List of documents with their IDs and optionally embeddings
+        """
         pass
src/vectorstores/chroma_vectorstore.py CHANGED
@@ -1,6 +1,8 @@
 # src/vectorstores/chroma_vectorstore.py
 import chromadb
-from typing import List, Callable, Any
+from typing import List, Callable, Any, Dict, Optional
+from chromadb.config import Settings
+import logging
 
 from .base_vectorstore import BaseVectorStore
 
@@ -8,7 +10,9 @@ class ChromaVectorStore(BaseVectorStore):
     def __init__(
         self,
         embedding_function: Callable[[List[str]], List[List[float]]],
-        persist_directory: str = './chroma_db'
+        persist_directory: str = './chroma_db',
+        collection_name: str = "documents",
+        client_settings: Optional[Dict[str, Any]] = None
     ):
         """
         Initialize Chroma Vector Store
@@ -16,39 +20,77 @@ class ChromaVectorStore(BaseVectorStore):
         Args:
             embedding_function (Callable): Function to generate embeddings
             persist_directory (str): Directory to persist the vector store
+            collection_name (str): Name of the collection to use
+            client_settings (Optional[Dict[str, Any]]): Additional settings for ChromaDB client
         """
-        self.client = chromadb.PersistentClient(path=persist_directory)
-        self.collection = self.client.get_or_create_collection(name="documents")
-        self.embedding_function = embedding_function
+        try:
+            settings = Settings(
+                persist_directory=persist_directory,
+                **(client_settings or {})
+            )
+            self.client = chromadb.PersistentClient(settings=settings)
+            self.collection = self.client.get_or_create_collection(
+                name=collection_name,
+                metadata={"hnsw:space": "cosine"}  # Using cosine similarity by default
+            )
+            self.embedding_function = embedding_function
+        except Exception as e:
+            logging.error(f"Error initializing ChromaDB: {str(e)}")
+            raise
 
     def add_documents(
         self,
         documents: List[str],
-        embeddings: List[List[float]] = None
+        embeddings: Optional[List[List[float]]] = None,
+        metadatas: Optional[List[Dict[str, Any]]] = None,
+        ids: Optional[List[str]] = None
     ) -> None:
         """
         Add documents to the vector store
 
         Args:
             documents (List[str]): List of document texts
-            embeddings (List[List[float]], optional): Pre-computed embeddings
+            embeddings (Optional[List[List[float]]]): Pre-computed embeddings
+            metadatas (Optional[List[Dict[str, Any]]]): Metadata for each document
+            ids (Optional[List[str]]): Custom IDs for the documents
         """
-        if not embeddings:
-            embeddings = self.embedding_function(documents)
-
-        # Generate unique IDs
-        ids = [f"doc_{i}" for i in range(len(documents))]
-
-        self.collection.add(
-            documents=documents,
-            embeddings=embeddings,
-            ids=ids
-        )
+        try:
+            if not documents:
+                logging.warning("No documents provided to add_documents")
+                return
+
+            if not embeddings:
+                embeddings = self.embedding_function(documents)
+
+            if len(documents) != len(embeddings):
+                raise ValueError("Number of documents and embeddings must match")
+
+            # Use provided IDs or generate them
+            doc_ids = ids if ids is not None else [f"doc_{i}" for i in range(len(documents))]
+
+            # Prepare add parameters
+            add_params = {
+                "documents": documents,
+                "embeddings": embeddings,
+                "ids": doc_ids
+            }
+
+            # Only include metadatas if provided
+            if metadatas is not None:
+                if len(metadatas) != len(documents):
+                    raise ValueError("Number of documents and metadatas must match")
+                add_params["metadatas"] = metadatas
+
+            self.collection.add(**add_params)
+        except Exception as e:
+            logging.error(f"Error adding documents to ChromaDB: {str(e)}")
+            raise
 
     def similarity_search(
         self,
         query_embedding: List[float],
-        top_k: int = 3
+        top_k: int = 3,
+        **kwargs
     ) -> List[str]:
         """
         Perform similarity search
@@ -56,13 +98,114 @@ class ChromaVectorStore(BaseVectorStore):
         Args:
             query_embedding (List[float]): Embedding of the query
             top_k (int): Number of top similar documents to retrieve
+            **kwargs: Additional search parameters
 
         Returns:
             List[str]: List of most similar documents
         """
-        results = self.collection.query(
-            query_embeddings=[query_embedding],
-            n_results=top_k
-        )
-
-        return results.get('documents', [[]])[0]
+        try:
+            results = self.collection.query(
+                query_embeddings=[query_embedding],
+                n_results=top_k,
+                **kwargs
+            )
+
+            # Handle the case where no results are found
+            if not results or 'documents' not in results:
+                return []
+
+            return results.get('documents', [[]])[0]
+        except Exception as e:
+            logging.error(f"Error performing similarity search in ChromaDB: {str(e)}")
+            raise
+
+    def get_all_documents(
+        self,
+        include_embeddings: bool = False
+    ) -> List[Dict[str, Any]]:
+        """
+        Retrieve all documents from the vector store
+        """
+        try:
+            include = ["documents", "metadatas"]
+            if include_embeddings:
+                include.append("embeddings")
+
+            results = self.collection.get(
+                include=include
+            )
+
+            if not results or 'documents' not in results:
+                return []
+
+            documents = []
+            for i in range(len(results['documents'])):
+                doc = {
+                    'id': str(i),  # Generate sequential IDs
+                    'text': results['documents'][i],
+                }
+
+                if include_embeddings and 'embeddings' in results:
+                    doc['embedding'] = results['embeddings'][i]
+
+                if 'metadatas' in results and results['metadatas'][i]:
+                    doc['metadata'] = results['metadatas'][i]
+
+                    # Use document_id from metadata if available
+                    if 'document_id' in results['metadatas'][i]:
+                        doc['id'] = results['metadatas'][i]['document_id']
+
+                documents.append(doc)
+
+            return documents
+        except Exception as e:
+            logging.error(f"Error retrieving documents from ChromaDB: {str(e)}")
+            raise
+
+    def get_document_chunks(self, document_id: str) -> List[Dict[str, Any]]:
+        """Retrieve all chunks for a specific document"""
+        try:
+            results = self.collection.get(
+                where={"document_id": document_id},
+                include=["documents", "metadatas"]
+            )
+
+            if not results or 'documents' not in results:
+                return []
+
+            chunks = []
+            for i in range(len(results['documents'])):
+                chunk = {
+                    'text': results['documents'][i],
+                    'metadata': results['metadatas'][i] if results.get('metadatas') else None
+                }
+                chunks.append(chunk)
+
+            # Sort by chunk_index if available
+            chunks.sort(key=lambda x: x.get('metadata', {}).get('chunk_index', 0))
+
+            return chunks
+        except Exception as e:
+            logging.error(f"Error retrieving document chunks: {str(e)}")
+            raise
+
+    def delete_document(self, document_id: str) -> None:
+        """Delete all chunks associated with a document_id"""
+        try:
+            # Get all chunks with the given document_id
+            results = self.collection.get(
+                where={"document_id": document_id},
+                include=["metadatas"]
+            )
+
+            if not results or 'ids' not in results:
+                logging.warning(f"No document found with ID: {document_id}")
+                return
+
+            # Delete all chunks associated with the document
+            chunk_ids = [f"{document_id}-chunk-{i}" for i in range(len(results['metadatas']))]
+            self.collection.delete(ids=chunk_ids)
+
+        except Exception as e:
+            logging.error(f"Error deleting document {document_id} from ChromaDB: {str(e)}")
+            raise
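
To exercise the extended store end to end, a minimal sketch under stated assumptions (the toy embedding function, the persist directory, and the document ID are placeholders; a real deployment would pass the project's embedding function and rely on the constructor defaults from the diff above):

from src.vectorstores.chroma_vectorstore import ChromaVectorStore

def toy_embeddings(texts):
    # Fixed-size dummy vectors, just enough to exercise the API
    return [[float(len(t)), 1.0, 0.0] for t in texts]

store = ChromaVectorStore(
    embedding_function=toy_embeddings,
    persist_directory="./chroma_db_demo",  # throwaway directory
)

doc_id = "example-document-id"
chunks = ["alpha chunk", "beta chunk"]
store.add_documents(
    documents=chunks,
    metadatas=[{"document_id": doc_id, "chunk_index": i, "total_chunks": len(chunks)}
               for i in range(len(chunks))],
    ids=[f"{doc_id}-chunk-{i}" for i in range(len(chunks))],
)

print(store.get_document_chunks(doc_id))   # chunks sorted by chunk_index
print(len(store.get_all_documents()))      # one entry per stored chunk
store.delete_document(doc_id)              # removes both chunks again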