Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -25,6 +25,9 @@ def analyze_sentence(index, vis_type):
|
|
25 |
attn_maps = attn_maps[:, 1:, 1:]
|
26 |
if vis_type == VisType.ALL:
|
27 |
plot_data = attn_maps.sum(0)
|
|
|
|
|
|
|
28 |
sns.heatmap(plot_data)
|
29 |
plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90);
|
30 |
plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0);
|
@@ -41,7 +44,7 @@ with demo:
|
|
41 |
choices=[x.split('</s> ')[1] for x in dataset['text']],
|
42 |
value=0, min_width=500, type='index')
|
43 |
vis_dropdown = gr.Dropdown(label="Visualization", choices=list(VisType),
|
44 |
-
min_width=
|
45 |
btn = gr.Button("Run", min_width=50)
|
46 |
output = gr.Plot(label="Plot", container=True)
|
47 |
metrics = gr.Label("Metrics")
|
|
|
25 |
attn_maps = attn_maps[:, 1:, 1:]
|
26 |
if vis_type == VisType.ALL:
|
27 |
plot_data = attn_maps.sum(0)
|
28 |
+
else:
|
29 |
+
print(vist_type)
|
30 |
+
0/0
|
31 |
sns.heatmap(plot_data)
|
32 |
plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90);
|
33 |
plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0);
|
|
|
44 |
choices=[x.split('</s> ')[1] for x in dataset['text']],
|
45 |
value=0, min_width=500, type='index')
|
46 |
vis_dropdown = gr.Dropdown(label="Visualization", choices=list(VisType),
|
47 |
+
min_width=200, value=VisType.ALL, type='value')
|
48 |
btn = gr.Button("Run", min_width=50)
|
49 |
output = gr.Plot(label="Plot", container=True)
|
50 |
metrics = gr.Label("Metrics")
|