Update app.py
app.py CHANGED
@@ -67,16 +67,16 @@ class Retriever:
     def load_chunks(self):
         self.text = self.extract_text_from_pdf(self.file_path)
         text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=
+            chunk_size=150,
             chunk_overlap=20,
             length_function=self.token_len,
-            separators=["\n\n", "
+            separators=["Section", "\n\n", "\n", ".", " ", ""]
         )

         self.chunks = text_splitter.split_text(self.text)

     def load_context_embeddings(self):
-        encoded_input = self.context_tokenizer(self.chunks, return_tensors='pt', padding=True, truncation=True, max_length=
+        encoded_input = self.context_tokenizer(self.chunks, return_tensors='pt', padding=True, truncation=True, max_length=300).to(device)

         with torch.no_grad():
             model_output = self.context_model(**encoded_input)
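The first hunk moves to token-measured chunking: `chunk_size=150` and `chunk_overlap=20` are counted by `self.token_len` rather than characters, and `"Section"` is tried before the default separators so splits prefer section headings. A minimal sketch of the same splitter configuration, assuming LangChain plus a Hugging Face DPR tokenizer (the `token_len` helper and sample text here are illustrative, not the app's exact code):

```python
# Token-based recursive splitting, mirroring the settings in this hunk.
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

def token_len(text: str) -> int:
    # Measure length in DPR tokens so chunk_size bounds what the encoder sees.
    return len(tokenizer.encode(text, add_special_tokens=False))

splitter = RecursiveCharacterTextSplitter(
    chunk_size=150,        # max tokens per chunk
    chunk_overlap=20,      # tokens shared between neighbouring chunks
    length_function=token_len,
    separators=["Section", "\n\n", "\n", ".", " ", ""],  # section headings first
)

chunks = splitter.split_text("Section 1. Attention is all you need. " * 100)
print(len(chunks), token_len(chunks[0]))
```

A 150-token budget also keeps each chunk comfortably under the `max_length=300` the context encoder is given below.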
@@ -89,20 +89,16 @@ class Retriever:
         encoded_query = self.question_tokenizer(query_prompt, return_tensors="pt", truncation=True, padding=True).to(device)

         with torch.no_grad():
-
-
+            model_output = self.question_model(**encoded_query)
+            query_vector = model_output.pooler_output

         query_vector_np = query_vector.cpu().numpy()
         D, I = self.index.search(query_vector_np, k)

-        retrieved_texts = [self.chunks[i] for i in I[0]]
+        retrieved_texts = [' '.join(self.chunks[i].split('\n')) for i in I[0]]  # Replacing newlines with spaces

         scores = [d for d in D[0]]

-        # print("Top 5 retrieved texts and their associated scores:")
-        # for idx, (text, score) in enumerate(zip(retrieved_texts, scores)):
-        #     print(f"{idx + 1}. Text: {text} \n Score: {score:.4f}\n")
-
         return retrieved_texts

 class RAG:
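The retrieval hunk fills in the query path: the question encoder's `pooler_output` is the query vector, FAISS returns the top-k chunk indices, and newlines in the retrieved chunks are flattened to spaces. A standalone sketch of that step, assuming a `faiss.IndexFlatIP` over 768-dimensional DPR embeddings (the random vectors stand in for the real chunk embeddings built in `load_context_embeddings`):

```python
# DPR question encoding + FAISS top-k search, as in retrieve_top_k.
import faiss
import numpy as np
import torch
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

name = "facebook/dpr-question_encoder-single-nq-base"
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(name)
model = DPRQuestionEncoder.from_pretrained(name)

# Stand-in index: ten fake 768-dim context embeddings; inner product is DPR's score.
index = faiss.IndexFlatIP(768)
index.add(np.random.rand(10, 768).astype("float32"))

encoded_query = tokenizer("What is another name for self-attention?", return_tensors="pt")
with torch.no_grad():
    query_vector = model(**encoded_query).pooler_output  # shape (1, 768)

D, I = index.search(query_vector.cpu().numpy(), 5)  # top-5 scores and chunk indices
print(I[0], D[0])
```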
@@ -115,22 +111,23 @@ class RAG:

         # generator_name = "valhalla/bart-large-finetuned-squadv1"
         # generator_name = "'vblagoje/bart_lfqa'"
-        generator_name = "a-ware/bart-squadv2"
-
+        # generator_name = "a-ware/bart-squadv2"
+
         self.generator_tokenizer = BartTokenizer.from_pretrained(generator_name)
         self.generator_model = BartForConditionalGeneration.from_pretrained(generator_name).to(device)

+        # generator_name = "MaRiOrOsSi/t5-base-finetuned-question-answering"
+        # generator_name = "t5-small"
+
+        # self.generator_tokenizer = T5Tokenizer.from_pretrained(generator_name)
+        # self.generator_model = T5ForConditionalGeneration.from_pretrained(generator_name)
+
         self.retriever = Retriever(file_path, device, context_model_name, question_model_name)
         self.retriever.load_chunks()
         self.retriever.load_context_embeddings()

-    def get_answer(self, question, context):
-        input_text = "context: %s <question for context: %s </s>" % (context,question)
-        features = self.generator_tokenizer([input_text], return_tensors='pt')
-        out = self.generator_model.generate(input_ids=features['input_ids'].to(device), attention_mask=features['attention_mask'].to(device))
-        return self.generator_tokenizer.decode(out[0])

-    def
+    def abstractive_query(self, question):
         context = self.retriever.retrieve_top_k(question, k=5)
         # input_text = question + " " + " ".join(context)

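This hunk comments out the `a-ware/bart-squadv2` generator (leaving T5 alternatives sketched in comments) and renames the method to `abstractive_query`, which hands the top-5 retrieved chunks to a BART seq2seq generator. A hedged sketch of that abstractive step, assuming the `vblagoje/bart_lfqa` checkpoint named in the comments; the `question: ... context: ...` prompt format is an assumption, not the app's exact string:

```python
# Abstractive answering: retrieved context + question -> generated free-form answer.
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

name = "vblagoje/bart_lfqa"  # one of the checkpoints this commit experiments with
tokenizer = BartTokenizer.from_pretrained(name)
model = BartForConditionalGeneration.from_pretrained(name)

context = "Self-attention, sometimes called intra-attention, relates different positions of a single sequence."
prompt = f"question: What is another name for self-attention? context: {context}"

inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
with torch.no_grad():
    output_ids = model.generate(inputs.input_ids, max_new_tokens=64, num_beams=4)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```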
@@ -144,22 +141,46 @@ class RAG:
         answer = self.generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
         return answer

+    def extractive_query(self, question):
+        context = self.retriever.retrieve_top_k(question, k=15)
+        generator_name = "valhalla/bart-large-finetuned-squadv1"

-
-
-question_model_name="facebook/dpr-question_encoder-multiset-base"
+        self.generator_tokenizer = AutoTokenizer.from_pretrained(generator_name)
+        self.generator_model = BartForQuestionAnswering.from_pretrained(generator_name).to(device)

-
+        inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=200, padding="max_length")
+        with torch.no_grad():
+            model_inputs = inputs.to(device)
+            outputs = self.generator_model(**model_inputs)
+
+        answer_start_index = outputs.start_logits.argmax()
+        answer_end_index = outputs.end_logits.argmax()

-
+        if answer_end_index < answer_start_index:
+            answer_start_index, answer_end_index = answer_end_index, answer_start_index

-        print(
+        print(answer_start_index, answer_end_index)
+
+        predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
+        answer = self.generator_tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
+        answer = answer.replace('\n', ' ').strip()
+        answer = answer.replace('$', '')
+
+        return answer
+
+context_model_name="facebook/dpr-ctx_encoder-single-nq-base"
+question_model_name = "facebook/dpr-question_encoder-single-nq-base"
+# context_model_name="facebook/dpr-ctx_encoder-multiset-base"
+# question_model_name="facebook/dpr-question_encoder-multiset-base"
+
+rag = RAG(file_path, device)

 st.title("RAG Model Query Interface")

-
+# offer to ask a question and get an answer. make it so they can ask as many questions as they want
+
+question = st.text_input("Ask a question", "What is another name for self-attention?")

-
-
-
-    st.write(answer)
+if st.button("Ask"):
+    answer = rag.extractive_query(question)
+    st.write(answer)
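The new `extractive_query` decodes a span rather than generating text: the QA head scores every token position, the answer is the tokens between the argmax of the start and end logits, and a swap guard handles the degenerate `end < start` case. A minimal sketch of that decoding, assuming the same `valhalla/bart-large-finetuned-squadv1` checkpoint the hunk loads (question and context strings are illustrative):

```python
# Extractive span decoding with a BART question-answering head.
import torch
from transformers import AutoTokenizer, BartForQuestionAnswering

name = "valhalla/bart-large-finetuned-squadv1"
tokenizer = AutoTokenizer.from_pretrained(name)
model = BartForQuestionAnswering.from_pretrained(name)

question = "What is another name for self-attention?"
context = "Self-attention, sometimes called intra-attention, relates different positions of a single sequence."

inputs = tokenizer(question, context, return_tensors="pt", truncation=True)
with torch.no_grad():
    outputs = model(**inputs)

start = outputs.start_logits.argmax()
end = outputs.end_logits.argmax()
if end < start:  # same guard the hunk adds for inverted spans
    start, end = end, start

span = inputs.input_ids[0, start : end + 1]
print(tokenizer.decode(span, skip_special_tokens=True))
```

Note that `extractive_query` reloads the tokenizer and model on every call; hoisting those two `from_pretrained` lines into `__init__` would avoid repeating the download and device transfer for each question.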
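On the Streamlit wiring at the end of the file: Streamlit reruns the whole script on each interaction, so the `st.text_input` plus `st.button` pattern already lets a user ask as many questions as they want, but it also re-instantiates `RAG` (and reloads all models) on every rerun. One common remedy, sketched here as an assumption rather than anything this commit does, is to cache the heavy object across reruns:

```python
# Keep the RAG object alive across Streamlit reruns (sketch; assumes the
# RAG class, file_path, and device defined earlier in app.py).
import streamlit as st

@st.cache_resource  # build once, reuse on every rerun
def load_rag():
    return RAG(file_path, device)

rag = load_rag()

st.title("RAG Model Query Interface")
question = st.text_input("Ask a question", "What is another name for self-attention?")
if st.button("Ask"):
    st.write(rag.extractive_query(question))
```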