Shriharsh committed on
Commit
aaaa3f2
·
verified ·
1 Parent(s): a951dd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -1,13 +1,15 @@
1
  # Web Content Q&A Tool for Hugging Face Spaces
2
  # Optimized for memory constraints (2GB RAM) and 24-hour timeline
3
- # Features: Ingest up to 3 URLs, ask questions, get concise answers using DistilBERT
4
 
5
  import gradio as gr
6
  from bs4 import BeautifulSoup
7
  import requests
8
  from sentence_transformers import SentenceTransformer, util
9
  import numpy as np
10
- from transformers import pipeline
 
 
11
 
12
  # Global variables for in-memory storage (reset on app restart)
13
  corpus = [] # List of paragraphs from URLs
@@ -17,8 +19,13 @@ sources_list = [] # Source URLs for each paragraph
17
  # Load models at startup (memory: ~340MB total)
18
  # Retrieval model: all-MiniLM-L6-v2 (~80MB, 384-dim embeddings)
19
  retriever = SentenceTransformer('all-MiniLM-L6-v2')
20
- # QA model: Xenova/distilbert-base-uncased-distilled-squad
21
- qa_model = pipeline("question-answering", model="Xenova/distilbert-base-uncased-distilled-squad")
 
 
 
 
 
22
 
23
  def ingest_urls(urls):
24
  """
@@ -69,7 +76,7 @@ def ingest_urls(urls):
69
 
70
  def answer_question(question):
71
  """
72
- Answer a question using retrieved context and DistilBERT QA.
73
  Retrieves top 3 paragraphs to provide broader context for cross-questioning.
74
  If total context exceeds 512 tokens (DistilBERT's max length), it will be truncated automatically.
75
  """
@@ -82,7 +89,7 @@ def answer_question(question):
82
 
83
  # Compute cosine similarity with stored embeddings
84
  cos_scores = util.cos_sim(question_embedding, embeddings)[0]
85
- top_k = min(1, len(corpus)) # Get topmost or less if fewer paragraphs
86
  top_indices = np.argsort(-cos_scores)[:top_k]
87
 
88
  # Retrieve context (top 3 paragraphs)
@@ -90,7 +97,7 @@ def answer_question(question):
90
  context = " ".join(contexts) # Concatenate with space
91
  sources = [sources_list[i] for i in top_indices]
92
 
93
- # Extract answer with DistilBERT
94
  # Note: If total tokens exceed 512, it will be truncated automatically
95
  result = qa_model(question=question, context=context)
96
  answer = result['answer']
 
1
  # Web Content Q&A Tool for Hugging Face Spaces
2
  # Optimized for memory constraints (2GB RAM) and 24-hour timeline
3
+ # Features: Ingest up to 3 URLs, ask questions, get concise answers using DistilBERT with ONNX
4
 
5
  import gradio as gr
6
  from bs4 import BeautifulSoup
7
  import requests
8
  from sentence_transformers import SentenceTransformer, util
9
  import numpy as np
10
+ from optimum.onnxruntime import ORTModelForQuestionAnswering
11
+ from transformers import AutoTokenizer
12
+ from optimum.pipelines import pipeline
13
 
14
  # Global variables for in-memory storage (reset on app restart)
15
  corpus = [] # List of paragraphs from URLs
 
19
  # Load models at startup (memory: ~340MB total)
20
  # Retrieval model: all-MiniLM-L6-v2 (~80MB, 384-dim embeddings)
21
  retriever = SentenceTransformer('all-MiniLM-L6-v2')
22
+
23
+ # Load ONNX model for QA using optimum.onnxruntime
24
+ # Model: Xenova/distilbert-base-uncased-distilled-squad (~260MB)
25
+ # Use ORTModelForQuestionAnswering to load the ONNX model
26
+ model = ORTModelForQuestionAnswering.from_pretrained("Xenova/distilbert-base-uncased-distilled-squad")
27
+ tokenizer = AutoTokenizer.from_pretrained("Xenova/distilbert-base-uncased-distilled-squad")
28
+ qa_model = pipeline("question-answering", model=model, tokenizer=tokenizer, framework="ort")
29
 
30
  def ingest_urls(urls):
31
  """
 
76
 
77
  def answer_question(question):
78
  """
79
+ Answer a question using retrieved context and DistilBERT QA (ONNX).
80
  Retrieves the top-ranked paragraphs (at most top_k) to provide broader context for cross-questioning.
81
  If the question plus context exceeds the pipeline's per-chunk token limit (DistilBERT accepts at most 512 tokens), the context is split into overlapping chunks automatically.
82
  """
 
89
 
90
  # Compute cosine similarity with stored embeddings
91
  cos_scores = util.cos_sim(question_embedding, embeddings)[0]
92
+ top_k = min(2, len(corpus)) # Get the top 2, or fewer if the corpus has fewer paragraphs
93
  top_indices = np.argsort(-cos_scores)[:top_k]
94
 
95
  # Retrieve context (the top_k most similar paragraphs)
 
97
  context = " ".join(contexts) # Concatenate with space
98
  sources = [sources_list[i] for i in top_indices]
99
 
100
+ # Extract answer with DistilBERT (ONNX)
101
  # Note: If total tokens exceed the per-chunk limit (at most 512 for DistilBERT), the pipeline splits the context into overlapping chunks automatically
102
  result = qa_model(question=question, context=context)
103
  answer = result['answer']