danseith commited on
Commit
a95bc58
1 Parent(s): e9178dd

Reduced suggested edits and max edits for faster testing. Changed s

Browse files
Files changed (1) hide show
  1. app.py +57 -17
app.py CHANGED
@@ -18,17 +18,24 @@ ex_str3 = "The graphite plane is composed of a two-dimensional hexagonal lattice
18
  "length and a width parallel to the graphite plane and a thickness orthogonal to the graphite plane with at " \
19
  "least one of the length, width, and thickness values being 100 nanometers or smaller. "
20
 
21
- examples = [[ex_str1, 1.2, 1],
22
- [ex_str2, 1.5, 10],
23
- [ex_str3, 1.4, 5]]
 
 
 
 
 
24
 
25
 
26
  def add_mask(text, size=1):
27
  split_text = text.split()
28
 
29
  # If the user supplies a mask, don't add more
30
- if '[MASK]' in split_text:
31
- return text
 
 
32
  idx = np.random.randint(len(split_text), size=size)
33
  masked_strings = []
34
  for i in idx:
@@ -146,6 +153,14 @@ def sample_output(out, sampling):
146
  return score_to_str[score]
147
 
148
 
 
 
 
 
 
 
 
 
149
  def unmask(text, temp, rounds):
150
  sampling = 'multi'
151
  for _ in range(rounds):
@@ -161,7 +176,7 @@ def unmask(text, temp, rounds):
161
  if unsuccessful_iters > 5:
162
  break
163
  print('skipped', new_token)
164
- new_token = sample_output(out, sampling='uniform')
165
  unsuccessful_iters += 1
166
  if new_token == masked[0]:
167
  split_text[mask_pos] = new_token
@@ -173,18 +188,43 @@ def unmask(text, temp, rounds):
173
  text[0] = text[0].upper()
174
  return ''.join(text)
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
- textbox = gr.Textbox(label="Example prompts", lines=5)
178
- output_textbox = gr.Textbox(placeholder="Output will appear here", lines=4)
179
- temp_slider = gr.Slider(1.0, 3.0, value=1.0, label='Creativity')
180
- edit_slider = gr.Slider(1, 20, step=5, value=1.0, label='Number of edits')
181
-
182
- demo = gr.Interface(
183
  fn=unmask,
184
- inputs=[textbox, temp_slider, edit_slider],
185
- outputs=[output_textbox],
186
- examples=examples,
187
- allow_flagging='never'
 
 
188
  )
189
 
190
- demo.launch()
 
 
 
18
  "length and a width parallel to the graphite plane and a thickness orthogonal to the graphite plane with at " \
19
  "least one of the length, width, and thickness values being 100 nanometers or smaller. "
20
 
21
+ tab_two_examples = [[ex_str1, 1.2, 1],
22
+ [ex_str2, 1.5, 10],
23
+ [ex_str3, 1.4, 5]]
24
+
25
+ tab_one_examples = [['A crustless _ made from two slices of baked bread.'],
26
+ ['The present disclosure provides a DNA-targeting RNA that comprises a targeting _.'],
27
+ ['The _ plane is composed of a two-dimensional hexagonal lattice of carbon atoms.']
28
+ ]
29
 
30
 
31
  def add_mask(text, size=1):
32
  split_text = text.split()
33
 
34
  # If the user supplies a mask, don't add more
35
+ if '_' in split_text:
36
+ u_pos = [i for i, s in enumerate(split_text) if '_' in s][0]
37
+ split_text[u_pos] = '[MASK]'
38
+ return ' '.join(split_text), '[MASK]'
39
  idx = np.random.randint(len(split_text), size=size)
40
  masked_strings = []
41
  for i in idx:
 
153
  return score_to_str[score]
154
 
155
 
156
+ def unmask_single(text, temp=1):
157
+ tp = add_mask(text, size=1)
158
+ masked_text, masked = tp[0], tp[1]
159
+ res = scrambler(masked_text, temp=temp, top_k=10)
160
+ out = {item["token_str"]: item["score"] for item in res}
161
+ return out
162
+
163
+
164
  def unmask(text, temp, rounds):
165
  sampling = 'multi'
166
  for _ in range(rounds):
 
176
  if unsuccessful_iters > 5:
177
  break
178
  print('skipped', new_token)
179
+ new_token = sample_output(out, sampling=sampling)
180
  unsuccessful_iters += 1
181
  if new_token == masked[0]:
182
  split_text[mask_pos] = new_token
 
188
  text[0] = text[0].upper()
189
  return ''.join(text)
190
 
191
+ textbox1 = gr.Textbox(label="Input Sentence", lines=5)
192
+ output_textbox1 = gr.Textbox(placeholder="Output will appear here", lines=4)
193
+
194
+ textbox2 = gr.Textbox(label="Input Sentences", lines=5)
195
+ output_textbox2 = gr.Textbox(placeholder="Output will appear here", lines=4)
196
+ temp_slider2 = gr.Slider(1.0, 3.0, value=1.0, label='Creativity')
197
+ edit_slider2 = gr.Slider(1, 20, step=1, value=1.0, label='Number of edits')
198
+
199
+ title1 = "Patent-BERT Sentence Remix-er: Single Edit"
200
+ description1 = """<p>Try inserting a '_' where you want the model to generate a list of likely words.
201
+ <br/>
202
+ <p/>"""
203
+ title2 = "Patent-BERT Sentence Remix-er: Multiple Edits"
204
+ description2 = """<p>Try typing in a sentence for the model to remix. Adjust the 'creativity' scale bar to change the
205
+ the model's confidence in its likely substitutions and the 'number of edits' for the number of edits you want
206
+ the model to attempt to make. <br/> <p/> """
207
+
208
+ demo1 = gr.Interface(
209
+ fn=unmask_single,
210
+ inputs=[textbox1],
211
+ outputs='label',
212
+ examples=tab_one_examples,
213
+ allow_flagging='never',
214
+ title=title1,
215
+ description=description1
216
+ )
217
 
218
+ demo2 = gr.Interface(
 
 
 
 
 
219
  fn=unmask,
220
+ inputs=[textbox2, temp_slider2, edit_slider2],
221
+ outputs=[output_textbox2],
222
+ examples=tab_two_examples,
223
+ allow_flagging='never',
224
+ title=title2,
225
+ description=description2
226
  )
227
 
228
+ gr.TabbedInterface(
229
+ [demo1, demo2], ["Single edit", "Multiple Edits"]
230
+ ).launch()