dar-tau commited on
Commit
872cb05
·
verified ·
1 Parent(s): 8d02fbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import numpy as np
 
3
  from datasets import load_dataset
4
  from transformers import AutoTokenizer
5
  import matplotlib.pyplot as plt
@@ -42,7 +43,7 @@ def analyze_sentence(index, vis_type, vis_format):
42
  plt.xlabel('SOURCE')
43
  plt.grid()
44
  fig_output, graph_output = fig, ""
45
- else:
46
  ex = [{
47
  "words": [{"text": x, "tag": ""} for x in tokenized[1:]],
48
  "arcs": [{"start": j, "end": i, "label": "", "dir": "right"}
@@ -58,6 +59,14 @@ def analyze_sentence(index, vis_type, vis_format):
58
  "</div>"
59
  )
60
  fig_output = None
 
 
 
 
 
 
 
 
61
  metrics = {'Metrics': 1}
62
 
63
  metrics.update({k: v for k, v in row.items() if k not in ['text', 'attention_maps', 'attention_maps_shape']})
@@ -75,7 +84,7 @@ with demo:
75
  min_width=70, value=VisType.SUM, type='value')
76
  btn = gr.Button("Run", min_width=30)
77
 
78
- vis_format_checkbox = gr.Radio(['Plot', 'Graph'])
79
 
80
  output = gr.Plot(label="Plot", container=True)
81
  graph_output = gr.HTML()
 
1
  import gradio as gr
2
  import numpy as np
3
+ import torch
4
  from datasets import load_dataset
5
  from transformers import AutoTokenizer
6
  import matplotlib.pyplot as plt
 
43
  plt.xlabel('SOURCE')
44
  plt.grid()
45
  fig_output, graph_output = fig, ""
46
+ elif vis_format == 'Graph':
47
  ex = [{
48
  "words": [{"text": x, "tag": ""} for x in tokenized[1:]],
49
  "arcs": [{"start": j, "end": i, "label": "", "dir": "right"}
 
59
  "</div>"
60
  )
61
  fig_output = None
62
+ else:
63
+ fig_output = None
64
+ top_values, top_indices = torch.tensor(plot_data).flatten().topk(30)
65
+ topk_data = []
66
+ for val, ind in zip(top_values, top_indices):
67
+ ind = np.unravel_index(ind, plot_data.shape)
68
+ topk_data += [(tokenized[1+ind[0]], tokenized[1+ind[1]])]
69
+ graph_output = '<div>' + str(topk_data) + '</div>'
70
  metrics = {'Metrics': 1}
71
 
72
  metrics.update({k: v for k, v in row.items() if k not in ['text', 'attention_maps', 'attention_maps_shape']})
 
84
  min_width=70, value=VisType.SUM, type='value')
85
  btn = gr.Button("Run", min_width=30)
86
 
87
+ vis_format_checkbox = gr.Radio(['Plot', 'Graph', 'Text'])
88
 
89
  output = gr.Plot(label="Plot", container=True)
90
  graph_output = gr.HTML()