spoggy committed
Commit 51c7636 · 1 Parent(s): a4d6677
Files changed (1):
  app.py +245 -3
app.py CHANGED
@@ -1,4 +1,246 @@
- import streamlit as st
 
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+ #from langchain.document_loaders.pdf import PDFPlumberLoader
+ from langchain.document_loaders import PyPDFLoader
+ from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
+ from transformers import pipeline
+ from langchain.prompts import PromptTemplate
+ from langchain.chat_models import ChatOpenAI
+ from langchain.vectorstores import Chroma
+ from langchain.chains import RetrievalQA
+ from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings
+ from langchain.embeddings.openai import OpenAIEmbeddings
+ from langchain.llms import OpenAI, HuggingFacePipeline
+ from constants import *
+ from transformers import AutoTokenizer
+ import torch
+ import os
+ import re
+ from pprint import pprint
 
+ class PdfQA:
+     def __init__(self, config: dict = {}):
+         self.config = config
+         self.embedding = None
+         self.vectordb = None
+         self.llm = None
+         self.qa = None
+         self.retriever = None
+
+     # The following class methods are useful to create global GPU model instances
+     # This way we don't need to reload models in an interactive app,
+     # and the same model instance can be used across multiple user sessions
+     @classmethod
+     def create_instructor_xl(cls):
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         return HuggingFaceInstructEmbeddings(model_name=EMB_INSTRUCTOR_XL, model_kwargs={"device": device})
+
+     @classmethod
+     def create_sbert_mpnet(cls):
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         return HuggingFaceEmbeddings(model_name=EMB_SBERT_MPNET_BASE, model_kwargs={"device": device})
+
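+     # NOTE (assumption): the two embedding factories above rely on the
+     # sentence-transformers package; the INSTRUCTOR variant additionally
+     # needs the InstructorEmbedding package installed.
+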
+     @classmethod
+     def create_flan_t5_xxl(cls, load_in_8bit=False):
+         # Local flan-t5-xxl with 8-bit quantization for inference
+         # Wrap it in an HF pipeline for use with LangChain
+         return pipeline(
+             task="text2text-generation",
+             model="google/flan-t5-xxl",
+             max_new_tokens=200,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+     @classmethod
+     def create_flan_t5_xl(cls, load_in_8bit=False):
+         return pipeline(
+             task="text2text-generation",
+             model="google/flan-t5-xl",
+             max_new_tokens=200,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+
+     @classmethod
+     def create_flan_t5_small(cls, load_in_8bit=False):
+         # Local flan-t5-small for inference
+         # Wrap it in an HF pipeline for use with LangChain
+         model = "google/flan-t5-small"
+         tokenizer = AutoTokenizer.from_pretrained(model)
+         return pipeline(
+             task="text2text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             max_new_tokens=100,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+     @classmethod
+     def create_flan_t5_base(cls, load_in_8bit=False):
+         # Wrap it in an HF pipeline for use with LangChain
+         model = "google/flan-t5-base"
+         tokenizer = AutoTokenizer.from_pretrained(model)
+         return pipeline(
+             task="text2text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             max_new_tokens=100,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+     @classmethod
+     def create_flan_t5_large(cls, load_in_8bit=False):
+         # Wrap it in an HF pipeline for use with LangChain
+         model = "google/flan-t5-large"
+         tokenizer = AutoTokenizer.from_pretrained(model)
+         return pipeline(
+             task="text2text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             max_new_tokens=100,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+     @classmethod
+     def create_fastchat_t5_xl(cls, load_in_8bit=False):
+         return pipeline(
+             task="text2text-generation",
+             model="lmsys/fastchat-t5-3b-v1.0",
+             max_new_tokens=100,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+
+     @classmethod
+     def create_falcon_instruct_small(cls, load_in_8bit=False):
+         model = "tiiuae/falcon-7b-instruct"
+
+         tokenizer = AutoTokenizer.from_pretrained(model)
+         hf_pipeline = pipeline(
+             task="text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             trust_remote_code=True,
+             max_new_tokens=100,
+             model_kwargs={
+                 "device_map": "auto",
+                 "load_in_8bit": load_in_8bit,
+                 "max_length": 512,
+                 "temperature": 0.01,
+                 "torch_dtype": torch.bfloat16,
+             }
+         )
+         return hf_pipeline
+
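+     # NOTE (assumption): load_in_8bit depends on the bitsandbytes and accelerate
+     # packages and a CUDA-capable GPU; on CPU-only hosts it should stay False.
+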
+     def init_embeddings(self) -> None:
+         # OpenAI ada embeddings API
+         if self.config["embedding"] == EMB_OPENAI_ADA:
+             self.embedding = OpenAIEmbeddings()
+         elif self.config["embedding"] == EMB_INSTRUCTOR_XL:
+             # Local INSTRUCTOR-XL embeddings
+             if self.embedding is None:
+                 self.embedding = PdfQA.create_instructor_xl()
+         elif self.config["embedding"] == EMB_SBERT_MPNET_BASE:
+             ## this is for SBERT
+             if self.embedding is None:
+                 self.embedding = PdfQA.create_sbert_mpnet()
+         else:
+             self.embedding = None  ## DuckDb uses sbert embeddings
+             # raise ValueError("Invalid config")
+
+     def init_models(self) -> None:
+         """ Initialize LLM models based on config """
+         load_in_8bit = self.config.get("load_in_8bit", False)
+         # OpenAI GPT 3.5 API
+         if self.config["llm"] == LLM_OPENAI_GPT35:
+             # Accessed via the OpenAI API; the client is created in retreival_qa_chain
+             pass
+         elif self.config["llm"] == LLM_FLAN_T5_SMALL:
+             if self.llm is None:
+                 self.llm = PdfQA.create_flan_t5_small(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FLAN_T5_BASE:
+             if self.llm is None:
+                 self.llm = PdfQA.create_flan_t5_base(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FLAN_T5_LARGE:
+             if self.llm is None:
+                 self.llm = PdfQA.create_flan_t5_large(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FLAN_T5_XL:
+             if self.llm is None:
+                 self.llm = PdfQA.create_flan_t5_xl(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FLAN_T5_XXL:
+             if self.llm is None:
+                 self.llm = PdfQA.create_flan_t5_xxl(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FASTCHAT_T5_XL:
+             if self.llm is None:
+                 self.llm = PdfQA.create_fastchat_t5_xl(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FALCON_SMALL:
+             if self.llm is None:
+                 self.llm = PdfQA.create_falcon_instruct_small(load_in_8bit=load_in_8bit)
+
+         else:
+             raise ValueError("Invalid config")
+     def vector_db_pdf(self) -> None:
+         """
+         Creates a vector DB from the document embeddings and persists it, or loads an existing vector DB from the persist directory
+         """
+         pdf_path = self.config.get("pdf_path", None)
+         persist_directory = self.config.get("persist_directory", None)
+         if persist_directory and os.path.exists(persist_directory):
+             ## Load from the persisted db
+             self.vectordb = Chroma(persist_directory=persist_directory, embedding_function=self.embedding)
+         elif pdf_path and os.path.exists(pdf_path):
+             ## 1. Extract the documents
+             loader = PyPDFLoader(pdf_path)
+             documents = loader.load()
+             ## 2. Split the texts
+             text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
+             texts = text_splitter.split_documents(documents)
+             # text_splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=10, encoding_name="cl100k_base")  # This is the encoding for text-embedding-ada-002
+             text_splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=10)  # default tiktoken encoding
+             texts = text_splitter.split_documents(texts)
+
+             ## 3. Create embeddings and add them to the Chroma store
+             ##TODO: Validate that self.embedding is not None
+             self.vectordb = Chroma.from_documents(documents=texts, embedding=self.embedding, persist_directory=persist_directory)
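+             # NOTE (assumption): with older LangChain/Chroma releases, an explicit
+             # self.vectordb.persist() call may be needed here to flush the index to disk.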
+         else:
+             raise ValueError("NO PDF found")
+
+     def retreival_qa_chain(self):
+         """
+         Creates a retrieval QA chain using the vectordb as retriever and an LLM to complete the prompt
+         """
+         ##TODO: Use custom prompt
+         print("one", self)
+         pprint(vars(self))
+         self.retriever = self.vectordb.as_retriever(search_kwargs={"k": 3})
+         print("two")
+
+         if self.config["llm"] == LLM_OPENAI_GPT35:
+             # Use ChatGPT API
+             self.qa = RetrievalQA.from_chain_type(llm=OpenAI(model_name=LLM_OPENAI_GPT35, temperature=0.), chain_type="stuff",\
+                             retriever=self.vectordb.as_retriever(search_kwargs={"k": 3}))
+         else:
+             hf_llm = HuggingFacePipeline(pipeline=self.llm, model_id=self.config["llm"])
+
+             self.qa = RetrievalQA.from_chain_type(llm=hf_llm, chain_type="stuff", retriever=self.retriever)
+             if self.config["llm"] == LLM_FLAN_T5_SMALL or self.config["llm"] == LLM_FLAN_T5_BASE or self.config["llm"] == LLM_FLAN_T5_LARGE:
+                 question_t5_template = """
+                 context: {context}
+                 question: {question}
+                 answer:
+                 """
+                 QUESTION_T5_PROMPT = PromptTemplate(
+                     template=question_t5_template, input_variables=["context", "question"]
+                 )
+                 self.qa.combine_documents_chain.llm_chain.prompt = QUESTION_T5_PROMPT
+             self.qa.combine_documents_chain.verbose = True
+             self.qa.return_source_documents = True
+     def answer_query(self, question: str) -> str:
+         """
+         Answer the question
+         """
+
+         answer_dict = self.qa({"query": question})
+         print(answer_dict)
+         answer = answer_dict["result"]
+         if self.config["llm"] == LLM_FASTCHAT_T5_XL:
+             answer = self._clean_fastchat_t5_output(answer)
+         return answer
+     def _clean_fastchat_t5_output(self, answer: str) -> str:
+         # Remove <pad> tags, double spaces, trailing newline
+         answer = re.sub(r"<pad>\s+", "", answer)
+         answer = re.sub(r"  ", " ", answer)
+         answer = re.sub(r"\n$", "", answer)
+         return answer
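
Taken together, these methods form a pipeline: init_embeddings and init_models load the models named in the config, vector_db_pdf builds or reloads the Chroma index, retreival_qa_chain wires the retriever and LLM into a RetrievalQA chain, and answer_query runs it. A minimal driver sketch under those assumptions (the PDF path and question are hypothetical; the EMB_/LLM_ constants come from the imported constants module):

    pdf_qa = PdfQA(config={
        "embedding": EMB_SBERT_MPNET_BASE,
        "llm": LLM_FLAN_T5_BASE,
        "pdf_path": "sample.pdf",        # hypothetical input file
        "persist_directory": None,       # in-memory index; set a path to persist
        "load_in_8bit": False,
    })
    pdf_qa.init_embeddings()
    pdf_qa.init_models()
    pdf_qa.vector_db_pdf()
    pdf_qa.retreival_qa_chain()
    print(pdf_qa.answer_query("What is this document about?"))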