File size: 981 Bytes
9457143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import diffusion
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    LearningRateMonitor
)


class ModelCallback:
    """Build the standard set of Lightning callbacks for a training run.

    On construction, ``<root_path>/model/`` is created on disk if it does not
    already exist, and three callbacks are configured:

    * a ``ModelCheckpoint`` that writes weights-only checkpoints named
      ``model`` into ``<root_path>/model/`` and keeps only the single best
      one according to ``ckpt_monitor`` / ``ckpt_mode``;
    * a ``LearningRateMonitor`` logging at ``lr_logging_interval``;
    * the project's ``diffusion.EMACallback`` built with ``ema_decay``.

    Args:
        root_path: Output directory for the run; checkpoints are stored in
            its ``model/`` subdirectory.
        ckpt_monitor: Name of the logged metric the checkpoint callback
            tracks.
        ckpt_mode: ``"min"`` or ``"max"`` -- whether a lower or a higher
            ``ckpt_monitor`` value counts as better.
        ema_decay: Decay passed to ``diffusion.EMACallback``. Defaults to
            ``0.995`` (the previously hard-coded value).
        lr_logging_interval: ``logging_interval`` passed to
            ``LearningRateMonitor``. Defaults to ``"step"`` (the previously
            hard-coded value).
    """

    def __init__(
        self,
        root_path: str,
        ckpt_monitor: str = "val_loss",
        ckpt_mode: str = "min",
        *,
        ema_decay: float = 0.995,
        lr_logging_interval: str = "step",
    ):
        ckpt_path = os.path.join(root_path, "model/")
        # A single makedirs call also creates root_path as an intermediate
        # directory. exist_ok=True replaces the old exists()/makedirs() pair,
        # which could raise FileExistsError if another process created the
        # directory between the check and the call.
        os.makedirs(ckpt_path, exist_ok=True)

        self.ckpt_callback = ModelCheckpoint(
            monitor=ckpt_monitor,
            dirpath=ckpt_path,
            filename="model",
            save_top_k=1,  # keep only the best checkpoint
            mode=ckpt_mode,
            save_weights_only=True,
        )

        self.lr_callback = LearningRateMonitor(
            logging_interval=lr_logging_interval
        )

        self.ema_callback = diffusion.EMACallback(decay=ema_decay)

    def get_callback(self) -> list:
        """Return the configured callbacks for passing to a Lightning ``Trainer``.

        Returns:
            A new list ``[checkpoint, learning-rate monitor, EMA]``, in that
            order.
        """
        return [self.ckpt_callback, self.lr_callback, self.ema_callback]