Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import gradio as gr
|
|
|
2 |
from datasets import load_dataset
|
3 |
import matplotlib.pyplot as plt
|
4 |
import seaborn as sns
|
@@ -8,7 +9,15 @@ dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
|
|
8 |
def analyze_sentence(index):
|
9 |
row = dataset[index]
|
10 |
attn_maps = np.array(row['attention_maps']).reshape(*row['attention_maps_shape'])
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
iface = gr.Interface(fn=analyze_sentence, inputs=[gr.Dropdown(choices=dataset['text'], type='index')],
|
|
|
1 |
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
from datasets import load_dataset
|
4 |
import matplotlib.pyplot as plt
|
5 |
import seaborn as sns
|
|
|
9 |
def analyze_sentence(index):
|
10 |
row = dataset[index]
|
11 |
attn_maps = np.array(row['attention_maps']).reshape(*row['attention_maps_shape'])
|
12 |
+
plot = sns.heatmap(attn_maps.sum(1).sum(0))
|
13 |
+
plt.xticks(np.arange(len(tokenized)-1) + 0.5,
|
14 |
+
tokenizer.tokenize(text, add_special_tokens=False), rotation=90);
|
15 |
+
plt.yticks(np.arange(len(tokenized)-1) + 0.5,
|
16 |
+
tokenizer.tokenize(text, add_special_tokens=False), rotation=0);
|
17 |
+
plt.ylabel('TARGET')
|
18 |
+
plt.xlabel('SOURCE')
|
19 |
+
plt.grid()
|
20 |
+
return row['text'], plot
|
21 |
|
22 |
|
23 |
iface = gr.Interface(fn=analyze_sentence, inputs=[gr.Dropdown(choices=dataset['text'], type='index')],
|