# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import logging
import os

import matplotlib.pyplot as plt
import numpy

from espnet.asr import asr_utils


def _plot_and_save_attention(att_w, filename, xtokens=None, ytokens=None):
    """Plot one attention-weight matrix per head and return the figure.

    Despite the name, the figure is returned rather than saved here;
    saving is delegated to the caller (see ``savefig``).

    :param att_w: attention weights, shape (head, output_length, input_length)
    :param str filename: output path; its directory is created if missing
    :param list xtokens: optional token labels for the x-axis
    :param list ytokens: optional token labels for the y-axis
    :rtype: matplotlib.figure.Figure
    """
    # Import locally so that importing this module does not fail on
    # headless setups before a backend is configured.
    from matplotlib.ticker import MaxNLocator

    d = os.path.dirname(filename)
    if d and not os.path.exists(d):
        os.makedirs(d)
    w, h = plt.figaspect(1.0 / len(att_w))
    fig = plt.Figure(figsize=(w * 2, h * 2))
    axes = fig.subplots(1, len(att_w))
    if len(att_w) == 1:
        axes = [axes]
    for ax, aw in zip(axes, att_w):
        ax.imshow(aw.astype(numpy.float32), aspect="auto")
        ax.set_xlabel("Input")
        ax.set_ylabel("Output")
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        # Labels for major ticks: one major tick per token.
        if xtokens is not None:
            ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, len(xtokens)))
            ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, 1), minor=True)
            ax.set_xticklabels(xtokens, rotation=40)
        if ytokens is not None:
            ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, len(ytokens)))
            ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, 1), minor=True)
            ax.set_yticklabels(ytokens)
    fig.tight_layout()
    return fig
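
# A minimal sketch of driving _plot_and_save_attention directly; the shapes,
# tokens, and path below are illustrative only (4 heads, 3 output steps,
# 5 input steps):
#
#   aw = numpy.random.rand(4, 3, 5)
#   fig = _plot_and_save_attention(
#       aw, "/tmp/att_ws/demo.png", xtokens=list("abcde"), ytokens=list("xyz")
#   )
#   fig.savefig("/tmp/att_ws/demo.png")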


def savefig(plot, filename):
    """Save a figure to ``filename`` and clear the pyplot state."""
    plot.savefig(filename)
    plt.clf()


def plot_multi_head_attention(
    data,
    uttid_list,
    attn_dict,
    outdir,
    suffix="png",
    savefn=savefig,
    ikey="input",
    iaxis=0,
    okey="output",
    oaxis=0,
    subsampling_factor=4,
):
    """Plot multi head attentions.

    :param dict data: utts info from json file
    :param List uttid_list: utterance IDs
    :param dict[str, numpy.ndarray] attn_dict: multi head attention dict;
        values should be arrays of shape (head, output_length, input_length)
    :param str outdir: dir to save fig
    :param str suffix: filename suffix including image type (e.g., png)
    :param savefn: function used to save each figure
    :param str ikey: key to access input
    :param int iaxis: dimension to access input
    :param str okey: key to access output
    :param int oaxis: dimension to access output
    :param int subsampling_factor: subsampling factor in encoder
    """
    for name, att_ws in attn_dict.items():
        for idx, att_w in enumerate(att_ws):
            data_i = data[uttid_list[idx]]
            filename = "%s/%s.%s.%s" % (outdir, uttid_list[idx], name, suffix)
            dec_len = int(data_i[okey][oaxis]["shape"][0]) + 1  # +1 for <eos>
            enc_len = int(data_i[ikey][iaxis]["shape"][0])
            is_mt = "token" in data_i[ikey][iaxis].keys()
            # for ASR/ST
            if not is_mt:
                enc_len //= subsampling_factor
            xtokens, ytokens = None, None
            if "encoder" in name:
                att_w = att_w[:, :enc_len, :enc_len]
                # for MT
                if is_mt:
                    xtokens = data_i[ikey][iaxis]["token"].split()
                    ytokens = xtokens[:]
            elif "decoder" in name:
                if "self" in name:
                    # self-attention
                    att_w = att_w[:, :dec_len, :dec_len]
                    if "token" in data_i[okey][oaxis].keys():
                        ytokens = data_i[okey][oaxis]["token"].split() + ["<eos>"]
                        xtokens = ["<sos>"] + data_i[okey][oaxis]["token"].split()
                else:
                    # cross-attention
                    att_w = att_w[:, :dec_len, :enc_len]
                    if "token" in data_i[okey][oaxis].keys():
                        ytokens = data_i[okey][oaxis]["token"].split() + ["<eos>"]
                    # for MT
                    if is_mt:
                        xtokens = data_i[ikey][iaxis]["token"].split()
            else:
                logging.warning("unknown name for shaping attention")
            fig = _plot_and_save_attention(att_w, filename, xtokens, ytokens)
            savefn(fig, filename)
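
# A self-contained sketch of calling plot_multi_head_attention outside the
# trainer (e.g. for debugging); the utterance id, shapes, and attention name
# below are made up, and real inputs come from the ESPnet data json:
#
#   data = {
#       "utt1": {
#           "input": [{"shape": [80, 83]}],  # 80 frames -> 20 after 4x subsampling
#           "output": [{"shape": [5], "token": "a b c d e"}],
#       }
#   }
#   attn_dict = {"decoder.0.src_attn": [numpy.random.rand(4, 6, 20)]}
#   plot_multi_head_attention(data, ["utt1"], attn_dict, "/tmp/att_ws")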


class PlotAttentionReport(asr_utils.PlotAttentionReport):
    """Plot attention reporter for multi head attention models."""

    def plotfn(self, *args, **kwargs):
        """Plot multi head attentions using the reporter's access keys."""
        kwargs["ikey"] = self.ikey
        kwargs["iaxis"] = self.iaxis
        kwargs["okey"] = self.okey
        kwargs["oaxis"] = self.oaxis
        kwargs["subsampling_factor"] = self.factor
        plot_multi_head_attention(*args, **kwargs)

    def __call__(self, trainer):
        """Plot and save attention figures for the current epoch."""
        attn_dict, uttid_list = self.get_attention_weights()
        suffix = "ep.{.updater.epoch}.png".format(trainer)
        self.plotfn(self.data_dict, uttid_list, attn_dict, self.outdir, suffix, savefig)

    def get_attention_weights(self):
        """Return attention weights and utterance IDs for the plot batch."""
        return_batch, uttid_list = self.transform(self.data, return_uttid=True)
        batch = self.converter([return_batch], self.device)
        if isinstance(batch, tuple):
            att_ws = self.att_vis_fn(*batch)
        elif isinstance(batch, dict):
            att_ws = self.att_vis_fn(**batch)
        else:
            raise TypeError("converter must return a tuple or a dict")
        return att_ws, uttid_list

    def log_attentions(self, logger, step):
        """Add attention figures to a tensorboard-style logger.

        :param logger: logger exposing an ``add_figure(tag, figure, step)`` method
        :param int step: current global step
        """

        def log_fig(plot, filename):
            logger.add_figure(os.path.basename(filename), plot, step)
            plt.clf()

        attn_dict, uttid_list = self.get_attention_weights()
        self.plotfn(self.data_dict, uttid_list, attn_dict, self.outdir, "", log_fig)
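
# A hedged sketch of logging attention plots to tensorboard; ``reporter`` is
# assumed to be a PlotAttentionReport configured as in the ESPnet training
# script (constructor arguments live in espnet.asr.asr_utils and are not
# shown here), and SummaryWriter.add_figure is the standard tensorboard API
# this class expects:
#
#   from torch.utils.tensorboard import SummaryWriter
#
#   writer = SummaryWriter("exp/tensorboard")
#   reporter.log_attentions(writer, step=1000)
#   writer.close()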