import os

import streamlit as st
from langchain import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.document_loaders import DirectoryLoader, UnstructuredPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from streamlit_extras.add_vertical_space import add_vertical_space

st.set_page_config(page_title="Welcome to our AI Question Answering Bot")

with st.sidebar:
    st.title('🤗💬 QA App')
    st.markdown('''
    ## About
    This app is an LLM-powered chatbot built using:
    - [Streamlit](https://streamlit.io/)
    - [HugChat](https://github.com/Soulter/hugging-chat-api)
    - Chat model: Llama-2-13b-chat-hf
    - Retriever model: all-MiniLM-L6-v2

    💡 Note: No API key required!
    ''')
    add_vertical_space(5)
    st.write('Made with ❤️ by us')

# logo = Image.open('logo.png')  # requires: from PIL import Image
# st.image(logo, use_column_width=True)


# Introduction
st.markdown("""
Welcome! This is not just any bot: it comes equipped with state-of-the-art
natural language processing capabilities and is ready to answer your queries.

Ready to explore? Let's get started!

* Step 1: Upload a PDF document.
* Step 2: Type in a question related to your document's content.
* Step 3: Get your answer!

Press the "Clear cache" button before uploading a new document!
""")


def write_text_file(content, file_path):
    """Write raw bytes to file_path; return True on success, False on error."""
    try:
        with open(file_path, 'wb') as file:
            file.write(content)
        return True
    except Exception as e:
        print(f"Error occurred while writing the file: {e}")
        return False


# Wrap prompt template in a PromptTemplate object

def set_qa_prompt():
    # set prompt template
    prompt_template = """<s>[INST] <<SYS>> Use the following pieces of context closed between $ to answer the question closed between |. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
    ${context}$ <</SYS>>
    Question: |{question}|
    Answer:[/INST]</s>"""
    prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    return prompt
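
# A minimal sketch of how the template above renders, assuming a toy context
# and question (hypothetical values, not part of the app flow); uncomment to
# inspect the final prompt string:
# demo_prompt = set_qa_prompt().format(
#     context="Paris is the capital of France.",
#     question="What is the capital of France?",
# )
# print(demo_prompt)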


# Build RetrievalQA object

def build_retrieval_qa(_llm, _prompt, _vectorstore):
    # 'stuff' concatenates the top-k retrieved chunks directly into the prompt;
    # k=3 keeps the stuffed context within the model's context window.
    dbqa = RetrievalQA.from_chain_type(llm=_llm,
                                       chain_type='stuff',
                                       retriever=_vectorstore.as_retriever(search_kwargs={'k': 3}),
                                       return_source_documents=True,
                                       chain_type_kwargs={'prompt': _prompt})
    return dbqa
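
# For reference, the chain returns a dict; with return_source_documents=True it
# carries 'query', 'result', and 'source_documents' keys. A hypothetical call:
# out = dbqa({'query': 'What is this document about?'})
# print(out['result'], out['source_documents'])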


# Instantiate QA object
@st.cache_resource
def setup_dbqa(_texts):
    print("setup_dbqa ...")
    # Note: load_in_8bit requires the bitsandbytes package and a CUDA GPU.
    llm = HuggingFacePipeline.from_model_id(
        model_id="NousResearch/Llama-2-13b-chat-hf",
        task="text-generation",
        model_kwargs={"max_length": 1500, "load_in_8bit": True},
    )

    embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2',
                                       model_kwargs={'device': 'cpu'})

    # The leading underscore tells Streamlit not to hash the documents argument.
    vectorstore = Chroma.from_documents(_texts, embeddings, persist_directory="vectorstore")

    prompt = set_qa_prompt()

    return build_retrieval_qa(llm, prompt, vectorstore)


def load_docs(uploaded_file):
    print("loading docs ...")
    content = uploaded_file.read()
    os.makedirs("./temp", exist_ok=True)  # make sure the temp directory exists
    file_path_aux = "./temp/file.pdf"
    write_text_file(content, file_path_aux)
    file_path = "./temp/"

    loader = DirectoryLoader(file_path,
                             glob="*.pdf",
                             loader_cls=UnstructuredPDFLoader)
    documents = loader.load()

    # Split text from the PDF into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000,
                                                   chunk_overlap=0,
                                                   length_function=len)
    texts = text_splitter.split_documents(documents)
    return texts
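
# A quick illustration of the splitter's behavior on a plain string, with a toy
# chunk size (hypothetical; the app itself uses chunk_size=1000 above):
# demo_splitter = RecursiveCharacterTextSplitter(chunk_size=20, chunk_overlap=0)
# print(demo_splitter.split_text("The quick brown fox jumps over the lazy dog."))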


# Load a PDF file
uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")

if uploaded_file is not None:
    st.write('Loading file')

    texts = load_docs(uploaded_file)
    model = setup_dbqa(texts)

    question = st.text_input('Ask a question:')

    if question:
        # Run the retrieval QA chain on the question
        answer = model({'query': question})
        print(question)
        print(answer)

        st.write('Question: ', answer["query"])
        st.write('Answer: ', answer["result"])
        st.write('Source documents: ', answer["source_documents"])

if st.button("Clear cache before loading new document"):
    # Clears the cached model/vectorstore so a new document can be indexed:
    st.cache_resource.clear()