IProject-10 committed on
Commit
3217c53
·
verified ·
1 Parent(s): 744ae18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -125
app.py CHANGED
@@ -1,125 +1,158 @@
1
- import os
2
- from rank_bm25 import BM25Okapi
3
- from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering
4
- import torch
5
- import gradio as gr
6
- from docx import Document
7
- import pdfplumber
8
-
9
- # Load the fine-tuned BERT-based QA model and tokenizer
10
- model_name = "IProject-10/roberta-base-finetuned-squad2" # Replace with your model name
11
- qa_model = AutoModelForQuestionAnswering.from_pretrained(model_name)
12
- tokenizer = AutoTokenizer.from_pretrained(model_name)
13
-
14
- # Set up the device for BERT
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- qa_model.to(device)
17
- qa_model.eval()
18
-
19
- # Create a pipeline for retrieval-augmented QA
20
- retrieval_qa_pipeline = pipeline(
21
- "question-answering",
22
- model=qa_model,
23
- tokenizer=tokenizer,
24
- device=device.index if torch.cuda.is_available() else -1
25
- )
26
-
27
- def extract_text_from_file(file):
28
- # Determine the file extension
29
- file_extension = os.path.splitext(file.name)[1].lower()
30
- text = ""
31
-
32
- try:
33
- if file_extension == ".txt":
34
- with open(file.name, "r") as f:
35
- text = f.read()
36
- elif file_extension == ".docx":
37
- doc = Document(file.name)
38
- for para in doc.paragraphs:
39
- text += para.text + "\n"
40
- elif file_extension == ".pdf":
41
- with pdfplumber.open(file.name) as pdf:
42
- for page in pdf.pages:
43
- text += page.extract_text() + "\n"
44
- else:
45
- raise ValueError("Unsupported file format: {}".format(file_extension))
46
- except Exception as e:
47
- text = str(e)
48
- return text
49
-
50
- def load_passages(files):
51
- passages = []
52
- for file in files:
53
- passage = extract_text_from_file(file)
54
- passages.append(passage)
55
- return passages
56
-
57
- def highlight_answer(context, answer):
58
- start_index = context.find(answer)
59
- if start_index != -1:
60
- end_index = start_index + len(answer)
61
- highlighted_context = f"{context[:start_index]}_________<<{context[start_index:end_index]}>>_________{context[end_index:]}"
62
- return highlighted_context
63
- else:
64
- return context
65
-
66
- def answer_question(question, files):
67
- try:
68
- # Load passages from the uploaded files
69
- passages = load_passages(files)
70
-
71
- # Create an index using BM25
72
- bm25 = BM25Okapi([passage.split() for passage in passages])
73
-
74
- # Retrieve relevant passages using BM25
75
- tokenized_query = question.split()
76
- candidate_passages = bm25.get_top_n(tokenized_query, passages, n=3)
77
- bm25_scores = bm25.get_scores(tokenized_query)
78
-
79
- # Extract answer using the pipeline for each candidate passage
80
- answers_with_context = []
81
- for passage in candidate_passages:
82
- answer = retrieval_qa_pipeline(question=question, context=passage)
83
- bm25_score = bm25_scores[passages.index(passage)]
84
- answer_with_context = {
85
- "context": passage,
86
- "answer": answer["answer"],
87
- "BM25-score": bm25_score # BM25 confidence score for this passage
88
- }
89
- answers_with_context.append(answer_with_context)
90
-
91
- # Choose the answer with the highest model confidence score
92
- best_answer = max(answers_with_context, key=lambda x: x["BM25-score"])
93
-
94
- # Highlight the answer in the context
95
- highlighted_context = highlight_answer(best_answer["context"], best_answer["answer"])
96
-
97
- return best_answer["answer"], highlighted_context, best_answer["BM25-score"]
98
- except Exception as e:
99
- return str(e), "", ""
100
-
101
- # Define Gradio interface
102
- iface = gr.Interface(
103
- fn=answer_question,
104
- inputs=[
105
- gr.Textbox(lines=2, placeholder="Enter your question here...", label="Question"),
106
- gr.Files(label="Upload text, Word, or PDF files")
107
- ],
108
- outputs=[
109
- gr.Textbox(label="Answer"),
110
- gr.Textbox(label="Context"),
111
- gr.Textbox(label="BM25 Score")
112
- ],
113
- title="Question Answering Model",
114
- description="Upload a text document and ask a question from the content",
115
- css="""
116
- .container { max-width: 800px; margin: auto; }
117
- .interface-title { font-family: Arial, sans-serif; font-size: 24px; font-weight: bold; }
118
- .interface-description { font-family: Arial, sans-serif; font-size: 16px; margin-bottom: 20px; }
119
- .input-textbox, .output-textbox { font-family: Arial, sans-serif; font-size: 14px; }
120
- .error { color: red; font-family: Arial, sans-serif; font-size: 14px; }
121
- """
122
- )
123
-
124
- # Launch the interface
125
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from rank_bm25 import BM25Okapi
3
+ from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering
4
+ import torch
5
+ import gradio as gr
6
+ from docx import Document
7
+ import pdfplumber
8
+
9
# Load the fine-tuned BERT-based QA model and tokenizer
model_name = "IProject-10/roberta-base-finetuned-squad2"  # Replace with your model name
qa_model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
qa_model.to(device)
qa_model.eval()  # the model is only used for inference here

# Create a pipeline for retrieval-augmented QA.
# torch.device("cuda") is built without an ordinal, so its ``.index`` is None;
# passing ``device.index`` would hand the pipeline None instead of a GPU id.
# Give it an explicit ordinal: 0 for the first CUDA device, -1 for the CPU.
retrieval_qa_pipeline = pipeline(
    "question-answering",
    model=qa_model,
    tokenizer=tokenizer,
    device=0 if device.type == "cuda" else -1,
)
26
+
27
def extract_text_from_file(file):
    """Return the plain text of an uploaded .txt, .docx or .pdf file.

    Args:
        file: Either an upload object that exposes the on-disk path as
            ``file.name``, or the path itself as ``str`` / ``os.PathLike``.

    Returns:
        The extracted text. When reading fails, or the extension is not one
        of the three supported ones, the error message is returned in place
        of the text; those failures are not raised to the caller.
    """
    # Accept a bare path as well as an upload wrapper. Path-like objects are
    # handled first because a pathlib ``.name`` is only the last component.
    if isinstance(file, (str, os.PathLike)):
        path = os.fspath(file)
    else:
        path = file.name

    # Determine the file extension
    file_extension = os.path.splitext(path)[1].lower()

    try:
        if file_extension == ".txt":
            # Explicit encoding so the result does not depend on the host
            # locale; undecodable bytes are replaced instead of aborting.
            with open(path, "r", encoding="utf-8", errors="replace") as f:
                return f.read()
        if file_extension == ".docx":
            # One line per paragraph, each terminated by a newline.
            return "".join(para.text + "\n" for para in Document(path).paragraphs)
        if file_extension == ".pdf":
            with pdfplumber.open(path) as pdf:
                # extract_text() can be None for a page with no extractable
                # text (seen in some pdfplumber versions); treat it as empty
                # so one such page does not discard the whole document.
                return "".join(
                    (page.extract_text() or "") + "\n" for page in pdf.pages
                )
        raise ValueError("Unsupported file format: {}".format(file_extension))
    except Exception as e:
        # Best-effort by design: the message is returned as the "text" so
        # the caller keeps one entry per uploaded file.
        return str(e)
49
+
50
def load_passages(files):
    """Extract the text of every uploaded file, preserving upload order.

    Args:
        files: Iterable of uploads to pass to ``extract_text_from_file``.
            ``None`` (nothing uploaded) is treated as an empty collection.

    Returns:
        A list holding one text string per file; empty when there are no
        files.
    """
    return [extract_text_from_file(file) for file in files or []]
56
+
57
def highlight_answer(context, answer):
    """Mark the first occurrence of *answer* inside *context*.

    Args:
        context: The passage the answer was extracted from.
        answer: The answer text to look for.

    Returns:
        *context* with the first match wrapped as
        ``_________<<answer>>_________``. *context* is returned unchanged
        when *answer* is empty or does not occur in it.
    """
    # str.find("") reports a match at index 0, which would only insert empty
    # markers at the start of the passage, so an empty answer is a no-op.
    if not answer:
        return context

    start_index = context.find(answer)
    if start_index == -1:
        return context

    end_index = start_index + len(answer)
    return f"{context[:start_index]}_________<<{answer}>>_________{context[end_index:]}"
65
+
66
def answer_question(question, files):
    """Answer *question* from the uploaded documents.

    The documents are ranked against the question with BM25, the QA pipeline
    is run on the three best-ranked ones, and the answer taken from the
    document with the highest BM25 score is returned.

    Args:
        question: The question text entered by the user.
        files: The uploaded files, as accepted by ``load_passages``.

    Returns:
        A 3-tuple ``(answer, highlighted_context, bm25_score)``. When the
        input is unusable or anything fails, the first item is a message
        describing the problem and the other two are empty strings.
    """
    # Validate up front so the user gets a readable message instead of an
    # error raised deep inside the retriever or the model.
    if not question or not question.strip():
        return "Please enter a question.", "", ""
    if not files:
        return "Please upload at least one document.", "", ""

    try:
        # Load passages from the uploaded files; blank documents carry
        # nothing to retrieve or answer from, so they are dropped.
        passages = [passage for passage in load_passages(files) if passage.strip()]
        if not passages:
            return "No text could be extracted from the uploaded files.", "", ""

        # Create an index using BM25 (whitespace tokenisation)
        bm25 = BM25Okapi([passage.split() for passage in passages])

        # Score every passage against the question
        tokenized_query = question.split()
        bm25_scores = bm25.get_scores(tokenized_query)

        # Rank by position rather than by passage text, so each candidate's
        # score is read directly instead of being looked up again with
        # list.index(); equal scores keep upload order.
        ranked = sorted(
            range(len(passages)),
            key=lambda i: bm25_scores[i],
            reverse=True,
        )[:3]

        # Extract an answer with the pipeline for each candidate passage
        answers_with_context = []
        for i in ranked:
            answer = retrieval_qa_pipeline(question=question, context=passages[i])
            answers_with_context.append(
                {
                    "context": passages[i],
                    "answer": answer["answer"],
                    "BM25-score": bm25_scores[i],  # retrieval score of this passage
                }
            )

        # Choose the candidate with the highest BM25 score.
        # NOTE(review): the reader's own confidence (answer["score"]) is not
        # consulted here - confirm whether selection was meant to use it.
        best_answer = max(answers_with_context, key=lambda x: x["BM25-score"])

        # Highlight the answer in the context
        highlighted_context = highlight_answer(best_answer["context"], best_answer["answer"])

        return best_answer["answer"], highlighted_context, best_answer["BM25-score"]
    except Exception as e:
        # UI boundary: show the failure in the answer box instead of raising.
        return str(e), "", ""
100
+
101
# Description
# Markdown text for the page: project overview, objectives, a screenshot,
# key features and usage steps. It is passed to gr.Interface below as the
# ``description`` argument.
md = """
### Brief Overview of the project:

A Document-Retrieval QA application built by training **[RoBERTa model](https://arxiv.org/pdf/1907.11692)** on **[SQuAD 2.0](https://rajpurkar.github.io/SQuAD-explorer/)** dataset for efficient answer extraction and
the system is augmented by using NLP based **[BM25](https://www.researchgate.net/publication/220613776_The_Probabilistic_Relevance_Framework_BM25_and_Beyond)** retriever for information retrieval from a large text corpus.
The project is a brief enhancement and augmentation to the work done in the research paper **[Encoder-based LLMs: Building QA systems and Comparative Analysis](https://drive.google.com/file/d/1Ztd6x46g21ufoewmKZMoElMxViNfd_2P/view?usp=sharing)**.
In this paper we study about BERT and its advanced variants and learn to build an efficient answer extraction QA system from scratch.
The built system can be used in information retrieval system and search engines.

**Objectives of the projects:**
1. Build a simple Answer Extraction QA system using **RoBERTa-base**: The project is deployed public url objective1.
2. Building a Information Retrieval system for data augmentation using **BM25**
3. **Document Retrieval QA** system by merging Answer Extraction QA system and Information retrieval system

### Demonstrating working of the Application:

<div style="text-align: center;">
<img src="https://i.imgur.com/oYg8y7N.jpeg" alt="Description Image" style="border: 2px solid #000; border-radius: 5px; width: 600px; height: auto; display: block; margin: 0 auto;">
</div>

**Key Features:**
- Fine-tuned **RoBERTa**- Performs **Answer Extraction** from the retrieved document
- **BM25** Retriever- Performs **Information Retrieval** from the text corpus
- Provides answers with **highlighted context**.
- Application displays accurate **answer**, most relevant document **context** and the corresponding **BM25 score** of the passage to the user

**How to Use:**
1. Upload your corpus document(s).
2. Enter your question in the text box followed by a question mark(?).
3. Get the answer with context and corresponding BM25 scores.
"""
133
+
134
# Build the Gradio page: a question box and a multi-file upload are fed to
# answer_question, and its three return values fill the three output boxes.
_question_input = gr.Textbox(lines=2, placeholder="Enter your question here...", label="Question")
_documents_input = gr.Files(label="Upload text, Word, or PDF files")
_result_boxes = [gr.Textbox(label=caption) for caption in ("Answer", "Context", "BM25 Score")]

# Custom stylesheet handed to the interface.
_page_css = """
.container { max-width: 800px; margin: auto; }
.interface-title { font-family: Arial, sans-serif; font-size: 24px; font-weight: bold; }
.interface-description { font-family: Arial, sans-serif; font-size: 16px; margin-bottom: 20px; }
.input-textbox, .output-textbox { font-family: Arial, sans-serif; font-size: 14px; }
.error { color: red; font-family: Arial, sans-serif; font-size: 14px; }
"""

iface = gr.Interface(
    fn=answer_question,
    inputs=[_question_input, _documents_input],
    outputs=_result_boxes,
    title="Document Retrieval Question Answering Application",
    description=md,
    css=_page_css,
)

# Start serving the interface
iface.launch()