import logging
import os

import matplotlib.pyplot as plt
import numpy

from espnet.asr import asr_utils


def _plot_and_save_attention(att_w, filename, xtokens=None, ytokens=None):
    """Plot each attention head as a heatmap and return the figure."""
    from matplotlib.ticker import MaxNLocator

    # make sure the output directory exists
    d = os.path.dirname(filename)
    if not os.path.exists(d):
        os.makedirs(d)
    w, h = plt.figaspect(1.0 / len(att_w))
    fig = plt.Figure(figsize=(w * 2, h * 2))
    axes = fig.subplots(1, len(att_w))
    if len(att_w) == 1:
        axes = [axes]
    # draw one heatmap per attention head
    for ax, aw in zip(axes, att_w):
        ax.imshow(aw.astype(numpy.float32), aspect="auto")
        ax.set_xlabel("Input")
        ax.set_ylabel("Output")
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        # label the major ticks with tokens when they are available
        if xtokens is not None:
            ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, len(xtokens)))
            ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, 1), minor=True)
            ax.set_xticklabels(xtokens + [""], rotation=40)
        if ytokens is not None:
            ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, len(ytokens)))
            ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, 1), minor=True)
            ax.set_yticklabels(ytokens + [""])
    fig.tight_layout()
    return fig


def savefig(plot, filename):
    plot.savefig(filename)
    plt.clf()


def plot_multi_head_attention(
    data,
    uttid_list,
    attn_dict,
    outdir,
    suffix="png",
    savefn=savefig,
    ikey="input",
    iaxis=0,
    okey="output",
    oaxis=0,
    subsampling_factor=4,
):
    """Plot multi head attentions.

    :param dict data: utterance info from the json file
    :param List uttid_list: utterance IDs
    :param dict[str, torch.Tensor] attn_dict: multi head attention dict.
        values should be torch.Tensor (head, output_length, input_length)
        for each utterance
    :param str outdir: dir to save fig
    :param str suffix: filename suffix including image type (e.g., png)
    :param savefn: function to save
    :param str ikey: key to access input
    :param int iaxis: dimension to access input
    :param str okey: key to access output
    :param int oaxis: dimension to access output
    :param int subsampling_factor: subsampling factor in encoder

    """
    for name, att_ws in attn_dict.items():
        for idx, att_w in enumerate(att_ws):
            data_i = data[uttid_list[idx]]
            filename = "%s/%s.%s.%s" % (outdir, uttid_list[idx], name, suffix)
            dec_len = int(data_i[okey][oaxis]["shape"][0]) + 1  # +1 for <eos>
            enc_len = int(data_i[ikey][iaxis]["shape"][0])
            is_mt = "token" in data_i[ikey][iaxis].keys()
            # for ASR/ST, the encoder subsamples the input frames
            if not is_mt:
                enc_len //= subsampling_factor
            xtokens, ytokens = None, None
            if "encoder" in name:
                att_w = att_w[:, :enc_len, :enc_len]
                # for MT, the input is a token sequence
                if is_mt:
                    xtokens = data_i[ikey][iaxis]["token"].split()
                    ytokens = xtokens[:]
            elif "decoder" in name:
                if "self" in name:
                    # decoder self-attention
                    att_w = att_w[:, :dec_len, :dec_len]
                    if "token" in data_i[okey][oaxis].keys():
                        ytokens = data_i[okey][oaxis]["token"].split() + ["<eos>"]
                        xtokens = ["<sos>"] + data_i[okey][oaxis]["token"].split()
                else:
                    # decoder source attention
                    att_w = att_w[:, :dec_len, :enc_len]
                    if "token" in data_i[okey][oaxis].keys():
                        ytokens = data_i[okey][oaxis]["token"].split() + ["<eos>"]
                    # for MT
                    if is_mt:
                        xtokens = data_i[ikey][iaxis]["token"].split()
            else:
                logging.warning("unknown name for shaping attention")
            fig = _plot_and_save_attention(att_w, filename, xtokens, ytokens)
            savefn(fig, filename)
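
# Illustrative note (an assumption, not from the original code): ``data`` is the
# json-style utterance info, e.g.
#   {"utt1": {"input":  [{"shape": [T_in, feat_dim]}],
#             "output": [{"shape": [T_out, odim], "token": "a b c"}]}}
# and each value in ``attn_dict`` holds one (head, output_length, input_length)
# weight array per utterance in ``uttid_list``; see the runnable sketch at the
# end of this file.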


class PlotAttentionReport(asr_utils.PlotAttentionReport):
    def plotfn(self, *args, **kwargs):
        # forward this report's own keys/axes to plot_multi_head_attention
        kwargs["ikey"] = self.ikey
        kwargs["iaxis"] = self.iaxis
        kwargs["okey"] = self.okey
        kwargs["oaxis"] = self.oaxis
        kwargs["subsampling_factor"] = self.factor
        plot_multi_head_attention(*args, **kwargs)

    def __call__(self, trainer):
        attn_dict, uttid_list = self.get_attention_weights()
        suffix = "ep.{.updater.epoch}.png".format(trainer)
        self.plotfn(
            self.data_dict, uttid_list, attn_dict, self.outdir, suffix, savefig
        )

    def get_attention_weights(self):
        return_batch, uttid_list = self.transform(self.data, return_uttid=True)
        batch = self.converter([return_batch], self.device)
        if isinstance(batch, tuple):
            att_ws = self.att_vis_fn(*batch)
        elif isinstance(batch, dict):
            att_ws = self.att_vis_fn(**batch)
        return att_ws, uttid_list

    def log_attentions(self, logger, step):
        def log_fig(plot, filename):
            logger.add_figure(os.path.basename(filename), plot, step)
            plt.clf()

        attn_dict, uttid_list = self.get_attention_weights()
        self.plotfn(self.data_dict, uttid_list, attn_dict, self.outdir, "", log_fig)
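

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (an assumption, not part of the original module):
# exercises plot_multi_head_attention with fake numpy attention weights and a
# minimal json-style data dict. The utterance id, attention-module name, and
# temporary output directory below are illustrative only.
if __name__ == "__main__":
    import tempfile

    n_head, dec_len, enc_len = 2, 3, 8
    # one utterance; input "shape" is pre-subsampling, output "shape" excludes <eos>
    data = {
        "utt1": {
            "input": [{"shape": [enc_len * 4, 83]}],
            "output": [{"shape": [dec_len - 1, 52]}],
        }
    }
    # one entry per utterance: (head, output_length, input_length) weights
    attn_dict = {
        "decoder.src_attn": numpy.random.rand(1, n_head, dec_len, enc_len),
    }
    outdir = tempfile.mkdtemp()
    plot_multi_head_attention(data, ["utt1"], attn_dict, outdir)
    print("attention plots written to", outdir)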