Shriharsh committed on
Commit
524057e
·
verified ·
1 Parent(s): 9121798

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -64
app.py CHANGED
@@ -1,7 +1,6 @@
1
  # Web Content Q&A Tool for Hugging Face Spaces
2
  # Optimized for memory constraints (2GB RAM) and 24-hour timeline
3
- # Features: Ingest up to 3 URLs, ask questions, get concise one-line answers using DistilBERT with PyTorch
4
- # Includes keyword search fallback for low-confidence QA answers
5
 
6
  import gradio as gr
7
  from bs4 import BeautifulSoup
@@ -32,20 +31,20 @@ corpus = [] # List of paragraphs from URLs
32
  embeddings = None # Precomputed embeddings for retrieval
33
  sources_list = [] # Source URLs for each paragraph
34
 
35
- # Load models at startup (memory: ~370MB total)
36
- # Retrieval model: all-mpnet-base-v2 (~110MB, 768-dim embeddings)
37
- retriever = SentenceTransformer('all-mpnet-base-v2')
38
 
39
  # Load PyTorch model for QA
40
- # Model: distilbert-base-uncased-distilled-squad (~260MB)
41
  try:
42
- model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-uncased-distilled-squad")
43
- tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-distilled-squad")
44
  except Exception as e:
45
  print(f"Error loading model: {str(e)}. Retrying with force_download=True...")
46
  # Force re-download in case of corrupted cache
47
- model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-uncased-distilled-squad", force_download=True)
48
- tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-distilled-squad", force_download=True)
49
 
50
  # Set model to evaluation mode
51
  model.eval()
@@ -68,38 +67,6 @@ def truncate_to_one_line(text):
68
  first_sentence = first_sentence[:100].rsplit(' ', 1)[0] + "..."
69
  return first_sentence if first_sentence else "No answer available."
70
 
71
- # Keyword search function for fallback
72
- def keyword_search(question, corpus, sources_list):
73
- stop_words = set(["what", "is", "the", "a", "an", "in", "on", "at", "for", "with", "and", "or", "but", "not", "this", "that", "these", "those", "to", "of", "it", "by", "as", "if", "when", "where", "who", "which", "how", "why"])
74
-
75
- def clean_text(text):
76
- return re.sub(r'[^a-zA-Z\s]', '', text).lower()
77
-
78
- cleaned_question = clean_text(question)
79
- keywords = [word for word in cleaned_question.split() if word not in stop_words]
80
- if not keywords:
81
- return "No keywords found for search.", None
82
-
83
- best_paragraph = None
84
- best_count = 0
85
- best_source = None
86
-
87
- for i, para in enumerate(corpus):
88
- cleaned_para = clean_text(para)
89
- words = set(cleaned_para.split()) # Use set for faster lookup
90
- count = sum(1 for kw in keywords if kw in words)
91
- if count > best_count:
92
- best_count = count
93
- best_paragraph = para
94
- best_source = sources_list[i]
95
-
96
- if best_paragraph is None:
97
- return "No relevant paragraph found.", None
98
-
99
- # Truncate the paragraph to one line
100
- best_paragraph = truncate_to_one_line(best_paragraph)
101
- return best_paragraph, best_source
102
-
103
  def ingest_urls(urls):
104
  """
105
  Ingest up to 3 URLs, scrape content, and compute embeddings.
@@ -149,11 +116,10 @@ def ingest_urls(urls):
149
 
150
  def answer_question(question):
151
  """
152
- Answer a question using retrieved context and DistilBERT QA (PyTorch).
153
  Retrieves top 3 paragraphs to improve answer accuracy.
154
- If total context exceeds 512 tokens (DistilBERT's max length), it will be truncated automatically.
155
- If QA confidence is below 0.4, falls back to keyword search.
156
- Ensures answers are one line (max 100 chars).
157
  """
158
  global corpus, embeddings, sources_list
159
  if not corpus or embeddings is None:
@@ -164,35 +130,33 @@ def answer_question(question):
164
 
165
  # Compute cosine similarity with stored embeddings
166
  cos_scores = util.cos_sim(question_embedding, embeddings)[0]
167
- top_k = min(2, len(corpus)) # Get top 3 paragraphs to improve accuracy
168
  top_indices = np.argsort(-cos_scores)[:top_k]
169
 
170
- # Retrieve context (top 2 paragraphs)
171
  contexts = [corpus[i] for i in top_indices]
172
  context = " ".join(contexts) # Concatenate with space
173
  sources = [sources_list[i] for i in top_indices]
174
 
175
- # Extract answer with DistilBERT (PyTorch)
176
  with torch.no_grad(): # Disable gradient computation for faster inference
177
  result = qa_model(question=question, context=context)
178
  answer = result['answer']
179
  confidence = result['score']
180
 
181
- if confidence >= 0.4:
182
- # Truncate QA answer to one line
183
- answer = truncate_to_one_line(answer)
184
- # Ensure at least one line
185
- if not answer:
186
- answer = "No answer available."
187
- sources_str = "\n".join(set(sources)) # Unique sources
188
- return f"Answer: {answer}\nConfidence: {confidence:.2f}\nSources:\n{sources_str}"
189
- else:
190
- # Perform keyword search
191
- kw_answer, kw_source = keyword_search(question, corpus, sources_list)
192
- if kw_source:
193
- return f"Answer: {kw_answer} (from keyword search, as QA confidence was {confidence:.2f})\nSource: {kw_source}"
194
- else:
195
- return "No relevant answer found from keyword search."
196
 
197
  def clear_all():
198
  """Clear all inputs and outputs for a fresh start."""
 
1
  # Web Content Q&A Tool for Hugging Face Spaces
2
  # Optimized for memory constraints (2GB RAM) and 24-hour timeline
3
+ # Features: Ingest up to 3 URLs, ask questions, get concise one-line answers using RoBERTa with PyTorch
 
4
 
5
  import gradio as gr
6
  from bs4 import BeautifulSoup
 
31
  embeddings = None # Precomputed embeddings for retrieval
32
  sources_list = [] # Source URLs for each paragraph
33
 
34
+ # Load models at startup (memory: ~410MB total)
35
+ # Retrieval model: multi-qa-mpnet-base-dot-v1 (~110MB, 768-dim embeddings)
36
+ retriever = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
37
 
38
  # Load PyTorch model for QA
39
+ # Model: roberta-base-squad2 (~355MB, quantized to ~200-250MB)
40
  try:
41
+ model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
42
+ tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
43
  except Exception as e:
44
  print(f"Error loading model: {str(e)}. Retrying with force_download=True...")
45
  # Force re-download in case of corrupted cache
46
+ model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2", force_download=True)
47
+ tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2", force_download=True)
48
 
49
  # Set model to evaluation mode
50
  model.eval()
 
67
  first_sentence = first_sentence[:100].rsplit(' ', 1)[0] + "..."
68
  return first_sentence if first_sentence else "No answer available."
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def ingest_urls(urls):
71
  """
72
  Ingest up to 3 URLs, scrape content, and compute embeddings.
 
116
 
117
  def answer_question(question):
118
  """
119
+ Answer a question using retrieved context and RoBERTa QA (PyTorch).
120
  Retrieves top 3 paragraphs to improve answer accuracy.
121
+ If total context exceeds 512 tokens (RoBERTa's max length), it will be truncated automatically.
122
+ Rejects answers with confidence below 0.3. Ensures answers are one line (max 100 chars).
 
123
  """
124
  global corpus, embeddings, sources_list
125
  if not corpus or embeddings is None:
 
130
 
131
  # Compute cosine similarity with stored embeddings
132
  cos_scores = util.cos_sim(question_embedding, embeddings)[0]
133
+ top_k = min(3, len(corpus)) # Get top 3 paragraphs as preferred
134
  top_indices = np.argsort(-cos_scores)[:top_k]
135
 
136
+ # Retrieve context (top 3 paragraphs)
137
  contexts = [corpus[i] for i in top_indices]
138
  context = " ".join(contexts) # Concatenate with space
139
  sources = [sources_list[i] for i in top_indices]
140
 
141
+ # Extract answer with RoBERTa (PyTorch)
142
  with torch.no_grad(): # Disable gradient computation for faster inference
143
  result = qa_model(question=question, context=context)
144
  answer = result['answer']
145
  confidence = result['score']
146
 
147
+ # Check confidence threshold
148
+ if confidence < 0.3:
149
+ return f"No confident answer found (confidence {confidence:.2f} below 0.3)."
150
+
151
+ # Truncate answer to one line
152
+ answer = truncate_to_one_line(answer)
153
+ # Ensure at least one line
154
+ if not answer:
155
+ answer = "No answer available."
156
+
157
+ # Format response with answer, confidence, and sources
158
+ sources_str = "\n".join(set(sources)) # Unique sources
159
+ return f"Answer: {answer}\nConfidence: {confidence:.2f}\nSources:\n{sources_str}"
 
 
160
 
161
  def clear_all():
162
  """Clear all inputs and outputs for a fresh start."""