import argparse
from contextlib import contextmanager
import dataclasses
from dataclasses import is_dataclass
from distutils.version import LooseVersion
import logging
from pathlib import Path
import time
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import humanfriendly
import numpy as np
import torch
import torch.nn
import torch.optim
from typeguard import check_argument_types

from espnet2.iterators.abs_iter_factory import AbsIterFactory
from espnet2.main_funcs.average_nbest_models import average_nbest_models
from espnet2.main_funcs.calculate_all_attentions import calculate_all_attentions
from espnet2.schedulers.abs_scheduler import AbsBatchStepScheduler
from espnet2.schedulers.abs_scheduler import AbsEpochStepScheduler
from espnet2.schedulers.abs_scheduler import AbsScheduler
from espnet2.schedulers.abs_scheduler import AbsValEpochStepScheduler
from espnet2.torch_utils.add_gradient_noise import add_gradient_noise
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.recursive_op import recursive_average
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet2.train.distributed_utils import DistributedOption
from espnet2.train.reporter import Reporter
from espnet2.train.reporter import SubReporter
from espnet2.utils.build_dataclass import build_dataclass

if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"):
    from torch.utils.tensorboard import SummaryWriter
else:
    from tensorboardX import SummaryWriter

if torch.distributed.is_available():
    if LooseVersion(torch.__version__) > LooseVersion("1.0.1"):
        from torch.distributed import ReduceOp
    else:
        from torch.distributed import reduce_op as ReduceOp
else:
    ReduceOp = None

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
    from torch.cuda.amp import GradScaler
else:

    @contextmanager
    def autocast(enabled=True):
        yield

    GradScaler = None

try:
    import fairscale
except ImportError:
    fairscale = None


@dataclasses.dataclass
class TrainerOptions:
    ngpu: int
    resume: bool
    use_amp: bool
    train_dtype: str
    grad_noise: bool
    accum_grad: int
    grad_clip: float
    grad_clip_type: float
    log_interval: Optional[int]
    no_forward_run: bool
    use_tensorboard: bool
    use_wandb: bool
    output_dir: Union[Path, str]
    max_epoch: int
    seed: int
    sharded_ddp: bool
    patience: Optional[int]
    keep_nbest_models: Union[int, List[int]]
    early_stopping_criterion: Sequence[str]
    best_model_criterion: Sequence[Sequence[str]]
    val_scheduler_criterion: Sequence[str]
    unused_parameters: bool


class Trainer:
    """Trainer having an optimizer.

    If you'd like to use multiple optimizers, then inherit this class
    and override the methods if necessary - at least "train_one_epoch()":

    >>> class TwoOptimizerTrainer(Trainer):
    ...     @classmethod
    ...     def add_arguments(cls, parser):
    ...         ...
    ...
    ...     @classmethod
    ...     def train_one_epoch(cls, model, optimizers, ...):
    ...         loss1 = model.model1(...)
    ...         loss1.backward()
    ...         optimizers[0].step()
    ...
    ...         loss2 = model.model2(...)
    ...         loss2.backward()
    ...         optimizers[1].step()

    """

    def __init__(self):
        raise RuntimeError("This class can't be instantiated.")

    @classmethod
    def build_options(cls, args: argparse.Namespace) -> TrainerOptions:
        """Build options consumed by train(), eval(), and plot_attention()"""
        assert check_argument_types()
        return build_dataclass(TrainerOptions, args)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Reserved for future development of another Trainer"""
        pass

    @staticmethod
    def resume(
        checkpoint: Union[str, Path],
        model: torch.nn.Module,
        reporter: Reporter,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        ngpu: int = 0,
    ):
        states = torch.load(
            checkpoint,
            map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
        )
        model.load_state_dict(states["model"])
        reporter.load_state_dict(states["reporter"])
        for optimizer, state in zip(optimizers, states["optimizers"]):
            optimizer.load_state_dict(state)
        for scheduler, state in zip(schedulers, states["schedulers"]):
            if scheduler is not None:
                scheduler.load_state_dict(state)
        if scaler is not None:
            if states["scaler"] is None:
                logging.warning("scaler state is not found")
            else:
                scaler.load_state_dict(states["scaler"])

        logging.info(f"The training was resumed using {checkpoint}")

    @classmethod
    def run(
        cls,
        model: AbsESPnetModel,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        train_iter_factory: AbsIterFactory,
        valid_iter_factory: AbsIterFactory,
        plot_attention_iter_factory: Optional[AbsIterFactory],
        trainer_options,
        distributed_option: DistributedOption,
    ) -> None:
        """Perform training. This method runs the main training loop over epochs."""
        assert check_argument_types()

        assert is_dataclass(trainer_options), type(trainer_options)
        assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))

        if isinstance(trainer_options.keep_nbest_models, int):
            keep_nbest_models = trainer_options.keep_nbest_models
        else:
            if len(trainer_options.keep_nbest_models) == 0:
                logging.warning("No keep_nbest_models is given. Changing to [1]")
                trainer_options.keep_nbest_models = [1]
            keep_nbest_models = max(trainer_options.keep_nbest_models)

        output_dir = Path(trainer_options.output_dir)
        reporter = Reporter()
        if trainer_options.use_amp:
            if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
                raise RuntimeError(
                    "Require torch>=1.6.0 for Automatic Mixed Precision"
                )
            if trainer_options.sharded_ddp:
                if fairscale is None:
                    raise RuntimeError(
                        "fairscale is required for sharded_ddp: 'pip install fairscale'"
                    )
                scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
            else:
                scaler = GradScaler()
        else:
            scaler = None
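
        # Resume from output_dir / "checkpoint.pth" if resuming is enabled and it exists.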
        if trainer_options.resume and (output_dir / "checkpoint.pth").exists():
            cls.resume(
                checkpoint=output_dir / "checkpoint.pth",
                model=model,
                optimizers=optimizers,
                schedulers=schedulers,
                reporter=reporter,
                scaler=scaler,
                ngpu=trainer_options.ngpu,
            )

        start_epoch = reporter.get_epoch() + 1
        if start_epoch == trainer_options.max_epoch + 1:
            logging.warning(
                f"The training has already reached max_epoch: {trainer_options.max_epoch}"
            )
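
        # Wrap the model for distributed or multi-GPU training.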
        if distributed_option.distributed:
            if trainer_options.sharded_ddp:
                dp_model = fairscale.nn.data_parallel.ShardedDataParallel(
                    module=model,
                    sharded_optimizer=optimizers,
                )
            else:
                dp_model = torch.nn.parallel.DistributedDataParallel(
                    model,
                    device_ids=(
                        # With one GPU per process, bind DDP to this process's device
                        [torch.cuda.current_device()]
                        if distributed_option.ngpu == 1
                        # Otherwise (multi-GPU per process or CPU) leave placement to DDP
                        else None
                    ),
                    output_device=(
                        torch.cuda.current_device()
                        if distributed_option.ngpu == 1
                        else None
                    ),
                    find_unused_parameters=trainer_options.unused_parameters,
                )
        elif distributed_option.ngpu > 1:
            dp_model = torch.nn.parallel.DataParallel(
                model,
                device_ids=list(range(distributed_option.ngpu)),
            )
        else:
            # Single device (GPU or CPU): use the model as-is
            dp_model = model

        if trainer_options.use_tensorboard and (
            not distributed_option.distributed or distributed_option.dist_rank == 0
        ):
            summary_writer = SummaryWriter(str(output_dir / "tensorboard"))
        else:
            summary_writer = None

        start_time = time.perf_counter()
        for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
            if iepoch != start_epoch:
                logging.info(
                    "{}/{}epoch started. Estimated time to finish: {}".format(
                        iepoch,
                        trainer_options.max_epoch,
                        humanfriendly.format_timespan(
                            (time.perf_counter() - start_time)
                            / (iepoch - start_epoch)
                            * (trainer_options.max_epoch - iepoch + 1)
                        ),
                    )
                )
            else:
                logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
            set_all_random_seed(trainer_options.seed + iepoch)

            reporter.set_epoch(iepoch)

            with reporter.observe("train") as sub_reporter:
                all_steps_are_invalid = cls.train_one_epoch(
                    model=dp_model,
                    optimizers=optimizers,
                    schedulers=schedulers,
                    iterator=train_iter_factory.build_iter(iepoch),
                    reporter=sub_reporter,
                    scaler=scaler,
                    summary_writer=summary_writer,
                    options=trainer_options,
                    distributed_option=distributed_option,
                )

            with reporter.observe("valid") as sub_reporter:
                cls.validate_one_epoch(
                    model=dp_model,
                    iterator=valid_iter_factory.build_iter(iepoch),
                    reporter=sub_reporter,
                    options=trainer_options,
                    distributed_option=distributed_option,
                )

            if not distributed_option.distributed or distributed_option.dist_rank == 0:
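                # Attention plotting is executed only on the rank-0 worker.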
                if plot_attention_iter_factory is not None:
                    with reporter.observe("att_plot") as sub_reporter:
                        cls.plot_attention(
                            model=model,
                            output_dir=output_dir / "att_ws",
                            summary_writer=summary_writer,
                            iterator=plot_attention_iter_factory.build_iter(iepoch),
                            reporter=sub_reporter,
                            options=trainer_options,
                        )

            for scheduler in schedulers:
                if isinstance(scheduler, AbsValEpochStepScheduler):
                    scheduler.step(
                        reporter.get_value(*trainer_options.val_scheduler_criterion)
                    )
                elif isinstance(scheduler, AbsEpochStepScheduler):
                    scheduler.step()
            if trainer_options.sharded_ddp:
                for optimizer in optimizers:
                    if isinstance(optimizer, fairscale.optim.oss.OSS):
                        optimizer.consolidate_state_dict()

            if not distributed_option.distributed or distributed_option.dist_rank == 0:
                logging.info(reporter.log_message())
                reporter.matplotlib_plot(output_dir / "images")
                if summary_writer is not None:
                    reporter.tensorboard_add_scalar(summary_writer)
                if trainer_options.use_wandb:
                    reporter.wandb_log()
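
                # Save the checkpoint needed for resuming
                # (model, reporter, optimizers, schedulers, scaler).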
                torch.save(
                    {
                        "model": model.state_dict(),
                        "reporter": reporter.state_dict(),
                        "optimizers": [o.state_dict() for o in optimizers],
                        "schedulers": [
                            s.state_dict() if s is not None else None
                            for s in schedulers
                        ],
                        "scaler": scaler.state_dict() if scaler is not None else None,
                    },
                    output_dir / "checkpoint.pth",
                )
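
                # Save the model parameters for this epoch and refresh the
                # 'latest.pth' symlink to point at it.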
                torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pth")

                p = output_dir / "latest.pth"
                if p.is_symlink() or p.exists():
                    p.unlink()
                p.symlink_to(f"{iepoch}epoch.pth")
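
                # Update the '<phase>.<criterion>.best.pth' symlinks for every
                # criterion that improved in this epoch.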
                _improved = []
                for _phase, k, _mode in trainer_options.best_model_criterion:
                    # e.g. _phase, k, _mode = "valid", "loss", "min"
                    if reporter.has(_phase, k):
                        best_epoch = reporter.get_best_epoch(_phase, k, _mode)
                        if best_epoch == iepoch:
                            p = output_dir / f"{_phase}.{k}.best.pth"
                            if p.is_symlink() or p.exists():
                                p.unlink()
                            p.symlink_to(f"{iepoch}epoch.pth")
                            _improved.append(f"{_phase}.{k}")
                if len(_improved) == 0:
                    logging.info("There are no improvements in this epoch")
                else:
                    logging.info(
                        "The best model has been updated: " + ", ".join(_improved)
                    )
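
                # Remove epoch models that are no longer among the n-best
                # for any of the criteria.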
                _removed = []
                nbests = set().union(
                    *[
                        set(reporter.sort_epochs(ph, k, m)[:keep_nbest_models])
                        for ph, k, m in trainer_options.best_model_criterion
                        if reporter.has(ph, k)
                    ]
                )
                for e in range(1, iepoch):
                    p = output_dir / f"{e}epoch.pth"
                    if p.exists() and e not in nbests:
                        p.unlink()
                        _removed.append(str(p))
                if len(_removed) != 0:
                    logging.info("The model files were removed: " + ", ".join(_removed))
            if all_steps_are_invalid:
                logging.warning(
                    f"The gradients at all steps are invalid in this epoch. "
                    f"Something seems wrong. This training was stopped at {iepoch}epoch"
                )
                break

            if trainer_options.patience is not None:
                if reporter.check_early_stopping(
                    trainer_options.patience, *trainer_options.early_stopping_criterion
                ):
                    break

        else:
            # This else clause runs only when the loop completed without break,
            # i.e. neither invalid gradients nor early stopping ended training early.
            logging.info(
                f"The training was finished at {trainer_options.max_epoch} epochs"
            )

        if not distributed_option.distributed or distributed_option.dist_rank == 0:
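            # Average the n-best checkpoints into a single model.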
            average_nbest_models(
                reporter=reporter,
                output_dir=output_dir,
                best_model_criterion=trainer_options.best_model_criterion,
                nbest=keep_nbest_models,
            )

    @classmethod
    def train_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        reporter: SubReporter,
        summary_writer: Optional[SummaryWriter],
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> bool:
        assert check_argument_types()

        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
        grad_clip = options.grad_clip
        grad_clip_type = options.grad_clip_type
        log_interval = options.log_interval
        no_forward_run = options.no_forward_run
        ngpu = options.ngpu
        use_wandb = options.use_wandb
        distributed = distributed_option.distributed

        if log_interval is None:
            try:
                log_interval = max(len(iterator) // 20, 10)
            except TypeError:
                log_interval = 100

        model.train()
        all_steps_are_invalid = True
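
        # [distributed] The number of iterations can differ between workers, so a
        # stop flag is all-reduced every step and the epoch ends consistently
        # on all workers.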
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")

        start_time = time.perf_counter()
        for iiter, (_, batch) in enumerate(
            reporter.measure_iter_time(iterator, "iter_time"), 1
        ):
            assert isinstance(batch, dict), type(batch)

            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                all_steps_are_invalid = False
                continue

            with autocast(scaler is not None):
                with reporter.measure_time("forward_time"):
                    retval = model(**batch)
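
                # The model may return either a dict holding "loss", "stats", "weight"
                # and an optional "optim_idx", or a plain (loss, stats, weight) tuple.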
                if isinstance(retval, dict):
                    loss = retval["loss"]
                    stats = retval["stats"]
                    weight = retval["weight"]
                    optim_idx = retval.get("optim_idx")
                    if optim_idx is not None and not isinstance(optim_idx, int):
                        if not isinstance(optim_idx, torch.Tensor):
                            raise RuntimeError(
                                "optim_idx must be int or 1dim torch.Tensor, "
                                f"but got {type(optim_idx)}"
                            )
                        if optim_idx.dim() >= 2:
                            raise RuntimeError(
                                "optim_idx must be int or 1dim torch.Tensor, "
                                f"but got {optim_idx.dim()}dim tensor"
                            )
                        if optim_idx.dim() == 1:
                            for v in optim_idx:
                                if v != optim_idx[0]:
                                    raise RuntimeError(
                                        "optim_idx must be 1dim tensor "
                                        "having same values for all entries"
                                    )
                            optim_idx = optim_idx[0].item()
                        else:
                            optim_idx = optim_idx.item()
                else:
                    loss, stats, weight = retval
                    optim_idx = None

                stats = {k: v for k, v in stats.items() if v is not None}
                if ngpu > 1 or distributed:
                    # Weighted sum of the per-device losses returned by data-parallel
                    loss = (loss * weight.type(loss.dtype)).sum()
                    # If distributed, recursive_average() also applies all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed)
                    # Now weight is the summation over all workers
                    loss /= weight
                if distributed:
                    # Multiply by world_size because DistributedDataParallel
                    # divides the gradient by world_size when averaging over workers
                    loss *= torch.distributed.get_world_size()

                loss /= accum_grad

            reporter.register(stats, weight)

            with reporter.measure_time("backward_time"):
                if scaler is not None:
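                    # AMP: scale the loss before backward so that small gradients are
                    # not flushed to zero in float16; they are unscaled again before
                    # the optimizer step.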
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

            if iiter % accum_grad == 0:
                if scaler is not None:
                    # Unscale the gradients in-place so that clipping and noise
                    # injection below operate on the true gradient values
                    for iopt, optimizer in enumerate(optimizers):
                        if optim_idx is not None and iopt != optim_idx:
                            continue
                        scaler.unscale_(optimizer)

                # Gradient noise injection
                if grad_noise:
                    add_gradient_noise(
                        model,
                        reporter.get_total_count(),
                        duration=100,
                        eta=1.0,
                        scale_factor=0.55,
                    )

                # Compute the gradient norm and clip it
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=grad_clip,
                    norm_type=grad_clip_type,
                )

                # clip_grad_norm_() may return a plain float on older PyTorch versions
                if not isinstance(grad_norm, torch.Tensor):
                    grad_norm = torch.tensor(grad_norm)

                if not torch.isfinite(grad_norm):
                    logging.warning(
                        f"The grad norm is {grad_norm}. Skipping updating the model."
                    )
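                    # Even though the update is skipped, scaler.step() and
                    # scaler.update() must still be called after unscale_() so the
                    # scaler's internal state is reset; step() itself skips
                    # optimizer.step() when the gradients are non-finite.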
                    if scaler is not None:
                        for iopt, optimizer in enumerate(optimizers):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            scaler.step(optimizer)
                            scaler.update()

                else:
                    all_steps_are_invalid = False
                    with reporter.measure_time("optim_step_time"):
                        for iopt, (optimizer, scheduler) in enumerate(
                            zip(optimizers, schedulers)
                        ):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            if scaler is not None:
                                # scaler.step() unscales (if not done yet) and applies
                                # the optimizer step, skipping non-finite gradients
                                scaler.step(optimizer)
                                # Update the loss scale for the next iteration
                                scaler.update()
                            else:
                                optimizer.step()
                            if isinstance(scheduler, AbsBatchStepScheduler):
                                scheduler.step()
                            optimizer.zero_grad()
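
                # Report the current learning rate of every param_group and the
                # wall-clock time spent on this accumulation cycle.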
                reporter.register(
                    dict(
                        {
                            f"optim{i}_lr{j}": pg["lr"]
                            for i, optimizer in enumerate(optimizers)
                            for j, pg in enumerate(optimizer.param_groups)
                            if "lr" in pg
                        },
                        train_time=time.perf_counter() - start_time,
                    ),
                )
                start_time = time.perf_counter()

            reporter.next()
            if iiter % log_interval == 0:
                logging.info(reporter.log_message(-log_interval))
                if summary_writer is not None:
                    reporter.tensorboard_add_scalar(summary_writer, -log_interval)
                if use_wandb:
                    reporter.wandb_log()
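
        # The for-else runs when the local iterator was exhausted without break:
        # raise the stop flag so that the other workers leave their loops as well.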
        else:
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)

        return all_steps_are_invalid

    @classmethod
    @torch.no_grad()
    def validate_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Dict[str, torch.Tensor]],
        reporter: SubReporter,
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> None:
        assert check_argument_types()
        ngpu = options.ngpu
        no_forward_run = options.no_forward_run
        distributed = distributed_option.distributed

        model.eval()
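
        # [distributed] As in train_one_epoch(), all-reduce a stop flag each step
        # so that every worker finishes the validation loop together.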
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
        for (_, batch) in iterator:
            assert isinstance(batch, dict), type(batch)
            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                continue

            retval = model(**batch)
            if isinstance(retval, dict):
                stats = retval["stats"]
                weight = retval["weight"]
            else:
                _, stats, weight = retval
            if ngpu > 1 or distributed:
                stats, weight = recursive_average(stats, weight, distributed)

            reporter.register(stats, weight)
            reporter.next()

        else:
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)

    @classmethod
    @torch.no_grad()
    def plot_attention(
        cls,
        model: torch.nn.Module,
        output_dir: Optional[Path],
        summary_writer: Optional[SummaryWriter],
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        reporter: SubReporter,
        options: TrainerOptions,
    ) -> None:
        assert check_argument_types()
        import matplotlib

        ngpu = options.ngpu
        no_forward_run = options.no_forward_run

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        from matplotlib.ticker import MaxNLocator

        model.eval()
        for ids, batch in iterator:
            assert isinstance(batch, dict), type(batch)
            assert len(next(iter(batch.values()))) == len(ids), (
                len(next(iter(batch.values()))),
                len(ids),
            )
            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                continue
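
            # Forward the model once and gather the attention weights from every
            # attention module in the model.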
            att_dict = calculate_all_attentions(model, batch)

            for k, att_list in att_dict.items():
                assert len(att_list) == len(ids), (len(att_list), len(ids))
                for id_, att_w in zip(ids, att_list):
                    if isinstance(att_w, torch.Tensor):
                        att_w = att_w.detach().cpu().numpy()

                    if att_w.ndim == 2:
                        att_w = att_w[None]
                    elif att_w.ndim > 3 or att_w.ndim == 1:
                        raise RuntimeError(f"Must be 2 or 3 dimensional: {att_w.ndim}")

                    w, h = plt.figaspect(1.0 / len(att_w))
                    fig = plt.Figure(figsize=(w * 1.3, h * 1.3))
                    axes = fig.subplots(1, len(att_w))
                    if len(att_w) == 1:
                        axes = [axes]

                    for ax, aw in zip(axes, att_w):
                        ax.imshow(aw.astype(np.float32), aspect="auto")
                        ax.set_title(f"{k}_{id_}")
                        ax.set_xlabel("Input")
                        ax.set_ylabel("Output")
                        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                        ax.yaxis.set_major_locator(MaxNLocator(integer=True))

                    if output_dir is not None:
                        p = output_dir / id_ / f"{k}.{reporter.get_epoch()}ep.png"
                        p.parent.mkdir(parents=True, exist_ok=True)
                        fig.savefig(p)

                    if summary_writer is not None:
                        summary_writer.add_figure(
                            f"{k}_{id_}", fig, reporter.get_epoch()
                        )
            reporter.next()
|