XThomasBU committed
Commit · 8f6647c
Parent(s): 33e5fa6

init commit for chainlit improvements

Files changed:
- code/main.py +238 -170
- code/modules/chat/helpers.py +32 -30
- code/modules/chat/llm_tutor.py +15 -10
- code/modules/vectorstore/base.py +3 -0
- code/modules/vectorstore/chroma.py +3 -0
- code/modules/vectorstore/colbert.py +72 -0
- code/modules/vectorstore/faiss.py +10 -0
- code/modules/vectorstore/raptor.py +7 -0
- code/modules/vectorstore/store_manager.py +6 -2
- code/modules/vectorstore/vectorstore.py +3 -0
code/main.py
CHANGED
@@ -1,176 +1,244 @@
-… (old lines 1-3: truncated imports, not recoverable from this view)
-from langchain_community.vectorstores import FAISS
-from langchain.chains import RetrievalQA
+import json
+import textwrap
+from typing import Any, Callable, Dict, List, Literal, Optional, no_type_check
 import chainlit as cl
-from …
-from …
+from chainlit import run_sync
+from chainlit.config import config
 import yaml
-import …
-from dotenv import load_dotenv
+import os
 
 from modules.chat.llm_tutor import LLMTutor
-from modules.config.constants import *
-from modules.chat.helpers import get_sources
 from modules.chat_processor.chat_processor import ChatProcessor
+from modules.config.constants import LLAMA_PATH
+from modules.chat.helpers import get_sources
 
-… (old lines 18-176: the previous module-level Chainlit app, not recoverable from this view)
+from chainlit.input_widget import Select, Switch, Slider
+
+USER_TIMEOUT = 60_000
+SYSTEM = "System 🖥️"
+LLM = "LLM 🧠"
+AGENT = "Agent <>"
+YOU = "You 😃"
+ERROR = "Error 🚫"
+
+
+class Chatbot:
+    def __init__(self):
+        self.llm_tutor = None
+        self.chain = None
+        self.chat_processor = None
+        self.config = self._load_config()
+
+    def _load_config(self):
+        with open("modules/config/config.yml", "r") as f:
+            config = yaml.safe_load(f)
+        return config
+
+    async def ask_helper(func, **kwargs):
+        res = await func(**kwargs).send()
+        while not res:
+            res = await func(**kwargs).send()
+        return res
+
+    @no_type_check
+    async def setup_llm(self) -> None:
+        """From the session `llm_settings`, create new LLMConfig and LLM objects,
+        save them in session state."""
+
+        llm_settings = cl.user_session.get("llm_settings", {})
+        chat_profile = llm_settings.get("chat_model")
+        retriever_method = llm_settings.get("retriever_method")
+        memory_window = llm_settings.get("memory_window")
+
+        self._configure_llm(chat_profile)
+
+        chain = cl.user_session.get("chain")
+        memory = chain.memory
+        self.config["vectorstore"][
+            "db_option"
+        ] = retriever_method  # update the retriever method in the config
+        memory.k = memory_window  # set the memory window
+
+        self.llm_tutor = LLMTutor(self.config)
+        self.chain = self.llm_tutor.qa_bot(memory=memory)
+
+        tags = [chat_profile, self.config["vectorstore"]["db_option"]]
+        self.chat_processor = ChatProcessor(self.config, tags=tags)
+
+        cl.user_session.set("chain", self.chain)
+        cl.user_session.set("llm_tutor", self.llm_tutor)
+        cl.user_session.set("chat_processor", self.chat_processor)
+
+    @no_type_check
+    async def update_llm(self, new_settings: Dict[str, Any]) -> None:
+        """Update LLMConfig and LLM from settings, and save in session state."""
+        cl.user_session.set("llm_settings", new_settings)
+        await self.inform_llm_settings()
+        await self.setup_llm()
+
+    async def make_llm_settings_widgets(self, config=None):
+        config = config or self.config
+        await cl.ChatSettings(
+            [
+                cl.input_widget.Select(
+                    id="chat_model",
+                    label="Model Name (Default GPT-3)",
+                    values=["llama", "gpt-3.5-turbo-1106", "gpt-4"],
+                    initial_index=0,
+                ),
+                cl.input_widget.Select(
+                    id="retriever_method",
+                    label="Retriever (Default FAISS)",
+                    values=["FAISS", "Chroma", "RAGatouille", "RAPTOR"],
+                    initial_index=0,
+                ),
+                cl.input_widget.Slider(
+                    id="memory_window",
+                    label="Memory Window (Default 3)",
+                    initial=3,
+                    min=0,
+                    max=10,
+                    step=1,
+                ),
+                cl.input_widget.Switch(
+                    id="view_sources", label="View Sources", initial=False
+                ),
+            ]
+        ).send()  # type: ignore
+
+    @no_type_check
+    async def inform_llm_settings(self) -> None:
+        llm_settings: Dict[str, Any] = cl.user_session.get("llm_settings", {})
+        llm_tutor = cl.user_session.get("llm_tutor")
+        settings_dict = dict(
+            model=llm_settings.get("chat_model"),
+            retriever=llm_settings.get("retriever_method"),
+            memory_window=llm_settings.get("memory_window"),
+            num_docs_in_db=len(llm_tutor.vector_db),
+            view_sources=llm_settings.get("view_sources"),
+        )
+        await cl.Message(
+            author=SYSTEM,
+            content="LLM settings have been updated. You can continue with your Query!",
+            elements=[
+                cl.Text(
+                    name="settings",
+                    display="side",
+                    content=json.dumps(settings_dict, indent=4),
+                    language="json",
+                )
+            ],
+        ).send()
+
+    async def set_starters(self):
+        return [
+            cl.Starter(
+                label="recording on CNNs?",
+                message="Where can I find the recording for the lecture on Transformers?",
+                icon="/public/adv-screen-recorder-svgrepo-com.svg",
+            ),
+            cl.Starter(
+                label="where's the slides?",
+                message="When are the lectures? I can't find the schedule.",
+                icon="/public/alarmy-svgrepo-com.svg",
+            ),
+            cl.Starter(
+                label="Due Date?",
+                message="When is the final project due?",
+                icon="/public/calendar-samsung-17-svgrepo-com.svg",
+            ),
+            cl.Starter(
+                label="Explain backprop.",
+                message="I didn't understand the math behind backprop, could you explain it?",
+                icon="/public/acastusphoton-svgrepo-com.svg",
+            ),
+        ]
+
+    async def chat_profile(self):
+        return [
+            # cl.ChatProfile(
+            #     name="gpt-3.5-turbo-1106",
+            #     markdown_description="Use OpenAI API for **gpt-3.5-turbo-1106**.",
+            # ),
+            # cl.ChatProfile(
+            #     name="gpt-4",
+            #     markdown_description="Use OpenAI API for **gpt-4**.",
+            # ),
+            cl.ChatProfile(
+                name="Llama",
+                markdown_description="Use the local LLM: **Tiny Llama**.",
+            ),
+        ]
+
+    def rename(self, orig_author: str):
+        rename_dict = {"Chatbot": "AI Tutor"}
+        return rename_dict.get(orig_author, orig_author)
+
+    async def start(self):
+        await self.make_llm_settings_widgets(self.config)
+
+        chat_profile = cl.user_session.get("chat_profile")
+        if chat_profile:
+            self._configure_llm(chat_profile)
+
+        self.llm_tutor = LLMTutor(self.config)
+        self.chain = self.llm_tutor.qa_bot()
+        tags = [chat_profile, self.config["vectorstore"]["db_option"]]
+        self.chat_processor = ChatProcessor(self.config, tags=tags)
+
+        cl.user_session.set("llm_tutor", self.llm_tutor)
+        cl.user_session.set("chain", self.chain)
+        cl.user_session.set("counter", 0)
+        cl.user_session.set("chat_processor", self.chat_processor)
+
+    async def on_chat_end(self):
+        await cl.Message(content="Sorry, I have to go now. Goodbye!").send()
+
+    async def main(self, message):
+        user = cl.user_session.get("user")
+        chain = cl.user_session.get("chain")
+        counter = cl.user_session.get("counter")
+        llm_settings = cl.user_session.get("llm_settings")
+
+        counter += 1
+        cl.user_session.set("counter", counter)
+
+        cb = cl.AsyncLangchainCallbackHandler()  # TODO: fix streaming here
+        cb.answer_reached = True
+
+        processor = cl.user_session.get("chat_processor")
+        res = await processor.rag(message.content, chain, cb)
+        answer = res.get("answer", res.get("result"))
+
+        answer_with_sources, source_elements, sources_dict = get_sources(
+            res, answer, view_sources=llm_settings.get("view_sources")
+        )
+        processor._process(message.content, answer, sources_dict)
+
+        await cl.Message(content=answer_with_sources, elements=source_elements).send()
+
+    def _configure_llm(self, chat_profile):
+        chat_profile = chat_profile.lower()
+        if chat_profile in ["gpt-3.5-turbo-1106", "gpt-4"]:
+            self.config["llm_params"]["llm_loader"] = "openai"
+            self.config["llm_params"]["openai_params"]["model"] = chat_profile
+        elif chat_profile == "llama":
+            self.config["llm_params"]["llm_loader"] = "local_llm"
+            self.config["llm_params"]["local_llm_params"]["model"] = LLAMA_PATH
+            self.config["llm_params"]["local_llm_params"]["model_type"] = "llama"
+        elif chat_profile == "mistral":
+            self.config["llm_params"]["llm_loader"] = "local_llm"
+            self.config["llm_params"]["local_llm_params"]["model"] = MISTRAL_PATH
+            self.config["llm_params"]["local_llm_params"]["model_type"] = "mistral"
+
+
+chatbot = Chatbot()
+
+# Register functions to Chainlit events
+cl.set_starters(chatbot.set_starters)
+cl.set_chat_profiles(chatbot.chat_profile)
+cl.author_rename(chatbot.rename)
+cl.on_chat_start(chatbot.start)
+cl.on_chat_end(chatbot.on_chat_end)
+cl.on_message(chatbot.main)
+cl.on_settings_update(chatbot.update_llm)
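The main.py rewrite replaces module-level Chainlit handlers with a Chatbot class whose bound methods are passed to the cl.* registrars at import time. A minimal standalone sketch of that registration pattern; EchoBot and its handlers are hypothetical, only the cl.on_chat_start / cl.on_message calls mirror the commit:

import chainlit as cl


class EchoBot:
    def __init__(self):
        self.calls = 0  # per-process state shared across handlers

    async def start(self):
        await cl.Message(content="Ready.").send()

    async def on_message(self, message: cl.Message):
        self.calls += 1
        await cl.Message(content=f"[{self.calls}] {message.content}").send()


bot = EchoBot()
# Passing bound methods to the registrars behaves the same as decorating
# free functions with @cl.on_chat_start / @cl.on_message.
cl.on_chat_start(bot.start)
cl.on_message(bot.on_message)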
code/modules/chat/helpers.py
CHANGED
@@ -3,7 +3,7 @@ import chainlit as cl
 from langchain_core.prompts import PromptTemplate
 
 
-def get_sources(res, answer):
+def get_sources(res, answer, view_sources=False):
     source_elements = []
     source_dict = {}  # Dictionary to store URL elements
 
@@ -40,40 +40,42 @@ def get_sources(res, answer):
     full_answer = "**Answer:**\n"
     full_answer += answer
 
-    full_answer += "\n\n**Sources:**\n"
-    for idx, (url_name, source_data) in enumerate(source_dict.items()):
-        full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
-
-        name = …
-        full_answer += name
-        source_elements.append(
-            …
-        )
-
-        if source_data["url"].lower().endswith(".pdf"):
-            name = f"Source {idx + 1} PDF\n"
-            full_answer += name
-            … (remaining old lines 58-76, including a cl.Text(name=…) element, not recoverable from this view)
+    if view_sources:
+
+        # Then, display the sources
+        full_answer += "\n\n**Sources:**\n"
+        for idx, (url_name, source_data) in enumerate(source_dict.items()):
+            full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
+
+            name = f"Source {idx + 1} Text\n"
+            full_answer += name
+            source_elements.append(
+                cl.Text(name=name, content=source_data["text"], display="side")
+            )
+
+            # Add a PDF element if the source is a PDF file
+            if source_data["url"].lower().endswith(".pdf"):
+                name = f"Source {idx + 1} PDF\n"
+                full_answer += name
+                pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
+                source_elements.append(cl.Pdf(name=name, url=pdf_url, display="side"))
+
+        full_answer += "\n**Metadata:**\n"
+        for idx, (url_name, source_data) in enumerate(source_dict.items()):
+            full_answer += f"\nSource {idx + 1} Metadata:\n"
+            source_elements.append(
+                cl.Text(
+                    name=f"Source {idx + 1} Metadata",
+                    content=f"Source: {source_data['url']}\n"
+                    f"Page: {source_data['page']}\n"
+                    f"Type: {source_data['source_type']}\n"
+                    f"Date: {source_data['date']}\n"
+                    f"TL;DR: {source_data['lecture_tldr']}\n"
+                    f"Lecture Recording: {source_data['lecture_recording']}\n"
+                    f"Suggested Readings: {source_data['suggested_readings']}\n",
+                    display="side",
+                )
+            )
 
     return full_answer, source_elements, source_dict
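get_sources now renders source text, PDF links, and metadata only when view_sources is set; when it is False the function returns just the answer text and an empty source_elements list. The flag originates from the "View Sources" Switch in main.py's settings panel; condensed from Chatbot.main in this commit, the call site is:

# view_sources is read from the session's llm_settings (the settings Switch)
answer_with_sources, source_elements, sources_dict = get_sources(
    res, answer, view_sources=llm_settings.get("view_sources")
)
await cl.Message(content=answer_with_sources, elements=source_elements).send()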
code/modules/chat/llm_tutor.py
CHANGED
@@ -157,18 +157,18 @@ class LLMTutor:
         return prompt
 
     # Retrieval QA Chain
-    def retrieval_qa_chain(self, llm, prompt, db):
+    def retrieval_qa_chain(self, llm, prompt, db, memory=None):
 
         retriever = Retriever(self.config)._return_retriever(db)
 
         if self.config["llm_params"]["use_history"]:
-            memory = ConversationBufferWindowMemory(
-                …
-            )
+            if memory is None:
+                memory = ConversationBufferWindowMemory(
+                    k=self.config["llm_params"]["memory_window"],
+                    memory_key="chat_history",
+                    return_messages=True,
+                    output_key="answer",
+                )
             qa_chain = CustomConversationalRetrievalChain.from_llm(
                 llm=llm,
                 chain_type="stuff",
@@ -195,11 +195,16 @@ class LLMTutor:
         return llm
 
     # QA Model Function
-    def qa_bot(self):
+    def qa_bot(self, memory=None):
        db = self.vector_db.load_database()
+        # sanity check to see if there are any documents in the database
+        if len(db) == 0:
+            raise ValueError(
+                "No documents in the database. Populate the database first."
+            )
         qa_prompt = self.set_custom_prompt()
         qa = self.retrieval_qa_chain(
-            self.llm, qa_prompt, db
+            self.llm, qa_prompt, db, memory
         )  # TODO: PROMPT is overwritten in CustomConversationalRetrievalChain
 
        return qa
code/modules/vectorstore/base.py
CHANGED
@@ -29,5 +29,8 @@ class VectorStoreBase:
         """
         raise NotImplementedError
 
+    def __len__(self):
+        raise NotImplementedError
+
     def __str__(self):
         return self.__class__.__name__
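With __len__ in the abstract base, every store reports its document count through the built-in len(), and a subclass that forgets to implement it fails loudly instead of silently returning something wrong. Illustrative usage; constructor and load signatures vary per store:

# Sketch: any concrete store implementing __len__ works with plain len().
store = ColbertVectorStore(config)
store.load_database()
print(len(store))  # delegates to the wrapped vector store's __len__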
code/modules/vectorstore/chroma.py
CHANGED
@@ -39,3 +39,6 @@ class ChromaVectorStore(VectorStoreBase):
 
     def as_retriever(self):
         return self.vectorstore.as_retriever()
+
+    def __len__(self):
+        return len(self.vectorstore)
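Note that len(self.vectorstore) assumes the wrapped LangChain Chroma object is itself sized; if a given langchain version does not define __len__ on Chroma, counting through the underlying chromadb collection is the usual fallback. A hypothetical defensive variant, not in this commit:

    def __len__(self):
        try:
            return len(self.vectorstore)
        except TypeError:  # wrapped Chroma has no __len__ in this version
            return self.vectorstore._collection.count()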
code/modules/vectorstore/colbert.py
CHANGED
@@ -1,6 +1,67 @@
 from ragatouille import RAGPretrainedModel
 from modules.vectorstore.base import VectorStoreBase
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun, Callbacks
+from langchain_core.documents import Document
+from typing import Any, List, Optional, Sequence
 import os
+import json
+
+
+class RAGatouilleLangChainRetrieverWithScore(BaseRetriever):
+    model: Any
+    kwargs: dict = {}
+
+    def _get_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,  # noqa
+    ) -> List[Document]:
+        """Get documents relevant to a query."""
+        docs = self.model.search(query, **self.kwargs)
+        return [
+            Document(
+                page_content=doc["content"],
+                metadata={**doc.get("document_metadata", {}), "score": doc["score"]},
+            )
+            for doc in docs
+        ]
+
+    async def _aget_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,  # noqa
+    ) -> List[Document]:
+        """Get documents relevant to a query."""
+        docs = self.model.search(query, **self.kwargs)
+        return [
+            Document(
+                page_content=doc["content"],
+                metadata={**doc.get("document_metadata", {}), "score": doc["score"]},
+            )
+            for doc in docs
+        ]
+
+
+class RAGPretrainedModel(RAGPretrainedModel):
+    """
+    Adding len property to RAGPretrainedModel
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._document_count = 0
+
+    def set_document_count(self, count):
+        self._document_count = count
+
+    def __len__(self):
+        return self._document_count
+
+    def as_langchain_retriever(self, **kwargs: Any) -> BaseRetriever:
+        return RAGatouilleLangChainRetrieverWithScore(model=self, kwargs=kwargs)
 
 
 class ColbertVectorStore(VectorStoreBase):
@@ -24,6 +85,7 @@ class ColbertVectorStore(VectorStoreBase):
             document_ids=document_names,
             document_metadatas=document_metadata,
         )
+        self.colbert.set_document_count(len(document_names))
 
     def load_database(self):
         path = os.path.join(
@@ -33,7 +95,17 @@ class ColbertVectorStore(VectorStoreBase):
         self.vectorstore = RAGPretrainedModel.from_index(
             f"{path}/colbert/indexes/new_idx"
         )
+
+        index_metadata = json.load(
+            open(f"{path}/colbert/indexes/new_idx/0.metadata.json")
+        )
+        num_documents = index_metadata["num_passages"]
+        self.vectorstore.set_document_count(num_documents)
+
         return self.vectorstore
 
     def as_retriever(self):
         return self.vectorstore.as_retriever()
+
+    def __len__(self):
+        return len(self.vectorstore)
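The subclassed RAGPretrainedModel swaps its as_langchain_retriever for RAGatouilleLangChainRetrieverWithScore, so every retrieved Document carries the ColBERT relevance score in its metadata. A usage sketch; the query and k are illustrative:

# Scores surface in Document.metadata alongside the stored document metadata.
vectorstore = ColbertVectorStore(config).load_database()
retriever = vectorstore.as_langchain_retriever(k=3)  # kwargs forwarded to model.search
for doc in retriever.invoke("What is backpropagation?"):
    print(f"{doc.metadata['score']:.3f}", doc.page_content[:80])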
code/modules/vectorstore/faiss.py
CHANGED
@@ -3,6 +3,13 @@ from modules.vectorstore.base import VectorStoreBase
 import os
 
 
+class FAISS(FAISS):
+    """To add length property to FAISS class"""
+
+    def __len__(self):
+        return self.index.ntotal
+
+
 class FaissVectorStore(VectorStoreBase):
     def __init__(self, config):
         self.config = config
@@ -43,3 +50,6 @@ class FaissVectorStore(VectorStoreBase):
 
     def as_retriever(self):
         return self.vectorstore.as_retriever()
+
+    def __len__(self):
+        return len(self.vectorstore)
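class FAISS(FAISS) subclasses the imported LangChain FAISS and rebinds the module-level name, so every later FAISS.from_documents(...) / FAISS.load_local(...) in this file returns the sized variant; index.ntotal is the vector count of the raw faiss index. The trick in isolation, assuming the LangChain import used elsewhere in the repo:

from langchain_community.vectorstores import FAISS


class FAISS(FAISS):  # rebinding the name upgrades all later uses in the module
    """LangChain FAISS with a __len__ that reports the raw index size."""

    def __len__(self):
        return self.index.ntotal  # number of vectors stored in the faiss index


# e.g. db = FAISS.from_documents(docs, embeddings); len(db) == db.index.ntotal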
code/modules/vectorstore/raptor.py
CHANGED
@@ -16,6 +16,13 @@ from modules.vectorstore.base import VectorStoreBase
 RANDOM_SEED = 42
 
 
+class FAISS(FAISS):
+    """To add length property to FAISS class"""
+
+    def __len__(self):
+        return self.index.ntotal
+
+
 class RAPTORVectoreStore(VectorStoreBase):
     def __init__(self, config, documents=[], text_splitter=None, embedding_model=None):
         self.documents = documents
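The sized FAISS subclass here is duplicated verbatim from faiss.py; a natural follow-up would be defining it once and importing it in both modules. A hypothetical refactor, not part of this commit:

# modules/vectorstore/sized_faiss.py (hypothetical shared module)
from langchain_community.vectorstores import FAISS as _FAISS


class FAISS(_FAISS):
    """LangChain FAISS with a __len__ reporting index.ntotal."""

    def __len__(self):
        return self.index.ntotal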
code/modules/vectorstore/store_manager.py
CHANGED
@@ -138,7 +138,7 @@ class VectorStoreManager:
         self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
         end_time = time.time()  # End time for loading database
         self.logger.info(
-            f"Time taken to load database: {end_time - start_time} seconds"
+            f"Time taken to load database {self.config['vectorstore']['db_option']} from Hugging Face: {end_time - start_time} seconds"
         )
         self.logger.info("Loaded database")
         return self.loaded_vector_db
@@ -148,8 +148,12 @@ class VectorStoreManager:
         self.vector_db._load_from_HF()
         end_time = time.time()
         self.logger.info(
-            f"Time taken to …"
+            f"Time taken to Download database {self.config['vectorstore']['db_option']} from Hugging Face: {end_time - start_time} seconds"
         )
+        self.logger.info("Downloaded database")
+
+    def __len__(self):
+        return len(self.vector_db)
 
 
 if __name__ == "__main__":
code/modules/vectorstore/vectorstore.py
CHANGED
@@ -86,3 +86,6 @@ class VectorStore:
 
     def _get_vectorstore(self):
         return self.vectorstore
+
+    def __len__(self):
+        return self.vectorstore.__len__()
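Taken together, these changes let len() resolve through the whole stack, which is what num_docs_in_db=len(llm_tutor.vector_db) in main.py relies on:

# Delegation chain wired up by this commit (attribute names from the diff):
#   len(VectorStoreManager)  -> len(self.vector_db)           store_manager.py
#   len(VectorStore)         -> self.vectorstore.__len__()    vectorstore.py
#   len(concrete store)      -> len(wrapped vector store)     faiss.py / chroma.py / colbert.py
#   len(wrapped store)       -> index.ntotal or document count
num_docs_in_db = len(llm_tutor.vector_db)  # shown by inform_llm_settings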