# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import torch
from torch.utils.tensorboard import SummaryWriter
import random
import numpy as np

from pkg_resources import parse_version
from model.third_party.HMNet.Models.Trainers.BaseTrainer import BaseTrainer
from model.third_party.HMNet.Utils.GeneralUtils import bcolors
from model.third_party.HMNet.Utils.distributed import distributed


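# DistributedTrainer extends BaseTrainer with distributed-data-parallel (DDP)
# bookkeeping: it seeds the Python/NumPy/PyTorch RNGs, queries the distributed()
# helper for device and rank information, writes one log file per rank, and
# restricts console output, TensorBoard logging, and config saving to rank 0.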
class DistributedTrainer(BaseTrainer):
    def __init__(self, opt):
        super().__init__(opt)

        self.seed = int(self.opt["SEED"]) if "SEED" in self.opt else 0

        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)

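        # distributed() returns this process's device plus the world size,
        # per-node (local) size, global rank, local rank, and the object stored
        # under opt["run"]; the unused slots of the returned tuple are dropped.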
        (
            self.opt["device"],
            _,
            self.opt["world_size"],
            self.opt["local_size"],
            self.opt["rank"],
            self.opt["local_rank"],
            _,
            self.opt["run"],
        ) = distributed(opt, not self.use_cuda)

        self.getSaveFolder()
        self.opt["logFile"] = f"log_{self.opt['rank']}.txt"
        self.saveConf()

        self.high_pytorch_version = parse_version(torch.__version__) >= parse_version(
            "1.2.0"
        )
        if self.opt["rank"] == 0:
            print(
                bcolors.OKGREEN,
                torch.__version__,
                bcolors.ENDC,
                "is",
                "high" if self.high_pytorch_version else "low",
            )

        if self.use_cuda:
            # torch.cuda.manual_seed_all(self.seed)
            # ddp: only seed the GPU associated with this process
            torch.cuda.manual_seed(self.seed)

        # ddp: print distributed training stats; the learning-rate boost is left commented out below
        if self.opt["rank"] == 0:
            print(
                "Number of GPUs is",
                bcolors.OKGREEN,
                self.opt["world_size"],
                bcolors.ENDC,
            )
            # print('Boost learning rate from', bcolors.OKGREEN, self.opt['START_LEARNING_RATE'], bcolors.ENDC, 'to',
            #     bcolors.OKGREEN, self.opt['START_LEARNING_RATE'] * self.opt['world_size'], bcolors.ENDC)
            print(
                "Effective batch size is increased from",
                bcolors.OKGREEN,
                self.opt["MINI_BATCH"],
                bcolors.ENDC,
                "to",
                bcolors.OKGREEN,
                self.opt["MINI_BATCH"] * self.opt["world_size"],
                bcolors.ENDC,
            )

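        # With gradient accumulation the optimizer steps once every
        # grad_acc_steps mini-batches, so the effective batch size becomes
        # MINI_BATCH * world_size * GRADIENT_ACCUMULATE_STEP.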
        self.grad_acc_steps = 1
        if "GRADIENT_ACCUMULATE_STEP" in self.opt:
            if self.opt["rank"] == 0:
                print(
                    "Gradient accumulation steps =",
                    bcolors.OKGREEN,
                    self.opt["GRADIENT_ACCUMULATE_STEP"],
                    bcolors.ENDC,
                )
                # print('Boost learning rate from', bcolors.OKGREEN, self.opt['START_LEARNING_RATE'], bcolors.ENDC, 'to',
                # bcolors.OKGREEN, self.opt['START_LEARNING_RATE'] * self.opt['world_size'] * self.opt['GRADIENT_ACCUMULATE_STEP'], bcolors.ENDC)
                print(
                    "Effective batch size =",
                    bcolors.OKGREEN,
                    self.opt["MINI_BATCH"]
                    * self.opt["world_size"]
                    * self.opt["GRADIENT_ACCUMULATE_STEP"],
                    bcolors.ENDC,
                )
            self.grad_acc_steps = int(self.opt["GRADIENT_ACCUMULATE_STEP"])
        # self.opt['START_LEARNING_RATE'] *= self.opt['world_size'] * self.grad_acc_steps

    def tb_log_scalar(self, name, value, step):
        if self.opt["rank"] == 0:
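            # Lazily create the TensorBoard writer on first use; only rank 0 logs scalars.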
            if self.tb_writer is None:
                self.tb_writer = SummaryWriter(
                    os.path.join(self.saveFolder, "tensorboard")
                )
            self.tb_writer.add_scalar(name, value, step)

    def log(self, s):
        # When the 'OFFICIAL' flag is set in the config file, the program does not output logs
        if self.is_official:
            return
        try:
            if self.logFileHandle is None:
                self.logFileHandle = open(
                    os.path.join(self.saveFolder, self.opt["logFile"]), "a"
                )
            self.logFileHandle.write(s + "\n")
        except Exception as e:
            print("ERROR while writing log file:", e)
        print(s)

    def getSaveFolder(self):
        runid = 1
        while True:
            saveFolder = os.path.join(
                self.opt["datadir"],
                self.opt["basename"] + "_conf~",
                "run_" + str(runid),
            )
            if not os.path.isdir(saveFolder):
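                # Barriers keep the ranks in step: all ranks agree the folder is
                # free, rank 0 creates it, and the second barrier guarantees the
                # directory exists before any rank starts writing into it.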
                if self.opt["world_size"] > 1:
                    torch.distributed.barrier()
                if self.opt["rank"] == 0:
                    os.makedirs(saveFolder)
                self.saveFolder = saveFolder
                if self.opt["world_size"] > 1:
                    torch.distributed.barrier()
                print(
                    "Saving logs, model, checkpoint, and evaluation in "
                    + self.saveFolder
                )
                return
            runid = runid + 1

    def saveConf(self):
        if self.opt["rank"] == 0:
            super().saveConf()