Lora
committed on
Commit
·
5758582
1
Parent(s):
0730e98
add negative sense words and note
Browse files
app.py
CHANGED
@@ -121,9 +121,10 @@ Args:
|
|
121 |
length: length of the input sentence, used to get the contextualization weights for the last token
|
122 |
token: the selected token
|
123 |
token_index: the index of the selected token in the input sentence
|
124 |
-
|
|
|
125 |
"""
|
126 |
-
def get_token_contextual_weights (contextualization_weights, length, token, token_index,
|
127 |
print(">>>>>in get_token_contextual_weights")
|
128 |
print(f"Selected {token_index}th token: {token}")
|
129 |
|
@@ -139,47 +140,54 @@ def get_token_contextual_weights (contextualization_weights, length, token, toke
|
|
139 |
senses = torch.squeeze(senses) # (nv, s=1, d)
|
140 |
|
141 |
# build dataframe
|
142 |
-
neg_word_lists = []
|
143 |
pos_dfs, neg_dfs = [], []
|
144 |
|
145 |
for i in range(num_senses):
|
146 |
logits = lm_head(senses[i,:]) # (vocab,) [768, 50257] -> [50257]
|
147 |
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
148 |
|
149 |
-
pos_sorted_words = [tokenizer.decode(sorted_indices[j]) for j in range(
|
150 |
-
pos_df = pd.DataFrame(pos_sorted_words)
|
151 |
pos_dfs.append(pos_df)
|
152 |
|
153 |
-
neg_sorted_words = [tokenizer.decode(sorted_indices[-j-1]) for j in range(
|
154 |
-
neg_df = pd.DataFrame(neg_sorted_words)
|
155 |
neg_dfs.append(neg_df)
|
156 |
|
157 |
sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, \
|
158 |
sense6words, sense7words, sense8words, sense9words, sense10words, sense11words, \
|
159 |
sense12words, sense13words, sense14words, sense15words = pos_dfs
|
160 |
|
|
|
|
|
|
|
|
|
161 |
sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, \
|
162 |
sense6slider, sense7slider, sense8slider, sense9slider, sense10slider, sense11slider, \
|
163 |
sense12slider, sense13slider, sense14slider, sense15slider = token_contextualization_weights_list
|
164 |
|
165 |
-
return token, token_index,
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
169 |
|
170 |
"""
|
171 |
Wrapper for when the user selects a new token in the tokens dataframe.
|
172 |
Converts `evt` (the selected token) to `token` and `token_index` which are used by get_token_contextual_weights.
|
173 |
"""
|
174 |
-
def new_token_contextual_weights (contextualization_weights, length, evt: gr.SelectData,
|
175 |
print(">>>>>in new_token_contextual_weights")
|
176 |
token_index = evt.index[1] # selected token is the token_index-th token in the sentence
|
177 |
token = evt.value
|
178 |
if not token:
|
179 |
-
return None, None,
|
180 |
-
None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, \
|
181 |
-
|
182 |
-
|
|
|
183 |
|
184 |
def change_sense0_weight(contextualization_weights, length, token_index, new_weight):
|
185 |
contextualization_weights[0, 0, length-1, token_index] = new_weight
|
@@ -273,7 +281,7 @@ with gr.Blocks( theme = gr.themes.Base(),
|
|
273 |
with gr.Column(scale=1):
|
274 |
selected_token = gr.Textbox(label="Current Selected Token", interactive=False)
|
275 |
with gr.Column(scale=8):
|
276 |
-
gr.Markdown("""
|
277 |
Once a token is chosen, you can **use the sliders below to change the weights of any senses** for that token, \
|
278 |
and then click "Predict next word" to see updated next-word predictions. \
|
279 |
You can change the weights of *multiple senses of multiple tokens;* \
|
@@ -314,6 +322,23 @@ with gr.Blocks( theme = gr.themes.Base(),
|
|
314 |
sense6words = gr.DataFrame(headers = ["Sense 6"])
|
315 |
with gr.Column(scale=0, min_width=120):
|
316 |
sense7words = gr.DataFrame(headers = ["Sense 7"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
317 |
with gr.Row():
|
318 |
with gr.Column(scale=0, min_width=120):
|
319 |
sense8slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 8", elem_id="sense8slider", interactive=True)
|
@@ -348,7 +373,26 @@ with gr.Blocks( theme = gr.themes.Base(),
|
|
348 |
sense14words = gr.DataFrame(headers = ["Sense 14"])
|
349 |
with gr.Column(scale=0, min_width=120):
|
350 |
sense15words = gr.DataFrame(headers = ["Sense 15"])
|
351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
352 |
# gr.Examples(
|
353 |
# examples=[["Messi plays for", top_k, None]],
|
354 |
# inputs=[input_sentence, top_k, contextualization_weights],
|
@@ -405,6 +449,7 @@ with gr.Blocks( theme = gr.themes.Base(),
|
|
405 |
inputs=[contextualization_weights, length, token_index, sense15slider],
|
406 |
outputs=[contextualization_weights])
|
407 |
|
|
|
408 |
predict.click(
|
409 |
fn=predict_next_word,
|
410 |
inputs = [input_sentence, top_k, contextualization_weights],
|
@@ -418,6 +463,9 @@ with gr.Blocks( theme = gr.themes.Base(),
|
|
418 |
sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words,
|
419 |
sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words,
|
420 |
|
|
|
|
|
|
|
421 |
sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider,
|
422 |
sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider]
|
423 |
)
|
@@ -438,6 +486,9 @@ with gr.Blocks( theme = gr.themes.Base(),
|
|
438 |
sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words,
|
439 |
sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words,
|
440 |
|
|
|
|
|
|
|
441 |
sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider,
|
442 |
sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider]
|
443 |
)
|
|
|
121 |
length: length of the input sentence, used to get the contextualization weights for the last token
|
122 |
token: the selected token
|
123 |
token_index: the index of the selected token in the input sentence
|
124 |
+
pos_count: how many top positive words to display for each sense
|
125 |
+
neg_count: how many top negative words to display for each sense
|
126 |
"""
|
127 |
+
def get_token_contextual_weights (contextualization_weights, length, token, token_index, pos_count = 5, neg_count = 3):
|
128 |
print(">>>>>in get_token_contextual_weights")
|
129 |
print(f"Selected {token_index}th token: {token}")
|
130 |
|
|
|
140 |
senses = torch.squeeze(senses) # (nv, s=1, d)
|
141 |
|
142 |
# build dataframe
|
|
|
143 |
pos_dfs, neg_dfs = [], []
|
144 |
|
145 |
for i in range(num_senses):
|
146 |
logits = lm_head(senses[i,:]) # (vocab,) [768, 50257] -> [50257]
|
147 |
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
148 |
|
149 |
+
pos_sorted_words = [tokenizer.decode(sorted_indices[j]) for j in range(pos_count)]
|
150 |
+
pos_df = pd.DataFrame(pos_sorted_words, columns=["Sense {}".format(i)])
|
151 |
pos_dfs.append(pos_df)
|
152 |
|
153 |
+
neg_sorted_words = [tokenizer.decode(sorted_indices[-j-1]) for j in range(neg_count)]
|
154 |
+
neg_df = pd.DataFrame(neg_sorted_words, columns=["Top Negative"])
|
155 |
neg_dfs.append(neg_df)
|
156 |
|
157 |
sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, \
|
158 |
sense6words, sense7words, sense8words, sense9words, sense10words, sense11words, \
|
159 |
sense12words, sense13words, sense14words, sense15words = pos_dfs
|
160 |
|
161 |
+
sense0negwords, sense1negwords, sense2negwords, sense3negwords, sense4negwords, sense5negwords, \
|
162 |
+
sense6negwords, sense7negwords, sense8negwords, sense9negwords, sense10negwords, sense11negwords, \
|
163 |
+
sense12negwords, sense13negwords, sense14negwords, sense15negwords = neg_dfs
|
164 |
+
|
165 |
sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, \
|
166 |
sense6slider, sense7slider, sense8slider, sense9slider, sense10slider, sense11slider, \
|
167 |
sense12slider, sense13slider, sense14slider, sense15slider = token_contextualization_weights_list
|
168 |
|
169 |
+
return token, token_index, \
|
170 |
+
sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words, \
|
171 |
+
sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words, \
|
172 |
+
sense0negwords, sense1negwords, sense2negwords, sense3negwords, sense4negwords, sense5negwords, sense6negwords, sense7negwords, \
|
173 |
+
sense8negwords, sense9negwords, sense10negwords, sense11negwords, sense12negwords, sense13negwords, sense14negwords, sense15negwords, \
|
174 |
+
sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider, \
|
175 |
+
sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider
|
176 |
|
177 |
"""
|
178 |
Wrapper for when the user selects a new token in the tokens dataframe.
|
179 |
Converts `evt` (the selected token) to `token` and `token_index` which are used by get_token_contextual_weights.
|
180 |
"""
|
181 |
+
def new_token_contextual_weights (contextualization_weights, length, evt: gr.SelectData, pos_count = 5, neg_count = 3):
|
182 |
print(">>>>>in new_token_contextual_weights")
|
183 |
token_index = evt.index[1] # selected token is the token_index-th token in the sentence
|
184 |
token = evt.value
|
185 |
if not token:
|
186 |
+
return None, None, \
|
187 |
+
None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, \
|
188 |
+
None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, \
|
189 |
+
None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
|
190 |
+
return get_token_contextual_weights (contextualization_weights, length, token, token_index, pos_count, neg_count)
|
191 |
|
192 |
def change_sense0_weight(contextualization_weights, length, token_index, new_weight):
|
193 |
contextualization_weights[0, 0, length-1, token_index] = new_weight
|
|
|
281 |
with gr.Column(scale=1):
|
282 |
selected_token = gr.Textbox(label="Current Selected Token", interactive=False)
|
283 |
with gr.Column(scale=8):
|
284 |
+
gr.Markdown("""####
|
285 |
Once a token is chosen, you can **use the sliders below to change the weights of any senses** for that token, \
|
286 |
and then click "Predict next word" to see updated next-word predictions. \
|
287 |
You can change the weights of *multiple senses of multiple tokens;* \
|
|
|
322 |
sense6words = gr.DataFrame(headers = ["Sense 6"])
|
323 |
with gr.Column(scale=0, min_width=120):
|
324 |
sense7words = gr.DataFrame(headers = ["Sense 7"])
|
325 |
+
with gr.Row():
|
326 |
+
with gr.Column(scale=0, min_width=120):
|
327 |
+
sense0negwords = gr.DataFrame(headers = ["Top Negative"])
|
328 |
+
with gr.Column(scale=0, min_width=120):
|
329 |
+
sense1negwords = gr.DataFrame(headers = ["Top Negative"])
|
330 |
+
with gr.Column(scale=0, min_width=120):
|
331 |
+
sense2negwords = gr.DataFrame(headers = ["Top Negative"])
|
332 |
+
with gr.Column(scale=0, min_width=120):
|
333 |
+
sense3negwords = gr.DataFrame(headers = ["Top Negative"])
|
334 |
+
with gr.Column(scale=0, min_width=120):
|
335 |
+
sense4negwords = gr.DataFrame(headers = ["Top Negative"])
|
336 |
+
with gr.Column(scale=0, min_width=120):
|
337 |
+
sense5negwords = gr.DataFrame(headers = ["Top Negative"])
|
338 |
+
with gr.Column(scale=0, min_width=120):
|
339 |
+
sense6negwords = gr.DataFrame(headers = ["Top Negative"])
|
340 |
+
with gr.Column(scale=0, min_width=120):
|
341 |
+
sense7negwords = gr.DataFrame(headers = ["Top Negative"])
|
342 |
with gr.Row():
|
343 |
with gr.Column(scale=0, min_width=120):
|
344 |
sense8slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 8", elem_id="sense8slider", interactive=True)
|
|
|
373 |
sense14words = gr.DataFrame(headers = ["Sense 14"])
|
374 |
with gr.Column(scale=0, min_width=120):
|
375 |
sense15words = gr.DataFrame(headers = ["Sense 15"])
|
376 |
+
with gr.Row():
|
377 |
+
with gr.Column(scale=0, min_width=120):
|
378 |
+
sense8negwords = gr.DataFrame(headers = ["Top Negative"])
|
379 |
+
with gr.Column(scale=0, min_width=120):
|
380 |
+
sense9negwords = gr.DataFrame(headers = ["Top Negative"])
|
381 |
+
with gr.Column(scale=0, min_width=120):
|
382 |
+
sense10negwords = gr.DataFrame(headers = ["Top Negative"])
|
383 |
+
with gr.Column(scale=0, min_width=120):
|
384 |
+
sense11negwords = gr.DataFrame(headers = ["Top Negative"])
|
385 |
+
with gr.Column(scale=0, min_width=120):
|
386 |
+
sense12negwords = gr.DataFrame(headers = ["Top Negative"])
|
387 |
+
with gr.Column(scale=0, min_width=120):
|
388 |
+
sense13negwords = gr.DataFrame(headers = ["Top Negative"])
|
389 |
+
with gr.Column(scale=0, min_width=120):
|
390 |
+
sense14negwords = gr.DataFrame(headers = ["Top Negative"])
|
391 |
+
with gr.Column(scale=0, min_width=120):
|
392 |
+
sense15negwords = gr.DataFrame(headers = ["Top Negative"])
|
393 |
+
gr.Markdown("""Note: **"Top Negative"** shows words that have the most negative dot products with the sense vector, which can
|
394 |
+
exhibit more coherent meaning than those with the most positive dot products.
|
395 |
+
To see more representative words of each sense, scroll to the top and use the **"Individual Word Sense Look Up"** tab.""")
|
396 |
# gr.Examples(
|
397 |
# examples=[["Messi plays for", top_k, None]],
|
398 |
# inputs=[input_sentence, top_k, contextualization_weights],
|
|
|
449 |
inputs=[contextualization_weights, length, token_index, sense15slider],
|
450 |
outputs=[contextualization_weights])
|
451 |
|
452 |
+
|
453 |
predict.click(
|
454 |
fn=predict_next_word,
|
455 |
inputs = [input_sentence, top_k, contextualization_weights],
|
|
|
463 |
sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words,
|
464 |
sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words,
|
465 |
|
466 |
+
sense0negwords, sense1negwords, sense2negwords, sense3negwords, sense4negwords, sense5negwords, sense6negwords, sense7negwords,
|
467 |
+
sense8negwords, sense9negwords, sense10negwords, sense11negwords, sense12negwords, sense13negwords, sense14negwords, sense15negwords,
|
468 |
+
|
469 |
sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider,
|
470 |
sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider]
|
471 |
)
|
|
|
486 |
sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words,
|
487 |
sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words,
|
488 |
|
489 |
+
sense0negwords, sense1negwords, sense2negwords, sense3negwords, sense4negwords, sense5negwords, sense6negwords, sense7negwords,
|
490 |
+
sense8negwords, sense9negwords, sense10negwords, sense11negwords, sense12negwords, sense13negwords, sense14negwords, sense15negwords,
|
491 |
+
|
492 |
sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider,
|
493 |
sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider]
|
494 |
)
|