Commit 197a291
Parent(s): 050c8b5
init chat
Files changed:
- app.py +91 -62
- build_rag.py +24 -0
- requirements.txt +4 -1
- src/__init__.py +0 -0
- src/chat.py +94 -0
- src/prompts.py +31 -0
- src/rag.py +166 -0
- templates/template_html.j2 +92 -0
app.py
CHANGED
@@ -1,63 +1,92 @@
+import logging
+from pathlib import Path
+
 import gradio as gr
+import os
+from jinja2 import Environment, FileSystemLoader
+
+from src.chat import Chat
+from src.rag import FaissDB, AICompletion, define_query
+from src.prompts import *
+
+chat_model = AICompletion()
+chat = Chat(system_prompt=SYSTEM_PROMPT)
+faiss_index = FaissDB(emb_model=os.environ["OPENAI_EMBEDDINGS_MODEL"])
+faiss_index.load_index(os.environ["PATH_TO_INDEX"])
+
+proj_dir = Path(__file__).parent
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))
+template_html = env.get_template('template_html.j2')
+
+
+def add_text(text, history):
+    history = [] if history is None else history
+    # Use a list (not a tuple) so bot() can assign to history[-1][1] while streaming.
+    history = history + [[text, None]]
+    # Return one value per wired output: the locked textbox and the updated history.
+    return gr.Textbox(value="", interactive=False), history
+
+
+def turn_on_activity():
+    return gr.Textbox(interactive=True)
+
+
+def bot(history):
+    user_query = history[-1][0]
+
+    if not user_query:
+        # gr.Warning is not an exception and cannot be raised; gr.Error aborts with a visible message.
+        raise gr.Error("Please submit a non-empty string")
+
+    retrieve_query = define_query(user_query, chat_model)
+    documents = faiss_index.similarity_search(retrieve_query) if retrieve_query else ''
+    # USER_PROMPT is a format string, so it must be filled with .format(), not called.
+    user_prompt = USER_PROMPT.format(user_query, documents)
+
+    prompt_html = template_html.render(documents=documents, query=user_query)
+    stream = chat.stream(user_prompt)
+
+    history[-1][1] = ""
+    for character in stream:
+        history[-1][1] = character
+        yield history, prompt_html
+
+
+with gr.Blocks() as demo:
+    chatbot = gr.Chatbot(
+        [],
+        elem_id="chatbot",
+        avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
+                       'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
+        bubble_full_width=False,
+        show_copy_button=True,
+        show_share_button=True,
+    )
+
+    with gr.Row():
+        txt = gr.Textbox(
+            scale=4,
+            show_label=False,
+            placeholder="Enter text",
+            container=False,
+        )
+        txt_btn = gr.Button(value="Submit text", scale=1)
+
+    prompt_html = gr.HTML()
+
+    txt_msg = txt_btn.click(
+        add_text, [txt, chatbot], [txt, chatbot], queue=False
+    ).then(
+        bot, [chatbot], [chatbot, prompt_html]
+    )
+
+    txt_msg.then(turn_on_activity, None, [txt], queue=False)
+
+    txt_msg = txt.submit(
+        add_text, [txt, chatbot], [txt, chatbot], queue=False
+    ).then(
+        bot, [chatbot], [chatbot, prompt_html]
+    )
+
+    txt_msg.then(turn_on_activity, None, [txt], queue=False)
+
+demo.queue()
+demo.launch(debug=True)
build_rag.py
ADDED
@@ -0,0 +1,24 @@
+from src.rag import CustomAgglomerativeSplitter, FaissDB
+import argparse
+from dotenv import load_dotenv
+import os
+
+load_dotenv()
+
+
+def main(path_to_dataset: str, path_to_index: str):
+    splitter = CustomAgglomerativeSplitter(emb_model=os.getenv("OPENAI_EMBEDDINGS_MODEL"))
+    documents = splitter.read_and_split(path_to_dataset)
+
+    faiss_db = FaissDB(emb_model=os.getenv("OPENAI_EMBEDDINGS_MODEL"))
+    faiss_db.init_index(documents)
+    faiss_db.save_index(path_to_index)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--path_to_dataset", type=str, required=True)
+    parser.add_argument("--path_to_index", type=str, required=True)
+    args = parser.parse_args()
+
+    main(args.path_to_dataset, args.path_to_index)
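The script is driven by the two argparse flags above (e.g. `python build_rag.py --path_to_dataset docs --path_to_index faiss_index`). A minimal programmatic sketch of the same step, assuming OPENAI_API_KEY and OPENAI_EMBEDDINGS_MODEL are set in .env and that the paths below (illustrative, not from the commit) point at real locations:

```python
# Illustrative smoke test: build the FAISS index from a folder of PDFs.
# nltk's 'punkt' tokenizer data must be available for sentence splitting.
from build_rag import main  # importing build_rag also runs load_dotenv()

main(path_to_dataset="docs", path_to_index="faiss_index")
```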
requirements.txt
CHANGED
@@ -4,4 +4,7 @@ langchain-community==0.2.7
 langchain-openai==0.1.15
 nltk==3.8.1
 textract==1.6.5
-faiss-cpu==1.8.0.post1
+faiss-cpu==1.8.0.post1
+numpy==1.26.4
+python-dotenv==1.0.1
+langchain_groq==0.1.6
src/__init__.py
ADDED
File without changes
src/chat.py
ADDED
@@ -0,0 +1,94 @@
+import os
+
+import gradio as gr
+from langchain_community.llms import OpenAI
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_community.chat_message_histories import ChatMessageHistory
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_openai.chat_models import ChatOpenAI
+
+
+GENERATE_ARGS = {
+    'temperature': max(float(os.getenv("TEMPERATURE", 0.3)), 1e-2),
+    'max_tokens': int(os.getenv("MAX_NEW_TOKENS", 512)),
+}
+
+GENERATE_KWARGS = {
+    'top_p': float(os.getenv("TOP_P", 0.6)),
+    'frequency_penalty': max(-2, min(float(os.getenv("FREQ_PENALTY", 0)), 2))
+}
+
+
+class Chat:
+
+    def __init__(self, system_prompt: str):
+
+        base = ChatOpenAI
+        model = os.getenv("OPENAI_MODEL")
+
+        self.assistant_model = base(
+            model=model,
+            streaming=True,
+            **GENERATE_ARGS,
+            model_kwargs=GENERATE_KWARGS
+        )
+
+        self.store = {}
+
+        self.prompt = ChatPromptTemplate.from_messages([
+            ("system", system_prompt),
+            MessagesPlaceholder(variable_name="history"),
+            ("human", "{input}")
+        ])
+        self.runnable = self.prompt | self.assistant_model
+
+        self.chat_model = RunnableWithMessageHistory(
+            self.runnable,
+            self.get_session_history,
+            input_messages_key="input",
+            history_messages_key="history",
+        )
+
+    def format_prompt(self, system_prompt: str, user_prompt: str):
+        messages = [
+            SystemMessage(
+                content=system_prompt
+            ),
+            HumanMessage(
+                content=user_prompt
+            ),
+        ]
+
+        return messages
+
+    def get_session_history(self, session_id: str | int) -> BaseChatMessageHistory:
+        if session_id not in self.store:
+            self.store[session_id] = ChatMessageHistory()
+        return self.store[session_id]
+
+    def stream(self, user_prompt: str, session_id: str | int = 0):
+        try:
+            stream_answer = self.chat_model.stream(
+                {"input": user_prompt},
+                config={"configurable": {"session_id": session_id}},
+            )
+            output = ""
+            for response in stream_answer:
+                if isinstance(self.assistant_model, OpenAI):
+                    if response.choices[0].delta.content:
+                        output += response.choices[0].delta.content
+                        yield output
+                else:
+                    output += response.content
+                    yield output
+
+        except Exception as e:
+            if "Too Many Requests" in str(e):
+                raise gr.Error(f"Too many requests: {str(e)}")
+            elif "Authorization header is invalid" in str(e):
+                raise gr.Error("Authentication error: API token was either not provided or incorrect")
+            else:
+                raise gr.Error(f"Unhandled Exception: {str(e)}")
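Chat.stream yields the cumulative answer so far on each chunk, which is why app.py assigns (rather than appends) each yielded value to the last history entry. A minimal usage sketch, assuming OPENAI_API_KEY and OPENAI_MODEL are set; the question string is illustrative:

```python
from src.chat import Chat
from src.prompts import SYSTEM_PROMPT

chat = Chat(system_prompt=SYSTEM_PROMPT)

answer = ""
# Each yielded value is the full answer so far; session_id defaults to 0,
# so repeated calls share one message history.
for partial in chat.stream("What is Faiss?"):
    answer = partial
print(answer)
```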
src/prompts.py
ADDED
@@ -0,0 +1,31 @@
+DEFINE_QUERY_PROMPT = """
+Prompt:
+You must identify whether the user's query is about a specific topic or is a follow-up question.
+If the user asks about a specific topic, you must extract this topic and return it.
+If the question is a follow-up query that does not mention any specific topic, you must return "Unrelated."
+
+Example 1 (Extract topic):
+User: Could you please explain what is Faiss. Thanks!
+Your response: What is Faiss?
+In this case your response must include only the topic name without any additional information or comments.
+
+Example 2 (Follow-up or Unrelated):
+User: Could you clarify the third point you mentioned earlier?
+Your response: Unrelated.
+In this case your response must be "Unrelated." without any additional information or comments.
+"""
+
+
+SYSTEM_PROMPT = """
+Your task is to answer the user's questions. You must provide clear and concise answers to the user's queries.
+If the user provides any documents in the 'Documents' section, your answer must be based on the information from these documents.
+If this section is empty, it means that the user is asking follow-up questions or questions that are not related to any specific topic.
+"""
+
+
+USER_PROMPT = """
+User query:
+{0}
+Documents:
+{1}
+"""
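USER_PROMPT is a plain format string with positional placeholders, which is why app.py fills it with str.format rather than calling it. A one-line sketch (query and document text below are illustrative):

```python
from src.prompts import USER_PROMPT

prompt = USER_PROMPT.format("What is Faiss?", "Faiss is a library for similarity search.")
```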
src/rag.py
ADDED
@@ -0,0 +1,166 @@
+from langchain_openai.embeddings import OpenAIEmbeddings
+from langchain.docstore.document import Document
+import nltk
+import os
+import numpy as np
+import textract
+from collections import defaultdict
+from langchain_community.vectorstores import FAISS
+from langchain_openai import ChatOpenAI
+from langchain_core.prompts import ChatPromptTemplate
+from src.prompts import DEFINE_QUERY_PROMPT
+from typing import Optional
+
+
+class AgglomerativeClustering:
+    def __init__(self, n_clusters: int = 16):
+        self.n_clusters = n_clusters
+        self.inf = 1e16
+        self.sample_size = 0
+        self._distances = None
+
+    def _init_clusters(self, X: np.array):
+        # Pairwise distance matrix; the diagonal is masked so a point never merges with itself.
+        distances = self.distance(XA=X, XB=X) + np.eye(self.sample_size) * self.inf
+        clusters = [[i] for i in range(self.sample_size)]
+        return distances, clusters
+
+    def _average(self, clusters, min_cluster, max_cluster):
+        # Size-weighted average of the two merged clusters' distance rows.
+        return (self._distances[min_cluster] * len(clusters[min_cluster]) + self._distances[max_cluster] * len(
+            clusters[max_cluster])) / (len(clusters[min_cluster]) + len(clusters[max_cluster]))
+
+    def _get_params(self, counter):
+        min_distance = np.argmin(self._distances)
+        param_1 = min_distance // counter
+        param_2 = min_distance % counter
+        return min(param_1, param_2), max(param_1, param_2)
+
+    def _merge_clusters(self, clusters, min_cluster, max_cluster):
+        self._distances[:, min_cluster] = self._distances[min_cluster, :]
+        self._distances = np.delete(self._distances, max_cluster, axis=0)
+        self._distances = np.delete(self._distances, max_cluster, axis=1)
+        self._distances[min_cluster][min_cluster] = np.inf
+        clusters[min_cluster].extend(clusters[max_cluster])
+        clusters.pop(max_cluster)
+
+    def _get_labels(self, clusters):
+        result = [0] * self.sample_size
+        for cluster in range(len(clusters)):
+            for point in clusters[cluster]:
+                result[point] = cluster
+        return result
+
+    def fit_predict(self, X: np.array) -> np.array:
+        self.sample_size = X.shape[0]
+        self._distances, clusters = self._init_clusters(X)
+
+        while len(clusters) > self.n_clusters:
+            min_cluster, max_cluster = self._get_params(len(clusters))
+            # Only merge clusters that are adjacent in the original sentence order,
+            # so every cluster stays a contiguous span of text.
+            if max(clusters[min_cluster]) + 1 == min(clusters[max_cluster]):
+                self._distances[min_cluster] = self._average(clusters=clusters, min_cluster=min_cluster,
+                                                             max_cluster=max_cluster)
+                self._merge_clusters(clusters=clusters, min_cluster=min_cluster, max_cluster=max_cluster)
+            else:
+                self._distances[min_cluster, max_cluster] = self.inf
+                self._distances[max_cluster, min_cluster] = self.inf
+
+        return np.array(self._get_labels(clusters))
+
+    @staticmethod
+    def distance(XA, XB):
+        return np.sqrt(((XA[:, np.newaxis] - XB[np.newaxis, :]) ** 2).sum(axis=2))
+
+
+class CustomAgglomerativeSplitter:
+    def __init__(self, emb_model: str):
+        self._embeddings_model = OpenAIEmbeddings(model=emb_model)
+
+    @staticmethod
+    def read_pdfs(path: str) -> tuple[list, list]:
+        files = os.listdir(path)
+        pages = []
+        file_names = []
+        for file in files:
+            page = textract.process(f"{path}/{file}", method='pdfminer').decode('utf-8').replace('\n', ' ')
+            text = nltk.sent_tokenize(page)
+            pages.append(text)
+            file_names.append(file)
+        return pages, file_names
+
+    def get_embeddings(self, pages: list) -> list[np.array]:
+        return [np.array(self._embeddings_model.embed_documents(texts)) for texts in pages]
+
+    @staticmethod
+    def split_list_by_indexes(data: list, indexes: list) -> list:
+        result_dict = defaultdict(list)
+        for element, index in zip(data, indexes):
+            result_dict[index].append(element)
+        return list(result_dict.values())
+
+    @staticmethod
+    def balance_pages(pages: list, max_tokens: int = 256) -> list:
+        balanced_pages = []
+        for page in pages:
+            str_page = ' '.join(page)
+            if len(str_page.split()) > max_tokens:
+                n_of_pages = int(np.ceil(len(str_page.split()) / max_tokens))
+                result = [' '.join(list(res)) for res in np.array_split(page, n_of_pages)]
+                balanced_pages.extend(result)
+            else:
+                balanced_pages.append(' '.join(page))
+        return balanced_pages
+
+    def cluster_pages(self, pages: list, embeddings: list, file_names: list, mean_n_of_sentences: int = 5) -> list:
+        documents = []
+        for page_number, page in enumerate(pages):
+            sentence_embeddings = embeddings[page_number]
+            n_clusters = len(page) // mean_n_of_sentences
+            model = AgglomerativeClustering(n_clusters=n_clusters)
+            labels = model.fit_predict(sentence_embeddings)
+            page_docs = self.split_list_by_indexes(page, labels)
+            page_docs = self.balance_pages(page_docs)
+            documents.extend([
+                Document(page_content=text, metadata={"file_name": file_names[page_number]}) for text in page_docs
+            ])
+        return documents
+
+    def read_and_split(self, path: str) -> list:
+        pages, file_names = self.read_pdfs(path)
+        embeddings = self.get_embeddings(pages)
+        return self.cluster_pages(pages, embeddings, file_names)
+
+
+class FaissDB:
+    def __init__(self, emb_model):
+        self._embeddings_model = OpenAIEmbeddings(model=emb_model)
+        self.index = None
+
+    def init_index(self, documents: list[Document]):
+        self.index = FAISS.from_documents(documents, self._embeddings_model)
+
+    def save_index(self, path: str):
+        self.index.save_local(path)
+
+    def load_index(self, path: str):
+        self.index = FAISS.load_local(path, self._embeddings_model, allow_dangerous_deserialization=True)
+
+    def similarity_search(self, query: str, k: int = 5):
+        if self.index is None:
+            raise ValueError("Index is not initialized")
+        return self.index.similarity_search(query, k)
+
+
+class AICompletion:
+    def __init__(self, chat_model: str = "gpt-4o", temperature: float = 0.0):
+        self.human = "{text}"
+        self.model = ChatOpenAI(model=chat_model, temperature=temperature)
+
+    def get_answer(self, system_prompt: str, text: str) -> str | None:
+        prompt = ChatPromptTemplate.from_messages([("system", system_prompt),
+                                                   ("human", self.human)])
+        chain = prompt | self.model
+        return chain.invoke({"text": text}).content
+
+
+def define_query(query: str, chat_model: AICompletion) -> Optional[str]:
+    result = chat_model.get_answer(DEFINE_QUERY_PROMPT, query)
+    return result if result != "Unrelated." else None
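A minimal sketch of the retrieval path these classes implement, assuming the FAISS index was already built by build_rag.py and the same env vars as app.py are set (the question string is illustrative):

```python
import os

from src.rag import AICompletion, FaissDB, define_query

db = FaissDB(emb_model=os.environ["OPENAI_EMBEDDINGS_MODEL"])
db.load_index(os.environ["PATH_TO_INDEX"])

model = AICompletion()
# define_query returns None when the model answers "Unrelated." (follow-up questions),
# in which case app.py skips retrieval entirely.
topic = define_query("Could you please explain what is Faiss?", model)
docs = db.similarity_search(topic, k=5) if topic else []
```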
templates/template_html.j2
ADDED
@@ -0,0 +1,92 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Information Page</title>
+    <link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap">
+    <link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600&display=swap">
+    <style>
+        * {
+            font-family: "Source Sans Pro";
+        }
+        .instructions > * {
+            color: #111 !important;
+        }
+        details.doc-box * {
+            color: #111 !important;
+        }
+        .dark {
+            background: #111;
+            color: white;
+        }
+        .doc-box {
+            padding: 10px;
+            margin-top: 10px;
+            background-color: #baecc2;
+            border-radius: 6px;
+            color: #111 !important;
+            max-width: 700px;
+            box-shadow: rgba(0, 0, 0, 0.2) 0px 1px 2px 0px;
+        }
+        .doc-full {
+            margin: 10px 14px;
+            line-height: 1.6rem;
+        }
+        .instructions {
+            color: #111 !important;
+            background: #b7bdfd;
+            display: block;
+            border-radius: 6px;
+            padding: 6px 10px;
+            line-height: 1.6rem;
+            max-width: 700px;
+            box-shadow: rgba(0, 0, 0, 0.2) 0px 1px 2px 0px;
+        }
+        .query {
+            color: #111 !important;
+            background: #ffbcbc;
+            display: block;
+            border-radius: 6px;
+            padding: 6px 10px;
+            line-height: 2rem;
+            max-width: 700px;
+            box-shadow: rgba(0, 0, 0, 0.2) 0px 1px 2px 0px;
+        }
+    </style>
+</head>
+<body>
+<div class="prose svelte-1ybaih5" id="component-6">
+    <h2>Prompt</h2>
+    Below is the prompt that is given to the model. <hr>
+    {#<h2>Instructions</h2>#}
+    {#<span class="instructions">{{ instructions }}</span>#}
+    <h2>Context</h2>
+    {# documents are langchain Document objects, so the text lives in .page_content #}
+    {% for doc in documents %}
+    <details class="doc-box">
+        <summary>
+            <b>Doc {{ loop.index }}:</b> <span class="doc-short">{{ doc.page_content[:100] }}...</span>
+        </summary>
+        <div class="doc-full">{{ doc.page_content }}</div>
+    </details>
+    {% endfor %}
+    <h2>Query</h2>
+    <span class="query">{{ query }}</span>
+</div>
+<script>
+    document.addEventListener("DOMContentLoaded", function() {
+        const detailsElements = document.querySelectorAll('.doc-box');
+        detailsElements.forEach(detail => {
+            detail.addEventListener('toggle', function() {
+                const docShort = this.querySelector('.doc-short');
+                if (this.open) {
+                    docShort.style.display = 'none';
+                } else {
+                    docShort.style.display = 'inline';
+                }
+            });
+        });
+    });
+</script>
+</body>
+</html>
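A minimal sketch of rendering this template outside the app, mirroring how app.py calls template_html.render (the Document and query below are illustrative):

```python
from jinja2 import Environment, FileSystemLoader
from langchain.docstore.document import Document

env = Environment(loader=FileSystemLoader("templates"))
template = env.get_template("template_html.j2")
html = template.render(
    documents=[Document(page_content="Faiss is a library for efficient similarity search.")],
    query="What is Faiss?",
)
print(html[:200])  # rendered prompt preview shown in the gr.HTML panel
```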