spoggy committed on
Commit 6f9b2cb · 1 Parent(s): bb8d2f1

move app to pdf_qa and create app

Files changed (4)
  1. README.md +4 -0
  2. app.py +82 -237
  3. constants.py +15 -0
  4. pdf_qa.py +246 -0
README.md CHANGED
@@ -12,3 +12,7 @@ pinned: false
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

 inspired by source https://www.shakudo.io/blog/build-pdf-bot-open-source-llms
+
+
+ # Deployed on hugging face
+ https://huggingface.co/spaces/spoggy/streamlit_pdf_qna_open_models
app.py CHANGED
@@ -1,246 +1,91 @@
- #from langchain.document_loaders.pdf import PDFPlumberLoader
- from langchain.document_loaders import PyPDFLoader
- from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
- from transformers import pipeline
- from langchain.prompts import PromptTemplate
- from langchain.chat_models import ChatOpenAI
- from langchain.vectorstores import Chroma
- from langchain.chains import RetrievalQA
- from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings
- from langchain.embeddings.openai import OpenAIEmbeddings
- from langchain.llms import OpenAI, HuggingFacePipeline
  from constants import *
- from transformers import AutoTokenizer
- import torch
- import os
- import re
- from pprint import pprint

- class PdfQA:
-     def __init__(self,config:dict = {}):
-         self.config = config
-         self.embedding = None
-         self.vectordb = None
-         self.llm = None
-         self.qa = None
-         self.retriever = None

-     # The following class methods are useful to create global GPU model instances
-     # This way we don't need to reload models in an interactive app,
-     # and the same model instance can be used across multiple user sessions
-     @classmethod
-     def create_instructor_xl(cls):
-         device = "cuda" if torch.cuda.is_available() else "cpu"
-         return HuggingFaceInstructEmbeddings(model_name=EMB_INSTRUCTOR_XL, model_kwargs={"device": device})
-
-     @classmethod
-     def create_sbert_mpnet(cls):
-         device = "cuda" if torch.cuda.is_available() else "cpu"
-         return HuggingFaceEmbeddings(model_name=EMB_SBERT_MPNET_BASE, model_kwargs={"device": device})
-
-     @classmethod
-     def create_flan_t5_xxl(cls, load_in_8bit=False):
-         # Local flan-t5-xxl with 8-bit quantization for inference
-         # Wrap it in HF pipeline for use with LangChain
-         return pipeline(
-             task="text2text-generation",
-             model="google/flan-t5-xxl",
-             max_new_tokens=200,
-             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
-         )
-     @classmethod
-     def create_flan_t5_xl(cls, load_in_8bit=False):
-         return pipeline(
-             task="text2text-generation",
-             model="google/flan-t5-xl",
-             max_new_tokens=200,
-             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
-         )
-
-     @classmethod
-     def create_flan_t5_small(cls, load_in_8bit=False):
-         # Local flan-t5-small for inference
-         # Wrap it in HF pipeline for use with LangChain
-         model="google/flan-t5-small"
-         tokenizer = AutoTokenizer.from_pretrained(model)
-         return pipeline(
-             task="text2text-generation",
-             model=model,
-             tokenizer = tokenizer,
-             max_new_tokens=100,
-             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
-         )
-     @classmethod
-     def create_flan_t5_base(cls, load_in_8bit=False):
-         # Wrap it in HF pipeline for use with LangChain
-         model="google/flan-t5-base"
-         tokenizer = AutoTokenizer.from_pretrained(model)
-         return pipeline(
-             task="text2text-generation",
-             model=model,
-             tokenizer = tokenizer,
-             max_new_tokens=100,
-             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
-         )
-     @classmethod
-     def create_flan_t5_large(cls, load_in_8bit=False):
-         # Wrap it in HF pipeline for use with LangChain
-         model="google/flan-t5-large"
-         tokenizer = AutoTokenizer.from_pretrained(model)
-         return pipeline(
-             task="text2text-generation",
-             model=model,
-             tokenizer = tokenizer,
-             max_new_tokens=100,
-             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
-         )
-     @classmethod
-     def create_fastchat_t5_xl(cls, load_in_8bit=False):
-         return pipeline(
-             task="text2text-generation",
-             model = "lmsys/fastchat-t5-3b-v1.0",
-             max_new_tokens=100,
-             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
-         )
-
-     @classmethod
-     def create_falcon_instruct_small(cls, load_in_8bit=False):
-         model = "tiiuae/falcon-7b-instruct"

-         tokenizer = AutoTokenizer.from_pretrained(model)
-         hf_pipeline = pipeline(
-             task="text-generation",
-             model = model,
-             tokenizer = tokenizer,
-             trust_remote_code = True,
-             max_new_tokens=100,
-             model_kwargs={
-                 "device_map": "auto",
-                 "load_in_8bit": load_in_8bit,
-                 "max_length": 512,
-                 "temperature": 0.01,
-                 "torch_dtype":torch.bfloat16,
-             }
-         )
-         return hf_pipeline
-
-     def init_embeddings(self) -> None:
-         # OpenAI ada embeddings API
-         if self.config["embedding"] == EMB_OPENAI_ADA:
-             self.embedding = OpenAIEmbeddings()
-         elif self.config["embedding"] == EMB_INSTRUCTOR_XL:
-             # Local INSTRUCTOR-XL embeddings
-             if self.embedding is None:
-                 self.embedding = PdfQA.create_instructor_xl()
-         elif self.config["embedding"] == EMB_SBERT_MPNET_BASE:
-             ## this is for SBERT
-             if self.embedding is None:
-                 self.embedding = PdfQA.create_sbert_mpnet()
-         else:
-             self.embedding = None ## DuckDb uses sbert embeddings
-             # raise ValueError("Invalid config")

-     def init_models(self) -> None:
-         """ Initialize LLM models based on config """
-         load_in_8bit = self.config.get("load_in_8bit",False)
-         # OpenAI GPT 3.5 API
-         if self.config["llm"] == LLM_OPENAI_GPT35:
-             # OpenAI GPT 3.5 API
-             pass
-         elif self.config["llm"] == LLM_FLAN_T5_SMALL:
-             if self.llm is None:
-                 self.llm = PdfQA.create_flan_t5_small(load_in_8bit=load_in_8bit)
-         elif self.config["llm"] == LLM_FLAN_T5_BASE:
-             if self.llm is None:
-                 self.llm = PdfQA.create_flan_t5_base(load_in_8bit=load_in_8bit)
-         elif self.config["llm"] == LLM_FLAN_T5_LARGE:
-             if self.llm is None:
-                 self.llm = PdfQA.create_flan_t5_large(load_in_8bit=load_in_8bit)
-         elif self.config["llm"] == LLM_FLAN_T5_XL:
-             if self.llm is None:
-                 self.llm = PdfQA.create_flan_t5_xl(load_in_8bit=load_in_8bit)
-         elif self.config["llm"] == LLM_FLAN_T5_XXL:
-             if self.llm is None:
-                 self.llm = PdfQA.create_flan_t5_xxl(load_in_8bit=load_in_8bit)
-         elif self.config["llm"] == LLM_FASTCHAT_T5_XL:
-             if self.llm is None:
-                 self.llm = PdfQA.create_fastchat_t5_xl(load_in_8bit=load_in_8bit)
-         elif self.config["llm"] == LLM_FALCON_SMALL:
-             if self.llm is None:
-                 self.llm = PdfQA.create_falcon_instruct_small(load_in_8bit=load_in_8bit)
-
-         else:
-             raise ValueError("Invalid config")
-     def vector_db_pdf(self) -> None:
-         """
-         creates vector db for the embeddings and persists them or loads a vector db from the persist directory
-         """
-         pdf_path = self.config.get("pdf_path",None)
-         persist_directory = self.config.get("persist_directory",None)
-         if persist_directory and os.path.exists(persist_directory):
-             ## Load from the persist db
-             self.vectordb = Chroma(persist_directory=persist_directory, embedding_function=self.embedding)
-         elif pdf_path and os.path.exists(pdf_path):
-             ## 1. Extract the documents
-             loader = PyPDFLoader(pdf_path)
-             documents = loader.load()
-             ## 2. Split the texts
-             text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
-             texts = text_splitter.split_documents(documents)
-             # text_splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=10, encoding_name="cl100k_base") # This the encoding for text-embedding-ada-002
-             text_splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=10) # This the encoding for text-embedding-ada-002
-             texts = text_splitter.split_documents(texts)

-             ## 3. Create Embeddings and add to chroma store
-             ##TODO: Validate if self.embedding is not None
-             self.vectordb = Chroma.from_documents(documents=texts, embedding=self.embedding, persist_directory=persist_directory)
-         else:
-             raise ValueError("NO PDF found")

-     def retreival_qa_chain(self):
-         """
-         Creates retrieval qa chain using vectordb as retrivar and LLM to complete the prompt
-         """
-         ##TODO: Use custom prompt
-         print("one", self)
-         pprint(vars(self))
-         self.retriever = self.vectordb.as_retriever(search_kwargs={"k":3})
-         print("two")
-
-         if self.config["llm"] == LLM_OPENAI_GPT35:
-             # Use ChatGPT API
-             self.qa = RetrievalQA.from_chain_type(llm=OpenAI(model_name=LLM_OPENAI_GPT35, temperature=0.), chain_type="stuff",\
-                         retriever=self.vectordb.as_retriever(search_kwargs={"k":3}))
-         else:
-             hf_llm = HuggingFacePipeline(pipeline=self.llm,model_id=self.config["llm"])

-             self.qa = RetrievalQA.from_chain_type(llm=hf_llm, chain_type="stuff",retriever=self.retriever)
-             if self.config["llm"] == LLM_FLAN_T5_SMALL or self.config["llm"] == LLM_FLAN_T5_BASE or self.config["llm"] == LLM_FLAN_T5_LARGE:
-                 question_t5_template = """
-                 context: {context}
-                 question: {question}
-                 answer:
-                 """
-                 QUESTION_T5_PROMPT = PromptTemplate(
-                     template=question_t5_template, input_variables=["context", "question"]
-                 )
-                 self.qa.combine_documents_chain.llm_chain.prompt = QUESTION_T5_PROMPT
-             self.qa.combine_documents_chain.verbose = True
-             self.qa.return_source_documents = True
-     def answer_query(self,question:str) ->str:
-         """
-         Answer the question
-         """

-         answer_dict = self.qa({"query":question,})
-         print(answer_dict)
-         answer = answer_dict["result"]
-         if self.config["llm"] == LLM_FASTCHAT_T5_XL:
-             answer = self._clean_fastchat_t5_output(answer)
-         return answer
-     def _clean_fastchat_t5_output(self, answer: str) -> str:
-         # Remove <pad> tags, double spaces, trailing newline
-         answer = re.sub(r"<pad>\s+", "", answer)
-         answer = re.sub(r"  ", " ", answer)
-         answer = re.sub(r"\n$", "", answer)
-         return answer
+ import streamlit as st
+ from pdf_qa import PdfQA
+ from pathlib import Path
+ from tempfile import NamedTemporaryFile
+ import time
+ import shutil
  from constants import *



+ # Streamlit app code
+ st.set_page_config(
+     page_title='Q&A Bot for PDF',
+     page_icon='🔖',
+     layout='wide',
+     initial_sidebar_state='auto',
+ )
+
+
+ if "pdf_qa_model" not in st.session_state:
+     st.session_state["pdf_qa_model"]:PdfQA = PdfQA() ## Intialisation

+ ## To cache resource across multiple session
+ @st.cache_resource
+ def load_llm(llm,load_in_8bit):

+     if llm == LLM_OPENAI_GPT35:
+         pass
+     elif llm == LLM_FLAN_T5_SMALL:
+         return PdfQA.create_flan_t5_small(load_in_8bit)
+     elif llm == LLM_FLAN_T5_BASE:
+         return PdfQA.create_flan_t5_base(load_in_8bit)
+     elif llm == LLM_FLAN_T5_LARGE:
+         return PdfQA.create_flan_t5_large(load_in_8bit)
+     elif llm == LLM_FASTCHAT_T5_XL:
+         return PdfQA.create_fastchat_t5_xl(load_in_8bit)
+     elif llm == LLM_FALCON_SMALL:
+         return PdfQA.create_falcon_instruct_small(load_in_8bit)
+     else:
+         raise ValueError("Invalid LLM setting")

+ ## To cache resource across multiple session
+ @st.cache_resource
+ def load_emb(emb):
+     if emb == EMB_INSTRUCTOR_XL:
+         return PdfQA.create_instructor_xl()
+     elif emb == EMB_SBERT_MPNET_BASE:
+         return PdfQA.create_sbert_mpnet()
+     elif emb == EMB_SBERT_MINILM:
+         pass ##ChromaDB takes care
+     else:
+         raise ValueError("Invalid embedding setting")
+
+
+
+ st.title("PDF Q&A (Self hosted LLMs)")
+
+ with st.sidebar:
+     emb = st.radio("**Select Embedding Model**", [EMB_INSTRUCTOR_XL, EMB_SBERT_MPNET_BASE,EMB_SBERT_MINILM],index=1)
+     llm = st.radio("**Select LLM Model**", [LLM_FASTCHAT_T5_XL, LLM_FLAN_T5_SMALL,LLM_FLAN_T5_BASE,LLM_FLAN_T5_LARGE,LLM_FLAN_T5_XL,LLM_FALCON_SMALL],index=2)
+     load_in_8bit = st.radio("**Load 8 bit**", [True, False],index=1)
+     pdf_file = st.file_uploader("**Upload PDF**", type="pdf")
+
+
+     if st.button("Submit") and pdf_file is not None:
+         with st.spinner(text="Uploading PDF and Generating Embeddings.."):
+             with NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
+                 shutil.copyfileobj(pdf_file, tmp)
+                 tmp_path = Path(tmp.name)
+                 st.session_state["pdf_qa_model"].config = {
+                     "pdf_path": str(tmp_path),
+                     "embedding": emb,
+                     "llm": llm,
+                     "load_in_8bit": load_in_8bit
+                 }
+                 st.session_state["pdf_qa_model"].embedding = load_emb(emb)
+                 st.session_state["pdf_qa_model"].llm = load_llm(llm,load_in_8bit)
+                 st.session_state["pdf_qa_model"].init_embeddings()
+                 st.session_state["pdf_qa_model"].init_models()
+                 st.session_state["pdf_qa_model"].vector_db_pdf()
+                 st.sidebar.success("PDF uploaded successfully")

+ question = st.text_input('Ask a question', 'What is this document?')

+ if st.button("Answer"):
+     try:
+         st.session_state["pdf_qa_model"].retreival_qa_chain()
+         answer = st.session_state["pdf_qa_model"].answer_query(question)
+         st.write(f"{answer}")
+     except Exception as e:
+         st.error(f"Error answering the question: {str(e)}")
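A note on the new app.py: the heavyweight pipeline and embedding objects are memoised with `st.cache_resource` (one instance shared across reruns and user sessions), while the `PdfQA` wrapper itself is kept per browser session in `st.session_state`. A minimal sketch of that split, with a hypothetical `make_expensive_model()` standing in for the `PdfQA.create_*` factories:

```python
import streamlit as st

@st.cache_resource  # built once per server process, shared by all sessions and reruns
def make_expensive_model(name: str, load_in_8bit: bool):
    # hypothetical stand-in for PdfQA.create_flan_t5_base(...) and friends
    return {"name": name, "load_in_8bit": load_in_8bit}

if "wrapper" not in st.session_state:  # per-session state, private to this user
    st.session_state["wrapper"] = {"llm": None}

st.session_state["wrapper"]["llm"] = make_expensive_model("google/flan-t5-base", False)
st.write(st.session_state["wrapper"])
```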
 
 
 
 
 
constants.py ADDED
@@ -0,0 +1,15 @@
+ # Constants
+ EMB_OPENAI_ADA = "text-embedding-ada-002"
+ EMB_INSTRUCTOR_XL = "hkunlp/instructor-xl"
+ EMB_SBERT_MPNET_BASE = "sentence-transformers/all-mpnet-base-v2" # Chroma takes care if embeddings are None
+ EMB_SBERT_MINILM = "sentence-transformers/all-MiniLM-L6-v2" # Chroma takes care if embeddings are None
+
+
+ LLM_OPENAI_GPT35 = "gpt-3.5-turbo"
+ LLM_FLAN_T5_XXL = "google/flan-t5-xxl"
+ LLM_FLAN_T5_XL = "google/flan-t5-xl"
+ LLM_FASTCHAT_T5_XL = "lmsys/fastchat-t5-3b-v1.0"
+ LLM_FLAN_T5_SMALL = "google/flan-t5-small"
+ LLM_FLAN_T5_BASE = "google/flan-t5-base"
+ LLM_FLAN_T5_LARGE = "google/flan-t5-large"
+ LLM_FALCON_SMALL = "tiiuae/falcon-7b-instruct"
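The two "Chroma takes care if embeddings are None" comments describe the fallback path taken when `EMB_SBERT_MINILM` is selected: `load_emb` returns `None`, `init_embeddings` sets `self.embedding` to `None`, and the vector store is then built without an explicit LangChain embedding object. A minimal sketch of that fallback, assuming (as the comments state) that Chroma supplies its own built-in default embedding function when none is passed:

```python
from langchain.schema import Document
from langchain.vectorstores import Chroma

docs = [Document(page_content="hello world")]
# embedding=None: chromadb's bundled default embedder is used instead of an
# explicit HuggingFaceEmbeddings object (assumed behaviour, per the comments above)
vectordb = Chroma.from_documents(documents=docs, embedding=None)
print(vectordb.similarity_search("hello", k=1))
```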
pdf_qa.py ADDED
@@ -0,0 +1,246 @@
+ #from langchain.document_loaders.pdf import PDFPlumberLoader
+ from langchain.document_loaders import PyPDFLoader
+ from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
+ from transformers import pipeline
+ from langchain.prompts import PromptTemplate
+ from langchain.chat_models import ChatOpenAI
+ from langchain.vectorstores import Chroma
+ from langchain.chains import RetrievalQA
+ from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings
+ from langchain.embeddings.openai import OpenAIEmbeddings
+ from langchain.llms import OpenAI, HuggingFacePipeline
+ from constants import *
+ from transformers import AutoTokenizer
+ import torch
+ import os
+ import re
+ from pprint import pprint
+
+ class PdfQA:
+     def __init__(self,config:dict = {}):
+         self.config = config
+         self.embedding = None
+         self.vectordb = None
+         self.llm = None
+         self.qa = None
+         self.retriever = None
+
+     # The following class methods are useful to create global GPU model instances
+     # This way we don't need to reload models in an interactive app,
+     # and the same model instance can be used across multiple user sessions
+     @classmethod
+     def create_instructor_xl(cls):
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         return HuggingFaceInstructEmbeddings(model_name=EMB_INSTRUCTOR_XL, model_kwargs={"device": device})
+
+     @classmethod
+     def create_sbert_mpnet(cls):
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         return HuggingFaceEmbeddings(model_name=EMB_SBERT_MPNET_BASE, model_kwargs={"device": device})
+
+     @classmethod
+     def create_flan_t5_xxl(cls, load_in_8bit=False):
+         # Local flan-t5-xxl with 8-bit quantization for inference
+         # Wrap it in HF pipeline for use with LangChain
+         return pipeline(
+             task="text2text-generation",
+             model="google/flan-t5-xxl",
+             max_new_tokens=200,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+     @classmethod
+     def create_flan_t5_xl(cls, load_in_8bit=False):
+         return pipeline(
+             task="text2text-generation",
+             model="google/flan-t5-xl",
+             max_new_tokens=200,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+
+     @classmethod
+     def create_flan_t5_small(cls, load_in_8bit=False):
+         # Local flan-t5-small for inference
+         # Wrap it in HF pipeline for use with LangChain
+         model="google/flan-t5-small"
+         tokenizer = AutoTokenizer.from_pretrained(model)
+         return pipeline(
+             task="text2text-generation",
+             model=model,
+             tokenizer = tokenizer,
+             max_new_tokens=100,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+     @classmethod
+     def create_flan_t5_base(cls, load_in_8bit=False):
+         # Wrap it in HF pipeline for use with LangChain
+         model="google/flan-t5-base"
+         tokenizer = AutoTokenizer.from_pretrained(model)
+         return pipeline(
+             task="text2text-generation",
+             model=model,
+             tokenizer = tokenizer,
+             max_new_tokens=100,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+     @classmethod
+     def create_flan_t5_large(cls, load_in_8bit=False):
+         # Wrap it in HF pipeline for use with LangChain
+         model="google/flan-t5-large"
+         tokenizer = AutoTokenizer.from_pretrained(model)
+         return pipeline(
+             task="text2text-generation",
+             model=model,
+             tokenizer = tokenizer,
+             max_new_tokens=100,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+     @classmethod
+     def create_fastchat_t5_xl(cls, load_in_8bit=False):
+         return pipeline(
+             task="text2text-generation",
+             model = "lmsys/fastchat-t5-3b-v1.0",
+             max_new_tokens=100,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+
+     @classmethod
+     def create_falcon_instruct_small(cls, load_in_8bit=False):
+         model = "tiiuae/falcon-7b-instruct"
+
+         tokenizer = AutoTokenizer.from_pretrained(model)
+         hf_pipeline = pipeline(
+             task="text-generation",
+             model = model,
+             tokenizer = tokenizer,
+             trust_remote_code = True,
+             max_new_tokens=100,
+             model_kwargs={
+                 "device_map": "auto",
+                 "load_in_8bit": load_in_8bit,
+                 "max_length": 512,
+                 "temperature": 0.01,
+                 "torch_dtype":torch.bfloat16,
+             }
+         )
+         return hf_pipeline
+
+     def init_embeddings(self) -> None:
+         # OpenAI ada embeddings API
+         if self.config["embedding"] == EMB_OPENAI_ADA:
+             self.embedding = OpenAIEmbeddings()
+         elif self.config["embedding"] == EMB_INSTRUCTOR_XL:
+             # Local INSTRUCTOR-XL embeddings
+             if self.embedding is None:
+                 self.embedding = PdfQA.create_instructor_xl()
+         elif self.config["embedding"] == EMB_SBERT_MPNET_BASE:
+             ## this is for SBERT
+             if self.embedding is None:
+                 self.embedding = PdfQA.create_sbert_mpnet()
+         else:
+             self.embedding = None ## DuckDb uses sbert embeddings
+             # raise ValueError("Invalid config")
+
+     def init_models(self) -> None:
+         """ Initialize LLM models based on config """
+         load_in_8bit = self.config.get("load_in_8bit",False)
+         # OpenAI GPT 3.5 API
+         if self.config["llm"] == LLM_OPENAI_GPT35:
+             # OpenAI GPT 3.5 API
+             pass
+         elif self.config["llm"] == LLM_FLAN_T5_SMALL:
+             if self.llm is None:
+                 self.llm = PdfQA.create_flan_t5_small(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FLAN_T5_BASE:
+             if self.llm is None:
+                 self.llm = PdfQA.create_flan_t5_base(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FLAN_T5_LARGE:
+             if self.llm is None:
+                 self.llm = PdfQA.create_flan_t5_large(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FLAN_T5_XL:
+             if self.llm is None:
+                 self.llm = PdfQA.create_flan_t5_xl(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FLAN_T5_XXL:
+             if self.llm is None:
+                 self.llm = PdfQA.create_flan_t5_xxl(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FASTCHAT_T5_XL:
+             if self.llm is None:
+                 self.llm = PdfQA.create_fastchat_t5_xl(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FALCON_SMALL:
+             if self.llm is None:
+                 self.llm = PdfQA.create_falcon_instruct_small(load_in_8bit=load_in_8bit)
+
+         else:
+             raise ValueError("Invalid config")
+     def vector_db_pdf(self) -> None:
+         """
+         creates vector db for the embeddings and persists them or loads a vector db from the persist directory
+         """
+         pdf_path = self.config.get("pdf_path",None)
+         persist_directory = self.config.get("persist_directory",None)
+         if persist_directory and os.path.exists(persist_directory):
+             ## Load from the persist db
+             self.vectordb = Chroma(persist_directory=persist_directory, embedding_function=self.embedding)
+         elif pdf_path and os.path.exists(pdf_path):
+             ## 1. Extract the documents
+             loader = PyPDFLoader(pdf_path)
+             documents = loader.load()
+             ## 2. Split the texts
+             text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
+             texts = text_splitter.split_documents(documents)
+             # text_splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=10, encoding_name="cl100k_base") # This the encoding for text-embedding-ada-002
+             text_splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=10) # This the encoding for text-embedding-ada-002
+             texts = text_splitter.split_documents(texts)
+
+             ## 3. Create Embeddings and add to chroma store
+             ##TODO: Validate if self.embedding is not None
+             self.vectordb = Chroma.from_documents(documents=texts, embedding=self.embedding, persist_directory=persist_directory)
+         else:
+             raise ValueError("NO PDF found")
+
+     def retreival_qa_chain(self):
+         """
+         Creates retrieval qa chain using vectordb as retrivar and LLM to complete the prompt
+         """
+         ##TODO: Use custom prompt
+         print("one", self)
+         pprint(vars(self))
+         self.retriever = self.vectordb.as_retriever(search_kwargs={"k":3})
+         print("two")
+
+         if self.config["llm"] == LLM_OPENAI_GPT35:
+             # Use ChatGPT API
+             self.qa = RetrievalQA.from_chain_type(llm=OpenAI(model_name=LLM_OPENAI_GPT35, temperature=0.), chain_type="stuff",\
+                         retriever=self.vectordb.as_retriever(search_kwargs={"k":3}))
+         else:
+             hf_llm = HuggingFacePipeline(pipeline=self.llm,model_id=self.config["llm"])
+
+             self.qa = RetrievalQA.from_chain_type(llm=hf_llm, chain_type="stuff",retriever=self.retriever)
+             if self.config["llm"] == LLM_FLAN_T5_SMALL or self.config["llm"] == LLM_FLAN_T5_BASE or self.config["llm"] == LLM_FLAN_T5_LARGE:
+                 question_t5_template = """
+                 context: {context}
+                 question: {question}
+                 answer:
+                 """
+                 QUESTION_T5_PROMPT = PromptTemplate(
+                     template=question_t5_template, input_variables=["context", "question"]
+                 )
+                 self.qa.combine_documents_chain.llm_chain.prompt = QUESTION_T5_PROMPT
+             self.qa.combine_documents_chain.verbose = True
+             self.qa.return_source_documents = True
+     def answer_query(self,question:str) ->str:
+         """
+         Answer the question
+         """
+
+         answer_dict = self.qa({"query":question,})
+         print(answer_dict)
+         answer = answer_dict["result"]
+         if self.config["llm"] == LLM_FASTCHAT_T5_XL:
+             answer = self._clean_fastchat_t5_output(answer)
+         return answer
+     def _clean_fastchat_t5_output(self, answer: str) -> str:
+         # Remove <pad> tags, double spaces, trailing newline
+         answer = re.sub(r"<pad>\s+", "", answer)
+         answer = re.sub(r"  ", " ", answer)
+         answer = re.sub(r"\n$", "", answer)
+         return answer
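Outside of Streamlit, the same class can be driven from a plain script. A minimal sketch, assuming a local `sample.pdf` and the same config keys app.py passes in:

```python
from pdf_qa import PdfQA
from constants import EMB_SBERT_MPNET_BASE, LLM_FLAN_T5_BASE

qa = PdfQA(config={
    "pdf_path": "sample.pdf",            # assumed local file
    "embedding": EMB_SBERT_MPNET_BASE,
    "llm": LLM_FLAN_T5_BASE,
    "load_in_8bit": False,
})
qa.init_embeddings()     # sentence-transformers embedder
qa.init_models()         # flan-t5-base HF pipeline
qa.vector_db_pdf()       # load, split and index the PDF into Chroma
qa.retreival_qa_chain()  # wire retriever + LLM into a RetrievalQA chain
print(qa.answer_query("What is this document about?"))
```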