danseith committed on
Commit
85ace6b
·
1 Parent(s): 1ca245c

Won't count iteration if no changes are made.

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -30,9 +30,10 @@ def add_mask(text, size=1):
30
  if '[MASK]' in split_text:
31
  return text
32
  idx = np.random.randint(len(split_text), size=size)
 
33
  for i in idx:
34
  split_text[i] = '[MASK]'
35
- return ' '.join(split_text)
36
 
37
 
38
  class TempScalePipe(FillMaskPipeline):
@@ -135,7 +136,7 @@ scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
135
  def unmask(text, temp, rounds):
136
  sampling = 'multi'
137
  for round in range(rounds):
138
- text = add_mask(text, size=1)
139
  split_text = text.split()
140
  res = scrambler(text, temp=temp, top_k=10)
141
  mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
@@ -148,7 +149,7 @@ def unmask(text, temp, rounds):
148
  idx = np.random.randint(0, len(score_list))
149
  score = score_list[idx]
150
  new_token = score_to_str[score]
151
- if len(list(new_token)) < 2:
152
  continue
153
  split_text[mask_pos] = '*' + new_token + '*'
154
  text = ' '.join(split_text)
 
30
  if '[MASK]' in split_text:
31
  return text
32
  idx = np.random.randint(len(split_text), size=size)
33
+ masked = split_text[idx]
34
  for i in idx:
35
  split_text[i] = '[MASK]'
36
+ return ' '.join(split_text), masked
37
 
38
 
39
  class TempScalePipe(FillMaskPipeline):
 
136
  def unmask(text, temp, rounds):
137
  sampling = 'multi'
138
  for round in range(rounds):
139
+ text, masked = add_mask(text, size=1)
140
  split_text = text.split()
141
  res = scrambler(text, temp=temp, top_k=10)
142
  mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
 
149
  idx = np.random.randint(0, len(score_list))
150
  score = score_list[idx]
151
  new_token = score_to_str[score]
152
+ if len(list(new_token)) < 2 or new_token == masked:
153
  continue
154
  split_text[mask_pos] = '*' + new_token + '*'
155
  text = ' '.join(split_text)