Spaces:
Sleeping
Sleeping
File size: 6,445 Bytes
b39c0ba |
|
import logging
import re
import os
from typing import List
from datetime import datetime
from fastapi.responses import JSONResponse
from script.vector_db import IndexManager
from llama_index.core.llms import MessageRole
from core.chat.engine import Engine
from core.chat.chatstore import ChatStore
from core.parser import clean_text, update_response, renumber_sources
from service.dto import ChatMessage
from pymongo.mongo_client import MongoClient
class ChatCompletionService:
    """Produce one non-streaming chat completion and persist the exchange.

    An instance serves a single request:

    1. check that MongoDB answers a ``ping``,
    2. ask the chat engine (built over the existing indexes) for an answer,
    3. resolve the numeric ``[n]`` citations in the answer to source nodes,
    4. store the assistant message in the chat store, and
    5. for ``type_bot == "specific"``, append the user/assistant pair to the
       MongoDB collection named after the session.
    """

    # Citation markers such as "[3]" in the model's answer.
    _REFERENCE_PATTERN = re.compile(r"\[(\d+)\]")
    # "source N:" labels removed from content before it is attached to metadata.
    _SOURCE_LABEL_PATTERN = re.compile(r"source \d+:")

    def __init__(self, session_id: str, user_request: str, titles: List = None, type_bot: str = "general"):
        """Create the service and its collaborators.

        Args:
            session_id: Chat session identifier. It is passed to the chat
                engine and chat store, and is used as the MongoDB collection
                name when history is saved.
            user_request: The user's message to answer.
            titles: Optional value forwarded to the chat engine when
                ``type_bot`` is not ``"general"``.
            type_bot: Bot type. ``"general"`` uses the default engine
                configuration; ``"specific"`` also saves the history to MongoDB.
        """
        self.session_id = session_id
        self.user_request = user_request
        self.titles = titles
        self.type_bot = type_bot
        # NOTE(review): a new MongoClient is created for every service
        # instance and this class never closes it; a shared, long-lived
        # client would avoid per-request connection setup.
        self.client = MongoClient(os.getenv("MONGO_URI"))
        self.engine = Engine()
        self.index_manager = IndexManager()
        self.chatstore = ChatStore()

    def generate_completion(self):
        """Answer ``self.user_request`` and record the exchange.

        Returns:
            On success, a tuple ``(response_text, metadata_collection, scores)``:
            the post-processed answer text, one metadata dict per resolved
            citation (each with a ``"content"`` key added), and the matching
            retrieval scores.

            On failure, a ``JSONResponse`` with status 500. This covers a
            failed Mongo ping, an error while generating the answer, and an
            error while saving history.
        """
        if not self._ping_mongo():
            return JSONResponse(status_code=500, content="Database Error: Unable to connect to MongoDB")

        try:
            # Load the existing indexes and build the engine for this bot type.
            index = self.index_manager.load_existing_indexes()
            chat_engine = self._get_chat_engine(index)

            response = chat_engine.chat(self.user_request)
            sources = response.sources
            number_reference_sorted = self._extract_sorted_references(response)
            contents, metadata_collection, scores = self._process_sources(sources, number_reference_sorted)

            # Post-process the answer text and the source snippets.
            response = update_response(str(response))
            contents = renumber_sources(contents)

            metadata_collection = self._attach_contents_to_metadata(contents, metadata_collection)
            self._store_message_in_chatstore(response, metadata_collection)
        except Exception as e:
            # logging.exception keeps the traceback, which logging.error dropped.
            logging.exception("An error occurred in generate text: %s", e)
            return JSONResponse(
                status_code=500,
                content=f"An internal server error occurred: {e}"
            )

        try:
            if self.type_bot == "specific":
                self._save_chat_history_to_db(response, metadata_collection)
        except Exception as e:
            logging.exception("An error occurred while saving chat history: %s", e)
            return JSONResponse(
                status_code=500,
                content=f"An internal server error occurred while saving chat history: {e}"
            )
        return str(response), metadata_collection, scores

    def _ping_mongo(self):
        """Return True if MongoDB answers an admin ``ping``, otherwise False.

        Any exception is logged and reported as False; it is not raised.
        """
        try:
            self.client.admin.command("ping")
        except Exception as e:
            logging.error("MongoDB connection failed: %s", e)
            return False
        logging.info("Pinged your deployment. Successfully connected to MongoDB!")
        return True

    def _get_chat_engine(self, index):
        """Return a chat engine for this session over ``index``.

        ``"general"`` bots use the engine's default arguments. Every other
        bot type also passes ``self.titles`` and ``self.type_bot``.
        """
        if self.type_bot == "general":
            return self.engine.get_chat_engine(self.session_id, index)
        return self.engine.get_chat_engine(self.session_id, index, self.titles, self.type_bot)

    def _extract_sorted_references(self, response):
        """Return the distinct ``[n]`` citation numbers found in ``response``.

        The numbers are returned as strings, sorted in ascending NUMERIC
        order. A plain string sort would place "10" before "2", which
        misorders contents, metadata and scores once ten or more sources
        are cited.
        """
        return sorted(set(self._REFERENCE_PATTERN.findall(str(response))), key=int)

    def _process_sources(self, sources, number_reference_sorted):
        """Resolve 1-based citation numbers to source-node data.

        Args:
            sources: ``response.sources`` from the chat engine. Only
                ``sources[0]["raw_output"].source_nodes`` is consulted.
            number_reference_sorted: Citation numbers as digit strings.

        Returns:
            A tuple ``(contents, metadata_collection, scores)`` of parallel
            lists: cleaned node text, a copy of each node's metadata as a
            dict, and each node's score. Citation numbers outside the node
            range are logged and skipped.
        """
        contents, metadata_collection, scores = [], [], []
        if not number_reference_sorted:
            logging.info("There are no references")
            return contents, metadata_collection, scores

        # The source check and node lookup do not depend on the citation,
        # so they are done once instead of on every loop iteration.
        if not (sources and sources[0].get("raw_output")):
            logging.warning("No sources available")
            return contents, metadata_collection, scores
        nodes = dict(sources[0])["raw_output"].source_nodes

        for number in number_reference_sorted:
            position = int(number) - 1  # citations are 1-based
            if not 0 <= position < len(nodes):
                logging.warning("Invalid reference number: %s", int(number))
                continue
            scored_node = nodes[position]
            contents.append(clean_text(scored_node.node.get_text()))
            metadata_collection.append(dict(scored_node.node.metadata))
            scores.append(scored_node.score)
        return contents, metadata_collection, scores

    def _attach_contents_to_metadata(self, contents, metadata_collection):
        """Copy each content string into the matching metadata dict.

        The string is stored under ``"content"`` with any "source N:" labels
        removed. Items are paired by position; extra items in the longer
        list are left untouched. The dicts are mutated in place, and the
        same list is returned.
        """
        for content, metadata in zip(contents, metadata_collection):
            metadata["content"] = self._SOURCE_LABEL_PATTERN.sub("", content)
        return metadata_collection

    def _store_message_in_chatstore(self, response, metadata_collection):
        """Write the assistant answer for this session to the chat store.

        The store's last message for the session is deleted first, this
        answer (with its citation metadata) is added, and then
        ``clean_message`` is called on the session.
        """
        message = ChatMessage(
            role=MessageRole.ASSISTANT,
            content=response,
            metadata=metadata_collection
        )
        self.chatstore.delete_last_message(self.session_id)
        self.chatstore.add_message(self.session_id, message)
        self.chatstore.clean_message(self.session_id)

    def _save_chat_history_to_db(self, response, metadata_collection):
        """Insert the user request and assistant answer into MongoDB.

        Both messages go into database ``bot_database``, in the collection
        named ``self.session_id``.
        """
        # ``payment`` is "free" only for "general" bots. generate_completion
        # calls this method only for "specific" bots, so on that path it is None.
        payment = "free" if self.type_bot == "general" else None
        chat_history_db = [
            # NOTE(review): the user's request is stored with
            # MessageRole.SYSTEM rather than MessageRole.USER. Confirm that
            # readers of this collection expect it before changing.
            ChatMessage(
                role=MessageRole.SYSTEM,
                content=self.user_request,
                timestamp=datetime.now(),
                payment=payment,
            ),
            ChatMessage(
                role=MessageRole.ASSISTANT,
                content=response,
                metadata=metadata_collection,
                timestamp=datetime.now(),
                payment=payment,
            ),
        ]
        chat_history_json = [message.model_dump() for message in chat_history_db]
        collection = self.client["bot_database"][self.session_id]
        result = collection.insert_many(chat_history_json)
        logging.info("Data inserted with record ids %s", result.inserted_ids)
def generate_completion_non_streaming(session_id, user_request, titles=None, type_bot="general"):
    """Run one non-streaming chat completion for ``session_id``.

    This is a convenience entry point. It builds a ``ChatCompletionService``
    for the request and returns exactly what its ``generate_completion()``
    returns: either a ``(response_text, metadata_collection, scores)`` tuple,
    or a ``JSONResponse`` with status 500.
    """
    service = ChatCompletionService(
        session_id,
        user_request,
        titles=titles,
        type_bot=type_bot,
    )
    return service.generate_completion()
|