Spaces:
Build error
Build error
danseith
commited on
Commit
•
f023836
1
Parent(s):
0d6ff2f
Updated rules to ignore punctuation
Browse files- app.py +12 -7
- requirements.txt +2 -1
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
import torch
|
|
|
4 |
from nltk.stem import PorterStemmer
|
5 |
from collections import defaultdict
|
6 |
from transformers import pipeline
|
@@ -32,7 +33,8 @@ tab_two_examples = [[ex_str1, ex_key1],
|
|
32 |
# ['The _ plane is composed of a two-dimensional hexagonal lattice of carbon atoms.']
|
33 |
# ]
|
34 |
|
35 |
-
|
|
|
36 |
|
37 |
|
38 |
def add_mask(text, lower_bound=0, index=None):
|
@@ -49,7 +51,7 @@ def add_mask(text, lower_bound=0, index=None):
|
|
49 |
idx = np.random.randint(low=lower_bound, high=len(split_text), size=1).astype(int)[0]
|
50 |
# Don't mask certain words
|
51 |
num_iters = 0
|
52 |
-
while split_text[idx].lower() in
|
53 |
num_iters += 1
|
54 |
idx = np.random.randint(len(split_text), size=1).astype(int)[0]
|
55 |
if num_iters > 10:
|
@@ -220,8 +222,9 @@ def extract_keywords(text, queries):
|
|
220 |
# Iterate through text and mask each token
|
221 |
ps = PorterStemmer()
|
222 |
top_scores = defaultdict(list)
|
223 |
-
top_k_range =
|
224 |
-
|
|
|
225 |
for i in indices:
|
226 |
masked_text, masked = add_mask(text, index=i)
|
227 |
res = scrambler(masked_text, temp=temp, top_k=top_k_range)
|
@@ -229,12 +232,14 @@ def extract_keywords(text, queries):
|
|
229 |
sorted_keys = sorted(out, key=out.get)
|
230 |
# If the key does not appear, floor its rank for that round
|
231 |
for rank, token_str in enumerate(sorted_keys):
|
|
|
|
|
232 |
stemmed = ps.stem(token_str)
|
233 |
-
if token_str not in top_scores.keys():
|
234 |
-
top_scores[stemmed].append(0)
|
235 |
norm_rank = rank / top_k_range
|
236 |
top_scores[stemmed].append(norm_rank)
|
237 |
-
|
|
|
|
|
238 |
# Calc mean
|
239 |
for key in top_scores.keys():
|
240 |
top_scores[key] = np.mean(top_scores[key])
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
import torch
|
4 |
+
import re
|
5 |
from nltk.stem import PorterStemmer
|
6 |
from collections import defaultdict
|
7 |
from transformers import pipeline
|
|
|
33 |
# ['The _ plane is composed of a two-dimensional hexagonal lattice of carbon atoms.']
|
34 |
# ]
|
35 |
|
36 |
+
ignore_str = ['a', 'an', 'the', 'is', 'and', 'or', '!', '(', ')', '-', '[', ']', '{', '}', ';', ':', "'", '"', '\\',
|
37 |
+
',', '<', '>', '.', '/', '?', '@', '#', '$', '%', '^', '&', '*', '_', '~']
|
38 |
|
39 |
|
40 |
def add_mask(text, lower_bound=0, index=None):
|
|
|
51 |
idx = np.random.randint(low=lower_bound, high=len(split_text), size=1).astype(int)[0]
|
52 |
# Don't mask certain words
|
53 |
num_iters = 0
|
54 |
+
while split_text[idx].lower() in ignore_str:
|
55 |
num_iters += 1
|
56 |
idx = np.random.randint(len(split_text), size=1).astype(int)[0]
|
57 |
if num_iters > 10:
|
|
|
222 |
# Iterate through text and mask each token
|
223 |
ps = PorterStemmer()
|
224 |
top_scores = defaultdict(list)
|
225 |
+
top_k_range = 30
|
226 |
+
text_no_punc = re.sub(r'[^\w\s]', '', text)
|
227 |
+
indices = [i for i, t in enumerate(text_no_punc.split()) if t.lower() == query.lower()]
|
228 |
for i in indices:
|
229 |
masked_text, masked = add_mask(text, index=i)
|
230 |
res = scrambler(masked_text, temp=temp, top_k=top_k_range)
|
|
|
232 |
sorted_keys = sorted(out, key=out.get)
|
233 |
# If the key does not appear, floor its rank for that round
|
234 |
for rank, token_str in enumerate(sorted_keys):
|
235 |
+
if token_str in ignore_str:
|
236 |
+
continue
|
237 |
stemmed = ps.stem(token_str)
|
|
|
|
|
238 |
norm_rank = rank / top_k_range
|
239 |
top_scores[stemmed].append(norm_rank)
|
240 |
+
for key in top_scores.keys():
|
241 |
+
if key not in out.keys():
|
242 |
+
top_scores[key].append(0)
|
243 |
# Calc mean
|
244 |
for key in top_scores.keys():
|
245 |
top_scores[key] = np.mean(top_scores[key])
|
requirements.txt
CHANGED
@@ -2,4 +2,5 @@ gradio
|
|
2 |
torch
|
3 |
transformers
|
4 |
numpy
|
5 |
-
nltk
|
|
|
|
2 |
torch
|
3 |
transformers
|
4 |
numpy
|
5 |
+
nltk
|
6 |
+
re
|