XThomasBU committed f0018f2 (parent: 40de40e)

Code to add metadata to the chunks

Changed files:
- .chainlit/config.toml +1 -1
- code/config.yml +4 -4
- code/modules/data_loader.py +115 -46
- code/modules/embedding_model_loader.py +2 -2
- code/modules/helpers.py +121 -53
- code/modules/llm_tutor.py +15 -9
- code/modules/vector_db.py +34 -9
- requirements.txt +1 -0
- storage/data/urls.txt +2 -0
.chainlit/config.toml
CHANGED
@@ -22,7 +22,7 @@ prompt_playground = true
 unsafe_allow_html = false
 
 # Process and display mathematical expressions. This can clash with "$" characters in messages.
-latex =
+latex = true
 
 # Authorize users to upload files with messages
 multi_modal = true
code/config.yml
CHANGED
@@ -2,14 +2,14 @@ embedding_options:
   embedd_files: False # bool
   data_path: 'storage/data' # str
   url_file_path: 'storage/data/urls.txt' # str
-  expand_urls:
-  db_option : '
+  expand_urls: False # bool
+  db_option : 'RAGatouille' # str [FAISS, Chroma, RAGatouille]
   db_path : 'vectorstores' # str
   model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
   search_top_k : 3 # int
-  score_threshold : 0.
+  score_threshold : 0.2 # float
 llm_params:
-  use_history:
+  use_history: False # bool
   memory_window: 3 # int
   llm_loader: 'local_llm' # str [local_llm, openai]
   openai_params:
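For context, a minimal sketch of how this config might be consumed downstream, assuming a plain `yaml.safe_load` and the key names shown above (the `load_config` helper is illustrative, not part of the commit):

```python
# Minimal sketch (assumed helper, not part of the commit): load config.yml
# and branch on db_option the way vector_db.py and llm_tutor.py do.
import yaml

def load_config(path: str = "code/config.yml") -> dict:
    with open(path) as f:
        return yaml.safe_load(f)

config = load_config()
db_option = config["embedding_options"]["db_option"]

if db_option in ["FAISS", "Chroma"]:
    # dense embeddings + score-threshold retrieval
    top_k = config["embedding_options"]["search_top_k"]
    threshold = config["embedding_options"]["score_threshold"]
elif db_option == "RAGatouille":
    # late-interaction ColBERT index; no separate embedding model is loaded
    top_k = config["embedding_options"]["search_top_k"]
```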
code/modules/data_loader.py
CHANGED
@@ -2,7 +2,7 @@ import os
 import re
 import requests
 import pysrt
-from
+from langchain_community.document_loaders import (
     PyMuPDFLoader,
     Docx2txtLoader,
     YoutubeLoader,
@@ -16,6 +16,15 @@ import logging
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_experimental.text_splitter import SemanticChunker
 from langchain_openai.embeddings import OpenAIEmbeddings
+from ragatouille import RAGPretrainedModel
+from langchain.chains import LLMChain
+from langchain.llms import OpenAI
+from langchain import PromptTemplate
+
+try:
+    from modules.helpers import get_lecture_metadata
+except:
+    from helpers import get_lecture_metadata
 
 logger = logging.getLogger(__name__)
 
@@ -58,23 +67,6 @@ class FileReader:
         return None
 
     def read_pdf(self, temp_file_path: str):
-        # parser = LlamaParse(
-        #     api_key="",
-        #     result_type="markdown",
-        #     num_workers=4,
-        #     verbose=True,
-        #     language="en",
-        # )
-        # documents = parser.load_data(temp_file_path)
-
-        # with open("temp/output.md", "a") as f:
-        #     for doc in documents:
-        #         f.write(doc.text + "\n")
-
-        # markdown_path = "temp/output.md"
-        # loader = UnstructuredMarkdownLoader(markdown_path)
-        # loader = PyMuPDFLoader(temp_file_path) # This loader preserves more metadata
-        # return loader.load()
         loader = self.pdf_reader.get_loader(temp_file_path)
         documents = self.pdf_reader.get_documents(loader)
         return documents
@@ -108,8 +100,6 @@ class FileReader:
 class ChunkProcessor:
     def __init__(self, config):
         self.config = config
-        self.document_chunks_full = []
-        self.document_names = []
 
         if config["splitter_options"]["use_splitter"]:
             if config["splitter_options"]["split_by_token"]:
@@ -130,6 +120,17 @@ class ChunkProcessor:
             self.splitter = None
         logger.info("ChunkProcessor instance created")
 
+    # def extract_metadata(self, document_content):
+
+    #     llm = OpenAI()
+    #     prompt_template = PromptTemplate(
+    #         input_variables=["document_content"],
+    #         template="Extract metadata for this document:\n\n{document_content}\n\nMetadata:",
+    #     )
+    #     chain = LLMChain(llm=llm, prompt=prompt_template)
+    #     metadata = chain.run(document_content=document_content)
+    #     return metadata
+
     def remove_delimiters(self, document_chunks: list):
         for chunk in document_chunks:
             for delimiter in self.config["splitter_options"]["delimiters_to_remove"]:
@@ -146,11 +147,23 @@ class ChunkProcessor:
         logger.info(f"\tNumber of pages after skipping: {len(document_chunks)}")
         return document_chunks
 
-    def process_chunks(
+    def process_chunks(
+        self, documents, file_type="txt", source="", page=0, metadata={}
+    ):
+        documents = [Document(page_content=documents, source=source, page=page)]
+        if file_type == "txt":
             document_chunks = self.splitter.split_documents(documents)
-            document_chunks = documents
+        elif file_type == "pdf":
+            document_chunks = documents  # Full page for now
+
+        # add the source and page number back to the metadata
+        for chunk in document_chunks:
+            chunk.metadata["source"] = source
+            chunk.metadata["page"] = page
+
+            # add the metadata extracted from the document
+            for key, value in metadata.items():
+                chunk.metadata[key] = value
 
         if self.config["splitter_options"]["remove_leftover_delimiters"]:
             document_chunks = self.remove_delimiters(document_chunks)
@@ -161,38 +174,77 @@ class ChunkProcessor:
 
     def get_chunks(self, file_reader, uploaded_files, weblinks):
         self.document_chunks_full = []
-        self.
+        self.parent_document_names = []
+        self.child_document_names = []
+        self.documents = []
+        self.document_metadata = []
+
+        lecture_metadata = get_lecture_metadata(
+            "https://dl4ds.github.io/sp2024/lectures/"
+        )  # TODO: Use more efficiently
 
         for file_index, file_path in enumerate(uploaded_files):
            file_name = os.path.basename(file_path)
            file_type = file_name.split(".")[-1].lower()
 
-            try:
+            # try:
+            if file_type == "pdf":
+                documents = file_reader.read_pdf(file_path)
+            elif file_type == "txt":
+                documents = file_reader.read_txt(file_path)
+            elif file_type == "docx":
+                documents = file_reader.read_docx(file_path)
+            elif file_type == "srt":
+                documents = file_reader.read_srt(file_path)
+            else:
+                logger.warning(f"Unsupported file type: {file_type}")
+                continue
+
+            # full_text = ""
+            # for doc in documents:
+            #     full_text += doc.page_content
+            #     break # getting only first page for now
+
+            # extracted_metadata = self.extract_metadata(full_text)
+
+            for doc in documents:
+                page_num = doc.metadata.get("page", 0)
+                self.documents.append(doc.page_content)
+                self.document_metadata.append({"source": file_path, "page": page_num})
+                if "lecture" in file_path.lower():
+                    metadata = lecture_metadata.get(file_path, {})
+                    metadata["source_type"] = "lecture"
+                    self.document_metadata[-1].update(metadata)
                 else:
-                self.
-                self.
+                    metadata = {"source_type": "other"}
+
+                self.child_document_names.append(f"{file_name}_{page_num}")
+
+            self.parent_document_names.append(file_name)
+            if self.config["embedding_options"]["db_option"] not in ["RAGatouille"]:
+                document_chunks = self.process_chunks(
+                    self.documents[-1],
+                    file_type,
+                    source=file_path,
+                    page=page_num,
+                    metadata=metadata,
+                )
+                self.document_chunks_full.extend(document_chunks)
 
-            except Exception as e:
+            # except Exception as e:
+            #     logger.error(f"Error processing file {file_name}: {str(e)}")
 
         self.process_weblinks(file_reader, weblinks)
 
         logger.info(
             f"Total document chunks extracted: {len(self.document_chunks_full)}"
         )
-        return
+        return (
+            self.document_chunks_full,
+            self.child_document_names,
+            self.documents,
+            self.document_metadata,
+        )
 
     def process_weblinks(self, file_reader, weblinks):
         if weblinks[0] != "":
@@ -206,9 +258,26 @@ class ChunkProcessor:
                 else:
                     documents = file_reader.read_html(link)
 
+                for doc in documents:
+                    page_num = doc.metadata.get("page", 0)
+                    self.documents.append(doc.page_content)
+                    self.document_metadata.append(
+                        {"source": link, "page": page_num}
+                    )
+                    self.child_document_names.append(f"{link}")
+
+                self.parent_document_names.append(link)
+                if self.config["embedding_options"]["db_option"] not in [
+                    "RAGatouille"
+                ]:
+                    document_chunks = self.process_chunks(
+                        self.documents[-1],
+                        "txt",
+                        source=link,
+                        page=0,
+                        metadata={"source_type": "webpage"},
+                    )
+                    self.document_chunks_full.extend(document_chunks)
             except Exception as e:
                 logger.error(
                     f"Error splitting link {link_index+1} : {link}: {str(e)}"
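The core of this change is `process_chunks`: every chunk carries its source path, page number, and any lecture metadata. A standalone sketch of that pattern, with assumed splitter settings (the commit reads them from config.yml):

```python
# Standalone sketch of the chunk-metadata pattern from ChunkProcessor.process_chunks.
# Splitter settings are assumptions; the commit reads them from config.yml.
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

def process_chunks(text, file_type="txt", source="", page=0, metadata=None):
    metadata = metadata or {}
    documents = [Document(page_content=text, metadata={"source": source, "page": page})]
    if file_type == "txt":
        chunks = splitter.split_documents(documents)
    else:  # e.g. "pdf": keep the full page as one chunk, as the commit does
        chunks = documents
    for chunk in chunks:
        chunk.metadata["source"] = source
        chunk.metadata["page"] = page
        chunk.metadata.update(metadata)  # e.g. tldr, lecture_recording, source_type
    return chunks

chunks = process_chunks(
    "some lecture text ...",
    source="lecture_01.pdf",
    page=3,
    metadata={"source_type": "lecture", "tldr": "Intro to RAG"},
)
```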
code/modules/embedding_model_loader.py
CHANGED
@@ -1,6 +1,6 @@
 from langchain_community.embeddings import OpenAIEmbeddings
-from
-from
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.embeddings import LlamaCppEmbeddings
 
 try:
     from modules.constants import *
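The new imports suggest the loader can return either an OpenAI or a local Hugging Face embedding model depending on `embedding_options.model`; a hedged sketch of that switch (the actual `EmbeddingModelLoader` body is not shown in this diff):

```python
# Sketch of the embedding-model switch implied by the new imports; the actual
# EmbeddingModelLoader logic is not part of this diff.
from langchain_community.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings

def load_embedding_model(config):
    model_name = config["embedding_options"]["model"]
    if model_name == "text-embedding-ada-002":
        return OpenAIEmbeddings(model=model_name)
    # default: local sentence-transformers model, e.g. all-MiniLM-L6-v2
    return HuggingFaceEmbeddings(model_name=model_name)
```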
code/modules/helpers.py
CHANGED
@@ -4,6 +4,8 @@ from tqdm import tqdm
 from urllib.parse import urlparse
 import chainlit as cl
 from langchain import PromptTemplate
+import requests
+from bs4 import BeautifulSoup
 
 try:
     from modules.constants import *
@@ -138,67 +140,133 @@ def get_prompt(config):
 
 
 def get_sources(res, answer):
-    source_elements_dict = {}
     source_elements = []
-    found_sources = []
-
     source_dict = {} # Dictionary to store URL elements
 
     for idx, source in enumerate(res["source_documents"]):
         source_metadata = source.metadata
         url = source_metadata["source"]
         score = source_metadata.get("score", "N/A")
+        page = source_metadata.get("page", 1)
+
+        lecture_tldr = source_metadata.get("tldr", "N/A")
+        lecture_recording = source_metadata.get("lecture_recording", "N/A")
+        suggested_readings = source_metadata.get("suggested_readings", "N/A")
 
+        source_type = source_metadata.get("source_type", "N/A")
+
+        url_name = f"{url}_{page}"
+        if url_name not in source_dict:
+            source_dict[url_name] = {
+                "text": source.page_content,
+                "url": url,
+                "score": score,
+                "page": page,
+                "lecture_tldr": lecture_tldr,
+                "lecture_recording": lecture_recording,
+                "suggested_readings": suggested_readings,
+                "source_type": source_type,
+            }
         else:
-            source_dict[
-            full_text += f"Source {url_idx + 1} (Score: {score}):\n{text}\n\n\n"
-            source_elements.append(cl.Text(name=url, content=full_text))
-            found_sources.append(f"{url} (Score: {score})")
+            source_dict[url_name]["text"] += f"\n\n{source.page_content}"
 
+    # First, display the answer
+    full_answer = "**Answer:**\n"
+    full_answer += answer
 
+    # Then, display the sources
+    full_answer += "\n\n**Sources:**\n"
+    for idx, (url_name, source_data) in enumerate(source_dict.items()):
+        full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
+
+        name = f"Source {idx + 1} Text\n"
+        full_answer += name
+        source_elements.append(cl.Text(name=name, content=source_data["text"]))
+
+        # Add a PDF element if the source is a PDF file
+        if source_data["url"].lower().endswith(".pdf"):
+            name = f"Source {idx + 1} PDF\n"
+            full_answer += name
+            pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
+            source_elements.append(cl.Pdf(name=name, url=pdf_url))
+
+    # Finally, include lecture metadata for each unique source
+    # displayed_urls = set()
+    # full_answer += "\n**Metadata:**\n"
+    # for url_name, source_data in source_dict.items():
+    #     if source_data["url"] not in displayed_urls:
+    #         full_answer += f"\nSource: {source_data['url']}\n"
+    #         full_answer += f"Type: {source_data['source_type']}\n"
+    #         full_answer += f"TL;DR: {source_data['lecture_tldr']}\n"
+    #         full_answer += f"Lecture Recording: {source_data['lecture_recording']}\n"
+    #         full_answer += f"Suggested Readings: {source_data['suggested_readings']}\n"
+    #         displayed_urls.add(source_data["url"])
+    full_answer += "\n**Metadata:**\n"
+    for url_name, source_data in source_dict.items():
+        full_answer += f"\nSource: {source_data['url']}\n"
+        full_answer += f"Page: {source_data['page']}\n"
+        full_answer += f"Type: {source_data['source_type']}\n"
+        full_answer += f"TL;DR: {source_data['lecture_tldr']}\n"
+        full_answer += f"Lecture Recording: {source_data['lecture_recording']}\n"
+        full_answer += f"Suggested Readings: {source_data['suggested_readings']}\n"
+
+    return full_answer, source_elements
+
+
+def get_lecture_metadata(schedule_url):
+    """
+    Function to get the lecture metadata from the schedule URL.
+    """
+    lecture_metadata = {}
+
+    # Get the main schedule page content
+    r = requests.get(schedule_url)
+    soup = BeautifulSoup(r.text, "html.parser")
+
+    # Find all lecture blocks
+    lecture_blocks = soup.find_all("div", class_="lecture-container")
+
+    for block in lecture_blocks:
+        try:
+            # Extract the lecture title
+            title = block.find("span", style="font-weight: bold;").text.strip()
+
+            # Extract the TL;DR
+            tldr = block.find("strong", text="tl;dr:").next_sibling.strip()
+
+            # Extract the link to the slides
+            slides_link_tag = block.find("a", title="Download slides")
+            slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
+
+            # Extract the link to the lecture recording
+            recording_link_tag = block.find("a", title="Download lecture recording")
+            recording_link = (
+                recording_link_tag["href"].strip() if recording_link_tag else None
+            )
+
+            # Extract suggested readings or summary if available
+            suggested_readings_tag = block.find("p", text="Suggested Readings:")
+            if suggested_readings_tag:
+                suggested_readings = suggested_readings_tag.find_next_sibling("ul")
+                if suggested_readings:
+                    suggested_readings = suggested_readings.get_text(
+                        separator="\n"
+                    ).strip()
+                else:
+                    suggested_readings = "No specific readings provided."
+            else:
+                suggested_readings = "No specific readings provided."
+
+            # Add to the dictionary
+            slides_link = f"https://dl4ds.github.io{slides_link}"
+            lecture_metadata[slides_link] = {
+                "tldr": tldr,
+                "title": title,
+                "lecture_recording": recording_link,
+                "suggested_readings": suggested_readings,
+            }
+        except Exception as e:
+            print(f"Error processing block: {e}")
+            continue
+
+    return lecture_metadata
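Usage sketch for `get_lecture_metadata`: the returned dict is keyed by the absolute slides URL (the schedule href prefixed with `https://dl4ds.github.io`), which is why `data_loader.py` can look entries up by the lecture file path or URL. The example URL below assumes the schedule hrefs resolve to the same paths added to `storage/data/urls.txt`:

```python
# Usage sketch for get_lecture_metadata; the lookup key is an assumption based
# on how the commit builds slides_link.
from modules.helpers import get_lecture_metadata  # import path as used in the commit

lecture_metadata = get_lecture_metadata("https://dl4ds.github.io/sp2024/lectures/")

slides_url = "https://dl4ds.github.io/sp2024/static_files/lectures/15_RAG_CoT.pdf"
meta = lecture_metadata.get(slides_url, {})
print(meta.get("title"), meta.get("tldr"), meta.get("lecture_recording"))
```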
code/modules/llm_tutor.py
CHANGED
@@ -8,7 +8,6 @@ from langchain.llms import CTransformers
 from langchain.memory import ConversationBufferWindowMemory
 from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
 import os
-
 from modules.constants import *
 from modules.helpers import get_prompt
 from modules.chat_model_loader import ChatModelLoader
@@ -34,14 +33,21 @@ class LLMTutor:
 
     # Retrieval QA Chain
     def retrieval_qa_chain(self, llm, prompt, db):
+        if self.config["embedding_options"]["db_option"] in ["FAISS", "Chroma"]:
+            retriever = VectorDBScore(
+                vectorstore=db,
+                search_type="similarity_score_threshold",
+                search_kwargs={
+                    "score_threshold": self.config["embedding_options"][
+                        "score_threshold"
+                    ],
+                    "k": self.config["embedding_options"]["search_top_k"],
+                },
+            )
+        elif self.config["embedding_options"]["db_option"] == "RAGatouille":
+            retriever = db.as_langchain_retriever(
+                k=self.config["embedding_options"]["search_top_k"]
+            )
         if self.config["llm_params"]["use_history"]:
             memory = ConversationBufferWindowMemory(
                 k=self.config["llm_params"]["memory_window"],
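A sketch of the two retriever paths in isolation. `VectorDBScore` is this repo's custom retriever; LangChain's stock `as_retriever` with a score threshold is used here as a close stand-in, and the RAGatouille branch assumes the index built in `vector_db.py` already exists:

```python
# Sketch of the two retriever paths. Stock as_retriever stands in for the
# repo's VectorDBScore; the RAGatouille branch assumes the index exists.
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from ragatouille import RAGPretrainedModel

# Dense path (FAISS/Chroma): similarity search gated by a relevance-score threshold
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
db = FAISS.from_texts(["RAG combines retrieval with generation."], embeddings)
retriever = db.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"score_threshold": 0.2, "k": 3},
)

# ColBERT path (RAGatouille): wrap an existing index as a LangChain retriever
rag = RAGPretrainedModel.from_index(".ragatouille/colbert/indexes/new_idx")
colbert_retriever = rag.as_langchain_retriever(k=3)
```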
code/modules/vector_db.py
CHANGED
@@ -1,11 +1,12 @@
 import logging
 import os
 import yaml
-from
+from langchain_community.vectorstores import FAISS, Chroma
 from langchain.schema.vectorstore import VectorStoreRetriever
 from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.schema.document import Document
 from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun
+from ragatouille import RAGPretrainedModel
 
 try:
     from modules.embedding_model_loader import EmbeddingModelLoader
@@ -25,7 +26,7 @@ class VectorDBScore(VectorStoreRetriever):
 
     # See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
     def _get_relevant_documents(
+        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
     ) -> List[Document]:
         docs_and_similarities = (
             self.vectorstore.similarity_search_with_relevance_scores(
@@ -55,7 +56,6 @@ class VectorDBScore(VectorStoreRetriever):
         return docs
 
 
-
 class VectorDB:
     def __init__(self, config, logger=None):
         self.config = config
@@ -116,7 +116,15 @@ class VectorDB:
         self.embedding_model_loader = EmbeddingModelLoader(self.config)
         self.embedding_model = self.embedding_model_loader.load_embedding_model()
 
-    def initialize_database(
+    def initialize_database(
+        self,
+        document_chunks: list,
+        document_names: list,
+        documents: list,
+        document_metadata: list,
+    ):
+        if self.db_option in ["FAISS", "Chroma"]:
+            self.create_embedding_model()
         # Track token usage
         self.logger.info("Initializing vector_db")
         self.logger.info("\tUsing {} as db_option".format(self.db_option))
@@ -136,6 +144,14 @@ class VectorDB:
                     + self.config["embedding_options"]["model"],
                 ),
             )
+        elif self.db_option == "RAGatouille":
+            self.RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
+            index_path = self.RAG.index(
+                index_name="new_idx",
+                collection=documents,
+                document_ids=document_names,
+                document_metadatas=document_metadata,
+            )
         self.logger.info("Completed initializing vector_db")
 
     def create_database(self):
@@ -146,11 +162,13 @@ class VectorDB:
         files += lecture_pdfs
         if "storage/data/urls.txt" in files:
             files.remove("storage/data/urls.txt")
-        document_chunks, document_names =
+        document_chunks, document_names, documents, document_metadata = (
+            data_loader.get_chunks(files, urls)
+        )
         self.logger.info("Completed loading data")
+        self.initialize_database(
+            document_chunks, document_names, documents, document_metadata
+        )
 
     def save_database(self):
         if self.db_option == "FAISS":
@@ -166,6 +184,9 @@ class VectorDB:
         elif self.db_option == "Chroma":
             # db is saved in the persist directory during initialization
             pass
+        elif self.db_option == "RAGatouille":
+            # index is saved during initialization
+            pass
         self.logger.info("Saved database")
 
     def load_database(self):
@@ -180,7 +201,7 @@ class VectorDB:
                     + self.config["embedding_options"]["model"],
                 ),
                 self.embedding_model,
+                allow_dangerous_deserialization=True,
             )
         elif self.db_option == "Chroma":
             self.vector_db = Chroma(
@@ -193,6 +214,10 @@ class VectorDB:
                 ),
                 embedding_function=self.embedding_model,
             )
+        elif self.db_option == "RAGatouille":
+            self.vector_db = RAGPretrainedModel.from_index(
+                ".ragatouille/colbert/indexes/new_idx"
+            )
         self.logger.info("Loaded database")
         return self.vector_db
 
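An end-to-end sketch of the RAGatouille path added here: build a ColBERT index with per-document metadata, reload it from the default index path used in `load_database`, and query it. The documents and ids below are made up, and the exact shape of the search results (e.g. the `document_metadata` field) follows my reading of ragatouille 0.0.8 rather than this diff:

```python
# End-to-end sketch of the RAGatouille path added in this commit: index a few
# documents with metadata, reload the index, and query it. Texts are made up.
from ragatouille import RAGPretrainedModel

docs = [
    "Retrieval-augmented generation grounds answers in retrieved passages.",
    "RLHF fine-tunes a model against a learned reward signal.",
]
doc_ids = ["15_RAG_CoT.pdf_0", "21_RL_RLHF_v2.pdf_0"]
doc_metadata = [
    {"source": "15_RAG_CoT.pdf", "page": 0, "source_type": "lecture"},
    {"source": "21_RL_RLHF_v2.pdf", "page": 0, "source_type": "lecture"},
]

rag = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
rag.index(
    index_name="new_idx",
    collection=docs,
    document_ids=doc_ids,
    document_metadatas=doc_metadata,
)

# Later (e.g. in load_database), reload the saved index and search it
rag = RAGPretrainedModel.from_index(".ragatouille/colbert/indexes/new_idx")
results = rag.search("What is RAG?", k=2)
for r in results:
    print(r["content"], r.get("document_metadata"))
```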
requirements.txt
CHANGED
@@ -17,3 +17,4 @@ fake-useragent==1.4.0
 git+https://github.com/huggingface/accelerate.git
 llama-cpp-python
 PyPDF2==3.0.1
+ragatouille==0.0.8.post2
storage/data/urls.txt
CHANGED
@@ -1 +1,3 @@
 https://dl4ds.github.io/sp2024/
+https://dl4ds.github.io/sp2024/static_files/lectures/15_RAG_CoT.pdf
+https://dl4ds.github.io/sp2024/static_files/lectures/21_RL_RLHF_v2.pdf