thesnak committed
Commit c58cb45 · verified · 1 Parent(s): bb13b3d

Update app.py

Files changed (1)
  1. app.py +13 -6
app.py CHANGED
@@ -6,11 +6,11 @@ import numpy as np
 from transformers import pipeline
 
 # Load models
-embedding_model = SentenceTransformer('all-MiniLM-L6-v2')  # For embedding text chunks
-qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")  # For QA
+embedding_model = SentenceTransformer('all-mpnet-base-v2')  # Better embedding model
+qa_pipeline = pipeline("question-answering", model="deepset/roberta-large-squad2")  # Larger QA model
 
 # Initialize FAISS index
-dimension = 384  # Dimension of the embedding model
+dimension = 768  # Dimension of the embedding model
 index = faiss.IndexFlatL2(dimension)
 
 # Store text chunks and their embeddings
@@ -58,14 +58,21 @@ def answer_question(question):
     # Embed the question
     question_embedding = embedding_model.encode([question])
 
-    # Retrieve top-k relevant chunks
-    distances, indices = index.search(question_embedding, k=2)
+    # Retrieve top-k relevant chunks (increase k for more context)
+    k = 5  # Retrieve more chunks for better context
+    distances, indices = index.search(question_embedding, k=k)
     relevant_chunks = [text_chunks[i] for i in indices[0]]
 
     # Use the QA model to generate an answer
     context = " ".join(relevant_chunks)
     result = qa_pipeline(question=question, context=context)
-    return result['answer']
+
+    # Post-process the answer
+    answer = result['answer']
+    if answer.strip() == "":
+        return "The paper does not provide enough information to answer this question."
+
+    return answer
 
 # Gradio Interface
 with gr.Blocks() as demo:
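Note (minimal sketch, not part of this commit): the two spots this change touches can be hardened further. The hardcoded dimension (384 → 768 here) can be read off the model itself, and k=5 can exceed the number of indexed chunks, in which case FAISS pads the result with -1 ids and text_chunks[-1] silently returns the wrong chunk. search_chunks below is a hypothetical helper illustrating both guards, assuming the same faiss and sentence-transformers APIs already used in app.py.

import faiss
from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer('all-mpnet-base-v2')

# Derive the dimension from the model instead of hardcoding 384 / 768,
# so swapping embedding models cannot silently mismatch the index.
dimension = embedding_model.get_sentence_embedding_dimension()
index = faiss.IndexFlatL2(dimension)

def search_chunks(question, text_chunks, k=5):
    # Hypothetical helper; mirrors the retrieval step in answer_question.
    if index.ntotal == 0:
        return []
    question_embedding = embedding_model.encode([question])
    k = min(k, index.ntotal)  # never request more neighbors than are stored
    distances, indices = index.search(question_embedding, k)
    # Drop the -1 padding ids FAISS returns when fewer than k hits exist.
    return [text_chunks[i] for i in indices[0] if i != -1]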