Ritvik19 committed
Commit ecb7a48 · verified · 1 Parent(s): 60e8923

Upload 2 files

Files changed (2)
  1. app.py +16 -6
  2. requirements.txt +2 -1
app.py CHANGED
@@ -2,12 +2,14 @@ import os
 from pathlib import Path

 from langchain.chains import ConversationalRetrievalChain
-from langchain.embeddings import OpenAIEmbeddings
 from langchain.vectorstores import Chroma
-from langchain.llms.openai import OpenAIChat
+from langchain.llms.openai import OpenAIChat, OpenAI
 from langchain.document_loaders import PyPDFLoader, WebBaseLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers.document_compressors import LLMChainExtractor
+from langchain_experimental.text_splitter import SemanticChunker

 import streamlit as st

@@ -29,7 +31,7 @@ def load_documents():


 def split_documents(documents):
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=256)
+    text_splitter = SemanticChunker(OpenAIEmbeddings())
     texts = text_splitter.split_documents(documents)
     return texts

@@ -41,7 +43,10 @@ def embeddings_on_local_vectordb(texts):
         persist_directory=LOCAL_VECTOR_STORE_DIR.as_posix(),
     )
     vectordb.persist()
-    retriever = vectordb.as_retriever(search_kwargs={"k": 3})
+    retriever = ContextualCompressionRetriever(
+        base_compressor=LLMChainExtractor.from_llm(OpenAI(temperature=0)),
+        base_retriever=vectordb.as_retriever(search_kwargs={"k": 3}, search_type="mmr"),
+    )
     return retriever


@@ -51,10 +56,11 @@ def query_llm(retriever, query):
         retriever=retriever,
         return_source_documents=True,
     )
+    relevant_docs = retriever.get_relevant_documents(query)
     result = qa_chain({"question": query, "chat_history": st.session_state.messages})
     result = result["answer"]
     st.session_state.messages.append((query, result))
-    return result
+    return relevant_docs, result


 def input_fields():
@@ -77,6 +83,8 @@ def boot():
     st.title("Enigma Chatbot")
     input_fields()
     st.sidebar.button("Submit Documents", on_click=process_documents)
+    st.sidebar.write("---")
+    st.sidebar.write("References made during the chat will appear here")
     if "messages" not in st.session_state:
         st.session_state.messages = []
     for message in st.session_state.messages:
@@ -84,7 +92,9 @@ def boot():
         st.chat_message("ai").write(message[1])
     if query := st.chat_input():
         st.chat_message("human").write(query)
-        response = query_llm(st.session_state.retriever, query)
+        references, response = query_llm(st.session_state.retriever, query)
+        for doc in references:
+            st.sidebar.info(f"Page {doc.metadata['page']}\n\n{doc.page_content}")
         st.chat_message("ai").write(response)

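Taken together, the app.py changes replace fixed-size chunking (RecursiveCharacterTextSplitter) with embedding-based SemanticChunker, switch the Chroma retriever to MMR search wrapped in a ContextualCompressionRetriever whose LLMChainExtractor keeps only query-relevant passages, and have query_llm also return those passages so boot() can show them as sidebar references. A minimal standalone sketch of that retrieval pipeline follows; it is illustrative only, not the full app, and assumes OPENAI_API_KEY is set, with the sample Document and the "vector_store" directory used here as placeholders.

# Minimal sketch of the retrieval pipeline this commit moves to.
# Assumptions (not in the diff): OPENAI_API_KEY is set; the sample Document and
# the "vector_store" persist directory are placeholders.
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.schema import Document
from langchain.vectorstores import Chroma
from langchain_experimental.text_splitter import SemanticChunker

docs = [Document(page_content="Enigma was an electro-mechanical cipher machine.",
                 metadata={"page": 0})]

# Semantic chunking: split where embedding similarity drops, instead of at fixed sizes.
texts = SemanticChunker(OpenAIEmbeddings()).split_documents(docs)

# Persisted Chroma store, mirroring embeddings_on_local_vectordb().
vectordb = Chroma.from_documents(texts, OpenAIEmbeddings(), persist_directory="vector_store")
vectordb.persist()

# MMR retrieval for diverse candidates, then an LLM-based extractor that keeps
# only the passages relevant to the query (contextual compression).
retriever = ContextualCompressionRetriever(
    base_compressor=LLMChainExtractor.from_llm(OpenAI(temperature=0)),
    base_retriever=vectordb.as_retriever(search_kwargs={"k": 3}, search_type="mmr"),
)

relevant_docs = retriever.get_relevant_documents("What kind of machine was Enigma?")
for doc in relevant_docs:
    print(doc.metadata.get("page"), doc.page_content)

Note the trade-off: LLMChainExtractor makes one extra LLM call per retrieved chunk, so the cleaner sidebar references come at the cost of additional latency per query.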
requirements.txt CHANGED
@@ -1,4 +1,5 @@
 openai==0.28
 langchain==0.1.1
 pypdf==4.0.0
-chromadb==0.4.22
+chromadb==0.4.22
+langchain-experimental==0.0.49
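requirements.txt adds langchain-experimental, which is the package that provides SemanticChunker; chromadb stays pinned at 0.4.22. A quick import check under the new pins (a hypothetical smoke test, not part of the commit):

# Hypothetical smoke test for the new pins (not part of the commit):
# verifies the SemanticChunker import added in app.py resolves.
from langchain_experimental.text_splitter import SemanticChunker  # langchain-experimental==0.0.49
import chromadb  # chromadb==0.4.22

print(SemanticChunker.__name__, chromadb.__version__)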