# src/db/mongodb_store.py
import json
from datetime import datetime, timezone
from typing import List, Dict, Optional, Any

from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
class MongoDBStore:
    """Async MongoDB persistence for uploaded documents and chat history.

    Works against the ``rag_chatbot`` database and two of its collections:
    ``documents`` (original uploaded documents) and ``chat_history`` (one
    record per query/response exchange).
    """

    def __init__(self, mongo_uri: str = "mongodb://localhost:27017"):
        """Initialize MongoDB connection.

        Args:
            mongo_uri: Connection string handed to ``AsyncIOMotorClient``.
        """
        self.client = AsyncIOMotorClient(mongo_uri)
        self.db = self.client.rag_chatbot
        self.chat_history = self.db.chat_history
        self.documents = self.db.documents  # Collection for original documents

    async def store_document(
        self,
        document_id: str,
        filename: str,
        content: str,
        content_type: str,
        file_size: int
    ) -> str:
        """Store original document in MongoDB.

        Args:
            document_id: Caller-supplied identifier saved with the document.
            filename: Name of the uploaded file.
            content: Document content.
            content_type: Content type recorded for the document.
            file_size: Size recorded for the document.

        Returns:
            The ``document_id`` that was passed in.
        """
        document = {
            "document_id": document_id,
            "filename": filename,
            "content": content,
            "content_type": content_type,
            "file_size": file_size,
            # Timezone-aware UTC: a naive local time would be persisted by
            # MongoDB as if it were UTC.
            "upload_timestamp": datetime.now(timezone.utc),
        }
        await self.documents.insert_one(document)
        return document_id

    async def get_document(self, document_id: str) -> Optional[Dict]:
        """Retrieve document by ID.

        Args:
            document_id: Value of the ``document_id`` field to look up.

        Returns:
            The matching document without MongoDB's ``_id`` field, or
            whatever ``find_one`` yields when nothing matches.
        """
        return await self.documents.find_one(
            {"document_id": document_id},
            {"_id": 0}  # Exclude MongoDB's _id
        )

    async def get_all_documents(self) -> List[Dict]:
        """Retrieve all documents.

        Returns:
            Every record of the ``documents`` collection, ``_id`` excluded.
            ``length=None`` places no cap on the number of records fetched.
        """
        cursor = self.documents.find({}, {"_id": 0})
        return await cursor.to_list(length=None)

    async def store_message(
        self,
        conversation_id: str,
        query: str,
        response: str,
        context: List[str],
        sources: List[Dict],
        llm_provider: str
    ) -> str:
        """Store chat message in MongoDB.

        One record holds both sides of an exchange (the user query and the
        generated response). ``feedback`` and ``rating`` start as ``None``.

        Args:
            conversation_id: Conversation the exchange belongs to.
            query: User query text.
            response: Generated response text.
            context: Context passages saved alongside the exchange.
            sources: Source descriptors saved alongside the exchange.
            llm_provider: Name of the LLM provider saved with the exchange.

        Returns:
            The inserted record's ``_id`` converted to ``str``.
        """
        document = {
            "conversation_id": conversation_id,
            # Timezone-aware UTC keeps stored instants unambiguous.
            "timestamp": datetime.now(timezone.utc),
            "query": query,
            "response": response,
            "context": context,
            "sources": sources,
            "llm_provider": llm_provider,
            "feedback": None,
            "rating": None,
        }
        result = await self.chat_history.insert_one(document)
        return str(result.inserted_id)

    async def get_conversation_history(self, conversation_id: str) -> List[Dict]:
        """Retrieve conversation history.

        Args:
            conversation_id: Conversation whose records are wanted.

        Returns:
            The conversation's records sorted by ascending ``timestamp``,
            each with its ``_id`` converted to ``str``.
        """
        cursor = self.chat_history.find(
            {"conversation_id": conversation_id}
        ).sort("timestamp", 1)
        history = []
        async for document in cursor:
            document["_id"] = str(document["_id"])
            history.append(document)
        return history

    async def update_feedback(
        self,
        conversation_id: str,
        feedback: Optional[str],
        rating: Optional[int]
    ) -> bool:
        """Update feedback for a conversation.

        The feedback and rating are written to every record of the
        conversation.

        Args:
            conversation_id: Conversation to update.
            feedback: Feedback text, or ``None``.
            rating: Rating value, or ``None``.

        Returns:
            ``True`` when at least one record matched ``conversation_id``.
        """
        result = await self.chat_history.update_many(
            {"conversation_id": conversation_id},
            {
                "$set": {
                    "feedback": feedback,
                    "rating": rating
                }
            }
        )
        # matched_count, not modified_count: re-submitting values identical
        # to the stored ones modifies nothing, yet the update still succeeded.
        return result.matched_count > 0

    async def get_messages_for_summary(self, conversation_id: str) -> List[Dict]:
        """Get messages in format suitable for summarization.

        Each stored record holds both a query and a response, so a record
        expands into up to two messages: the user turn followed by the
        assistant turn. Empty or missing sides are skipped.

        Args:
            conversation_id: Conversation whose records are wanted.

        Returns:
            Messages ordered by ascending record ``timestamp``. Each is a
            dict with ``role``, ``content``, ``timestamp`` and ``sources``.
            A record's sources are attached to its assistant message; the
            user message carries them only when the record has no response.
        """
        cursor = self.chat_history.find(
            {"conversation_id": conversation_id}
        ).sort("timestamp", 1)
        messages = []
        async for doc in cursor:
            timestamp = doc['timestamp']
            sources = doc.get('sources') or []
            query = doc.get('query')
            response = doc.get('response')
            if query:
                messages.append({
                    'role': 'user',
                    'content': query,
                    'timestamp': timestamp,
                    'sources': [] if response else sources,
                })
            if response:
                messages.append({
                    'role': 'assistant',
                    'content': response,
                    'timestamp': timestamp,
                    'sources': sources,
                })
        return messages