Spaces:
Sleeping
Sleeping
import logging | |
import re | |
import os | |
import pytz | |
from typing import List | |
from datetime import datetime | |
from datetime import timedelta | |
from fastapi.responses import JSONResponse | |
from script.vector_db import IndexManager | |
from llama_index.core.llms import MessageRole | |
from core.chat.engine import Engine | |
from core.chat.chatstore import ChatStore | |
from core.parser import ( | |
filter_metadata_by_pages, | |
extract_sorted_page_numbers | |
) | |
from service.dto import ChatMessage | |
from pymongo.mongo_client import MongoClient | |
class ChatCompletionService:
    """Answer one user request through the chat engine and persist the result.

    An instance is bound to a single chat session and a single user request.
    ``generate_completion`` is the entry point; the remaining methods are
    private helpers.
    """

    def __init__(
        self,
        session_id: str,
        user_request: str,
        titles: List = None,
        type_bot: str = "general",
    ):
        """Store the request parameters and build the collaborators.

        Args:
            session_id: Chat session identifier. It is passed to the chat
                engine and chat store, and is also used as the MongoDB
                collection name when history is saved.
            user_request: The user's message to send to the chat engine.
            titles: Optional list forwarded to the engine for non-"general"
                bots.
            type_bot: Bot flavour; "general" and "specific" are the values
                this class branches on.
        """
        self.session_id = session_id
        self.user_request = user_request
        self.titles = titles
        self.type_bot = type_bot
        # The connection string comes from the MONGO_URI environment variable.
        self.client = MongoClient(os.getenv("MONGO_URI"))
        self.engine = Engine()
        self.index_manager = IndexManager()
        self.chatstore = ChatStore()

    def generate_completion(self):
        """Run the chat engine for the stored request and persist the answer.

        Returns:
            On success, a tuple ``(response_text, metadata_collection, scores)``.
            On any failure, a ``JSONResponse`` with status code 500 and a
            message describing the error.
        """
        if not self._ping_mongo():
            return JSONResponse(
                status_code=500, content="Database Error: Unable to connect to MongoDB"
            )

        try:
            # Load and retrieve chat engine with appropriate index
            index = self.index_manager.load_existing_indexes()
            chat_engine = self._get_chat_engine(index)

            # Generate chat response
            response = chat_engine.chat(self.user_request)
            sources = response.source_nodes
            contents, metadata_collection, scores = self._process_sources_images(sources)

            # From here on only the textual form of the response is needed.
            response = str(response)

            # Add contents to metadata
            metadata_collection = self._attach_contents_to_metadata(
                contents, metadata_collection
            )

            # Keep only the metadata entries selected by the page numbers
            # extracted from the response text.
            page_sources = extract_sorted_page_numbers(response)
            metadata_collection = filter_metadata_by_pages(metadata_collection, page_sources)

            # Save the message to chat store
            self._store_message_in_chatstore(response, metadata_collection)
        except Exception as e:
            logging.error(f"An error occurred in generate text: {e}")
            return JSONResponse(
                status_code=500, content=f"An internal server error occurred: {e}"
            )

        try:
            # Only "specific" bots have their history written to MongoDB.
            if self.type_bot == "specific":
                self._save_chat_history_to_db(response, metadata_collection)
            return response, metadata_collection, scores
        except Exception as e:
            logging.error(f"An error occurred while saving chat history: {e}")
            return JSONResponse(
                status_code=500,
                content=f"An internal server error occurred while saving chat history: {e}",
            )

    def _ping_mongo(self):
        """Return True if the MongoDB ``ping`` admin command succeeds, else False."""
        try:
            self.client.admin.command("ping")
            print("Pinged your deployment. Successfully connected to MongoDB!")
            return True
        except Exception as e:
            logging.error(f"MongoDB connection failed: {e}")
            return False

    def _get_chat_engine(self, index):
        """Build the chat engine for this session.

        "general" bots use the engine's defaults; any other bot type also
        passes ``titles`` and ``type_bot`` through.
        """
        if self.type_bot == "general":
            return self.engine.get_chat_engine(self.session_id, index)
        return self.engine.get_chat_engine(
            self.session_id, index, self.titles, self.type_bot
        )

    def _extract_sorted_references(self, response):
        """Return the unique ``[n]`` reference numbers found in ``response``.

        The numbers are returned as strings, sorted by numeric value so that
        "10" comes after "2" rather than before it.
        """
        unique_numbers = set(re.findall(r"\[(\d+)\]", str(response)))
        return sorted(unique_numbers, key=int)

    def _process_sources_images(self, sources):
        """Split source nodes into parallel lists of text, metadata and score.

        Args:
            sources: Iterable of objects exposing ``.node.get_text()``,
                ``.node.metadata`` and ``.score``; may be empty or ``None``.

        Returns:
            ``(contents, metadata_collection, scores)`` in the same order as
            ``sources``. Each metadata entry is a fresh ``dict`` copy, so later
            mutation does not touch the node's own metadata mapping.
        """
        contents, metadata_collection, scores = [], [], []
        if not sources:
            print("No sources available")
            return contents, metadata_collection, scores

        # Iterate in retrieval order; indexing with ``number - 1`` would start
        # at the last element and rotate the whole result.
        for source in sources:
            contents.append(source.node.get_text())
            metadata_collection.append(dict(source.node.metadata))
            scores.append(source.score)
        return contents, metadata_collection, scores

    def _attach_contents_to_metadata(self, contents, metadata_collection):
        """Store each content string under the "content" key of its metadata.

        Pairs are matched by position; extra items in the longer list are left
        untouched. The metadata dicts are mutated in place and the same list
        is returned.
        """
        for metadata, content in zip(metadata_collection, contents):
            metadata["content"] = content
        return metadata_collection

    def _store_message_in_chatstore(self, response, metadata_collection):
        """Replace the session's last chat-store message with the final answer.

        The assistant message carries the response text and its metadata.
        NOTE(review): presumably the deleted message is the raw assistant
        reply recorded by the chat engine - confirm against ChatStore.
        """
        message = ChatMessage(
            role=MessageRole.ASSISTANT,
            content=response,
            metadata=metadata_collection,
        )
        self.chatstore.delete_last_message(self.session_id)
        self.chatstore.add_message(self.session_id, message)
        self.chatstore.clean_message(self.session_id)

    def _save_chat_history_to_db(self, response, metadata_collection):
        """Insert the user request and the assistant answer into MongoDB.

        Both documents go into the ``bot_database`` database, in a collection
        named after the session id. Timestamps use the Asia/Jakarta timezone.
        """
        jakarta_tz = pytz.timezone("Asia/Jakarta")
        time_now = datetime.now(jakarta_tz)
        # Back-date the user message slightly so it has an earlier timestamp
        # than the assistant reply.
        user_timestamp = time_now - timedelta(seconds=0.2)
        payment = "free" if self.type_bot == "general" else None

        chat_history_db = [
            ChatMessage(
                role=MessageRole.USER,
                content=self.user_request,
                timestamp=user_timestamp,
                payment=payment,
            ),
            ChatMessage(
                role=MessageRole.ASSISTANT,
                content=response,
                metadata=metadata_collection,
                timestamp=time_now,
                payment=payment,
            ),
        ]
        chat_history_json = [message.model_dump() for message in chat_history_db]

        db = self.client["bot_database"]  # Replace with your database name
        collection = db[self.session_id]  # Replace with your collection name
        collection.insert_many(chat_history_json)