tarrasyed19472007 commited on
Commit
7b6f550
·
verified ·
1 Parent(s): 3c5d220

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -19
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import streamlit as st
2
  import fitz # PyMuPDF
3
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
4
- import numpy as np
5
 
6
  # Load the RAG model components
7
  tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
@@ -18,24 +18,19 @@ def extract_text_from_pdf(pdf_file):
18
 
19
  # Function to handle question answering
20
  def answer_question(question, pdf_text):
21
- # Tokenize the question
22
- inputs = tokenizer(question, return_tensors="pt")
23
-
24
- # Retrieve documents based on the PDF text
25
- doc_embeds = retriever.get_document_embeddings([pdf_text]) # Wrap pdf_text in a list
26
- retriever.set_retriever_doc_embeddings(doc_embeds)
27
-
28
- # Get the top k documents for the question
29
- k = 5
30
- retrieved_docs = retriever(question, n_docs=k)
31
-
32
  # Prepare the context for the model
33
- context = retrieved_docs["document_texts"]
34
- context = " ".join(context)
35
 
 
 
 
36
  # Generate the answer
37
- input_dict = tokenizer.prepare_seq2seq_batch(question, context, return_tensors="pt")
38
- outputs = model.generate(**input_dict)
 
 
 
 
39
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
40
  return answer
41
 
@@ -55,9 +50,12 @@ if pdf_file is not None:
55
 
56
  if question:
57
  with st.spinner("Finding answer..."):
58
- answer = answer_question(question, pdf_text)
59
- st.write("### Answer:")
60
- st.write(answer)
 
 
 
61
 
62
 
63
 
 
1
  import streamlit as st
2
  import fitz # PyMuPDF
3
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
4
+ import torch
5
 
6
  # Load the RAG model components
7
  tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
 
18
 
19
  # Function to handle question answering
20
  def answer_question(question, pdf_text):
 
 
 
 
 
 
 
 
 
 
 
21
  # Prepare the context for the model
22
+ inputs = tokenizer([question], return_tensors="pt")
 
23
 
24
+ # Tokenize PDF text
25
+ pdf_inputs = tokenizer([pdf_text], return_tensors="pt")
26
+
27
  # Generate the answer
28
+ with torch.no_grad():
29
+ outputs = model.generate(input_ids=inputs['input_ids'],
30
+ attention_mask=inputs['attention_mask'],
31
+ context_input_ids=pdf_inputs['input_ids'],
32
+ context_attention_mask=pdf_inputs['attention_mask'])
33
+
34
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
35
  return answer
36
 
 
50
 
51
  if question:
52
  with st.spinner("Finding answer..."):
53
+ try:
54
+ answer = answer_question(question, pdf_text)
55
+ st.write("### Answer:")
56
+ st.write(answer)
57
+ except Exception as e:
58
+ st.error(f"Error occurred: {str(e)}")
59
 
60
 
61