Lora committed on
Commit
5758582
·
1 Parent(s): 0730e98

add negative sense words and note

Browse files
Files changed (1) hide show
  1. app.py +69 -18
app.py CHANGED
@@ -121,9 +121,10 @@ Args:
121
  length: length of the input sentence, used to get the contextualization weights for the last token
122
  token: the selected token
123
  token_index: the index of the selected token in the input sentence
124
- count: how many top words to display for each sense
 
125
  """
126
- def get_token_contextual_weights (contextualization_weights, length, token, token_index, count = 7):
127
  print(">>>>>in get_token_contextual_weights")
128
  print(f"Selected {token_index}th token: {token}")
129
 
@@ -139,47 +140,54 @@ def get_token_contextual_weights (contextualization_weights, length, token, toke
139
  senses = torch.squeeze(senses) # (nv, s=1, d)
140
 
141
  # build dataframe
142
- neg_word_lists = []
143
  pos_dfs, neg_dfs = [], []
144
 
145
  for i in range(num_senses):
146
  logits = lm_head(senses[i,:]) # (vocab,) [768, 50257] -> [50257]
147
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
148
 
149
- pos_sorted_words = [tokenizer.decode(sorted_indices[j]) for j in range(count)]
150
- pos_df = pd.DataFrame(pos_sorted_words)
151
  pos_dfs.append(pos_df)
152
 
153
- neg_sorted_words = [tokenizer.decode(sorted_indices[-j-1]) for j in range(count)]
154
- neg_df = pd.DataFrame(neg_sorted_words)
155
  neg_dfs.append(neg_df)
156
 
157
  sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, \
158
  sense6words, sense7words, sense8words, sense9words, sense10words, sense11words, \
159
  sense12words, sense13words, sense14words, sense15words = pos_dfs
160
 
 
 
 
 
161
  sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, \
162
  sense6slider, sense7slider, sense8slider, sense9slider, sense10slider, sense11slider, \
163
  sense12slider, sense13slider, sense14slider, sense15slider = token_contextualization_weights_list
164
 
165
- return token, token_index, sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, \
166
- sense7words, sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words, \
167
- sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider, \
168
- sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider
 
 
 
169
 
170
  """
171
  Wrapper for when the user selects a new token in the tokens dataframe.
172
  Converts `evt` (the selected token) to `token` and `token_index` which are used by get_token_contextual_weights.
173
  """
174
- def new_token_contextual_weights (contextualization_weights, length, evt: gr.SelectData, count = 7):
175
  print(">>>>>in new_token_contextual_weights")
176
  token_index = evt.index[1] # selected token is the token_index-th token in the sentence
177
  token = evt.value
178
  if not token:
179
- return None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, \
180
- None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, \
181
- None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
182
- return get_token_contextual_weights (contextualization_weights, length, token, token_index, count)
 
183
 
184
  def change_sense0_weight(contextualization_weights, length, token_index, new_weight):
185
  contextualization_weights[0, 0, length-1, token_index] = new_weight
@@ -273,7 +281,7 @@ with gr.Blocks( theme = gr.themes.Base(),
273
  with gr.Column(scale=1):
274
  selected_token = gr.Textbox(label="Current Selected Token", interactive=False)
275
  with gr.Column(scale=8):
276
- gr.Markdown("""#####
277
  Once a token is chosen, you can **use the sliders below to change the weights of any senses** for that token, \
278
  and then click "Predict next word" to see updated next-word predictions. \
279
  You can change the weights of *multiple senses of multiple tokens;* \
@@ -314,6 +322,23 @@ with gr.Blocks( theme = gr.themes.Base(),
314
  sense6words = gr.DataFrame(headers = ["Sense 6"])
315
  with gr.Column(scale=0, min_width=120):
316
  sense7words = gr.DataFrame(headers = ["Sense 7"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  with gr.Row():
318
  with gr.Column(scale=0, min_width=120):
319
  sense8slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 8", elem_id="sense8slider", interactive=True)
@@ -348,7 +373,26 @@ with gr.Blocks( theme = gr.themes.Base(),
348
  sense14words = gr.DataFrame(headers = ["Sense 14"])
349
  with gr.Column(scale=0, min_width=120):
350
  sense15words = gr.DataFrame(headers = ["Sense 15"])
351
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  # gr.Examples(
353
  # examples=[["Messi plays for", top_k, None]],
354
  # inputs=[input_sentence, top_k, contextualization_weights],
@@ -405,6 +449,7 @@ with gr.Blocks( theme = gr.themes.Base(),
405
  inputs=[contextualization_weights, length, token_index, sense15slider],
406
  outputs=[contextualization_weights])
407
 
 
408
  predict.click(
409
  fn=predict_next_word,
410
  inputs = [input_sentence, top_k, contextualization_weights],
@@ -418,6 +463,9 @@ with gr.Blocks( theme = gr.themes.Base(),
418
  sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words,
419
  sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words,
420
 
 
 
 
421
  sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider,
422
  sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider]
423
  )
@@ -438,6 +486,9 @@ with gr.Blocks( theme = gr.themes.Base(),
438
  sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words,
439
  sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words,
440
 
 
 
 
441
  sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider,
442
  sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider]
443
  )
 
121
  length: length of the input sentence, used to get the contextualization weights for the last token
122
  token: the selected token
123
  token_index: the index of the selected token in the input sentence
124
+ pos_count: how many top positive words to display for each sense
125
+ neg_count: how many top negative words to display for each sense
126
  """
127
+ def get_token_contextual_weights (contextualization_weights, length, token, token_index, pos_count = 5, neg_count = 3):
128
  print(">>>>>in get_token_contextual_weights")
129
  print(f"Selected {token_index}th token: {token}")
130
 
 
140
  senses = torch.squeeze(senses) # (nv, s=1, d)
141
 
142
  # build dataframe
 
143
  pos_dfs, neg_dfs = [], []
144
 
145
  for i in range(num_senses):
146
  logits = lm_head(senses[i,:]) # (vocab,) [768, 50257] -> [50257]
147
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
148
 
149
+ pos_sorted_words = [tokenizer.decode(sorted_indices[j]) for j in range(pos_count)]
150
+ pos_df = pd.DataFrame(pos_sorted_words, columns=["Sense {}".format(i)])
151
  pos_dfs.append(pos_df)
152
 
153
+ neg_sorted_words = [tokenizer.decode(sorted_indices[-j-1]) for j in range(neg_count)]
154
+ neg_df = pd.DataFrame(neg_sorted_words, columns=["Top Negative"])
155
  neg_dfs.append(neg_df)
156
 
157
  sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, \
158
  sense6words, sense7words, sense8words, sense9words, sense10words, sense11words, \
159
  sense12words, sense13words, sense14words, sense15words = pos_dfs
160
 
161
+ sense0negwords, sense1negwords, sense2negwords, sense3negwords, sense4negwords, sense5negwords, \
162
+ sense6negwords, sense7negwords, sense8negwords, sense9negwords, sense10negwords, sense11negwords, \
163
+ sense12negwords, sense13negwords, sense14negwords, sense15negwords = neg_dfs
164
+
165
  sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, \
166
  sense6slider, sense7slider, sense8slider, sense9slider, sense10slider, sense11slider, \
167
  sense12slider, sense13slider, sense14slider, sense15slider = token_contextualization_weights_list
168
 
169
+ return token, token_index, \
170
+ sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words, \
171
+ sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words, \
172
+ sense0negwords, sense1negwords, sense2negwords, sense3negwords, sense4negwords, sense5negwords, sense6negwords, sense7negwords, \
173
+ sense8negwords, sense9negwords, sense10negwords, sense11negwords, sense12negwords, sense13negwords, sense14negwords, sense15negwords, \
174
+ sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider, \
175
+ sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider
176
 
177
  """
178
  Wrapper for when the user selects a new token in the tokens dataframe.
179
  Converts `evt` (the selected token) to `token` and `token_index` which are used by get_token_contextual_weights.
180
  """
181
+ def new_token_contextual_weights (contextualization_weights, length, evt: gr.SelectData, pos_count = 5, neg_count = 3):
182
  print(">>>>>in new_token_contextual_weights")
183
  token_index = evt.index[1] # selected token is the token_index-th token in the sentence
184
  token = evt.value
185
  if not token:
186
+ return None, None, \
187
+ None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, \
188
+ None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, \
189
+ None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
190
+ return get_token_contextual_weights (contextualization_weights, length, token, token_index, pos_count, neg_count)
191
 
192
  def change_sense0_weight(contextualization_weights, length, token_index, new_weight):
193
  contextualization_weights[0, 0, length-1, token_index] = new_weight
 
281
  with gr.Column(scale=1):
282
  selected_token = gr.Textbox(label="Current Selected Token", interactive=False)
283
  with gr.Column(scale=8):
284
+ gr.Markdown("""####
285
  Once a token is chosen, you can **use the sliders below to change the weights of any senses** for that token, \
286
  and then click "Predict next word" to see updated next-word predictions. \
287
  You can change the weights of *multiple senses of multiple tokens;* \
 
322
  sense6words = gr.DataFrame(headers = ["Sense 6"])
323
  with gr.Column(scale=0, min_width=120):
324
  sense7words = gr.DataFrame(headers = ["Sense 7"])
325
+ with gr.Row():
326
+ with gr.Column(scale=0, min_width=120):
327
+ sense0negwords = gr.DataFrame(headers = ["Top Negative"])
328
+ with gr.Column(scale=0, min_width=120):
329
+ sense1negwords = gr.DataFrame(headers = ["Top Negative"])
330
+ with gr.Column(scale=0, min_width=120):
331
+ sense2negwords = gr.DataFrame(headers = ["Top Negative"])
332
+ with gr.Column(scale=0, min_width=120):
333
+ sense3negwords = gr.DataFrame(headers = ["Top Negative"])
334
+ with gr.Column(scale=0, min_width=120):
335
+ sense4negwords = gr.DataFrame(headers = ["Top Negative"])
336
+ with gr.Column(scale=0, min_width=120):
337
+ sense5negwords = gr.DataFrame(headers = ["Top Negative"])
338
+ with gr.Column(scale=0, min_width=120):
339
+ sense6negwords = gr.DataFrame(headers = ["Top Negative"])
340
+ with gr.Column(scale=0, min_width=120):
341
+ sense7negwords = gr.DataFrame(headers = ["Top Negative"])
342
  with gr.Row():
343
  with gr.Column(scale=0, min_width=120):
344
  sense8slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 8", elem_id="sense8slider", interactive=True)
 
373
  sense14words = gr.DataFrame(headers = ["Sense 14"])
374
  with gr.Column(scale=0, min_width=120):
375
  sense15words = gr.DataFrame(headers = ["Sense 15"])
376
+ with gr.Row():
377
+ with gr.Column(scale=0, min_width=120):
378
+ sense8negwords = gr.DataFrame(headers = ["Top Negative"])
379
+ with gr.Column(scale=0, min_width=120):
380
+ sense9negwords = gr.DataFrame(headers = ["Top Negative"])
381
+ with gr.Column(scale=0, min_width=120):
382
+ sense10negwords = gr.DataFrame(headers = ["Top Negative"])
383
+ with gr.Column(scale=0, min_width=120):
384
+ sense11negwords = gr.DataFrame(headers = ["Top Negative"])
385
+ with gr.Column(scale=0, min_width=120):
386
+ sense12negwords = gr.DataFrame(headers = ["Top Negative"])
387
+ with gr.Column(scale=0, min_width=120):
388
+ sense13negwords = gr.DataFrame(headers = ["Top Negative"])
389
+ with gr.Column(scale=0, min_width=120):
390
+ sense14negwords = gr.DataFrame(headers = ["Top Negative"])
391
+ with gr.Column(scale=0, min_width=120):
392
+ sense15negwords = gr.DataFrame(headers = ["Top Negative"])
393
+ gr.Markdown("""Note: **"Top Negative"** shows words that have the most negative dot products with the sense vector, which can
394
+ exhibit more coherent meaning than those with the most positive dot products.
395
+ To see more representative words of each sense, scroll to the top and use the **"Individual Word Sense Look Up"** tab.""")
396
  # gr.Examples(
397
  # examples=[["Messi plays for", top_k, None]],
398
  # inputs=[input_sentence, top_k, contextualization_weights],
 
449
  inputs=[contextualization_weights, length, token_index, sense15slider],
450
  outputs=[contextualization_weights])
451
 
452
+
453
  predict.click(
454
  fn=predict_next_word,
455
  inputs = [input_sentence, top_k, contextualization_weights],
 
463
  sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words,
464
  sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words,
465
 
466
+ sense0negwords, sense1negwords, sense2negwords, sense3negwords, sense4negwords, sense5negwords, sense6negwords, sense7negwords,
467
+ sense8negwords, sense9negwords, sense10negwords, sense11negwords, sense12negwords, sense13negwords, sense14negwords, sense15negwords,
468
+
469
  sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider,
470
  sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider]
471
  )
 
486
  sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words,
487
  sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words,
488
 
489
+ sense0negwords, sense1negwords, sense2negwords, sense3negwords, sense4negwords, sense5negwords, sense6negwords, sense7negwords,
490
+ sense8negwords, sense9negwords, sense10negwords, sense11negwords, sense12negwords, sense13negwords, sense14negwords, sense15negwords,
491
+
492
  sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider,
493
  sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider]
494
  )