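"""Collect the attention weights produced by every attention module of a model."""
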
from collections import defaultdict
from typing import Dict
from typing import List
import torch
from espnet.nets.pytorch_backend.rnn.attentions import AttAdd
from espnet.nets.pytorch_backend.rnn.attentions import AttCov
from espnet.nets.pytorch_backend.rnn.attentions import AttCovLoc
from espnet.nets.pytorch_backend.rnn.attentions import AttDot
from espnet.nets.pytorch_backend.rnn.attentions import AttForward
from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA
from espnet.nets.pytorch_backend.rnn.attentions import AttLoc
from espnet.nets.pytorch_backend.rnn.attentions import AttLoc2D
from espnet.nets.pytorch_backend.rnn.attentions import AttLocRec
from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadAdd
from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadDot
from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadLoc
from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadMultiResLoc
from espnet.nets.pytorch_backend.rnn.attentions import NoAtt
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet2.train.abs_espnet_model import AbsESPnetModel


@torch.no_grad()
def calculate_all_attentions(
    model: AbsESPnetModel, batch: Dict[str, torch.Tensor]
) -> Dict[str, List[torch.Tensor]]:
    """Derive the outputs from all the attention layers.

    Args:
        model: The model whose attention modules are inspected.
        batch: The same keyword tensors as given to ``model.forward()``.
    Returns:
        return_dict: A dict of lists of tensors:
            key_names x batch x (D1, D2, ...)
    """
    bs = len(next(iter(batch.values())))
    assert all(len(v) == bs for v in batch.values()), {
        k: v.shape for k, v in batch.items()
    }

    # 1. Register a forward_hook to save the output from the relevant layers
    outputs = {}
    handles = {}
    for name, modu in model.named_modules():

        def hook(module, input, output, name=name):
            if isinstance(module, MultiHeadedAttention):
                # NOTE(kamo): MultiHeadedAttention doesn't return the attention
                # weight, so take it from the ``attn`` attribute instead.
                # attn: (B, Head, Tout, Tin)
                outputs[name] = module.attn.detach().cpu()
            elif isinstance(module, AttLoc2D):
                c, w = output
                # w: previously concatenated attentions
                # w: (B, nprev, Tin)
                att_w = w[:, -1].detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(module, (AttCov, AttCovLoc)):
                c, w = output
                assert isinstance(w, list), type(w)
                # w: list of previous attentions
                # w: nprev x (B, Tin)
                att_w = w[-1].detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(module, AttLocRec):
                # w: (B, Tin)
                c, (w, (att_h, att_c)) = output
                att_w = w.detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(
                module,
                (
                    AttMultiHeadDot,
                    AttMultiHeadAdd,
                    AttMultiHeadLoc,
                    AttMultiHeadMultiResLoc,
                ),
            ):
                c, w = output
                # w: nhead x (B, Tin)
                assert isinstance(w, list), type(w)
                att_w = [_w.detach().cpu() for _w in w]
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(
                module,
                (
                    AttAdd,
                    AttDot,
                    AttForward,
                    AttForwardTA,
                    AttLoc,
                    NoAtt,
                ),
            ):
                c, w = output
                att_w = w.detach().cpu()
                outputs.setdefault(name, []).append(att_w)

        handle = modu.register_forward_hook(hook)
        handles[name] = handle

    # 2. Forward the samples one by one.
    # Batch mode is not used, in order to keep the requirements on each model small.
    keys = []
    for k in batch:
        if not k.endswith("_lengths"):
            keys.append(k)

    return_dict = defaultdict(list)
    for ibatch in range(bs):
        # *: (B, L, ...) -> (1, L2, ...)
        _sample = {
            k: batch[k][ibatch, None, : batch[k + "_lengths"][ibatch]]
            if k + "_lengths" in batch
            else batch[k][ibatch, None]
            for k in keys
        }

        # *_lengths: (B,) -> (1,)
        _sample.update(
            {
                k + "_lengths": batch[k + "_lengths"][ibatch, None]
                for k in keys
                if k + "_lengths" in batch
            }
        )
        model(**_sample)

        # Derive the attention results
        for name, output in outputs.items():
            if isinstance(output, list):
                if isinstance(output[0], list):
                    # output: nhead x (Tout, Tin)
                    output = torch.stack(
                        [
                            # Tout x (1, Tin) -> (Tout, Tin)
                            torch.cat([o[idx] for o in output], dim=0)
                            for idx in range(len(output[0]))
                        ],
                        dim=0,
                    )
                else:
                    # Tout x (1, Tin) -> (Tout, Tin)
                    output = torch.cat(output, dim=0)
            else:
                # output: (1, NHead, Tout, Tin) -> (NHead, Tout, Tin)
                output = output.squeeze(0)
            # output: (Tout, Tin) or (NHead, Tout, Tin)
            return_dict[name].append(output)
        outputs.clear()

    # 3. Remove all hooks
    for _, handle in handles.items():
        handle.remove()

    return dict(return_dict)
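

if __name__ == "__main__":
    # Illustrative smoke test (a minimal sketch, not part of the original
    # module): a toy AbsESPnetModel wrapping a single MultiHeadedAttention
    # layer, so that the MultiHeadedAttention hook path above is exercised.
    # ``_ToyModel``, the feature sizes, and the random batch are placeholders
    # chosen for this example only.
    class _ToyModel(AbsESPnetModel):
        def __init__(self):
            super().__init__()
            self.att = MultiHeadedAttention(n_head=2, n_feat=8, dropout_rate=0.0)

        def forward(self, feats: torch.Tensor, feats_lengths: torch.Tensor):
            # Self-attention over the (already length-truncated) sample
            x = self.att(feats, feats, feats, mask=None)
            loss = x.sum()
            return loss, {"loss": loss.detach()}, torch.tensor(float(len(feats)))

        def collect_feats(self, feats: torch.Tensor, feats_lengths: torch.Tensor):
            return {"feats": feats, "feats_lengths": feats_lengths}

    toy_batch = {
        "feats": torch.randn(2, 5, 8),
        "feats_lengths": torch.tensor([5, 3]),
    }
    att_dict = calculate_all_attentions(_ToyModel(), toy_batch)
    # Each value is a list over samples; each entry is (NHead, Tout, Tin) here.
    for att_name, att_list in att_dict.items():
        print(att_name, [tuple(w.shape) for w in att_list])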