danseith committed on
Commit
f023836
1 Parent(s): 0d6ff2f

Updated rules to ignore punctuation

Browse files
Files changed (2) hide show
  1. app.py +12 -7
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import numpy as np
3
  import torch
 
4
  from nltk.stem import PorterStemmer
5
  from collections import defaultdict
6
  from transformers import pipeline
@@ -32,7 +33,8 @@ tab_two_examples = [[ex_str1, ex_key1],
32
  # ['The _ plane is composed of a two-dimensional hexagonal lattice of carbon atoms.']
33
  # ]
34
 
35
- ignore = ['a', 'an', 'the', 'is', 'and', 'or']
 
36
 
37
 
38
  def add_mask(text, lower_bound=0, index=None):
@@ -49,7 +51,7 @@ def add_mask(text, lower_bound=0, index=None):
49
  idx = np.random.randint(low=lower_bound, high=len(split_text), size=1).astype(int)[0]
50
  # Don't mask certain words
51
  num_iters = 0
52
- while split_text[idx].lower() in ignore:
53
  num_iters += 1
54
  idx = np.random.randint(len(split_text), size=1).astype(int)[0]
55
  if num_iters > 10:
@@ -220,8 +222,9 @@ def extract_keywords(text, queries):
220
  # Iterate through text and mask each token
221
  ps = PorterStemmer()
222
  top_scores = defaultdict(list)
223
- top_k_range = 10
224
- indices = [i for i, t in enumerate(text.split()) if t.lower() == query.lower()]
 
225
  for i in indices:
226
  masked_text, masked = add_mask(text, index=i)
227
  res = scrambler(masked_text, temp=temp, top_k=top_k_range)
@@ -229,12 +232,14 @@ def extract_keywords(text, queries):
229
  sorted_keys = sorted(out, key=out.get)
230
  # If the key does not appear, floor its rank for that round
231
  for rank, token_str in enumerate(sorted_keys):
 
 
232
  stemmed = ps.stem(token_str)
233
- if token_str not in top_scores.keys():
234
- top_scores[stemmed].append(0)
235
  norm_rank = rank / top_k_range
236
  top_scores[stemmed].append(norm_rank)
237
-
 
 
238
  # Calc mean
239
  for key in top_scores.keys():
240
  top_scores[key] = np.mean(top_scores[key])
 
1
  import gradio as gr
2
  import numpy as np
3
  import torch
4
+ import re
5
  from nltk.stem import PorterStemmer
6
  from collections import defaultdict
7
  from transformers import pipeline
 
33
  # ['The _ plane is composed of a two-dimensional hexagonal lattice of carbon atoms.']
34
  # ]
35
 
36
+ ignore_str = ['a', 'an', 'the', 'is', 'and', 'or', '!', '(', ')', '-', '[', ']', '{', '}', ';', ':', "'", '"', '\\',
37
+ ',', '<', '>', '.', '/', '?', '@', '#', '$', '%', '^', '&', '*', '_', '~']
38
 
39
 
40
  def add_mask(text, lower_bound=0, index=None):
 
51
  idx = np.random.randint(low=lower_bound, high=len(split_text), size=1).astype(int)[0]
52
  # Don't mask certain words
53
  num_iters = 0
54
+ while split_text[idx].lower() in ignore_str:
55
  num_iters += 1
56
  idx = np.random.randint(len(split_text), size=1).astype(int)[0]
57
  if num_iters > 10:
 
222
  # Iterate through text and mask each token
223
  ps = PorterStemmer()
224
  top_scores = defaultdict(list)
225
+ top_k_range = 30
226
+ text_no_punc = re.sub(r'[^\w\s]', '', text)
227
+ indices = [i for i, t in enumerate(text_no_punc.split()) if t.lower() == query.lower()]
228
  for i in indices:
229
  masked_text, masked = add_mask(text, index=i)
230
  res = scrambler(masked_text, temp=temp, top_k=top_k_range)
 
232
  sorted_keys = sorted(out, key=out.get)
233
  # If the key does not appear, floor its rank for that round
234
  for rank, token_str in enumerate(sorted_keys):
235
+ if token_str in ignore_str:
236
+ continue
237
  stemmed = ps.stem(token_str)
 
 
238
  norm_rank = rank / top_k_range
239
  top_scores[stemmed].append(norm_rank)
240
+ for key in top_scores.keys():
241
+ if key not in out.keys():
242
+ top_scores[key].append(0)
243
  # Calc mean
244
  for key in top_scores.keys():
245
  top_scores[key] = np.mean(top_scores[key])
requirements.txt CHANGED
@@ -2,4 +2,5 @@ gradio
2
  torch
3
  transformers
4
  numpy
5
- nltk
 
 
2
  torch
3
  transformers
4
  numpy
5
+ nltk
6
+ re