Npps committed on
Commit
b5162b6
1 Parent(s): 4992e0a

Upload 3 files

Browse files
Files changed (3) hide show
  1. LLM_tool.py +319 -0
  2. constants.py +16 -0
  3. requirements.txt +24 -0
LLM_tool.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import glob
4
+ from typing import Union
5
+ from io import BytesIO
6
+ from typing import List
7
+ from dotenv import load_dotenv
8
+ from multiprocessing import Pool
9
+ from constants import CHROMA_SETTINGS
10
+ import tempfile
11
+ from tqdm import tqdm
12
+ import argparse
13
+ import time
14
+ from PIL import Image
15
+ from langchain.chains import RetrievalQA
16
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
17
+ from langchain_community.chat_models import ChatOpenAI
18
+ from langchain.chains import ConversationalRetrievalChain
19
+ from langchain.docstore.document import Document
20
+ from langchain_community.embeddings import OpenAIEmbeddings
21
+ from langchain.memory import ConversationBufferMemory
22
+ from langchain.text_splitter import CharacterTextSplitter,RecursiveCharacterTextSplitter
23
+ from langchain_community.vectorstores import FAISS,Chroma
24
+ from langchain_community.llms import Ollama
25
+ from langchain_cohere import CohereEmbeddings
26
+
27
+ load_dotenv()
28
+
29
+
30
######################### HTML CSS ############################
# Stylesheet for the chat bubbles. It is injected into the page as raw HTML,
# so the <style> element must be explicitly closed.
css = '''
<style>
.chat-message {
padding: 1.5rem; border-radius: 0.5rem; margin-bottom: 1rem; display: flex
}
.chat-message.user {
background-color: #2b313e
}
.chat-message.bot {
background-color: #475063
}
.chat-message .avatar {
width: 20%;
}
.chat-message .avatar img {
max-width: 78px;
max-height: 78px;
border-radius: 50%;
object-fit: cover;
}
.chat-message .message {
width: 80%;
padding: 0 1.5rem;
color: #fff;
}
</style>
'''

# HTML for one bot message; "{{MSG}}" is replaced with the message text.
bot_template = '''
<div class="chat-message bot">
<div class="avatar">
<img src="https://i.pinimg.com/originals/0c/67/5a/0c675a8e1061478d2b7b21b330093444.gif" style="max-height: 70px; max-width: 50px; border-radius: 50%; object-fit: cover;">
</div>
<div class="message">{{MSG}}</div>
</div>
'''


# HTML for one user message; "{{MSG}}" is replaced with the message text.
user_template = '''
<div class="chat-message user">
<div class="avatar">
<img src="https://th.bing.com/th/id/OIP.uDqZFTOXkEWF9PPDHLCntAHaHa?pid=ImgDet&rs=1" style="max-height: 80px; max-width: 50px; border-radius: 50%; object-fit: cover;">
</div>
<div class="message">{{MSG}}</div>
</div>
'''
###################################################
77
+
78
# Splitting parameters handed to the text splitter.
chunk_size, chunk_overlap = 500, 50

# Runtime configuration read from the environment (populated by load_dotenv()).
persist_directory = os.getenv('PERSIST_DIRECTORY')
print(persist_directory)
source_directory = os.getenv('SOURCE_DIRECTORY', 'source_documents')
target_source_chunks = int(os.getenv('TARGET_SOURCE_CHUNKS', 5))
embeddings_model_name = os.getenv('EMBEDDINGS_MODEL_NAME')
model_type = os.getenv('MODEL_TYPE')
86
+
87
+
88
+ from langchain_community.document_loaders import (
89
+ CSVLoader,
90
+ PyMuPDFLoader,
91
+ TextLoader)
92
+
93
+
94
# For each supported file extension: the loader class to instantiate and the
# extra keyword arguments to pass to its constructor.
LOADER_MAPPING = dict([
    (".csv", (CSVLoader, {})),
    (".pdf", (PyMuPDFLoader, {})),
    (".txt", (TextLoader, {"encoding": "utf8"})),
])
100
+
101
+
102
+
103
+
104
+
105
+
106
def load_single_document(file_content: BytesIO, file_type: str) -> List[Document]:
    """Load one uploaded file into a list of documents.

    Args:
        file_content: In-memory bytes of the uploaded file.
        file_type: Type string of the upload. The text after the last "/" is
            used as the file extension (e.g. "application/pdf" -> ".pdf").

    Returns:
        The documents produced by the loader registered for that extension in
        LOADER_MAPPING.

    Raises:
        ValueError: If no loader is registered for the derived extension.
    """
    # rpartition tolerates a type string without "/" instead of raising IndexError.
    ext = "." + file_type.rpartition("/")[2]

    # Reject unsupported types before touching the disk so that no temporary
    # file is left behind on the error path.
    if ext not in LOADER_MAPPING:
        raise ValueError(f"Unsupported file extension '{ext}'")
    loader_class, loader_args = LOADER_MAPPING[ext]

    # The loader is constructed from a path, so spool the bytes to a named file.
    with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as temp_file:
        temp_file.write(file_content.getvalue())
        temp_file_path = temp_file.name

    try:
        return loader_class(temp_file_path, **loader_args).load()
    finally:
        # Clean up even when the loader raises.
        os.remove(temp_file_path)
121
+
122
+
123
+
124
def load_uploaded_documents(uploaded_files, uploaded_files_type, ignored_files: Union[List[str], None] = None) -> List[Document]:
    """Load every uploaded file and return all resulting documents in one list.

    Args:
        uploaded_files: File-like objects exposing ``read()``.
        uploaded_files_type: Type string for each entry of ``uploaded_files``,
            in the same order.
        ignored_files: Accepted for interface compatibility; this function
            does not consult it.

    Returns:
        The concatenated documents of all files, in upload order.
    """
    # Files are loaded sequentially in this process. The worker pool that used
    # to be created here was never given any work, so it only cost start-up time.
    documents: List[Document] = []
    with tqdm(total=len(uploaded_files), desc='Loading new documents', ncols=80) as progress_bar:
        for uploaded_file, file_type in zip(uploaded_files, uploaded_files_type):
            file_content = BytesIO(uploaded_file.read())
            documents.extend(load_single_document(file_content, file_type))
            progress_bar.update()
    return documents
135
+
136
+
137
def get_pdf_text(uploaded_files):
    """Load the given uploads into documents.

    Collects the upload objects and their ``type`` attributes and delegates to
    load_uploaded_documents with an empty ignore list.
    """
    files = [entry for entry in uploaded_files]
    file_types = [entry.type for entry in uploaded_files]
    nothing_ignored = []  # Add files to ignore if needed
    return load_uploaded_documents(files, file_types, nothing_ignored)
144
+
145
+
146
+
147
+
148
def does_vectorstore_exist(persist_directory: str) -> bool:
    """Check whether a populated vectorstore is present in ``persist_directory``.

    A store counts as present when the directory contains an ``index``
    sub-directory, both ``chroma-collections.parquet`` and
    ``chroma-embeddings.parquet``, and at least one ``*.bin`` or ``*.pkl``
    file inside ``index``.

    Args:
        persist_directory: Directory to inspect; ``None`` is treated as
            "no store".

    Returns:
        True if all of the files above are found, otherwise False.
    """
    # NOTE(review): this is the parquet-based on-disk layout; confirm that the
    # installed chromadb version still writes these files.
    if persist_directory is None:
        # An unset directory would otherwise make os.path.join raise TypeError.
        return False

    index_dir = os.path.join(persist_directory, 'index')
    if not os.path.exists(index_dir):
        return False

    required_files = ('chroma-collections.parquet', 'chroma-embeddings.parquet')
    if not all(os.path.exists(os.path.join(persist_directory, name)) for name in required_files):
        return False

    index_files = glob.glob(os.path.join(index_dir, '*.bin'))
    index_files += glob.glob(os.path.join(index_dir, '*.pkl'))
    # At least one index file is needed for a working vectorstore.
    if not index_files:
        return False

    print("Yes vectorstore exists")
    return True
161
+
162
+
163
+
164
def get_text_chunks(results, chunk_size, chunk_overlap):
    """Split the loaded documents into chunks.

    Builds a RecursiveCharacterTextSplitter with the given ``chunk_size`` and
    ``chunk_overlap`` and returns what its ``split_documents`` yields for
    ``results``.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(results)
169
+
170
+
171
def get_vectorstore(results, embeddings_model_name, persist_directory, client_settings, chunk_size, chunk_overlap):
    """Create a Chroma vectorstore, or extend the existing one, with ``results``.

    Args:
        results: Documents to chunk and store.
        embeddings_model_name: ``"openai"`` or ``"Cohereembeddings"``.
        persist_directory: Directory of the on-disk store.
        client_settings: Chroma client settings; the module-level
            CHROMA_SETTINGS is used when this is None.
        chunk_size: Chunk size passed to get_text_chunks.
        chunk_overlap: Chunk overlap passed to get_text_chunks.

    Returns:
        The Chroma vectorstore containing the chunks.

    Raises:
        ValueError: If ``embeddings_model_name`` is not one of the supported
            values.
    """
    if embeddings_model_name == "openai":
        embeddings = OpenAIEmbeddings()
        print('OpenAI embeddings loaded')
    elif embeddings_model_name == "Cohereembeddings":
        embeddings = CohereEmbeddings()
        print('Cohere embeddings loaded')
    else:
        # Fail with a clear message instead of an unbound-variable error below.
        raise ValueError(
            f"Unsupported embeddings model '{embeddings_model_name}'. "
            "Expected 'openai' or 'Cohereembeddings'."
        )

    # Honour the settings supplied by the caller rather than silently ignoring them.
    settings = client_settings if client_settings is not None else CHROMA_SETTINGS

    texts = get_text_chunks(results, chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    if does_vectorstore_exist(persist_directory):
        # Update the locally stored vectorstore.
        print(f"Appending to existing vectorstore at {persist_directory}")
        db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=settings)
        if texts:
            db.add_documents(texts)
    else:
        # Create and locally store a new vectorstore.
        print("Creating new vectorstore")
        print("Creating embeddings. May take some minutes...")
        # from_documents already embeds and stores `texts`; calling
        # add_documents afterwards would store every chunk a second time.
        db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory, client_settings=settings)

    return db
199
+
200
+
201
+ def get_conversation_chain(vectorstore,target_source_chunks,model_type):
202
+ retriever = vectorstore.as_retriever(search_kwargs={"k": target_source_chunks})
203
+
204
+ # activate/deactivate the streaming StdOut callback for LLMs
205
+ #callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
206
+ # Prepare the LLM.
207
+
208
+ match model_type:
209
+ case "OpenaAI":
210
+ llm= ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
211
+ case "Llama3":
212
+ llm = Ollama(model="llama3")
213
+ case _default:
214
+ # raise exception if model_type is not supported
215
+ raise Exception(f"Model type {model_type} is not supported. Please choose one of the following: ")
216
+
217
+
218
+ #llm = ChatOpenAI()
219
+ # llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature":0.5, "max_length":512})
220
+ # llm = GPT4All(model=model_path, max_tokens=model_n_ctx, backend='gptj', n_batch=model_n_batch, verbose=False)
221
+
222
+ memory = ConversationBufferMemory(
223
+ memory_key='chat_history', return_messages=True)
224
+ conversation_chain = ConversationalRetrievalChain.from_llm(
225
+ llm=llm,
226
+ retriever=retriever,
227
+ memory=memory
228
+ )
229
+ return conversation_chain
230
+
231
+
232
# Browser tab title and icon for the app.
st.set_page_config(
    page_title="Generate Insights",
    page_icon=":bar_chart:",
)
233
+
234
+
235
def handle_userinput(user_question):
    """Send ``user_question`` to the conversation chain and render the chat.

    The chain's ``chat_history`` is stored in the session state and every
    message is written to the page: even positions use ``user_template``, odd
    positions use ``bot_template``.

    Args:
        user_question: The question typed by the user.
    """
    import html  # local import: only needed to escape chat text for HTML

    conversation = st.session_state.get("conversation")
    if conversation is None:
        # No documents have been processed yet, so there is no chain to call.
        st.warning("Please upload and process your documents before asking a question.")
        return

    response = conversation({'question': user_question})
    st.session_state.chat_history = response['chat_history']

    for i, message in enumerate(st.session_state.chat_history):
        template = user_template if i % 2 == 0 else bot_template
        # The templates are rendered with unsafe_allow_html=True, so message
        # text must be escaped to keep user/model text from being read as markup.
        st.write(
            template.replace("{{MSG}}", html.escape(message.content)),
            unsafe_allow_html=True,
        )
246
+
247
+
248
+
249
+
250
+
251
def add_logo(logo_path, width, height):
    """Open the image at ``logo_path`` and return it resized to (width, height)."""
    target_size = (width, height)
    return Image.open(logo_path).resize(target_size)
256
+
257
# `css` already carries its own <style> tag. Strip any tags before wrapping so
# the page receives exactly one well-formed <style> element, not a nested one.
_css_rules = css.replace('<style>', '').replace('</style>', '')
st.markdown(f'<style>{_css_rules}</style>', unsafe_allow_html=True)

# Six equal columns; the two logos are placed in the last two.
col1, col2, col3, col4, col5, col6 = st.columns(6)

with col5:
    my_logo = add_logo(logo_path="CampusX.jfif", width=100, height=20)
    st.image(my_logo)
with col6:
    pg_logo = add_logo(logo_path="Q&A logo.jfif", width=60, height=40)
    st.image(pg_logo)
266
+
267
+
268
+
269
+
270
+
271
def main():
    """Render the page: question box, chat output and the document sidebar.

    Pressing "Process" loads the uploaded files, builds the vectorstore and
    stores a new conversation chain in the session state.
    """
    load_dotenv()
    css2 = '''
<style>
[data-testid="stSidebar"]{
min-width: 300px;
max-width: 300px;
}
</style>
'''
    st.markdown(css2, unsafe_allow_html=True)

    st.write(css, unsafe_allow_html=True)

    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    # Streamlit's colour markup is ":blue[text]"; a bare ":blue" is shown literally.
    st.header(":blue[Generate Insights] :bar_chart:")
    user_question = st.text_input("Ask a question about your documents:")
    if user_question:
        if st.session_state.conversation is None:
            # Nothing has been processed yet, so there is no chain to query.
            st.warning("Please upload and process your documents before asking a question.")
        else:
            handle_userinput(user_question)

    with st.sidebar:
        st.subheader("Your documents")
        # NOTE(review): "xlsx" is accepted here but LOADER_MAPPING has no
        # loader for it; confirm whether it should be supported or removed.
        uploaded_files = st.file_uploader("Upload documents", type=["pdf", "xlsx", 'csv'], accept_multiple_files=True)

        # NOTE(review): the button is assumed to live inside the sidebar; the
        # original indentation was not recoverable.
        if st.button("Process"):
            # With accept_multiple_files=True an empty selection is an empty
            # list, not None, so test for emptiness.
            if not uploaded_files:
                st.warning("Please upload at least one document before processing.")
            else:
                with st.spinner("Processing"):
                    # Load the uploaded files into documents.
                    documents = get_pdf_text(uploaded_files=uploaded_files)

                    # get_vectorstore chunks its input itself, so the raw
                    # documents are passed instead of chunking them twice.
                    vectorstore = get_vectorstore(
                        results=documents,
                        embeddings_model_name=embeddings_model_name,
                        persist_directory=persist_directory,
                        client_settings=CHROMA_SETTINGS,
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap,
                    )

                    # Create the conversation chain used to answer questions.
                    st.session_state.conversation = get_conversation_chain(
                        vectorstore=vectorstore,
                        target_source_chunks=target_source_chunks,
                        model_type=model_type,
                    )


if __name__ == '__main__':
    main()
319
+
constants.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

from chromadb.config import Settings
from dotenv import load_dotenv

load_dotenv()

# Folder in which the database is stored; it must be provided via the environment.
PERSIST_DIRECTORY = os.getenv('PERSIST_DIRECTORY')
if PERSIST_DIRECTORY is None:
    raise Exception("Please set the PERSIST_DIRECTORY environment variable")

# Chroma settings built from that folder, with anonymized telemetry turned off.
CHROMA_SETTINGS = Settings(persist_directory=PERSIST_DIRECTORY, anonymized_telemetry=False)
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langchain
2
+ chromadb
3
+ urllib3
4
+ PyMuPDF
5
+ python-dotenv
6
+ unstructured
7
+ extract-msg
8
+ tabulate
9
+ pandoc
10
+ pypandoc
11
+ tqdm
12
+ sentence_transformers
13
+ langchain-community
14
+ tiktoken
15
+ langchain-openai
16
+ langchainhub
17
+ langchain-cohere
18
+ pymupdf
19
+ streamlit
20
+ chroma-migrate
21
+ langchain
22
+ llama-index
23
+ langchain-experimental
24
+ ollama