Upload 2 files
Browse files- app.py +16 -6
- requirements.txt +2 -1
app.py
CHANGED
@@ -2,12 +2,14 @@ import os
|
|
2 |
from pathlib import Path
|
3 |
|
4 |
from langchain.chains import ConversationalRetrievalChain
|
5 |
-
from langchain.embeddings import OpenAIEmbeddings
|
6 |
from langchain.vectorstores import Chroma
|
7 |
-
from langchain.llms.openai import OpenAIChat
|
8 |
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
|
9 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
10 |
from langchain.embeddings.openai import OpenAIEmbeddings
|
|
|
|
|
|
|
11 |
|
12 |
import streamlit as st
|
13 |
|
@@ -29,7 +31,7 @@ def load_documents():
|
|
29 |
|
30 |
|
31 |
def split_documents(documents):
|
32 |
-
text_splitter =
|
33 |
texts = text_splitter.split_documents(documents)
|
34 |
return texts
|
35 |
|
@@ -41,7 +43,10 @@ def embeddings_on_local_vectordb(texts):
|
|
41 |
persist_directory=LOCAL_VECTOR_STORE_DIR.as_posix(),
|
42 |
)
|
43 |
vectordb.persist()
|
44 |
-
retriever =
|
|
|
|
|
|
|
45 |
return retriever
|
46 |
|
47 |
|
@@ -51,10 +56,11 @@ def query_llm(retriever, query):
|
|
51 |
retriever=retriever,
|
52 |
return_source_documents=True,
|
53 |
)
|
|
|
54 |
result = qa_chain({"question": query, "chat_history": st.session_state.messages})
|
55 |
result = result["answer"]
|
56 |
st.session_state.messages.append((query, result))
|
57 |
-
return result
|
58 |
|
59 |
|
60 |
def input_fields():
|
@@ -77,6 +83,8 @@ def boot():
|
|
77 |
st.title("Enigma Chatbot")
|
78 |
input_fields()
|
79 |
st.sidebar.button("Submit Documents", on_click=process_documents)
|
|
|
|
|
80 |
if "messages" not in st.session_state:
|
81 |
st.session_state.messages = []
|
82 |
for message in st.session_state.messages:
|
@@ -84,7 +92,9 @@ def boot():
|
|
84 |
st.chat_message("ai").write(message[1])
|
85 |
if query := st.chat_input():
|
86 |
st.chat_message("human").write(query)
|
87 |
-
response = query_llm(st.session_state.retriever, query)
|
|
|
|
|
88 |
st.chat_message("ai").write(response)
|
89 |
|
90 |
|
|
|
2 |
from pathlib import Path
|
3 |
|
4 |
from langchain.chains import ConversationalRetrievalChain
|
|
|
5 |
from langchain.vectorstores import Chroma
|
6 |
+
from langchain.llms.openai import OpenAIChat, OpenAI
|
7 |
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
|
8 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
9 |
from langchain.embeddings.openai import OpenAIEmbeddings
|
10 |
+
from langchain.retrievers import ContextualCompressionRetriever
|
11 |
+
from langchain.retrievers.document_compressors import LLMChainExtractor
|
12 |
+
from langchain_experimental.text_splitter import SemanticChunker
|
13 |
|
14 |
import streamlit as st
|
15 |
|
|
|
31 |
|
32 |
|
33 |
def split_documents(documents):
|
34 |
+
text_splitter = SemanticChunker(OpenAIEmbeddings())
|
35 |
texts = text_splitter.split_documents(documents)
|
36 |
return texts
|
37 |
|
|
|
43 |
persist_directory=LOCAL_VECTOR_STORE_DIR.as_posix(),
|
44 |
)
|
45 |
vectordb.persist()
|
46 |
+
retriever = ContextualCompressionRetriever(
|
47 |
+
base_compressor=LLMChainExtractor.from_llm(OpenAI(temperature=0)),
|
48 |
+
base_retriever=vectordb.as_retriever(search_kwargs={"k": 3}, search_type="mmr"),
|
49 |
+
)
|
50 |
return retriever
|
51 |
|
52 |
|
|
|
56 |
retriever=retriever,
|
57 |
return_source_documents=True,
|
58 |
)
|
59 |
+
relevant_docs = retriever.get_relevant_documents(query)
|
60 |
result = qa_chain({"question": query, "chat_history": st.session_state.messages})
|
61 |
result = result["answer"]
|
62 |
st.session_state.messages.append((query, result))
|
63 |
+
return relevant_docs, result
|
64 |
|
65 |
|
66 |
def input_fields():
|
|
|
83 |
st.title("Enigma Chatbot")
|
84 |
input_fields()
|
85 |
st.sidebar.button("Submit Documents", on_click=process_documents)
|
86 |
+
st.sidebar.write("---")
|
87 |
+
st.sidebar.write("References made during the chat will appear here")
|
88 |
if "messages" not in st.session_state:
|
89 |
st.session_state.messages = []
|
90 |
for message in st.session_state.messages:
|
|
|
92 |
st.chat_message("ai").write(message[1])
|
93 |
if query := st.chat_input():
|
94 |
st.chat_message("human").write(query)
|
95 |
+
references, response = query_llm(st.session_state.retriever, query)
|
96 |
+
for doc in references:
|
97 |
+
st.sidebar.info(f"Page {doc.metadata['page']}\n\n{doc.page_content}")
|
98 |
st.chat_message("ai").write(response)
|
99 |
|
100 |
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
openai==0.28
|
2 |
langchain==0.1.1
|
3 |
pypdf==4.0.0
|
4 |
-
chromadb==0.4.22
|
|
|
|
1 |
openai==0.28
|
2 |
langchain==0.1.1
|
3 |
pypdf==4.0.0
|
4 |
+
chromadb==0.4.22
|
5 |
+
langchain-experimental==0.0.49
|