Commit 0743bb0
Parent(s): 9002555
Update Repository
Files changed:
- api/function.py +58 -36
- api/router/bot.py +49 -4
- api/router/reader.py +17 -0
- api/router/topic.py +23 -4
- app.py +2 -1
- core/chat/chatstore.py +93 -0
- core/chat/engine.py +18 -52
- core/chat/messaging.py +1 -1
- core/prompt.py +1 -1
- db/get_data.py +14 -3
- script/get_metadata.py +17 -16
- script/vector_db.py +157 -0
- service/aws_loader.py +41 -8
- service/dto.py +14 -2
- utils/utils.py +9 -1
api/function.py
CHANGED
@@ -1,57 +1,62 @@
-from script.…
+from script.vector_db import IndexManager
 from script.document_uploader import Uploader
 from db.save_data import InsertDatabase
 from db.get_data import GetDatabase
 from db.delete_data import DeleteDatabase
 from db.update_data import UpdateDatabase
-…
+
+from typing import Any, Optional, List
 from fastapi import UploadFile
 from fastapi import HTTPException
+
+from service.dto import ChatMessage
 from core.chat.engine import Engine
+from core.chat.chatstore import ChatStore
 from core.parser import clean_text, update_response, renumber_sources, seperate_to_list
-from llama_index.core.…
-from service.dto import BotResponseStreaming
+from llama_index.core.llms import MessageRole
+from service.dto import BotResponseStreaming
 from service.aws_loader import Loader

 import logging
 import re
+import json


 # Configure logging
 logging.basicConfig(level=logging.INFO)

 # async def data_ingestion(
 #     db_conn, reference, file: UploadFile, content_table: UploadFile
 # ) -> Any:

-async def data_ingestion(
-    db_conn, reference, file: UploadFile
-) -> Any:
-
-    insert_database = InsertDatabase(db_conn)
-
-    file_name = f"{reference['title']}.pdf"
-    aws_loader = Loader()
-
-    file_obj = file
-    aws_loader.upload_to_s3(file_obj, file_name)
-
-    print("Uploaded Success")
-
+async def data_ingestion(db_conn, reference, file: UploadFile) -> Any:
     try:
-        …
+
+        # insert_database = InsertDatabase(db_conn)
+
+        file_name = f"{reference['title']}"
+        aws_loader = Loader()
+
+        file_obj = file
+        aws_loader.upload_to_s3(file_obj, file_name)
+
+        print("Uploaded Success")

+        response = json.dumps({"status": "success", "message": "Vector Index loaded successfully."})
+
+        # Insert data into the database
+        # await insert_database.insert_data(reference)
+
+        # # uploader = Uploader(reference, file, content_table)
+        # uploader = Uploader(reference, file)
+        # print("uploader : ", uploader)

+        # nodes_with_metadata = await uploader.process_documents()
+
+        # # Build indexes using IndexManager
+        # index = IndexManager()
+        # response = index.build_indexes(nodes_with_metadata)

         return response

@@ -63,6 +68,7 @@ async def data_ingestion(
         detail="An internal server error occurred in data ingestion.",
     )

+
 async def get_data(db_conn, title="", fetch_all_data=True):
     get_database = GetDatabase(db_conn)
     print(get_database)
@@ -118,21 +124,31 @@ async def delete_data(id: int, db_conn):
     )


-def generate_completion_non_streaming(user_request, chat_engine):
+def generate_completion_non_streaming(
+    session_id, user_request, chat_engine, title=None, category=None, type="general"
+):
     try:
         engine = Engine()
         index_manager = IndexManager()
+        chatstore = ChatStore()

         # Load existing indexes
         index = index_manager.load_existing_indexes()

-        …
+        if type == "general":
+            # Retrieve the chat engine with the loaded index
+            chat_engine = engine.get_chat_engine(session_id, index)
+        else:
+            # Retrieve the chat engine with the loaded index
+            chat_engine = engine.get_chat_engine(
+                session_id, index, title=title, category=category
+            )

         # Generate completion response
         response = chat_engine.chat(user_request)

         sources = response.sources
+        print(sources)

         number_reference = list(set(re.findall(r"\[(\d+)\]", str(response))))
         number_reference_sorted = sorted(number_reference)
@@ -153,10 +169,8 @@ def generate_completion_non_streaming(user_request, chat_engine):

             # Pastikan number valid sebagai indeks
             if 0 <= number - 1 < len(node):

-                raw_content = seperate_to_list(
-                    node[number - 1].node.get_text()
-                )
+                raw_content = seperate_to_list(node[number - 1].node.get_text())
                 raw_contents.append(raw_content)

                 content = clean_text(node[number - 1].node.get_text())
@@ -176,7 +190,7 @@ def generate_completion_non_streaming(user_request, chat_engine):

         response = update_response(str(response))
         contents = renumber_sources(contents)

         # Check the lengths of content and metadata
         num_content = len(contents)
         num_metadata = len(metadata_collection)
@@ -185,6 +199,14 @@ def generate_completion_non_streaming(user_request, chat_engine):
         for i in range(min(num_content, num_metadata)):
             metadata_collection[i]["content"] = re.sub(r"source \d+\:", "", contents[i])

+        message = ChatMessage(
+            role=MessageRole.ASSISTANT, content=response, metadata=metadata_collection
+        )
+
+        chatstore.delete_last_message(session_id)
+        chatstore.add_message(session_id, message)
+        chatstore.clean_message(session_id)
+
         return str(response), raw_contents, contents, metadata_collection, scores
     except Exception as e:
         # Log the error and raise HTTPException for FastAPI
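Two things stand out in this file. data_ingestion now only uploads the PDF to S3 and returns a hard-coded success JSON (the database insert and index build are commented out), and generate_completion_non_streaming gained a session_id so the assistant reply is persisted to Redis via ChatStore. Note also that the chat_engine argument is rebuilt inside the function, so whatever a caller passes is discarded. A minimal call sketch under those assumptions:

    # Sketch only; wiring inferred from this diff, not shown in the repo.
    from api.function import generate_completion_non_streaming

    session_id = "demo-session"   # normally issued by POST /bot/new
    response, raw_refs, refs, metadata, scores = generate_completion_non_streaming(
        session_id,
        "Apa itu hipertensi?",    # user prompt
        None,                     # placeholder: chat_engine is rebuilt inside the function
    )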
api/router/bot.py
CHANGED
@@ -1,16 +1,49 @@
-from fastapi import APIRouter
+from fastapi import APIRouter, HTTPException, Depends
 from service.dto import UserPromptRequest, BotResponse
+from core.chat.chatstore import ChatStore

 from api.function import (
     generate_streaming_completion,
     generate_completion_non_streaming,
 )
 from sse_starlette.sse import EventSourceResponse
+from utils.utils import generate_uuid

 router = APIRouter(tags=["Bot"])

-@router.post("/bot")
+def get_chat_store():
+    return ChatStore()
+
+@router.post("/bot/new")
+async def create_new_session():
+    session_id = generate_uuid()
+    return {"session_id" : session_id}
+
+@router.get("/bot/{session_id}")
+async def get_session_id(session_id: str, chat_store: ChatStore = Depends(get_chat_store)):
+    chat_history = chat_store.get_messages(session_id)
+
+    if not chat_history:
+        raise HTTPException(status_code=404, detail="Session not found or empty.")
+
+    return chat_history
+
+@router.get("/bot")
+async def get_all_session_ids():
+    try:
+        chat_store = ChatStore()
+        all_keys = chat_store.get_keys()
+        print(all_keys)
+        return all_keys
+    except Exception as e:
+        # Log the error and raise HTTPException for FastAPI
+        print(f"An error occurred in update data.: {e}")
+        raise HTTPException(
+            status_code=400, detail="the error when get all session ids"
+        )
+
+
+@router.post("/bot/{session_id}")
 async def bot_generator_general(user_prompt_request: UserPromptRequest):

     if user_prompt_request.streaming:
@@ -22,7 +55,7 @@ async def bot_generator_general(user_prompt_request: UserPromptRequest):
     else:
         response, raw_references, references, metadata, scores = (
             generate_completion_non_streaming(
-                user_prompt_request.prompt, user_prompt_request.streaming
+                user_prompt_request.session_id, user_prompt_request.prompt, user_prompt_request.streaming
             )
         )

@@ -35,12 +68,24 @@ async def bot_generator_general(user_prompt_request: UserPromptRequest):
     )


-@router.post("/bot/{category_id}/{title}")
+@router.post("/bot/{category_id}/{title}") #Ganti router
 async def bot_generator_spesific(
     category_id: int, title: str, user_prompt_request: UserPromptRequest
 ):
     pass

+@router.delete("/bot/{session_id}")
+async def delete_bot(session_id: str, chat_store: ChatStore = Depends(get_chat_store)):
+    try:
+        chat_store.delete_messages(session_id)
+        return {"info": f"Delete {session_id} successful"}
+    except Exception as e:
+        # Log the error and raise HTTPException for FastAPI
+        print(f"An error occurred in update data.: {e}")
+        raise HTTPException(
+            status_code=400, detail="the error when deleting message"
+        )
+

 @router.get("/bot/{category_id}/{title}")
 async def get_favourite_data(category_id: int, title: str, human_template):
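The router now carries a full session lifecycle: POST /bot/new mints a UUID, GET /bot lists all Redis keys, GET /bot/{session_id} replays stored history, and DELETE /bot/{session_id} drops it. Note that the chat handler keeps the signature bot_generator_general(user_prompt_request) even though the route is /bot/{session_id}, so the session id is actually read from the request body, not the path. A rough client walkthrough, assuming a local dev server:

    import requests

    BASE = "http://localhost:8000"  # assumed dev address

    # 1. Mint a session, 2. chat inside it, 3. replay history, 4. discard.
    sid = requests.post(f"{BASE}/bot/new").json()["session_id"]
    body = {"session_id": sid, "prompt": "Apa itu hipertensi?", "streaming": False}
    print(requests.post(f"{BASE}/bot/{sid}", json=body).json())
    print(requests.get(f"{BASE}/bot/{sid}").json())
    requests.delete(f"{BASE}/bot/{sid}")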
api/router/reader.py
ADDED
@@ -0,0 +1,17 @@
+from fastapi import APIRouter, File, UploadFile, HTTPException
+from core.journal_reading.upload import upload_file
+
+router = APIRouter(tags=["Journal Reading"])
+
+
+
+@router.post("/upload")
+async def upload_journal(file: UploadFile = File(...)):
+    try :
+        documents = await upload_file(file)
+
+        return {"Success"}
+    except Exception as e:
+        raise HTTPException(
+            status_code=400, detail=f"Error processing file: {str(e)}"
+        )
api/router/topic.py
CHANGED
@@ -1,12 +1,18 @@
 from fastapi import Form, APIRouter, File, UploadFile, HTTPException, Request
 from db.repository import get_db_conn
+from db.get_data import GetDatabase
+from db.save_data import InsertDatabase
 from config import MYSQL_CONFIG
 from api.function import data_ingestion, get_data, delete_data, update_data
+from script.vector_db import IndexManager
 from service.dto import MetadataRequest

 router = APIRouter(tags=["Topics"])

 db_conn = get_db_conn(MYSQL_CONFIG)
+get_database = GetDatabase(db_conn)
+index_manager = IndexManager()
+

 @router.post("/topic")
 async def upload_file(
@@ -40,11 +46,24 @@ async def get_metadata():

 @router.put("/topic/{id}")
 async def update_metadata(id: int, reference: MetadataRequest):
-    …
+    try :
+        old_reference = await get_database.get_data_by_id(id)
+        index_manager.update_vector_database(old_reference, reference)
+
+        return await update_data(id, reference, db_conn)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail="An error occurred while updating metadata")


 @router.delete("/topic/{id}")
 async def delete_metadata(id: int):
-    …
+    try:
+        old_reference = await get_database.get_data_by_id(id)
+        index_manager.delete_vector_database(old_reference)
+
+        return await delete_data(id, db_conn)
+
+    except Exception as e:
+        print(e)
+        raise HTTPException(status_code=500, detail="An error occurred while delete metadata")
+
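get_database and index_manager are now constructed at module import time, so MySQL and Pinecone credentials must be available before the app can boot. The update path is: fetch the old row, rewrite the Pinecone metadata on every vector carrying the old title, then update MySQL. A hypothetical client call; the full MetadataRequest field list is not visible in this diff, so the payload below is an assumption:

    import requests

    payload = {  # assumed MetadataRequest shape; only `title` is shown elsewhere
        "title": "Buku Ajar Kardiologi",
        "author": "Santoso",
        "category": "Kardiologi",
        "year": 2020,
        "publisher": "EGC",
    }
    requests.put("http://localhost:8000/topic/3", json=payload)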
app.py
CHANGED
@@ -1,5 +1,5 @@
 from fastapi.applications import FastAPI
-from api.router import health, topic, user, bot, trial, role
+from api.router import health, topic, user, bot, trial, role, reader
 from fastapi.middleware.cors import CORSMiddleware
 from api.events import register_events
 from utils.utils import pipe
@@ -27,6 +27,7 @@ def register_routers(app: FastAPI) -> FastAPI:
     app.include_router(bot.router)
     app.include_router(trial.router)
     app.include_router(role.router)
+    app.include_router(reader.router)
     app.include_router(health.router)

     return app
core/chat/chatstore.py
ADDED
@@ -0,0 +1,93 @@
+import redis
+import os
+import json
+from fastapi import HTTPException
+from uuid import uuid4
+from typing import Optional, List
+from llama_index.storage.chat_store.redis import RedisChatStore
+from llama_index.core.memory import ChatMemoryBuffer
+from service.dto import ChatMessage
+
+
+class ChatStore:
+    def __init__(self):
+        self.redis_client = redis.Redis(
+            host="redis-10365.c244.us-east-1-2.ec2.redns.redis-cloud.com",
+            port=10365,
+            password=os.environ.get("REDIS_PASSWORD"),
+        )
+
+    def generate_uuid(use_hex=False):
+        if use_hex:
+            return str(uuid4().hex)
+        else:
+            return str(uuid4())
+
+    def initialize_memory_bot(self, session_id=None):
+        if session_id is None:
+            session_id = self.generate_uuid()
+        # chat_store = SimpleChatStore()
+        chat_store = RedisChatStore(
+            redis_client=self.redis_client
+        )  # Need to be configured
+
+        memory = ChatMemoryBuffer.from_defaults(
+            token_limit=3000, chat_store=chat_store, chat_store_key=session_id
+        )
+
+        return memory
+
+    def get_messages(self, session_id: str) -> List[dict]:
+        """Get messages for a session_id."""
+        items = self.redis_client.lrange(session_id, 0, -1)
+        if len(items) == 0:
+            return []
+
+        # Decode and parse each item into a dictionary
+        return [json.loads(m.decode("utf-8")) for m in items]
+
+    def delete_last_message(self, session_id: str) -> Optional[ChatMessage]:
+        """Delete last message for a session_id."""
+        return self.redis_client.rpop(session_id)
+
+    def delete_messages(self, key: str) -> Optional[List[ChatMessage]]:
+        """Delete messages for a key."""
+        self.redis_client.delete(key)
+        return None
+
+    def clean_message(self, session_id: str) -> Optional[ChatMessage]:
+        """Delete specific message for a session_id."""
+        current_list = self.redis_client.lrange(session_id, 0, -1)
+
+        indices_to_delete = []
+        for index, item in enumerate(current_list):
+            data = json.loads(item)  # Parse JSON string to dict
+
+            # Logic to determine if item should be removed
+            if (data.get("role") == "assistant" and data.get("content") is None) or (data.get("role") == "tool"):
+                indices_to_delete.append(index)
+
+        # Remove elements by their indices in reverse order
+        for index in reversed(indices_to_delete):
+            self.redis_client.lrem(session_id, 1, current_list[index])  # Remove the element from the list in Redis
+
+    def get_keys(self) -> List[str]:
+        """Get all keys."""
+        try :
+            print(self.redis_client.keys("*"))
+            return [key.decode("utf-8") for key in self.redis_client.keys("*")]
+
+        except Exception as e:
+            # Log the error and raise HTTPException for FastAPI
+            print(f"An error occurred in update data.: {e}")
+            raise HTTPException(
+                status_code=400, detail="the error when get keys"
+            )
+
+    def add_message(self, session_id: str, message: ChatMessage) -> None:
+        """Add a message for a session_id."""
+        item = json.dumps(self._message_to_dict(message))
+        self.redis_client.rpush(session_id, item)
+
+    def _message_to_dict(self, message: ChatMessage) -> dict:
+        return message.model_dump()
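ChatStore keeps one Redis list per session, with each message stored as a JSON blob; clean_message strips tool calls and empty assistant turns so only displayable history survives. One caveat visible in the code: generate_uuid is declared without self, so self.generate_uuid() inside initialize_memory_bot silently binds the instance to use_hex and always returns the hex form; the module-level utils.utils.generate_uuid is the cleaner copy. A minimal round-trip, assuming REDIS_PASSWORD is set and the hard-coded Redis Cloud host is reachable:

    from core.chat.chatstore import ChatStore
    from service.dto import ChatMessage

    store = ChatStore()
    store.add_message("demo-session", ChatMessage(content="Halo!", metadata=[]))
    print(store.get_messages("demo-session"))  # [{'role': 'assistant', 'content': 'Halo!', 'metadata': []}]
    store.delete_messages("demo-session")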
core/chat/engine.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import …
+from typing import Optional, List
 from llama_index.core.vector_stores import (
     MetadataFilter,
     MetadataFilters,
@@ -10,15 +10,17 @@ from llama_index.core.tools import QueryEngineTool, ToolMetadata
 from llama_index.agent.openai import OpenAIAgent
 from llama_index.llms.openai import OpenAI
 from llama_index.storage.chat_store.redis import RedisChatStore
-from llama_index.core.storage.chat_store import SimpleChatStore
 from llama_index.core.memory import ChatMemoryBuffer
 from llama_index.core.query_engine import CitationQueryEngine
 from llama_index.core import Settings
+from core.chat.chatstore import ChatStore

+from service.dto import ChatMessage
 from config import GPTBOT_CONFIG
 from core.prompt import SYSTEM_BOT_TEMPLATE
 import redis
 import os
+import json


 class Engine:
@@ -29,22 +31,10 @@ class Engine:
             max_tokens=GPTBOT_CONFIG.max_tokens,
             api_key=GPTBOT_CONFIG.api_key,
         )
-
-        Settings.llm = self.llm

-
-        redis_client = redis.Redis(
-            host="redis-10365.c244.us-east-1-2.ec2.redns.redis-cloud.com",
-            port=10365,
-            password=os.environ.get("REDIS_PASSWORD"),
-        )
-        # chat_store = SimpleChatStore()
-        chat_store = RedisChatStore(redis_client=redis_client, ttl=3600)  # Need to be configured
-        memory = ChatMemoryBuffer.from_defaults(
-            token_limit=3000, chat_store=chat_store, chat_store_key=user_id
-        )
+        self.chat_store = ChatStore()

-
+        Settings.llm = self.llm

     def _build_description_bot(self, title, category):
         try:
@@ -56,22 +46,6 @@ class Engine:
         except Exception as e:
             return f"Error generating description: {str(e)}"

-    def index_to_query_engine(self, title, category, index):
-        filters = MetadataFilters(
-            filters=[
-                MetadataFilter(key="title", value=title),
-                MetadataFilter(key="category", value=category),
-            ],
-            condition=FilterCondition.AND,
-        )
-
-        # Create the QueryEngineTool with the index and filters
-        kwargs = {"similarity_top_k": 5, "filters": filters}
-
-        query_engine = index.as_query_engine(**kwargs)
-
-        return query_engine
-
     def get_citation_engine(self, title, category, index):
         filters = MetadataFilters(
             filters=[
@@ -80,39 +54,33 @@ class Engine:
             ],
             condition=FilterCondition.AND,
         )

        # Create the QueryEngineTool with the index and filters
         kwargs = {"similarity_top_k": 5, "filters": filters}

         retriever = index.as_retriever(**kwargs)

         citation_engine = CitationQueryEngine(retriever=retriever)

         return citation_engine

-    def get_chat_engine(self, index, title=None, category=None, type="general"):
-        # Define the metadata for the QueryEngineTool
-
+    def get_chat_engine(
+        self, session_id, index, title=None, category=None, type="general"
+    ):
         # Create the QueryEngineTool based on the type
         if type == "general":
             # query_engine = index.as_query_engine(similarity_top_k=3)
             citation_engine = CitationQueryEngine.from_args(index, similarity_top_k=5)
             description = "A book containing information about medicine"
         else:
-            query_engine = self.index_to_query_engine(title, category, index)
             citation_engine = self.get_citation_engine(title, category, index)
             description = self._build_description_bot()

-        metadata = ToolMetadata(
-            name="bot-belajar",
-            description=description
-        )
+        metadata = ToolMetadata(name="bot-belajar", description=description)
         print(metadata)

         vector_query_engine = QueryEngineTool(
-            query_engine=citation_engine,
-            metadata=metadata
+            query_engine=citation_engine, metadata=metadata
         )
         print(vector_query_engine)

@@ -120,11 +88,9 @@ class Engine:
         chat_engine = OpenAIAgent.from_tools(
             tools=[vector_query_engine],
             llm=self.llm,
-            memory=self.initialize_memory_bot(),
+            memory=self.chat_store.initialize_memory_bot(session_id),
+            # memory = self.initialize_memory_bot(session_id),
             system_prompt=SYSTEM_BOT_TEMPLATE,
         )

-        return chat_engine
-
-    def get_chat_history(self):
-        pass
+        return chat_engine
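Memory setup has moved out of Engine.__init__ into ChatStore, keyed by the caller's session_id, and the unused index_to_query_engine/get_chat_history helpers are gone. One thing to watch: on the non-general path, _build_description_bot() is called with no arguments even though it is defined as _build_description_bot(self, title, category), which looks like it would raise a TypeError. A sketch of the intended call path:

    from core.chat.engine import Engine
    from script.vector_db import IndexManager

    engine = Engine()
    index = IndexManager().load_existing_indexes()

    # General bot: one citation engine over the whole index,
    # with chat memory bound to this session's Redis key.
    chat_engine = engine.get_chat_engine("demo-session", index)
    print(chat_engine.chat("Apa itu hipertensi?"))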
core/chat/messaging.py
CHANGED
@@ -20,7 +20,7 @@ from core.chat import schema
 from db.db import MessageSubProcessSourceEnum
 from core.chat.schema import SubProcessMetadataKeysEnum, SubProcessMetadataMap
 from core.chat.engine import Engine
-from script.…
+from script.vector_db import IndexManager
 from service.dto import UserPromptRequest

 logger = logging.getLogger(__name__)
core/prompt.py
CHANGED
@@ -1,5 +1,5 @@
 SYSTEM_BOT_TEMPLATE = """
-Kamu adalah Medbot …
+Kamu adalah Medbot yang gunakan tool kamu untuk menjawab pertanyaan tentang kedokteran. Tugasmu adalah memberikan jawaban yang informatif dan akurat berdasarkan tools yang tersedia serta selalu cantumkan kutipan dari teks yang anda kutip. Jika tidak ada jawaban melalui alat yang digunakan, carilah informasi lebih lanjut dengan menggunakan alat. Jika setelah itu tidak ada informasi yang ditemukan, katakan bahwa kamu tidak mengetahuinya.

 **Instruksi**:
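Roughly, the new Indonesian template reads: "You are Medbot; use your tools to answer questions about medicine. Your task is to give informative, accurate answers based on the available tools and to always include citations from the text you quote. If the tools yield no answer, search further using the tools; if nothing is found after that, say you do not know."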
db/get_data.py
CHANGED
@@ -6,9 +6,6 @@ logging.basicConfig(level=logging.INFO)


 class GetDatabase(Repository):
-    def __init__(self, db_conn):
-        super().__init__(db_conn)
-
     async def execute_query(self, query, params=None, fetch_one=False):
         """

@@ -54,3 +51,17 @@ class GetDatabase(Repository):
         """
         results = await self.execute_query(query)
         return results
+
+    async def get_data_by_id(self, id):
+        query = f"""
+        SELECT * FROM Metadata WHERE id = :id
+        """
+
+        param = {"id" : id}
+        try:
+            results = await self.execute_query(query, param)
+            print('Query successful, results: %s', results)
+            return results[0] if results else None
+        except Exception as e:
+            print('Error fetching data by ID %s: %s', id, e)
+            return None
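get_data_by_id binds a named :id placeholder, so the exact binding style depends on the Repository.execute_query implementation, which is not shown here. Minor nits visible in the hunk: the f-string prefix on the query is redundant (nothing is interpolated), and the print calls use logging-style %s placeholders that print verbatim. A hedged usage sketch:

    import asyncio
    from db.repository import get_db_conn
    from db.get_data import GetDatabase
    from config import MYSQL_CONFIG

    async def demo():
        db_conn = get_db_conn(MYSQL_CONFIG)
        row = await GetDatabase(db_conn).get_data_by_id(3)  # None when the id is absent
        print(row)

    # asyncio.run(demo())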
script/get_metadata.py
CHANGED
@@ -3,32 +3,33 @@

 class Metadata:
     def __init__(self, reference):
-        self.title = reference["title"]
-        self.author = reference["author"]
-        self.category = reference["category"]
-        self.year = reference["year"]
-        self.publisher = reference["publisher"]
+        self.reference = reference

     def add_metadata(self, documents, metadata):
-        """Add metadata to each document."""
-        for document in documents:
+        """Add metadata to each document and include page number."""
+        for page_number, document in enumerate(documents, start=1):
+            # Ensure the document has a metadata attribute
             if not hasattr(document, "metadata") or document.metadata is None:
                 document.metadata = {}
+
+            # Update metadata with page number
+            document.metadata["page"] = page_number
             document.metadata.update(metadata)
-
-
-
+
+            print(f"Metadata added to page {page_number}")
+            # self.logger.log_action(f"Metadata added to document {document.id_}", action_type="METADATA")
+
         return documents

     def _generate_metadata(self):
         """Generate metadata and return it."""
         metadata = {
-            "title": self.title,
-            "author": self.author,
-            "category": self.category,
-            "year": self.year,
-            "publisher": self.publisher,
-            "reference": f"{self.author}. ({self.year}). *{self.title}*. {self.publisher}."
+            "title": self.reference["title"],
+            "author": self.reference["author"],
+            "category": self.reference["category"],
+            "year": self.reference["year"],
+            "publisher": self.reference["publisher"],
+            "reference": f"{self.reference['author']}. ({self.reference['year']}). *{self.reference['title']}*. {self.reference['publisher']}."  # APA style reference
         }
         print("metadata is generated")
         return metadata
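The class now keeps the whole reference dict and stamps a page number on each document, which assumes the upstream loader yields exactly one Document per PDF page, in order. A self-contained sketch with stand-in documents:

    from types import SimpleNamespace
    from script.get_metadata import Metadata

    reference = {  # hypothetical record
        "title": "Buku Ajar Kardiologi", "author": "Santoso",
        "category": "Kardiologi", "year": 2020, "publisher": "EGC",
    }
    docs = [SimpleNamespace(metadata=None) for _ in range(3)]  # stand-ins for loader output

    meta = Metadata(reference)
    docs = meta.add_metadata(docs, meta._generate_metadata())
    print(docs[0].metadata["page"], docs[0].metadata["reference"])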
script/vector_db.py
ADDED
@@ -0,0 +1,157 @@
+from llama_index.core import VectorStoreIndex
+from llama_index.core import StorageContext
+from pinecone import Pinecone, ServerlessSpec
+from llama_index.llms.openai import OpenAI
+from llama_index.vector_stores.pinecone import PineconeVectorStore
+from fastapi import HTTPException, status
+from config import PINECONE_CONFIG
+from math import ceil
+import numpy as np
+import os
+import json
+
+
+class IndexManager:
+    def __init__(self, index_name: str = "summarizer-semantic-index"):
+        self.vector_index = None
+        self.index_name = index_name
+        self.client = self._get_pinecone_client()
+        self.pinecone_index = self._create_pinecone_index()
+
+    def _get_pinecone_client(self):
+        """Initialize and return the Pinecone client."""
+        # api_key = os.getenv("PINECONE_API_KEY")
+        api_key = PINECONE_CONFIG.PINECONE_API_KEY
+        if not api_key:
+            raise ValueError(
+                "Pinecone API key is missing. Please set it in environment variables."
+            )
+        return Pinecone(api_key=api_key)
+
+    def _create_pinecone_index(self):
+        """Create Pinecone index if it doesn't already exist."""
+        if self.index_name not in self.client.list_indexes().names():
+            self.client.create_index(
+                name=self.index_name,
+                dimension=1536,
+                metric="cosine",
+                spec=ServerlessSpec(cloud="aws", region="us-east-1"),
+            )
+        return self.client.Index(self.index_name)
+
+    def _initialize_vector_store(self) -> StorageContext:
+        """Initialize and return the vector store with the Pinecone index."""
+        vector_store = PineconeVectorStore(pinecone_index=self.pinecone_index)
+        return StorageContext.from_defaults(vector_store=vector_store)
+
+    def build_indexes(self, nodes):
+        """Build vector and tree indexes from nodes."""
+        try:
+            storage_context = self._initialize_vector_store()
+            self.vector_index = VectorStoreIndex(nodes, storage_context=storage_context)
+            self.vector_index.set_index_id("vector")
+
+            print(f"Vector Index ID: {self.vector_index.index_id}")
+            print("Vector Index created successfully.")
+
+            return json.dumps({"status": "success", "message": "Vector Index loaded successfully."})
+
+        except HTTPException as http_exc:
+            raise http_exc  # Re-raise HTTPExceptions to ensure FastAPI handles them
+        except Exception as e:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=f"Error loading existing indexes: {str(e)}"
+            )
+
+    def get_ids_from_query(self, input_vector, title):
+        print("Searching Pinecone...")
+        print(title)
+
+        new_ids = set()  # Initialize new_ids outside the loop
+
+        while True:
+            results = self.pinecone_index.query(
+                vector=input_vector,
+                top_k=10000,
+                filter={
+                    "title": {"$eq": f"{title}"},
+                },
+            )
+
+            ids = set()
+            for result in results['matches']:
+                ids.add(result['id'])
+            # Check if there's any overlap between ids and new_ids
+            if ids.issubset(new_ids):
+                break
+            else:
+                new_ids.update(ids)  # Add all new ids to new_ids
+
+        return new_ids
+
+    def get_all_ids_from_index(self, title):
+        num_dimensions = 1536
+
+        num_vectors = self.pinecone_index.describe_index_stats(
+        )["total_vector_count"]
+
+        print("Length of ids list is shorter than the number of total vectors...")
+        input_vector = np.random.rand(num_dimensions).tolist()
+        print("creating random vector...")
+        ids = self.get_ids_from_query(input_vector, title)
+        print("getting ids from a vector query...")
+
+        print("updating ids set...")
+        print(f"Collected {len(ids)} ids out of {num_vectors}.")
+
+        return ids
+
+    def delete_vector_database(self, old_reference):
+        try :
+            batch_size = 1000
+            all_ids = self.get_all_ids_from_index(old_reference['title'])
+            all_ids = list(all_ids)
+
+            # Split ids into chunks of batch_size
+            num_batches = ceil(len(all_ids) / batch_size)
+
+            for i in range(num_batches):
+                # Fetch a batch of IDs
+                batch_ids = all_ids[i * batch_size: (i + 1) * batch_size]
+                self.pinecone_index.delete(ids=batch_ids)
+                print(f"delete from id {i * batch_size} to {(i + 1) * batch_size} successful")
+        except Exception as e:
+            print(e)
+            raise HTTPException(status_code=500, detail="An error occurred while delete metadata")
+
+    def update_vector_database(self, old_reference, new_reference):
+
+        reference = new_reference.model_dump()
+
+        all_ids = self.get_all_ids_from_index(old_reference['title'])
+        all_ids = list(all_ids)
+
+        for id in all_ids:
+            self.pinecone_index.update(
+                id=id,
+                set_metadata=reference
+            )
+
+    def load_existing_indexes(self):
+        """Load existing indexes from Pinecone."""
+        try:
+            client = self._get_pinecone_client()
+            pinecone_index = client.Index(self.index_name)
+
+            vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
+            retriever = VectorStoreIndex.from_vector_store(vector_store)
+
+            print("Existing Vector Index loaded successfully.")
+            return retriever
+        except Exception as e:
+            print(f"Error loading existing indexes: {e}")
+            raise
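Since there is no plain "list all ids" call in play here, get_all_ids_from_index queries with a random 1536-dim vector at the 10,000-match top_k cap and loops until no new ids turn up; with a fixed query vector that converges after a couple of passes, so it is best-effort on indexes larger than the cap. How the routers use it:

    from script.vector_db import IndexManager  # needs PINECONE_API_KEY via config

    manager = IndexManager()
    old_reference = {"title": "Buku Ajar Kardiologi"}  # metadata of the records to touch

    # Drop every vector whose `title` matches, in batches of 1000.
    manager.delete_vector_database(old_reference)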
service/aws_loader.py
CHANGED
@@ -18,22 +18,55 @@ class Loader:
         region_name="us-west-2",
     )

+    # def upload_to_s3(self, file, object_name, folder_name="summarizer"):
+    #     try:
+    #         # If folder_name is provided, prepend it to the object_name
+    #         if folder_name:
+    #             object_name = f"{folder_name}/{object_name}"
+
+    #         # Create an in-memory file-like object
+    #         with BytesIO() as file_stream:
+    #             # Write the contents of the uploaded file to the stream
+    #             file_stream.write(file.file.read())
+    #             file_stream.seek(0)  # Move to the beginning of the stream
+
+    #             # Upload file to S3
+    #             self.s3_client.upload_fileobj(file_stream, self.bucket_name, object_name)
+
+    #         print(f"File '{object_name}' successfully uploaded to bucket '{self.bucket_name}'.")
+    #     except Exception as e:
+    #         raise HTTPException(status_code=400, detail=f"Error uploading to AWS: {e}")
+
     def upload_to_s3(self, file, object_name, folder_name="summarizer"):
         try:
             # If folder_name is provided, prepend it to the object_name
             if folder_name:
                 object_name = f"{folder_name}/{object_name}"

-            # Create an in-memory file-like object
-            with BytesIO() as file_stream:
-                # Write the contents of the uploaded file to the stream
-                file_stream.write(file.file.read())
-                file_stream.seek(0)  # Move to the beginning of the stream
-
-                # Upload file to S3
-                self.s3_client.upload_fileobj(file_stream, self.bucket_name, object_name)
-
-            print(f"File '{object_name}' successfully uploaded to bucket '{self.bucket_name}'.")
+            # Open the PDF with PyMuPDF (fitz)
+            pdf_document = fitz.open(stream=file.file.read(), filetype="pdf")
+
+            # Loop through each page of the PDF
+            for page_num in range(pdf_document.page_count):
+
+                # Convert the page to bytes (as a separate PDF)
+                page_stream = BytesIO()
+                single_page_pdf = fitz.open()  # Create a new PDF
+                single_page_pdf.insert_pdf(pdf_document, from_page=page_num, to_page=page_num)
+                single_page_pdf.save(page_stream)
+                single_page_pdf.close()
+
+                # Reset the stream position to the start
+                page_stream.seek(0)
+
+                # Define the object name for each page (e.g., 'summarizer/object_name/page_1.pdf')
+                page_object_name = f"{object_name}/{page_num + 1}.pdf"
+
+                # Upload each page to S3
+                self.s3_client.upload_fileobj(page_stream, self.bucket_name, page_object_name)

+            print(f"Page {page_num + 1} of '{object_name}' successfully uploaded as '{page_object_name}' to bucket '{self.bucket_name}'.")

         except Exception as e:
             raise HTTPException(status_code=400, detail=f"Error uploading to AWS: {e}")
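The rewritten upload_to_s3 splits the incoming PDF and uploads one single-page PDF per object under {folder}/{object_name}/{page}.pdf. The hunk does not show the imports the new body needs, so the module presumably already has BytesIO from the old code and must gain the PyMuPDF import; note also that the closing print appears to sit after the loop, so it reports only the last page.

    # Imports the new method relies on (not visible in this hunk):
    from io import BytesIO
    import fitz  # PyMuPDF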
service/dto.py
CHANGED
@@ -1,5 +1,6 @@
 from pydantic import BaseModel, Field
-from typing import List, Optional, Dict
+from typing import List, Optional, Dict, Any
+from llama_index.core.llms import MessageRole

 class MetadataRequest(BaseModel):
     title: str
@@ -12,6 +13,7 @@ class DeleteById(BaseModel):
     id : str

 class UserPromptRequest(BaseModel):
+    session_id : str
     prompt : str
     streaming : bool

@@ -33,4 +35,14 @@ class BotResponseStreaming(BaseModel):

 class TestStreaming(BaseModel):
     role : str = "assistant"
     content : str
+
+class ChatMessage(BaseModel):
+    """Chat message."""
+
+    role: MessageRole = MessageRole.ASSISTANT
+    content: Optional[Any] = ""
+    metadata: List
+
+    def __str__(self) -> str:
+        return f"{self.role.value}: {self.content}"
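ChatMessage mirrors the llama-index role enum but adds a metadata list so citations ride along with the reply; model_dump() is what ChatStore JSON-encodes into Redis. For example:

    from service.dto import ChatMessage

    msg = ChatMessage(content="Halo!", metadata=[{"title": "Buku A"}])
    print(str(msg))          # assistant: Halo!
    print(msg.model_dump())  # dict ready for json.dumps / Redis storage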
utils/utils.py
CHANGED
@@ -1,3 +1,5 @@
+from uuid import uuid4
+
 def pipe(data, *funcs):
     """ Pipe a value through a sequence of functions

@@ -19,4 +21,10 @@ def pipe(data, *funcs):
     """
     for func in funcs:
         data = func(data)
     return data
+
+def generate_uuid(use_hex=False):
+    if use_hex:
+        return str(uuid4().hex)
+    else:
+        return str(uuid4())
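generate_uuid backs the new POST /bot/new route; hex mode just drops the dashes:

    from utils.utils import generate_uuid

    print(generate_uuid())              # e.g. a 36-char dashed UUID4
    print(generate_uuid(use_hex=True))  # same entropy, 32 hex chars, no dashes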