File size: 1,219 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import chainer
import logging


def check_early_stop(trainer, epochs):
    """Warn the user if training ended before the epoch budget was reached.

    The epoch counter of the updater's ``"main"`` iterator is compared
    against ``epochs - 1``; when it is smaller, a warning is logged via the
    root logger.

    :param trainer: The trainer used for training
    :param epochs: The maximum number of epochs
    """
    main_iterator = trainer.updater.get_iterator("main")
    reached_epoch = main_iterator.epoch
    stopped_early = reached_epoch < (epochs - 1)
    if stopped_early:
        message = "".join(
            (
                "Hit early stop at epoch ",
                str(reached_epoch),
                "\nYou can change the patience or set it to 0 to run all epochs",
            )
        )
        logging.warning(message)


def set_early_stop(trainer, args, is_lm=False):
    """Install an early stopping trigger on the trainer from the arguments.

    The trainer's ``stop_trigger`` is replaced only when ``args.patience`` is
    greater than 0; otherwise the trainer is left untouched.

    :param trainer: The trainer used for training
    :param args: The program arguments (``patience``, ``early_stop_criterion``
        and ``epoch``/``epochs`` are read)
    :param is_lm: If the trainer is for a LM (epoch instead of epochs)
    """
    patience = args.patience
    monitored = args.early_stop_criterion

    # LM training stores its epoch budget under a different attribute name.
    if is_lm:
        max_epochs = args.epoch
    else:
        max_epochs = args.epochs

    # Criteria whose name contains "acc" are maximized, all others minimized.
    if "acc" in monitored:
        direction = "max"
    else:
        direction = "min"

    if patience > 0:
        # NOTE(review): the keyword is spelled `patients` here; confirm it is
        # still accepted by the installed Chainer version.
        trigger_options = dict(
            monitor=monitored,
            mode=direction,
            patients=patience,
            max_trigger=(max_epochs, "epoch"),
        )
        trainer.stop_trigger = chainer.training.triggers.EarlyStoppingTrigger(
            **trigger_options
        )