import os
import streamlit as st
from langchain_community.document_loaders.pdf import PyPDFDirectoryLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceHub
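
# NOTE: HuggingFaceHub authenticates via the HUGGINGFACEHUB_API_TOKEN
# environment variable. A minimal launch sketch, assuming the token is set
# and this file is saved as app.py:
#   export HUGGINGFACEHUB_API_TOKEN=hf_...
#   streamlit run app.py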

def make_vectorstore(embeddings):
    # load every PDF in ./data, split it into chunks, and index it with FAISS
    loader = PyPDFDirectoryLoader("data")
    documents = loader.load()
    # CharacterTextSplitter splits on "\n\n" by default; 400-character chunks
    # with no overlap keep each retrieved passage short
    text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)
    docsearch = FAISS.from_documents(texts, embeddings)
    return docsearch
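
# Optional: a FAISS index can be persisted between runs instead of being
# rebuilt on every upload; a sketch using the vectorstore built above
# (newer langchain_community versions also require
# allow_dangerous_deserialization=True when loading):
#   docsearch.save_local("faiss_index")
#   docsearch = FAISS.load_local("faiss_index", embeddings)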

def get_conversation(vectorstore, model):
    # RetrievalQA's default prompt has no slot for chat history, so this
    # memory records the exchanges but never feeds earlier turns to the LLM
    memory = ConversationBufferMemory(memory_key="messages", return_messages=True)

    conversation_chain = RetrievalQA.from_llm(
        llm=model,
        retriever=vectorstore.as_retriever(),
        memory=memory)

    return conversation_chain
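
# Note: for genuinely multi-turn behaviour, LangChain's
# ConversationalRetrievalChain is the history-aware alternative; a sketch
# (it expects memory_key="chat_history" by default):
#   from langchain.chains import ConversationalRetrievalChain
#   chain = ConversationalRetrievalChain.from_llm(
#       llm=model, retriever=vectorstore.as_retriever(), memory=memory)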

def get_response(conversation_chain, query):
    # run the chain; RetrievalQA returns a dict like {"query": ..., "result": ...}
    response = conversation_chain.invoke(query)
    result = response["result"]
    # HuggingFaceHub may echo the whole prompt back, so keep only the text after
    # the final "Helpful Answer:" marker; if the marker is absent, split()[-1]
    # returns the full result instead of raising an IndexError
    answer = result.split("Helpful Answer:")[-1].strip()
    return answer

def main():
    st.title("Chat LLM")
    if not os.path.exists("data"):
        os.makedirs("data")

    print("Downloading Embeddings Model")
    with st.spinner("Downloading Embeddings Model..."):
        embeddings = HuggingFaceInstructEmbeddings(
            model_name="hkunlp/instructor-base",
            model_kwargs={"device": "cpu"})

    print("Loading LLM from HuggingFace")
    with st.spinner("Loading LLM from HuggingFace..."):
        llm = HuggingFaceHub(
            repo_id="HuggingFaceH4/zephyr-7b-beta",
            model_kwargs={"temperature": 0.7, "max_new_tokens": 512,
                          "top_p": 0.95, "top_k": 50})
    
    st.sidebar.title("Upload PDFs")
    uploaded_files = st.sidebar.file_uploader("Upload PDFs", accept_multiple_files=True)
    if uploaded_files:
        for file in uploaded_files:
            with open(f"data/{file.name}", "wb") as f:
                f.write(file.getbuffer())
        with st.spinner("Making a vectorstore database..."):
            vectorstore = make_vectorstore(embeddings)
        with st.spinner("Making a conversation chain..."):
            # keep the chain in session state so it survives Streamlit reruns
            # and is still defined when the user submits a question later
            st.session_state.conversation_chain = get_conversation(vectorstore, llm)
        st.sidebar.success("PDFs uploaded successfully")
    else:
        st.sidebar.warning("Please upload PDFs")
    # a "Clear Chat" button that resets the stored message history
    if st.button("Clear Chat"):
        st.session_state.messages = []
    
    if "messages" not in st.session_state:
        st.session_state.messages = []
    
    for message in st.session_state.messages:
        if message["role"] == "user":
            st.chat_message("user").markdown(message["content"])
        else:
            # "assistant" is the role Streamlit styles with the bot avatar
            st.chat_message("assistant").markdown(message["content"])
    
    user_prompt = st.chat_input("ask a question", key="user")
    if user_prompt:
        st.chat_message("user").markdown(user_prompt)
        st.session_state.messages.append({"role": "user", "content": user_prompt})
        if "conversation_chain" not in st.session_state:
            # no PDFs indexed yet, so there is no chain to query
            st.chat_message("assistant").markdown("Please upload PDFs first.")
        else:
            response = get_response(st.session_state.conversation_chain, user_prompt)
            st.chat_message("assistant").markdown(response)
            st.session_state.messages.append({"role": "assistant", "content": response})
        
if __name__ == "__main__":
    main()