danseith commited on
Commit
020fa3d
·
1 Parent(s): eabdff9

Require minimum number of specified edits to be made and added error message.

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -137,7 +137,10 @@ scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
137
 
138
  def unmask(text, temp, rounds):
139
  sampling = 'multi'
140
- for round in range(rounds):
 
 
 
141
  tp = add_mask(text, size=1)
142
  masked_text, masked = tp[0], tp[1]
143
  split_text = masked_text.split()
@@ -156,6 +159,10 @@ def unmask(text, temp, rounds):
156
  continue
157
  split_text[mask_pos] = '*' + new_token + '*'
158
  text = ' '.join(split_text)
 
 
 
 
159
  text = list(text)
160
  text[0] = text[0].upper()
161
  return ''.join(text)
 
137
 
138
  def unmask(text, temp, rounds):
139
  sampling = 'multi'
140
+ successful_iters = 0
141
+ unsuccessful_iters = 0
142
+ while successful_iters < rounds or unsuccessful_iters > 5:
143
+ unsuccessful_iters += 1
144
  tp = add_mask(text, size=1)
145
  masked_text, masked = tp[0], tp[1]
146
  split_text = masked_text.split()
 
159
  continue
160
  split_text[mask_pos] = '*' + new_token + '*'
161
  text = ' '.join(split_text)
162
+ successful_iters += 1
163
+ unsuccessful_iters -= 1
164
+ if unsuccessful_iters > 5:
165
+ text = "Ran into an issue :( Please try again."
166
  text = list(text)
167
  text[0] = text[0].upper()
168
  return ''.join(text)