# File size: 6,445 Bytes
# Revision: b39c0ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import logging
import re
import os

from typing import List
from datetime import datetime
from fastapi.responses import JSONResponse
from script.vector_db import IndexManager
from llama_index.core.llms import MessageRole

from core.chat.engine import Engine
from core.chat.chatstore import ChatStore
from core.parser import clean_text, update_response, renumber_sources

from service.dto import ChatMessage
from pymongo.mongo_client import MongoClient




class ChatCompletionService:
    """Produce one non-streaming chat completion for a session.

    The service asks the chat engine for an answer, collects the source
    nodes that the answer cites with ``[n]`` markers, stores the assistant
    message in the chat store and, when ``type_bot`` is ``"specific"``,
    also inserts the exchange into MongoDB.
    """

    def __init__(self, session_id: str, user_request: str, titles: List = None, type_bot: str = "general"):
        """Keep the request data and create the collaborating objects.

        Args:
            session_id: Identifier passed to the chat engine and the chat
                store; also used as the MongoDB collection name.
            user_request: Text sent to the chat engine.
            titles: Forwarded to ``Engine.get_chat_engine`` when
                ``type_bot`` is not ``"general"``.
            type_bot: ``"general"`` (default) or another bot type such as
                ``"specific"``.
        """
        self.session_id = session_id
        self.user_request = user_request
        self.titles = titles
        self.type_bot = type_bot
        # The connection string is read from the MONGO_URI environment variable.
        self.client = MongoClient(os.getenv("MONGO_URI"))
        self.engine = Engine()
        self.index_manager = IndexManager()
        self.chatstore = ChatStore()

    def generate_completion(self):
        """Run the whole completion pipeline for ``self.user_request``.

        Returns:
            On success, a tuple ``(response_text, metadata_collection,
            scores)``. On any failure (MongoDB unreachable, error while
            generating, error while saving history), a ``JSONResponse``
            with status code 500.
        """
        if not self._ping_mongo():
            return JSONResponse(status_code=500, content="Database Error: Unable to connect to MongoDB")

        try:
            # Load and retrieve chat engine with appropriate index
            index = self.index_manager.load_existing_indexes()
            chat_engine = self._get_chat_engine(index)

            # Generate chat response
            response = chat_engine.chat(self.user_request)
            sources = response.sources
            number_reference_sorted = self._extract_sorted_references(response)

            contents, metadata_collection, scores = self._process_sources(sources, number_reference_sorted)

            # Update response and renumber sources
            response = update_response(str(response))
            contents = renumber_sources(contents)

            # Add contents to metadata
            metadata_collection = self._attach_contents_to_metadata(contents, metadata_collection)

            # Save the message to chat store
            self._store_message_in_chatstore(response, metadata_collection)

        except Exception as e:
            # Request boundary: log with traceback and answer with a 500
            # instead of letting the exception propagate.
            logging.exception(f"An error occurred in generate text: {e}")
            return JSONResponse(
                status_code=500,
                content=f"An internal server error occurred: {e}"
            )

        try:
            # Only the "specific" bot type persists its history to MongoDB.
            if self.type_bot == "specific":
                self._save_chat_history_to_db(response, metadata_collection)

            return str(response), metadata_collection, scores

        except Exception as e:
            logging.exception(f"An error occurred while saving chat history: {e}")
            return JSONResponse(
                status_code=500,
                content=f"An internal server error occurred while saving chat history: {e}"
            )

    def _ping_mongo(self):
        """Return True when the MongoDB ``ping`` command succeeds, else False."""
        try:
            self.client.admin.command("ping")
            print("Pinged your deployment. Successfully connected to MongoDB!")
            return True
        except Exception as e:
            logging.error(f"MongoDB connection failed: {e}")
            return False

    def _get_chat_engine(self, index):
        """Build the chat engine for this session.

        The ``"general"`` bot passes only the session id and index; every
        other bot type additionally passes ``titles`` and ``type_bot``.
        """
        if self.type_bot == "general":
            return self.engine.get_chat_engine(self.session_id, index)
        return self.engine.get_chat_engine(self.session_id, index, self.titles, self.type_bot)

    def _extract_sorted_references(self, response):
        """Return the distinct ``[n]`` citation numbers found in the response.

        Args:
            response: Any object whose ``str()`` is the answer text.

        Returns:
            The unique digit strings, ordered by numeric value so that
            ``"2"`` comes before ``"10"``.
        """
        unique_numbers = set(re.findall(r"\[(\d+)\]", str(response)))
        # A plain string sort would place "10" before "2"; compare as integers.
        return sorted(unique_numbers, key=int)

    def _process_sources(self, sources, number_reference_sorted):
        """Collect text, metadata and score of every cited source node.

        Args:
            sources: The ``sources`` of the chat response; only the first
                entry's ``raw_output.source_nodes`` is read.
            number_reference_sorted: 1-based citation numbers (digit strings).

        Returns:
            A tuple ``(contents, metadata_collection, scores)`` of parallel
            lists. Citation numbers outside the source-node range are
            reported and skipped.
        """
        contents, metadata_collection, scores = [], [], []
        if not number_reference_sorted:
            print("There are no references")
            return contents, metadata_collection, scores

        # The source list does not depend on the citation number, so it is
        # checked and fetched once instead of on every loop iteration.
        if not (sources and sources[0].get("raw_output")):
            print("No sources available")
            return contents, metadata_collection, scores

        source_nodes = dict(sources[0])["raw_output"].source_nodes
        for number in number_reference_sorted:
            number = int(number)
            # Citations are 1-based, the node list is 0-based.
            if not 0 <= number - 1 < len(source_nodes):
                print(f"Invalid reference number: {number}")
                continue
            cited = source_nodes[number - 1]
            contents.append(clean_text(cited.node.get_text()))
            metadata_collection.append(dict(cited.node.metadata))
            scores.append(cited.score)

        return contents, metadata_collection, scores

    def _attach_contents_to_metadata(self, contents, metadata_collection):
        """Store each content string under ``"content"`` in its metadata dict.

        Every ``source <n>:`` label is stripped from the content first.
        Pairing stops at the shorter of the two lists; the metadata dicts
        are modified in place and the same list is returned.
        """
        for content, metadata in zip(contents, metadata_collection):
            metadata["content"] = re.sub(r"source \d+:", "", content)
        return metadata_collection

    def _store_message_in_chatstore(self, response, metadata_collection):
        """Write the assistant message for this session to the chat store.

        Calls ``delete_last_message``, ``add_message`` and ``clean_message``
        on the chat store, in that order.
        """
        message = ChatMessage(
            role=MessageRole.ASSISTANT,
            content=response,
            metadata=metadata_collection
        )
        self.chatstore.delete_last_message(self.session_id)
        self.chatstore.add_message(self.session_id, message)
        self.chatstore.clean_message(self.session_id)

    def _save_chat_history_to_db(self, response, metadata_collection):
        """Insert the request and the answer into MongoDB.

        Both messages go into database ``bot_database``, in the collection
        named after ``self.session_id``.
        """
        chat_history_db = [
            # NOTE(review): the user's request is stored with
            # MessageRole.SYSTEM - confirm this is intended rather than
            # MessageRole.USER.
            ChatMessage(
                role=MessageRole.SYSTEM,
                content=self.user_request,
                timestamp=datetime.now(),
                payment="free" if self.type_bot == "general" else None,
            ),
            ChatMessage(
                role=MessageRole.ASSISTANT,
                content=response,
                metadata=metadata_collection,
                timestamp=datetime.now(),
                payment="free" if self.type_bot == "general" else None,
            ),
        ]

        chat_history_json = [message.model_dump() for message in chat_history_db]

        db = self.client["bot_database"]  # Replace with your database name
        collection = db[self.session_id]  # Replace with your collection name
        result = collection.insert_many(chat_history_json)
        print("Data inserted with record ids", result.inserted_ids)


# Example usage
def generate_completion_non_streaming(session_id, user_request, titles=None, type_bot="general"):
    """Create a ``ChatCompletionService`` for the request and run it.

    All four arguments are handed to the service constructor positionally;
    the value returned by its ``generate_completion()`` is returned as is.
    """
    return ChatCompletionService(
        session_id,
        user_request,
        titles,
        type_bot,
    ).generate_completion()