Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
|
|
3 |
from datasets import load_dataset
|
4 |
from transformers import AutoTokenizer
|
5 |
import matplotlib.pyplot as plt
|
@@ -42,7 +43,7 @@ def analyze_sentence(index, vis_type, vis_format):
|
|
42 |
plt.xlabel('SOURCE')
|
43 |
plt.grid()
|
44 |
fig_output, graph_output = fig, ""
|
45 |
-
|
46 |
ex = [{
|
47 |
"words": [{"text": x, "tag": ""} for x in tokenized[1:]],
|
48 |
"arcs": [{"start": j, "end": i, "label": "", "dir": "right"}
|
@@ -58,6 +59,14 @@ def analyze_sentence(index, vis_type, vis_format):
|
|
58 |
"</div>"
|
59 |
)
|
60 |
fig_output = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
metrics = {'Metrics': 1}
|
62 |
|
63 |
metrics.update({k: v for k, v in row.items() if k not in ['text', 'attention_maps', 'attention_maps_shape']})
|
@@ -75,7 +84,7 @@ with demo:
|
|
75 |
min_width=70, value=VisType.SUM, type='value')
|
76 |
btn = gr.Button("Run", min_width=30)
|
77 |
|
78 |
-
vis_format_checkbox = gr.Radio(['Plot', 'Graph'])
|
79 |
|
80 |
output = gr.Plot(label="Plot", container=True)
|
81 |
graph_output = gr.HTML()
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
+
import torch
|
4 |
from datasets import load_dataset
|
5 |
from transformers import AutoTokenizer
|
6 |
import matplotlib.pyplot as plt
|
|
|
43 |
plt.xlabel('SOURCE')
|
44 |
plt.grid()
|
45 |
fig_output, graph_output = fig, ""
|
46 |
+
elif vis_format == 'Graph':
|
47 |
ex = [{
|
48 |
"words": [{"text": x, "tag": ""} for x in tokenized[1:]],
|
49 |
"arcs": [{"start": j, "end": i, "label": "", "dir": "right"}
|
|
|
59 |
"</div>"
|
60 |
)
|
61 |
fig_output = None
|
62 |
+
else:
|
63 |
+
fig_output = None
|
64 |
+
top_values, top_indices = torch.tensor(plot_data).flatten().topk(30)
|
65 |
+
topk_data = []
|
66 |
+
for val, ind in zip(top_values, top_indices):
|
67 |
+
ind = np.unravel_index(ind, plot_data.shape)
|
68 |
+
topk_data += [(tokenized[1+ind[0]], tokenized[1+ind[1]])]
|
69 |
+
graph_output = '<div>' + str(topk_data) + '</div>'
|
70 |
metrics = {'Metrics': 1}
|
71 |
|
72 |
metrics.update({k: v for k, v in row.items() if k not in ['text', 'attention_maps', 'attention_maps_shape']})
|
|
|
84 |
min_width=70, value=VisType.SUM, type='value')
|
85 |
btn = gr.Button("Run", min_width=30)
|
86 |
|
87 |
+
vis_format_checkbox = gr.Radio(['Plot', 'Graph', 'Text'])
|
88 |
|
89 |
output = gr.Plot(label="Plot", container=True)
|
90 |
graph_output = gr.HTML()
|