File size: 5,536 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import logging
import matplotlib.pyplot as plt
import numpy
import os
from espnet.asr import asr_utils
def _plot_and_save_attention(att_w, filename, xtokens=None, ytokens=None):
    """Render one attention weight tensor as a row of per-head heatmaps.

    :param att_w: attention weights, iterable over heads where each element is a
        2D array castable to float32 — presumably (head, output_len, input_len);
        TODO(review) confirm layout against att_vis_fn output
    :param str filename: output path; its parent directory is created if missing
        (the figure itself is NOT saved here — the caller's savefn does that)
    :param list xtokens: optional tick labels for the x (input) axis
    :param list ytokens: optional tick labels for the y (output) axis
    :return: matplotlib Figure with one subplot per attention head
    """
    # dynamically import matplotlib due to not found error
    from matplotlib.ticker import MaxNLocator

    d = os.path.dirname(filename)
    # exist_ok avoids the exists()+makedirs() race; skip entirely when the path
    # has no directory component, where os.makedirs("") would raise.
    if d:
        os.makedirs(d, exist_ok=True)
    w, h = plt.figaspect(1.0 / len(att_w))
    fig = plt.Figure(figsize=(w * 2, h * 2))
    axes = fig.subplots(1, len(att_w))
    if len(att_w) == 1:
        # Figure.subplots returns a bare Axes (not an array) for a single plot
        axes = [axes]
    for ax, aw in zip(axes, att_w):
        ax.imshow(aw.astype(numpy.float32), aspect="auto")
        ax.set_xlabel("Input")
        ax.set_ylabel("Output")
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        # Labels for major ticks
        if xtokens is not None:
            ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, len(xtokens)))
            ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, 1), minor=True)
            # label count must equal the major-tick count set above; the old
            # `xtokens + [""]` (one extra label) raises ValueError on matplotlib >= 3.5
            ax.set_xticklabels(xtokens, rotation=40)
        if ytokens is not None:
            ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, len(ytokens)))
            ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, 1), minor=True)
            ax.set_yticklabels(ytokens)
    fig.tight_layout()
    return fig
def savefig(plot, filename):
    """Save a figure to *filename* and clear pyplot's current-figure state.

    :param plot: matplotlib Figure (anything exposing ``.savefig``)
    :param str filename: destination path, including the image-type extension
    """
    plot.savefig(filename)
    # clear the implicit pyplot figure so repeated calls don't accumulate state
    plt.clf()
def plot_multi_head_attention(
    data,
    uttid_list,
    attn_dict,
    outdir,
    suffix="png",
    savefn=savefig,
    ikey="input",
    iaxis=0,
    okey="output",
    oaxis=0,
    subsampling_factor=4,
):
    """Plot multi head attentions.

    :param dict data: utts info from json file
    :param List uttid_list: utterance IDs
    :param dict[str, torch.Tensor] attn_dict: multi head attention dict.
        values should be torch.Tensor (head, input_length, output_length)
    :param str outdir: dir to save fig
    :param str suffix: filename suffix including image type (e.g., png)
    :param savefn: function to save
    :param str ikey: key to access input
    :param int iaxis: dimension to access input
    :param str okey: key to access output
    :param int oaxis: dimension to access output
    :param int subsampling_factor: subsampling factor in encoder
    """
    for name, att_ws in attn_dict.items():
        for idx, att_w in enumerate(att_ws):
            uttid = uttid_list[idx]
            utt = data[uttid]
            filename = "%s/%s.%s.%s" % (outdir, uttid, name, suffix)
            # +1 for <eos> appended to the decoder-side sequence
            out_len = int(utt[okey][oaxis]["shape"][0]) + 1
            in_len = int(utt[ikey][iaxis]["shape"][0])
            # MT inputs carry a "token" entry; ASR/ST inputs do not
            is_mt = "token" in utt[ikey][iaxis]
            if not is_mt:
                # for ASR/ST the encoder subsamples the input frames
                in_len //= subsampling_factor
            xtokens = ytokens = None
            if "encoder" in name:
                # encoder self-attention: square crop over input positions
                att_w = att_w[:, :in_len, :in_len]
                if is_mt:
                    xtokens = utt[ikey][iaxis]["token"].split()
                    ytokens = xtokens[:]
            elif "decoder" in name:
                has_out_tokens = "token" in utt[okey][oaxis]
                if "self" in name:
                    # decoder self-attention: square crop over output positions
                    att_w = att_w[:, :out_len, :out_len]
                    if has_out_tokens:
                        out_tokens = utt[okey][oaxis]["token"].split()
                        ytokens = out_tokens + ["<eos>"]
                        xtokens = ["<sos>"] + out_tokens
                else:
                    # encoder-decoder cross-attention: output x input crop
                    att_w = att_w[:, :out_len, :in_len]
                    if has_out_tokens:
                        ytokens = utt[okey][oaxis]["token"].split() + ["<eos>"]
                    if is_mt:
                        xtokens = utt[ikey][iaxis]["token"].split()
            else:
                logging.warning("unknown name for shaping attention")
            savefn(_plot_and_save_attention(att_w, filename, xtokens, ytokens), filename)
class PlotAttentionReport(asr_utils.PlotAttentionReport):
    """Attention-plot trainer extension specialized for multi-head attention.

    Reuses the base-class data pipeline (transform/converter/att_vis_fn) but
    routes plotting through :func:`plot_multi_head_attention`.
    """

    def plotfn(self, *args, **kwargs):
        """Forward to plot_multi_head_attention with this reporter's keys."""
        kwargs.update(
            ikey=self.ikey,
            iaxis=self.iaxis,
            okey=self.okey,
            oaxis=self.oaxis,
            subsampling_factor=self.factor,
        )
        plot_multi_head_attention(*args, **kwargs)

    def __call__(self, trainer):
        """Plot all attention maps for the current epoch as image files."""
        att_dict, uttids = self.get_attention_weights()
        epoch_suffix = "ep.{.updater.epoch}.png".format(trainer)
        self.plotfn(self.data_dict, uttids, att_dict, self.outdir, epoch_suffix, savefig)

    def get_attention_weights(self):
        """Run the attention visualization function on one converted batch.

        :return: tuple of (attention dict, utterance-id list)
        """
        batch_in, uttids = self.transform(self.data, return_uttid=True)
        batch = self.converter([batch_in], self.device)
        # converter output may be positional or keyword style
        if isinstance(batch, dict):
            att_ws = self.att_vis_fn(**batch)
        elif isinstance(batch, tuple):
            att_ws = self.att_vis_fn(*batch)
        return att_ws, uttids

    def log_attentions(self, logger, step):
        """Emit all attention maps to a tensorboard-style *logger* at *step*."""

        def write_fig(plot, filename):
            # the basename doubles as the tensorboard tag
            logger.add_figure(os.path.basename(filename), plot, step)
            plt.clf()

        att_dict, uttids = self.get_attention_weights()
        self.plotfn(self.data_dict, uttids, att_dict, self.outdir, "", write_fig)
|