Spaces:
Sleeping
Sleeping
File size: 6,445 Bytes
b39c0ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import logging
import re
import os
from typing import List
from datetime import datetime
from fastapi.responses import JSONResponse
from script.vector_db import IndexManager
from llama_index.core.llms import MessageRole
from core.chat.engine import Engine
from core.chat.chatstore import ChatStore
from core.parser import clean_text, update_response, renumber_sources
from service.dto import ChatMessage
from pymongo.mongo_client import MongoClient
class ChatCompletionService:
    """Produce one chat completion for a session and persist the exchange.

    The service pings MongoDB, asks the chat engine for an answer to
    ``user_request``, resolves the ``[n]`` citations found in the answer
    against the engine's source nodes, stores the assistant message in the
    chat store and, for ``type_bot == "specific"``, writes the exchange to
    MongoDB.
    """

    def __init__(self, session_id: str, user_request: str, titles: List = None, type_bot: str = "general"):
        """Store the request parameters and build the collaborators.

        Args:
            session_id: Identifier of the chat session; also used as the
                MongoDB collection name when history is saved.
            user_request: The text sent to the chat engine.
            titles: Passed through to the engine for non-"general" bots.
            type_bot: "general" selects the plain engine; any other value
                passes ``titles`` and ``type_bot`` on to the engine.
        """
        self.session_id = session_id
        self.user_request = user_request
        self.titles = titles
        self.type_bot = type_bot
        # The connection string is read from the MONGO_URI environment variable.
        self.client = MongoClient(os.getenv("MONGO_URI"))
        self.engine = Engine()
        self.index_manager = IndexManager()
        self.chatstore = ChatStore()

    def generate_completion(self):
        """Run the full completion pipeline.

        Returns:
            On success, a tuple ``(response_text, metadata_collection, scores)``.
            On failure, a ``JSONResponse`` with status code 500 describing
            which stage failed (database ping, generation, or history save).
        """
        if not self._ping_mongo():
            return JSONResponse(status_code=500, content="Database Error: Unable to connect to MongoDB")

        try:
            # Load and retrieve chat engine with appropriate index
            index = self.index_manager.load_existing_indexes()
            chat_engine = self._get_chat_engine(index)

            # Generate chat response
            response = chat_engine.chat(self.user_request)
            sources = response.sources

            # Citations must be read from the raw answer, before
            # update_response() rewrites it below.
            number_reference_sorted = self._extract_sorted_references(response)
            contents, metadata_collection, scores = self._process_sources(sources, number_reference_sorted)

            # Update response and renumber sources
            response = update_response(str(response))
            contents = renumber_sources(contents)

            # Add contents to metadata
            metadata_collection = self._attach_contents_to_metadata(contents, metadata_collection)

            # Save the message to chat store
            self._store_message_in_chatstore(response, metadata_collection)
        except Exception as e:
            # logging.exception keeps the traceback, which logging.error dropped.
            logging.exception("An error occurred in generate text: %s", e)
            return JSONResponse(
                status_code=500,
                content=f"An internal server error occurred: {e}"
            )

        try:
            if self.type_bot == "specific":
                self._save_chat_history_to_db(response, metadata_collection)
            return str(response), metadata_collection, scores
        except Exception as e:
            logging.exception("An error occurred while saving chat history: %s", e)
            return JSONResponse(
                status_code=500,
                content=f"An internal server error occurred while saving chat history: {e}"
            )

    def _ping_mongo(self):
        """Return True if the MongoDB ``ping`` admin command succeeds, else False."""
        try:
            self.client.admin.command("ping")
            print("Pinged your deployment. Successfully connected to MongoDB!")
            return True
        except Exception as e:
            logging.error(f"MongoDB connection failed: {e}")
            return False

    def _get_chat_engine(self, index):
        """Return the chat engine for this session, built on ``index``.

        "general" bots use the two-argument engine call; every other bot type
        also passes ``titles`` and ``type_bot``.
        """
        if self.type_bot == "general":
            return self.engine.get_chat_engine(self.session_id, index)
        return self.engine.get_chat_engine(self.session_id, index, self.titles, self.type_bot)

    def _extract_sorted_references(self, response):
        """Return the distinct ``[n]`` citation numbers found in ``response``.

        The numbers are returned as strings, ordered by numeric value. A plain
        string sort would put "10" before "2" and misorder the sources as soon
        as an answer cites ten or more of them.
        """
        distinct_refs = set(re.findall(r"\[(\d+)\]", str(response)))
        # The string itself is the tie-breaker so the order is deterministic
        # even for spellings with equal value (e.g. "1" and "01").
        return sorted(distinct_refs, key=lambda ref: (int(ref), ref))

    def _process_sources(self, sources, number_reference_sorted):
        """Collect text, metadata and score for every cited source node.

        Args:
            sources: The ``sources`` attribute of the chat response; only the
                first entry's ``raw_output.source_nodes`` is consulted.
            number_reference_sorted: 1-based citation numbers (strings or ints).

        Returns:
            Three parallel lists ``(contents, metadata_collection, scores)``.
            Citation numbers outside the range of available nodes are reported
            and skipped.
        """
        contents, metadata_collection, scores = [], [], []
        if not number_reference_sorted:
            print("There are no references")
            return contents, metadata_collection, scores

        # The source lookup does not depend on the citation number, so it is
        # checked once up front instead of on every loop iteration.
        if not (sources and sources[0].get("raw_output")):
            print("No sources available")
            return contents, metadata_collection, scores

        source_nodes = dict(sources[0])["raw_output"].source_nodes
        for number in number_reference_sorted:
            number = int(number)
            position = number - 1  # citations are 1-based, the node list is 0-based
            if not 0 <= position < len(source_nodes):
                print(f"Invalid reference number: {number}")
                continue
            cited = source_nodes[position]
            contents.append(clean_text(cited.node.get_text()))
            metadata_collection.append(dict(cited.node.metadata))
            scores.append(cited.score)
        return contents, metadata_collection, scores

    def _attach_contents_to_metadata(self, contents, metadata_collection):
        """Store each content string under ``"content"`` in its metadata dict.

        Any ``source <n>:`` label is removed from the content first. Entries
        are paired by position; surplus items on either side are left alone.
        The metadata dicts are modified in place and the same list is returned.
        """
        for metadata, content in zip(metadata_collection, contents):
            metadata["content"] = re.sub(r"source \d+:", "", content)
        return metadata_collection

    def _store_message_in_chatstore(self, response, metadata_collection):
        """Replace the session's last chat-store message with the assistant reply."""
        message = ChatMessage(
            role=MessageRole.ASSISTANT,
            content=response,
            metadata=metadata_collection
        )
        self.chatstore.delete_last_message(self.session_id)
        self.chatstore.add_message(self.session_id, message)
        self.chatstore.clean_message(self.session_id)

    def _save_chat_history_to_db(self, response, metadata_collection):
        """Insert the request/response pair into MongoDB.

        Both messages go to database ``bot_database``, in a collection named
        after the session id.
        """
        chat_history_db = [
            ChatMessage(
                # NOTE(review): the user's request is stored with the SYSTEM
                # role; confirm whether MessageRole.USER was intended before
                # changing the stored schema.
                role=MessageRole.SYSTEM,
                content=self.user_request,
                timestamp=datetime.now(),
                payment="free" if self.type_bot == "general" else None,
            ),
            ChatMessage(
                role=MessageRole.ASSISTANT,
                content=response,
                metadata=metadata_collection,
                timestamp=datetime.now(),
                payment="free" if self.type_bot == "general" else None,
            ),
        ]
        chat_history_json = [message.model_dump() for message in chat_history_db]

        db = self.client["bot_database"]  # Replace with your database name
        collection = db[self.session_id]  # Replace with your collection name
        result = collection.insert_many(chat_history_json)
        print("Data inserted with record ids", result.inserted_ids)
# Example usage
def generate_completion_non_streaming(session_id, user_request, titles=None, type_bot="general"):
    """Build a one-shot ChatCompletionService and return its completion.

    Args:
        session_id: Identifier of the chat session.
        user_request: The text to answer.
        titles: Forwarded to the service unchanged.
        type_bot: Forwarded to the service unchanged.

    Returns:
        Whatever ``ChatCompletionService.generate_completion()`` returns.
    """
    chat_service = ChatCompletionService(session_id, user_request, titles, type_bot)
    try:
        return chat_service.generate_completion()
    finally:
        # The service opens a new MongoClient on every call; release it here
        # so repeated requests do not accumulate open clients.
        chat_service.client.close()
|