Spaces • Build error
danseith committed • Commit 4db26d9 • 1 parent: b7321ed
search string generator
app.py CHANGED
@@ -1,6 +1,8 @@
 import gradio as gr
 import numpy as np
 import torch
+from nltk.stem import PorterStemmer
+from collections import defaultdict
 from transformers import pipeline
 from transformers.pipelines import PIPELINE_REGISTRY, FillMaskPipeline
 from transformers import AutoModelForMaskedLM
@@ -9,38 +11,45 @@ ex_str1 = "A crustless sandwich made from two slices of baked bread. The sandwic
           "crustless bread pieces. The bread pieces have the same general outer shape defined by an outer periphery " \
           "with central portions surrounded by an outer peripheral area, the bread pieces being at least partially " \
           "crimped together at the outer peripheral area."
+ex_key1 = "sandwich bread crimped"
 
 ex_str2 = "The present disclosure provides a DNA-targeting RNA that comprises a targeting sequence and, together with" \
           " a modifying polypeptide, provides for site-specific modification of a target DNA and/or a polypeptide" \
           " associated with the target DNA. "
+ex_key2 = "DNA target modification"
 
 ex_str3 = "The graphite plane is composed of a two-dimensional hexagonal lattice of carbon atoms and the plate has a " \
           "length and a width parallel to the graphite plane and a thickness orthogonal to the graphite plane with at " \
           "least one of the length, width, and thickness values being 100 nanometers or smaller. "
+ex_key3 = "graphite lattice orthogonal "
 
-tab_two_examples = [[ex_str1,
-                    [ex_str2,
-                    [ex_str3,
+tab_two_examples = [[ex_str1, ex_key1],
+                    [ex_str2, ex_key2],
+                    [ex_str3, ex_key3]]
+#
 
-tab_one_examples = [['A crustless _ made from two slices of baked bread.'],
-                    ['The present disclosure provides a DNA-targeting RNA that comprises a targeting _.'],
-                    ['The _ plane is composed of a two-dimensional hexagonal lattice of carbon atoms.']
-                    ]
+# tab_one_examples = [['A crustless _ made from two slices of baked bread.'],
+#                     ['The present disclosure provides a DNA-targeting RNA that comprises a targeting _.'],
+#                     ['The _ plane is composed of a two-dimensional hexagonal lattice of carbon atoms.']
+#                     ]
 
+ignore = ['a', 'an', 'the', 'is', 'and', 'or']
 
 
-def add_mask(text):
+def add_mask(text, lower_bound=0, index=None):
     split_text = text.split()
-
+    if index is not None:
+        split_text[index] = '[MASK]'
+        return ' '.join(split_text), None
     # If the user supplies a mask, don't add more
     if '_' in split_text:
         u_pos = [i for i, s in enumerate(split_text) if '_' in s][0]
         split_text[u_pos] = '[MASK]'
         return ' '.join(split_text), '[MASK]'
 
-    idx = np.random.randint(len(split_text), size=1).astype(int)[0]
+    idx = np.random.randint(low=lower_bound, high=len(split_text), size=1).astype(int)[0]
     # Don't mask certain words
     num_iters = 0
-    while split_text[idx].lower() in
+    while split_text[idx].lower() in ignore:
         num_iters += 1
         idx = np.random.randint(len(split_text), size=1).astype(int)[0]
         if num_iters > 10:
@@ -148,6 +157,7 @@ PIPELINE_REGISTRY.register_pipeline(
 )
 scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
 
+generator = pipeline('text-generation', model='gpt2')
 
 def sample_output(out, sampling):
     score_to_str = {out[k]: k for k in out.keys()}
@@ -167,10 +177,10 @@ def unmask_single(text, temp=1):
     return out
 
 
-def unmask(text, temp, rounds):
+def unmask(text, temp, rounds, lower_bound=0):
     sampling = 'multi'
     for _ in range(rounds):
-        masked_text, masked = add_mask(text)
+        masked_text, masked = add_mask(text, lower_bound)
         split_text = masked_text.split()
         res = scrambler(masked_text, temp=temp, top_k=15)
         mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
@@ -194,51 +204,140 @@ def unmask(text, temp, rounds):
     return ''.join(text)
 
 
-
-
-
-
-
-
-
+def autocomplete(text, temp):
+    output = generator(text, max_length=30, num_return_sequences=1)
+    gpt_out = output[0]['generated_text']
+    # diff = gpt_out.replace(text, '')
+    patent_bert_out = unmask(gpt_out, temp=temp, rounds=5, lower_bound=len(text.split()))
+    # Take the output from gpt-2 and randomly mask, if a mask is confident, swap it in. Iterate 5 times
+    return patent_bert_out
+
+
+def extract_keywords(text, queries):
+    q_dict = {}
+    temp = 1  # set temperature to 1
+    for query in queries.split():
+        # Iterate through text and mask each token
+        ps = PorterStemmer()
+        top_scores = defaultdict(list)
+        top_k_range = 10
+        indices = [i for i, t in enumerate(text.split()) if t.lower() == query.lower()]
+        for i in indices:
+            masked_text, masked = add_mask(text, index=i)
+            res = scrambler(masked_text, temp=temp, top_k=top_k_range)
+            out = {item["token_str"]: item["score"] for item in res}
+            sorted_keys = sorted(out, key=out.get)
+            # If the key does not appear, floor its rank for that round
+            for rank, token_str in enumerate(sorted_keys):
+                stemmed = ps.stem(token_str)
+                if token_str not in top_scores.keys():
+                    top_scores[stemmed].append(0)
+                norm_rank = rank / top_k_range
+                top_scores[stemmed].append(norm_rank)
+
+        # Calc mean
+        for key in top_scores.keys():
+            top_scores[key] = np.mean(top_scores[key])
+        # Normalize
+        for key in top_scores.keys():
+            top_scores[key] = top_scores[key] / np.sum(list(top_scores.values()))
+        # Get top_k
+        top_n = sorted(list(top_scores.values()))[-3]
+        for key in list(top_scores.keys()):
+            if top_scores[key] < top_n:
+                del top_scores[key]
+        q_dict[query] = top_scores
+
+    keywords = ''
+    for i, q in enumerate(q_dict.keys()):
+        keywords += '['
+        for ii, k in enumerate(q_dict[q].keys()):
+            keywords += k
+            if ii < len(q_dict[q].keys()) - 1:
+                keywords += ' OR '
+            else:
+                keywords += ']'
+        if i < len(q_dict.keys()) - 1:
+            keywords += ' AND '
+    # keywords = set([k for q in q_dict.keys() for k in q_dict[q].keys()])
+    # search_str = ' OR '.join(keywords)
+    output = [q_dict[q] for q in q_dict]
+    output.append(keywords)
+    return output
+    # fig, ax = plt.subplots(nrows=1, ncols=3)
+    # for q in q_dict:
+    #     ax.bar(q_dict[q])
+    # return fig
+
+label0 = gr.Label(label='keyword 1', num_top_classes=3)
+label01 = gr.Label(label='keyword 2', num_top_classes=3)
+label02 = gr.Label(label='keyword 3', num_top_classes=3)
+textbox02 = gr.Textbox(label="Input Keywords", lines=3)
+textbox01 = gr.Textbox(label="Input Keywords", placeholder="Type keywords here", lines=1)
+textbox0 = gr.Textbox(label="Input Sentences", placeholder="Type sentences here", lines=5)
+
+output_textbox0 = gr.Textbox(label='Search String of Keywords', placeholder="Output will appear here", lines=4)
+# temp_slider0 = gr.Slider(1.0, 3.0, value=1.0, label='Creativity')
 
-
+textbox1 = gr.Textbox(label="Input Sentence", lines=5)
+# output_textbox1 = gr.Textbox(placeholder="Output will appear here", lines=4)
+title1 = "Patent-BERT: Context-Dependent Synonym Generator"
 description1 = """<p>
+Try inserting a few sentences from a patent, and pick keywords for the model to analyze. The model will analyze the
+context of the keywords in the sentences and generate the top five most likely candidates for each word.
+Can be used for more creative patent drafting or patent searches using the generated search string.
+
 This is a model based on
 <a href= "https://github.com/google/patents-public-data/blob/master/models/BERT%20for%20Patents.md">Patent BERT</a> created by Google.
-
-
-<strong>Note:</strong> You can only add one '_' per submission.
+
+<strong>Note:</strong> Current pipeline only allows for three keyword submission.
 <br/>
 <p/>"""
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+# textbox2 = gr.Textbox(label="Input Sentences", lines=5)
+# output_textbox2 = gr.Textbox(placeholder="Output will appear here", lines=4)
+# temp_slider2 = gr.Slider(1.0, 3.0, value=1.0, label='Creativity')
+# edit_slider2 = gr.Slider(1, 20, step=1, value=1.0, label='Number of edits')
+
+
+# title2 = "Patent-BERT Sentence Remix-er: Multiple Edits"
+# description2 = """<p>
+#
+# Try typing in a sentence for the model to remix. Adjust the 'creativity' scale bar to change the
+# the model's confidence in its likely substitutions and the 'number of edits' for the number of edits you want
+# the model to attempt to make. The words substituted in the output sentence will be enclosed in asterisks (e.g., *word*).
+# <br/> <p/> """
+
+demo0 = gr.Interface(
+    fn=extract_keywords,
+    inputs=[textbox0, textbox01],
+    outputs=[label0, label01, label02, output_textbox0],
+    examples=tab_two_examples,
     allow_flagging='never',
     title=title1,
     description=description1
 )
 
-
-
-
-
-
-
-
-
-)
+# demo1 = gr.Interface(
+#     fn=unmask_single,
+#     inputs=[textbox1],
+#     outputs='label',
+#     examples=tab_one_examples,
+#     allow_flagging='never',
+#     title=title1,
+#     description=description1
+# )
+
+# demo2 = gr.Interface(
+#     fn=unmask,
+#     inputs=[textbox2, temp_slider2, edit_slider2],
+#     outputs=[output_textbox2],
+#     examples=tab_two_examples,
#     allow_flagging='never',
+#     title=title2,
+#     description=description2
+# )
 
 gr.TabbedInterface(
-    [
+    [demo0], ["Keyword generator"]
 ).launch()
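
A note on the scoring inside the new extract_keywords: each occurrence of a query word is replaced with [MASK], Patent-BERT's fill-mask candidates are collected, and every candidate stem accumulates a normalized rank per masked occurrence; the per-stem ranks are then averaged and renormalized. A minimal, runnable sketch of just that scoring step (the candidate lists below are made-up stand-ins for fill-mask output, not real model results):

    import numpy as np
    from collections import defaultdict

    def score_candidates(rounds, top_k_range=10):
        # rounds: one candidate list per masked occurrence, ordered
        # worst-to-best (mirroring sorted(out, key=out.get) in app.py).
        top_scores = defaultdict(list)
        for sorted_keys in rounds:
            for rank, token in enumerate(sorted_keys):
                top_scores[token].append(rank / top_k_range)  # normalized rank
        # Average each candidate's rank across rounds, then renormalize to sum to 1.
        means = {k: float(np.mean(v)) for k, v in top_scores.items()}
        total = sum(means.values())
        return {k: v / total for k, v in means.items()}

    # Hypothetical rankings for two masked occurrences of 'bread':
    print(score_candidates([['loaf', 'dough', 'bread'], ['dough', 'loaf', 'bread']]))

Higher averaged rank means the candidate consistently appeared near the top, which is why the app keeps only the three largest values per keyword.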
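The search string itself, the feature this commit names, is just bracketed OR-groups joined with AND. A standalone sketch of that assembly (build_search_string is a hypothetical helper, not a function in app.py, and the candidate lists are invented placeholders):

    def build_search_string(q_dict):
        # Each query becomes a bracketed OR-group of its candidate synonyms;
        # groups are joined with AND, as in the keyword loop of extract_keywords.
        groups = ['[' + ' OR '.join(cands) + ']' for cands in q_dict.values()]
        return ' AND '.join(groups)

    q_dict = {
        'sandwich': ['sandwich', 'roll', 'wrap'],
        'bread': ['bread', 'dough', 'loaf'],
        'crimped': ['crimp', 'seal', 'press'],
    }
    print(build_search_string(q_dict))
    # -> [sandwich OR roll OR wrap] AND [bread OR dough OR loaf] AND [crimp OR seal OR press]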