"""
Training Engine   Script  ver: Feb 8th 16:00

Based on MAE code.
https://github.com/facebookresearch/mae

"""

import math
import sys
from typing import Iterable
import os
import torch
from torchvision.transforms import ToPILImage
import SSL_structures.misc as misc
import utils.schedulers as lr_sched
from utils.visual_usage import unpatchify, patchify, Draw_tri_fig


def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, fix_position_ratio_scheduler=None,
                    puzzle_patch_size_scheduler=None, check_samples=1, print_freq=20, log_writer=None, args=None):
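    """
    Run one training epoch for an SAE ('sae*') or MAE ('mae*') model.

    Parameters (brief, based on how they are used below):
        model: the SSL model; its forward returns (loss, pred, patch-level extras)
        data_loader: iterable yielding (samples, labels); labels are ignored
        optimizer / loss_scaler: optimizer plus an AMP loss scaler that also steps and clips gradients
        fix_position_ratio_scheduler / puzzle_patch_size_scheduler: per-epoch SAE curriculum schedulers
        check_samples: number of samples to visualize at the end of the epoch
        print_freq: logging interval (in iterations) for the metric logger
        log_writer: optional TensorBoard SummaryWriter
        args: namespace with accum_iter, mask_ratio, model, DDP_distributed, log_dir, output_dir, etc.

    Returns:
        dict mapping each tracked metric name to its global average over the epoch.
    """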
    model.train(True)

    # update logger
    metric_logger = misc.MetricLogger(delimiter="  ")
    # initialize the learning-rate meter
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))

    header = 'Epoch: [{}]'.format(epoch)

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:  # Tensorboard PATH
        print('log_dir: {}'.format(args.log_dir))

    # Iteration
    for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        # per-iteration lr schedule: use the fractional epoch position
        # (data_iter_step / len(data_loader) + epoch) for finer-grained learning-rate adjustment
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        # move the batch to the target device
        samples = samples.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():  # automatic mixed precision to speed up training

            if fix_position_ratio_scheduler is not None and puzzle_patch_size_scheduler is not None:  # SAE
                fix_position_ratio = fix_position_ratio_scheduler(epoch)
                puzzle_patch_size = puzzle_patch_size_scheduler(epoch)
            else:
                fix_position_ratio = None
                puzzle_patch_size = None

            if args.model[0:3] == 'sae':
                loss, pred, imgs_puzzled_patches = model(samples, fix_position_ratio=fix_position_ratio,
                                                         puzzle_patch_size=puzzle_patch_size)  # SAE
            else:  # args.model[0:3] == 'mae'
                loss, pred, mask_patch_indicators = model(samples, mask_ratio=args.mask_ratio)  # MAE
                # FIXME: the MAE masking curriculum may not be good enough for future use
        if args.DDP_distributed:
            loss_value = loss.item()
        else:
            # without DDP: a single GPU gives a scalar loss; DataParallel returns one loss per GPU, so sum them
            loss_value = float(loss.cpu().detach().numpy()) \
                if torch.cuda.device_count() == 1 else sum(loss.cpu().detach().numpy())

        if not math.isfinite(loss_value):  # stop if the loss has exploded (NaN/Inf)
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # this is the per-minibatch loss; with gradient accumulation, scale it down since
        # loss_scaler accumulates the gradients over accum_iter steps before updating
        loss = loss / accum_iter

        # core backward step: loss_scaler wraps loss.backward() + optimizer.step() and adds gradient clipping
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)

        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()  # wait for all kernels in all streams on the current device to finish

        # update the logger
        metric_logger.update(loss=loss_value)
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        # average the loss value across all processes
        loss_value_reduce = misc.all_reduce_mean(loss_value)

    if log_writer is not None:
        log_writer.add_scalar('train_loss', loss_value_reduce, epoch)
        log_writer.add_scalar('lr', lr, epoch)

        if fix_position_ratio is not None and puzzle_patch_size is not None:
            log_writer.add_scalar('puzzle_patch_size', puzzle_patch_size, epoch)
            log_writer.add_scalar('fix_position_ratio', fix_position_ratio, epoch)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    if fix_position_ratio is not None and puzzle_patch_size is not None:
        print("Averaged stats:", metric_logger, 'fix_position_ratio:', fix_position_ratio,
              '  puzzle_patch_size:', puzzle_patch_size)
    else:
        print("Averaged stats:", metric_logger)

    # TODO: currently, visualizations are only drawn at the end of each training epoch
    if args.model[0:3] == 'sae':
        imgs_puzzled_batch = unpatchify(imgs_puzzled_patches, patch_size=16)
    else:  # MAE
        sample_img_patches = patchify(samples, patch_size=16)  # on GPU
        masked_img_patches = sample_img_patches * \
                             mask_patch_indicators.unsqueeze(-1).expand(-1, -1,
                                                                        sample_img_patches.shape[-1])
        masked_img_batch = unpatchify(masked_img_patches, patch_size=16)

    # paint images at the end of each epoch on main process
    if misc.is_main_process():
        for sampleIDX in range(check_samples):

            sample_img = samples.cpu()[sampleIDX]
            sample_img = ToPILImage()(sample_img)
            sample_img.save(os.path.join(args.output_dir, 'figs', 'sample_e_' + str(epoch)
                                         + '_sampleIDX_' + str(sampleIDX) + '.jpg'))

            recons_img_batch = unpatchify(pred, patch_size=16)
            recons_img = recons_img_batch.cpu()[sampleIDX]
            recons_img = ToPILImage()(recons_img)
            recons_img.save(os.path.join(args.output_dir, 'figs', 'recons_e_' + str(epoch)
                                         + '_sampleIDX_' + str(sampleIDX) + '.jpg'))

            if args.model[0:3] == 'sae':  # SAE
                puzzled_img = imgs_puzzled_batch.cpu()[sampleIDX]
                puzzled_img = ToPILImage()(puzzled_img)
                puzzled_img.save(os.path.join(args.output_dir, 'figs', 'puzzled_e_' + str(epoch)
                                              + '_sampleIDX_' + str(sampleIDX) + '.jpg'))

                picpath = os.path.join(args.output_dir, 'figs', 'puzzled_e_' + str(epoch)
                                       + '_sampleIDX_' + str(sampleIDX) + '.jpg')
                Draw_tri_fig(sample_img, puzzled_img, recons_img, picpath)

            else:  # MAE
                masked_img = masked_img_batch.cpu()[sampleIDX]  # put on CPU
                masked_img = ToPILImage()(masked_img)
                masked_img.save(os.path.join(args.output_dir, 'figs', 'masked_e_' + str(epoch)
                                             + '_sampleIDX_' + str(sampleIDX) + '.jpg'))

                picpath = os.path.join(args.output_dir, 'figs', 'masked_e_' + str(epoch)
                                       + '_sampleIDX_' + str(sampleIDX) + '.jpg')
                Draw_tri_fig(sample_img, masked_img, recons_img, picpath)

    # return the recorded stats; everything else has already been aggregated inside the metric logger
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
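

# Illustrative usage sketch (hypothetical driver code; the variable names below are
# assumptions and the real setup lives in the training entry script):
#
#   for epoch in range(args.epochs):
#       train_stats = train_one_epoch(model, data_loader_train, optimizer, device, epoch,
#                                     loss_scaler,  # e.g. an MAE-style NativeScaler wrapper
#                                     fix_position_ratio_scheduler=fix_position_ratio_scheduler,
#                                     puzzle_patch_size_scheduler=puzzle_patch_size_scheduler,
#                                     log_writer=log_writer, args=args)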