Shriharsh committed · verified
Commit 1bb4299 · 1 Parent(s): bc9fd78

Update app.py

Files changed (1)
  1. app.py +32 -21
app.py CHANGED
@@ -1,15 +1,14 @@
  # Web Content Q&A Tool for Hugging Face Spaces
  # Optimized for memory constraints (2GB RAM) and 24-hour timeline
- # Features: Ingest up to 3 URLs, ask questions, get concise answers using DistilBERT with ONNX
+ # Features: Ingest up to 3 URLs, ask questions, get concise answers using DistilBERT with PyTorch

  import gradio as gr
  from bs4 import BeautifulSoup
  import requests
  from sentence_transformers import SentenceTransformer, util
  import numpy as np
- from optimum.onnxruntime import ORTModelForQuestionAnswering
- from transformers import AutoTokenizer
- from optimum.pipelines import pipeline
+ from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
+ import torch

  # Global variables for in-memory storage (reset on app restart)
  corpus = [] # List of paragraphs from URLs
@@ -20,16 +19,28 @@ sources_list = [] # Source URLs for each paragraph
  # Retrieval model: all-MiniLM-L6-v2 (~80MB, 384-dim embeddings)
  retriever = SentenceTransformer('all-MiniLM-L6-v2')

- # Load ONNX model for QA using optimum.onnxruntime
- # Model: Xenova/distilbert-base-uncased-distilled-squad (~260MB)
- # Specify file_name="model.onnx" to select the correct ONNX file
- model = ORTModelForQuestionAnswering.from_pretrained(
-     "Xenova/distilbert-base-uncased-distilled-squad",
-     file_name="onnx/model.onnx",
-     provider="CPUExecutionProvider"
+ # Load PyTorch model for QA
+ # Model: distilbert-base-uncased-distilled-squad (~260MB)
+ model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-uncased-distilled-squad")
+ tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-distilled-squad")
+
+ # Set model to evaluation mode
+ model.eval()
+
+ # Compile the model with torch.compile for faster inference (PyTorch 2.0+)
+ # Use backend="inductor" for CPU optimization
+ try:
+     model = torch.compile(model, backend="inductor")
+ except Exception as e:
+     print(f"Warning: torch.compile failed with error: {str(e)}. Proceeding without compilation.")
+
+ # Apply quantization to the model for additional speedup on CPU
+ model = torch.quantization.quantize_dynamic(
+     model, {torch.nn.Linear}, dtype=torch.qint8
  )
- tokenizer = AutoTokenizer.from_pretrained("Xenova/distilbert-base-uncased-distilled-squad")
- qa_model = pipeline("question-answering", model=model, tokenizer=tokenizer, framework="ort", device=0)
+
+ # Create the QA pipeline with PyTorch
+ qa_model = pipeline("question-answering", model=model, tokenizer=tokenizer, framework="pt", device=-1) # device=-1 for CPU

  def ingest_urls(urls):
      """
@@ -80,8 +91,8 @@ def ingest_urls(urls):

  def answer_question(question):
      """
-     Answer a question using retrieved context and DistilBERT QA (ONNX).
-     Retrieves top 3 paragraphs to provide broader context for cross-questioning.
+     Answer a question using retrieved context and DistilBERT QA (PyTorch).
+     Retrieves top 1 paragraph to reduce inference time.
      If total context exceeds 512 tokens (DistilBERT's max length), it will be truncated automatically.
      """
      global corpus, embeddings, sources_list
@@ -93,17 +104,17 @@ def answer_question(question):

      # Compute cosine similarity with stored embeddings
      cos_scores = util.cos_sim(question_embedding, embeddings)[0]
-     top_k = min(2, len(corpus)) # Get top 2 or less if fewer paragraphs
+     top_k = min(1, len(corpus)) # Get top 1 paragraph to speed up inference
      top_indices = np.argsort(-cos_scores)[:top_k]

-     # Retrieve context (top 3 paragraphs)
+     # Retrieve context (top 1 paragraph)
      contexts = [corpus[i] for i in top_indices]
      context = " ".join(contexts) # Concatenate with space
      sources = [sources_list[i] for i in top_indices]

-     # Extract answer with DistilBERT (ONNX)
-     # Note: If total tokens exceed 512, it will be truncated automatically
-     result = qa_model(question=question, context=context)
+     # Extract answer with DistilBERT (PyTorch)
+     with torch.no_grad(): # Disable gradient computation for faster inference
+         result = qa_model(question=question, context=context)
      answer = result['answer']
      confidence = result['score']

@@ -149,4 +160,4 @@ with gr.Blocks(title="Web Content Q&A Tool") as demo:
      clear_btn.click(fn=clear_all, inputs=None, outputs=[url_input, ingest_output, answer_output])

  # Launch the app (HF Spaces expects port 7860)
- demo.launch(share = True, server_name="0.0.0.0", server_port=7860)
+ demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
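Second, the retrieval step that the commit narrows from top-2 to top-1: a minimal sketch with the same all-MiniLM-L6-v2 retriever and a toy two-paragraph corpus invented for illustration.

# Sketch: embed a toy corpus, score it against the question with cosine
# similarity, and keep only the single best paragraph (top_k = 1).
import numpy as np
from sentence_transformers import SentenceTransformer, util

retriever = SentenceTransformer('all-MiniLM-L6-v2')
corpus = ["Paris is the capital of France.", "The Nile flows through Egypt."]  # toy data
embeddings = retriever.encode(corpus, convert_to_tensor=True)

question_embedding = retriever.encode("What is the capital of France?", convert_to_tensor=True)
cos_scores = util.cos_sim(question_embedding, embeddings)[0]

top_k = min(1, len(corpus))  # top 1 paragraph, as in the updated app.py
top_indices = np.argsort(-cos_scores.cpu().numpy())[:top_k]
print(corpus[top_indices[0]])  # expected: the Paris paragraph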
 
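Third, the torch.no_grad() wrapper around the pipeline call: recent transformers pipelines already run their forward pass with gradients disabled, so the explicit context manager is mostly belt-and-braces, but it is harmless and documents intent. A usage sketch, assuming the qa_model pipeline defined above and an invented question/context pair:

import torch

question = "How many URLs can be ingested?"  # invented example inputs
context = "The tool ingests up to 3 URLs and answers questions about their content."

with torch.no_grad():  # no autograd bookkeeping during inference
    result = qa_model(question=question, context=context)

print(result["answer"], round(result["score"], 3))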
 
 
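Finally, on the 512-token remark in the docstring: in recent transformers releases the question-answering pipeline handles long contexts by splitting them into overlapping windows rather than hard truncation, and the windowing can be tuned per call. A sketch assuming the qa_model pipeline and the question/context variables from answer_question; the parameter values are illustrative.

result = qa_model(
    question=question,
    context=context,
    max_seq_len=384,    # tokens per window (question + context chunk)
    doc_stride=128,     # overlap between consecutive windows
    max_answer_len=30,  # cap on the returned answer span length
)
print(result["answer"], result["score"])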