Commit 9700f95
Parent(s): 406f1ed
Working chat with context and history

Files changed:
- config/__pycache__/config.cpython-312.pyc +0 -0
- config/config.py +3 -0
- src/__pycache__/main.cpython-312.pyc +0 -0
- src/agents/__pycache__/rag_agent.cpython-312.pyc +0 -0
- src/agents/rag_agent.py +134 -107
- src/db/__pycache__/mongodb_store.cpython-312.pyc +0 -0
- src/db/mongodb_store.py +209 -21
- src/main.py +40 -13
- src/models/__pycache__/chat.cpython-312.pyc +0 -0
- src/models/__pycache__/rag.cpython-312.pyc +0 -0
- src/models/chat.py +16 -2
- src/models/rag.py +28 -2
- src/utils/__pycache__/conversation_manager.cpython-312.pyc +0 -0
- src/utils/conversation_manager.py +111 -0
config/__pycache__/config.cpython-312.pyc
CHANGED
Binary files a/config/__pycache__/config.cpython-312.pyc and b/config/__pycache__/config.cpython-312.pyc differ
config/config.py
CHANGED
@@ -26,6 +26,9 @@ class Settings:
    # MongoDB Configuration
    MONGODB_URI = os.getenv('MONGODB_URI', 'mongodb://localhost:27017')

    # Feedback Configuration
    MAX_RATING = int(os.getenv('MAX_RATING', '5'))

    # Application Configuration
    DEBUG = os.getenv('DEBUG', 'False') == 'True'
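For illustration, a minimal sketch of overriding the new setting through the environment before the config module is imported; the override value here is made up:

import os

# Illustrative: Settings reads the variable at import time, so set it first
os.environ["MAX_RATING"] = "10"

from config.config import settings
assert settings.MAX_RATING == 10  # ratings would then render as "n/10"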
src/__pycache__/main.cpython-312.pyc
CHANGED
Binary files a/src/__pycache__/main.cpython-312.pyc and b/src/__pycache__/main.cpython-312.pyc differ
src/agents/__pycache__/rag_agent.cpython-312.pyc
CHANGED
Binary files a/src/agents/__pycache__/rag_agent.cpython-312.pyc and b/src/agents/__pycache__/rag_agent.cpython-312.pyc differ
src/agents/rag_agent.py
CHANGED
@@ -1,136 +1,163 @@
# src/agents/rag_agent.py
from typing import List, Optional, Tuple, Dict
import uuid

from ..llms.base_llm import BaseLLM
from src.embeddings.base_embedding import BaseEmbedding
from src.vectorstores.base_vectorstore import BaseVectorStore
from src.utils.conversation_manager import ConversationManager
from src.db.mongodb_store import MongoDBStore
from src.models.rag import RAGResponse
from src.utils.logger import logger

class RAGAgent:
    def __init__(
        self,
        llm: BaseLLM,
        embedding: BaseEmbedding,
        vector_store: BaseVectorStore,
        mongodb: MongoDBStore,
        max_history_tokens: int = 4000,
        max_history_messages: int = 10
    ):
        """
        Initialize RAG Agent

        Args:
            llm (BaseLLM): Language model instance
            embedding (BaseEmbedding): Embedding model instance
            vector_store (BaseVectorStore): Vector store instance
            mongodb (MongoDBStore): MongoDB store instance
            max_history_tokens (int): Maximum tokens in conversation history
            max_history_messages (int): Maximum messages to keep in history
        """
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store
        self.mongodb = mongodb
        self.conversation_manager = ConversationManager(
            max_tokens=max_history_tokens,
            max_messages=max_history_messages
        )

    async def generate_response(
        self,
        query: str,
        conversation_id: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        context_docs: Optional[List[str]] = None
    ) -> RAGResponse:
        """Generate a response using RAG with conversation history"""
        try:
            # Create new conversation if no ID provided
            if not conversation_id:
                conversation_id = str(uuid.uuid4())
                await self.mongodb.create_conversation(conversation_id)

            # Get conversation history
            history = await self.mongodb.get_recent_messages(
                conversation_id,
                limit=self.conversation_manager.max_messages
            )

            # Get relevant history within token limits
            relevant_history = self.conversation_manager.get_relevant_history(
                messages=history,
                current_query=query
            ) if history else []

            # Retrieve context if not provided
            if not context_docs:
                context_docs, sources, scores = await self.retrieve_context(
                    query,
                    conversation_history=relevant_history
                )
            else:
                sources = None
                scores = None

            # Generate prompt with context and history
            augmented_prompt = self.conversation_manager.generate_prompt_with_history(
                current_query=query,
                history=relevant_history,
                context_docs=context_docs
            )

            # Generate response using LLM
            response = self.llm.generate(
                augmented_prompt,
                temperature=temperature,
                max_tokens=max_tokens
            )

            return RAGResponse(
                response=response,
                context_docs=context_docs,
                sources=sources,
                scores=scores
            )

        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            raise

    async def retrieve_context(
        self,
        query: str,
        conversation_history: Optional[List[Dict]] = None,
        top_k: int = 3
    ) -> Tuple[List[str], List[Dict], Optional[List[float]]]:
        """
        Retrieve context with conversation history enhancement

        Args:
            query (str): Current query
            conversation_history (Optional[List[Dict]]): Recent conversation history
            top_k (int): Number of documents to retrieve

        Returns:
            Tuple[List[str], List[Dict], Optional[List[float]]]:
                Retrieved documents, sources, and scores
        """
        # Enhance query with conversation history
        if conversation_history:
            recent_queries = [
                msg['query'] for msg in conversation_history[-2:]
                if msg.get('query')
            ]
            enhanced_query = " ".join([*recent_queries, query])
        else:
            enhanced_query = query

        # Embed the enhanced query
        query_embedding = self.embedding.embed_query(enhanced_query)

        # Retrieve similar documents
        results = self.vector_store.similarity_search(
            query_embedding,
            top_k=top_k
        )

        # Process results
        documents = [doc['text'] for doc in results]
        sources = [self._convert_metadata_to_strings(doc['metadata'])
                   for doc in results]
        scores = [doc['score'] for doc in results
                  if doc.get('score') is not None]

        # Return scores only if available for all documents
        if len(scores) != len(documents):
            scores = None

        return documents, sources, scores

    def _convert_metadata_to_strings(self, metadata: Dict) -> Dict:
        """Convert numeric metadata values to strings"""
        converted = {}
        for key, value in metadata.items():
            if isinstance(value, (int, float)):
                converted[key] = str(value)
            else:
                converted[key] = value
        return converted
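To see the new conversation flow end to end, here is a minimal usage sketch. The EchoLLM, DummyEmbedding, DummyVectorStore, and DummyMongo stand-ins are hypothetical and exist only to exercise generate_response; the real providers are wired up in src/main.py, and tiktoken must be installed since ConversationManager uses it:

import asyncio
from src.agents.rag_agent import RAGAgent

# Hypothetical in-memory stand-ins for the real providers
class EchoLLM:
    def generate(self, prompt, temperature=0.7, max_tokens=None):
        return "Paris."

class DummyEmbedding:
    def embed_query(self, text):
        return [0.0] * 8

class DummyVectorStore:
    def similarity_search(self, query_embedding, top_k=3):
        return [{"text": "Paris is the capital of France.",
                 "metadata": {"page": 1}, "score": 0.92}]

class DummyMongo:
    async def create_conversation(self, conversation_id):
        return conversation_id
    async def get_recent_messages(self, conversation_id, limit=5):
        return []

async def main():
    agent = RAGAgent(
        llm=EchoLLM(),
        embedding=DummyEmbedding(),
        vector_store=DummyVectorStore(),
        mongodb=DummyMongo(),
    )
    # First turn creates a conversation id internally; callers that want to
    # continue a thread pass conversation_id= on later turns
    result = await agent.generate_response("What is the capital of France?")
    print(result.response, result.sources)  # sources -> [{'page': '1'}]

asyncio.run(main())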
src/db/__pycache__/mongodb_store.cpython-312.pyc
CHANGED
Binary files a/src/db/__pycache__/mongodb_store.cpython-312.pyc and b/src/db/__pycache__/mongodb_store.cpython-312.pyc differ
src/db/mongodb_store.py
CHANGED
@@ -10,8 +10,10 @@ class MongoDBStore:
        self.client = AsyncIOMotorClient(mongo_uri)
        self.db = self.client.db_chatbot
        self.chat_history = self.db.chat_history
        self.conversations = self.db.conversations
        self.documents = self.db.knowledge_base

    # Document-related methods
    async def store_document(
        self,
        document_id: str,

@@ -56,6 +58,60 @@ class MongoDBStore:
        )
        return await cursor.to_list(length=None)

    async def delete_document(self, document_id: str) -> bool:
        """Delete document from MongoDB"""
        result = await self.documents.delete_one({"document_id": document_id})
        return result.deleted_count > 0

    # Conversation and chat history methods
    async def create_conversation(
        self,
        conversation_id: str,
        metadata: Optional[Dict] = None
    ) -> str:
        """Create a new conversation"""
        conversation = {
            "conversation_id": conversation_id,
            "created_at": datetime.now(),
            "last_updated": datetime.now(),
            "message_count": 0,
            "metadata": metadata or {}
        }

        await self.conversations.insert_one(conversation)
        return conversation_id

    async def get_conversation_metadata(
        self,
        conversation_id: str
    ) -> Optional[Dict]:
        """Get conversation metadata"""
        result = await self.conversations.find_one(
            {"conversation_id": conversation_id}
        )
        if result:
            result["_id"] = str(result["_id"])
        return result

    async def update_conversation_metadata(
        self,
        conversation_id: str,
        metadata: Dict
    ) -> bool:
        """Update conversation metadata"""
        result = await self.conversations.update_one(
            {"conversation_id": conversation_id},
            {
                "$set": {
                    "metadata": metadata,
                    "last_updated": datetime.now()
                }
            }
        )
        return result.modified_count > 0

    async def store_message(
        self,
        conversation_id: str,

@@ -66,23 +122,52 @@ class MongoDBStore:
        llm_provider: str
    ) -> str:
        """Store chat message in MongoDB"""
        # Store user message
        user_message = {
            "conversation_id": conversation_id,
            "timestamp": datetime.now(),
            "role": "user",
            "content": query,
            "query": query,  # Keep for backward compatibility
            "response": None,
            "context": context,
            "sources": sources,
            "llm_provider": llm_provider,
            "feedback": None,
            "rating": None
        }
        await self.chat_history.insert_one(user_message)

        # Store assistant message
        assistant_message = {
            "conversation_id": conversation_id,
            "timestamp": datetime.now(),
            "role": "assistant",
            "content": response,
            "query": None,
            "response": response,  # Keep for backward compatibility
            "context": context,
            "sources": sources,
            "llm_provider": llm_provider,
            "feedback": None,
            "rating": None
        }
        result = await self.chat_history.insert_one(assistant_message)

        # Update conversation metadata
        await self.conversations.update_one(
            {"conversation_id": conversation_id},
            {
                "$set": {"last_updated": datetime.now()},
                "$inc": {"message_count": 2}  # Two inserts: user and assistant
            },
            upsert=True
        )

        return str(result.inserted_id)

    async def get_conversation_history(self, conversation_id: str) -> List[Dict]:
        """Retrieve complete conversation history"""
        cursor = self.chat_history.find(
            {"conversation_id": conversation_id}
        ).sort("timestamp", 1)

@@ -94,25 +179,77 @@ class MongoDBStore:

        return history

    async def get_recent_messages(
        self,
        conversation_id: str,
        limit: int = 5
    ) -> List[Dict]:
        """Get most recent messages from conversation"""
        cursor = self.chat_history.find(
            {"conversation_id": conversation_id}
        ).sort("timestamp", -1).limit(limit * 2)  # Double the limit to cover user-assistant pairs

        messages = []
        async for doc in cursor:
            messages.append(self._format_message(doc))

        return list(reversed(messages))

    async def update_feedback(
        self,
        conversation_id: str,
        feedback: Optional[str],
        rating: Optional[int]
    ) -> bool:
        """
        Update feedback for a conversation

        Args:
            conversation_id (str): Conversation ID
            feedback (Optional[str]): Feedback text
            rating (Optional[int]): Numeric rating

        Returns:
            bool: True if update successful
        """
        update_fields = {}

        if feedback is not None:
            update_fields["feedback"] = feedback

        if rating is not None:
            from config.config import settings
            formatted_rating = f"{rating}/{settings.MAX_RATING}"
            update_fields.update({
                "rating": rating,  # Store numeric value
                "formatted_rating": formatted_rating  # Store formatted string
            })

        if not update_fields:
            return False

        result = await self.chat_history.update_many(
            {"conversation_id": conversation_id},
            {"$set": update_fields}
        )

        # Also update conversation metadata
        if result.modified_count > 0:
            await self.update_conversation_metadata(
                conversation_id,
                {
                    "last_feedback": datetime.now(),
                    "last_rating": rating if rating is not None else None,
                    "formatted_rating": formatted_rating if rating is not None else None
                }
            )

        return result.modified_count > 0

    async def get_messages_for_summary(
        self,
        conversation_id: str
    ) -> List[Dict]:
        """Get messages in format suitable for summarization"""
        cursor = self.chat_history.find(
            {"conversation_id": conversation_id}

@@ -120,16 +257,67 @@ class MongoDBStore:

        messages = []
        async for doc in cursor:
            formatted = self._format_message(doc)
            # For summary, we only need specific fields
            messages.append({
                'role': formatted['role'],
                'content': formatted['content'],
                'timestamp': formatted['timestamp'],
                'sources': formatted['sources']
            })

        return messages

    def _format_message(self, doc: Dict) -> Dict:
        """Helper method to format message documents consistently"""
        return {
            "_id": str(doc["_id"]) if "_id" in doc else None,
            "conversation_id": doc.get("conversation_id"),
            "timestamp": doc.get("timestamp"),
            "role": doc.get("role", "user" if doc.get("query") else "assistant"),
            "content": doc.get("content", doc.get("query") or doc.get("response", "")),
            "context": doc.get("context", []),
            "sources": doc.get("sources", []),
            "llm_provider": doc.get("llm_provider"),
            "feedback": doc.get("feedback"),
            "rating": doc.get("rating")
        }

    # Vector store related methods
    async def store_vector_metadata(
        self,
        document_id: str,
        chunk_id: str,
        metadata: Dict[str, Any]
    ) -> str:
        """Store vector chunk metadata"""
        vector_metadata = {
            "document_id": document_id,
            "chunk_id": chunk_id,
            "metadata": metadata,
            "created_at": datetime.now()
        }

        result = await self.db.vector_metadata.insert_one(vector_metadata)
        return str(result.inserted_id)

    async def get_vector_metadata(
        self,
        document_id: str
    ) -> List[Dict]:
        """Get vector metadata for a document"""
        cursor = self.db.vector_metadata.find(
            {"document_id": document_id}
        )
        return await cursor.to_list(length=None)

    async def delete_vector_metadata(
        self,
        document_id: str
    ) -> bool:
        """Delete vector metadata for a document"""
        result = await self.db.vector_metadata.delete_many(
            {"document_id": document_id}
        )
        return result.deleted_count > 0
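The history queries above always filter on conversation_id and sort on timestamp, so they benefit from a compound index. A minimal sketch; the ensure_indexes helper is not part of this commit and is shown only as a suggested companion:

from motor.motor_asyncio import AsyncIOMotorClient

async def ensure_indexes(mongo_uri: str) -> None:
    """Create the indexes the chat-history queries rely on."""
    db = AsyncIOMotorClient(mongo_uri).db_chatbot
    # Serves find({"conversation_id": ...}).sort("timestamp", ...)
    await db.chat_history.create_index([("conversation_id", 1), ("timestamp", -1)])
    # Serves the metadata lookups keyed by conversation_id
    await db.conversations.create_index("conversation_id", unique=True)
    # Serves the vector metadata lookups keyed by document_id
    await db.vector_metadata.create_index("document_id")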
src/main.py
CHANGED
@@ -2,6 +2,7 @@
from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from typing import List
import uuid
from datetime import datetime

@@ -30,6 +31,14 @@ from config.config import settings

app = FastAPI(title="Chatbot API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:8080"],  # Frontend URL
    allow_credentials=True,
    allow_methods=["*"],  # Allow all methods
    allow_headers=["*"],  # Allow all headers
)

# Initialize MongoDB
mongodb = MongoDBStore(settings.MONGODB_URI)

@@ -192,26 +201,25 @@ async def chat_endpoint(

    vector_store, embedding_model = await get_vector_store()
    llm = get_llm_instance(request.llm_provider)

    # Initialize RAG agent with required MongoDB
    rag_agent = RAGAgent(
        llm=llm,
        embedding=embedding_model,
        vector_store=vector_store,
        mongodb=mongodb
    )

    # Use the provided conversation ID or create a new one
    conversation_id = request.conversation_id or str(uuid.uuid4())
    query = request.query + ". The response should be short and to the point. Make sure not to add any irrelevant information; sticking to the point is very important."

    # Generate response
    response = await rag_agent.generate_response(
        query=query,
        conversation_id=conversation_id,
        temperature=request.temperature
    )

    # Store message in chat history
    await mongodb.store_message(
        conversation_id=conversation_id,
        query=request.query,

@@ -274,6 +282,12 @@

):
    """Submit feedback for a conversation"""
    try:
        # Validate conversation exists
        conversation = await mongodb.get_conversation_metadata(conversation_id)
        if not conversation:
            raise HTTPException(status_code=404, detail="Conversation not found")

        # Update feedback
        success = await mongodb.update_feedback(
            conversation_id=conversation_id,
            feedback=feedback_request.feedback,

@@ -281,10 +295,23 @@

        )

        if not success:
            raise HTTPException(
                status_code=500,
                detail="Failed to update feedback"
            )

        return {
            "status": "success",
            "message": "Feedback submitted successfully",
            "data": {
                "conversation_id": conversation_id,
                "feedback": feedback_request.feedback,
                "rating": feedback_request.format_rating()
            }
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error submitting feedback: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
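A quick client-side sketch of the two handlers. The /chat and /feedback/{conversation_id} routes, the base URL, and the request fields are assumptions inferred from the handler names and the request models in this commit, not confirmed by the diff:

import asyncio
import httpx

async def demo_client() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        # Assumed route for chat_endpoint
        chat = await client.post("/chat", json={
            "query": "What is our refund policy?",
            "llm_provider": "openai",
            "temperature": 0.7,
        })
        conversation_id = chat.json().get("conversation_id")

        # Assumed route for submit_feedback
        fb = await client.post(f"/feedback/{conversation_id}", json={
            "rating": 4,
            "feedback": "Concise and accurate.",
        })
        print(fb.json())

asyncio.run(demo_client())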
src/models/__pycache__/chat.cpython-312.pyc
CHANGED
Binary files a/src/models/__pycache__/chat.cpython-312.pyc and b/src/models/__pycache__/chat.cpython-312.pyc differ
src/models/__pycache__/rag.cpython-312.pyc
CHANGED
Binary files a/src/models/__pycache__/rag.cpython-312.pyc and b/src/models/__pycache__/rag.cpython-312.pyc differ
src/models/chat.py
CHANGED
@@ -4,6 +4,11 @@ from typing import Optional, List, Dict
from datetime import datetime
from .base import ChatMetadata

from pydantic import BaseModel, validator, Field
from typing import Optional
from config.config import settings


class ChatRequest(BaseModel):
    """Request model for chat endpoint"""
    query: str

@@ -21,10 +26,19 @@ class ChatResponse(ChatMetadata):

    relevant_doc_scores: Optional[List[float]] = None

class FeedbackRequest(BaseModel):
    rating: int = Field(..., ge=0, le=settings.MAX_RATING)
    feedback: Optional[str] = None

    @validator('rating')
    def validate_rating(cls, v):
        if v < 0 or v > settings.MAX_RATING:
            raise ValueError(f'Rating must be between 0 and {settings.MAX_RATING}')
        return v

    def format_rating(self) -> str:
        """Format rating as a fraction of the maximum"""
        return f"{self.rating}/{settings.MAX_RATING}"

class SummarizeRequest(BaseModel):
    """Request model for summarize endpoint"""
    conversation_id: str
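For illustration, how the FeedbackRequest constraints behave with the default MAX_RATING of 5 (pydantic v1 validator API, as used above):

from pydantic import ValidationError
from src.models.chat import FeedbackRequest

ok = FeedbackRequest(rating=4, feedback="Helpful answer")
print(ok.format_rating())  # -> "4/5"

try:
    FeedbackRequest(rating=7)  # Exceeds le=settings.MAX_RATING
except ValidationError as e:
    print(e.errors()[0]["msg"])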
src/models/rag.py
CHANGED
@@ -1,11 +1,37 @@
# src/models/rag.py
from dataclasses import dataclass
from typing import List, Optional, Dict
from datetime import datetime

@dataclass
class Message:
    """Single message in a conversation"""
    role: str  # 'user' or 'assistant'
    content: str
    timestamp: datetime

@dataclass
class ConversationContext:
    """Conversation context with history"""
    messages: List[Message]
    max_messages: int = 10

    def add_message(self, role: str, content: str):
        """Add a message while maintaining max size"""
        self.messages.append(Message(
            role=role,
            content=content,
            timestamp=datetime.now()
        ))
        # Keep only the most recent messages
        if len(self.messages) > self.max_messages:
            self.messages = self.messages[-self.max_messages:]

@dataclass
class RAGResponse:
    """Enhanced RAG response with conversation context"""
    response: str
    context_docs: Optional[List[str]] = None
    sources: Optional[List[Dict]] = None
    scores: Optional[List[float]] = None
    conversation_context: Optional[ConversationContext] = None
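A small sketch of the trimming behavior in ConversationContext:

from src.models.rag import ConversationContext

ctx = ConversationContext(messages=[], max_messages=3)
for i in range(5):
    ctx.add_message("user", f"message {i}")

# Only the three most recent messages survive
print([m.content for m in ctx.messages])  # ['message 2', 'message 3', 'message 4']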
src/utils/__pycache__/conversation_manager.cpython-312.pyc
ADDED
Binary file (4.53 kB)
src/utils/conversation_manager.py
ADDED
@@ -0,0 +1,111 @@
# src/utils/conversation_manager.py
from typing import List, Dict, Optional
import tiktoken
from datetime import datetime

class ConversationManager:
    def __init__(
        self,
        max_tokens: int = 4000,
        max_messages: int = 10,
        model: str = "gpt-3.5-turbo"
    ):
        """
        Initialize conversation manager

        Args:
            max_tokens (int): Maximum tokens to keep in context
            max_messages (int): Maximum number of messages to keep
            model (str): Model name for token counting
        """
        self.max_tokens = max_tokens
        self.max_messages = max_messages
        self.encoding = tiktoken.encoding_for_model(model)

    def format_messages(self, messages: List[Dict]) -> str:
        """Format messages into a conversation string"""
        formatted = []
        for msg in messages:
            role = msg.get('role', 'unknown')
            content = msg.get('content', '')
            formatted.append(f"{role.capitalize()}: {content}")
        return "\n".join(formatted)

    def count_tokens(self, text: str) -> int:
        """Count tokens in text"""
        return len(self.encoding.encode(text))

    def get_relevant_history(
        self,
        messages: List[Dict],
        current_query: str,
        max_tokens: Optional[int] = None
    ) -> List[Dict]:
        """
        Get relevant conversation history within token limit

        Args:
            messages (List[Dict]): Full message history
            current_query (str): Current user query
            max_tokens (Optional[int]): Override default max tokens

        Returns:
            List[Dict]: Relevant message history
        """
        max_tokens = max_tokens or self.max_tokens
        current_tokens = self.count_tokens(current_query)

        # Keep track of tokens and messages
        history = []
        total_tokens = current_tokens

        # Process messages from most recent to oldest
        for msg in reversed(messages[-self.max_messages:]):
            msg_text = f"{msg['role']}: {msg['content']}\n"
            msg_tokens = self.count_tokens(msg_text)

            # Check if adding this message would exceed the token limit
            if total_tokens + msg_tokens > max_tokens:
                break

            total_tokens += msg_tokens
            history.append(msg)

        # Reverse back to chronological order
        return list(reversed(history))

    def generate_prompt_with_history(
        self,
        current_query: str,
        history: List[Dict],
        context_docs: List[str]
    ) -> str:
        """
        Generate a prompt that includes conversation history and context

        Args:
            current_query (str): Current user query
            history (List[Dict]): Relevant conversation history
            context_docs (List[str]): Retrieved context documents

        Returns:
            str: Formatted prompt
        """
        # Format conversation history
        conversation_context = self.format_messages(history)

        # Format context documents
        context_str = "\n\n".join(context_docs)

        prompt = f"""
Previous Conversation:
{conversation_context}

Relevant Context:
{context_str}

Current Query: {current_query}

Based on the previous conversation and the provided context, please provide a comprehensive and accurate response that maintains continuity with the conversation history."""

        return prompt
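A short sketch of the token-budgeting behavior; it requires the tiktoken package, and the sample messages and the 50-token budget are illustrative:

from src.utils.conversation_manager import ConversationManager

manager = ConversationManager(max_tokens=50, max_messages=10)

messages = [
    {"role": "user", "content": "Tell me about the Eiffel Tower. " * 5},
    {"role": "assistant", "content": "It is a wrought-iron tower in Paris."},
    {"role": "user", "content": "How tall is it?"},
]

# Walking from newest to oldest, the long first message no longer fits
# the 50-token budget, so only the two most recent messages are kept
history = manager.get_relevant_history(messages, current_query="And who built it?")
print(manager.format_messages(history))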