Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,11 +6,11 @@ import numpy as np
|
|
6 |
from transformers import pipeline
|
7 |
|
8 |
# Load models
|
9 |
-
embedding_model = SentenceTransformer('all-
|
10 |
-
qa_pipeline = pipeline("question-answering", model="deepset/roberta-
|
11 |
|
12 |
# Initialize FAISS index
|
13 |
-
dimension =
|
14 |
index = faiss.IndexFlatL2(dimension)
|
15 |
|
16 |
# Store text chunks and their embeddings
|
@@ -58,14 +58,21 @@ def answer_question(question):
|
|
58 |
# Embed the question
|
59 |
question_embedding = embedding_model.encode([question])
|
60 |
|
61 |
-
# Retrieve top-k relevant chunks
|
62 |
-
|
|
|
63 |
relevant_chunks = [text_chunks[i] for i in indices[0]]
|
64 |
|
65 |
# Use the QA model to generate an answer
|
66 |
context = " ".join(relevant_chunks)
|
67 |
result = qa_pipeline(question=question, context=context)
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
# Gradio Interface
|
71 |
with gr.Blocks() as demo:
|
|
|
6 |
from transformers import pipeline
|
7 |
|
8 |
# Load models
|
9 |
+
embedding_model = SentenceTransformer('all-mpnet-base-v2') # Better embedding model
|
10 |
+
qa_pipeline = pipeline("question-answering", model="deepset/roberta-large-squad2") # Larger QA model
|
11 |
|
12 |
# Initialize FAISS index
|
13 |
+
dimension = 768 # Dimension of the embedding model
|
14 |
index = faiss.IndexFlatL2(dimension)
|
15 |
|
16 |
# Store text chunks and their embeddings
|
|
|
58 |
# Embed the question
|
59 |
question_embedding = embedding_model.encode([question])
|
60 |
|
61 |
+
# Retrieve top-k relevant chunks (increase k for more context)
|
62 |
+
k = 5 # Retrieve more chunks for better context
|
63 |
+
distances, indices = index.search(question_embedding, k=k)
|
64 |
relevant_chunks = [text_chunks[i] for i in indices[0]]
|
65 |
|
66 |
# Use the QA model to generate an answer
|
67 |
context = " ".join(relevant_chunks)
|
68 |
result = qa_pipeline(question=question, context=context)
|
69 |
+
|
70 |
+
# Post-process the answer
|
71 |
+
answer = result['answer']
|
72 |
+
if answer.strip() == "":
|
73 |
+
return "The paper does not provide enough information to answer this question."
|
74 |
+
|
75 |
+
return answer
|
76 |
|
77 |
# Gradio Interface
|
78 |
with gr.Blocks() as demo:
|