File size: 2,089 Bytes
546a9ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os


class BaseTrainer:
    """Common scaffolding for trainers: device choice, logging and run folders.

    Subclasses are expected to override :meth:`train` and :meth:`load`; the
    base implementations do nothing.
    """

    def __init__(self, opt):
        """Store the option mapping and derive basic runtime settings.

        Args:
            opt: Mutable dict of configuration options. ``opt["logFile"]`` is
                overwritten with ``"log.txt"`` (the caller's dict is mutated).
                A missing ``"cuda"`` key is treated as "use the CPU".
        """
        self.opt = opt
        # Strict equality (not truthiness) so only True-equal values (True, 1)
        # enable CUDA; a non-empty string such as "False" must not count as on.
        self.use_cuda = bool(self.opt.get("cuda", False) == True)  # noqa: E712
        print("Using Cuda\n" if self.use_cuda else "Using CPU\n")

        # Only the presence of the "OFFICIAL" key matters (its value is
        # ignored); when present, log() produces no output at all.
        self.is_official = "OFFICIAL" in self.opt
        self.opt["logFile"] = "log.txt"
        self.saveFolder = None  # set by getSaveFolder()
        self.logFileHandle = None  # opened lazily by log()
        # NOTE(review): not used in this class; presumably a TensorBoard
        # writer assigned by subclasses — confirm.
        self.tb_writer = None

    def log(self, s):
        """Print ``s`` and append it to ``<saveFolder>/<opt["logFile"]>``.

        Nothing is printed or written for official runs. Failures to open or
        write the log file (for example when :meth:`getSaveFolder` has not run
        yet, so ``saveFolder`` is ``None``) are reported on stdout and
        otherwise ignored, so logging never aborts training.

        Args:
            s: The message to log; a newline is appended in the file.
        """
        # In official case, the program does not output logs
        if self.is_official:
            return
        try:
            if self.logFileHandle is None:
                log_path = os.path.join(self.saveFolder, self.opt["logFile"])
                # Explicit UTF-8 (as in saveConf) so non-ASCII messages do not
                # depend on the platform's locale encoding.
                self.logFileHandle = open(log_path, "a", encoding="utf-8")
            self.logFileHandle.write(s + "\n")
            # Flush every line so the log survives a crash mid-training.
            self.logFileHandle.flush()
        except Exception as e:  # deliberate best-effort: never stop training
            print("ERROR while writing log file:", e)
        print(s)

    def getSaveFolder(self):
        """Create the first unused ``<datadir>/<basename>_conf~/run_<k>`` folder.

        ``k`` starts at 1. The chosen path is stored in ``self.saveFolder``.
        Creating the directory *is* the existence test, so two processes
        starting at the same time can neither claim the same run folder nor
        crash with ``FileExistsError``.
        """
        runs_root = os.path.join(
            self.opt["datadir"], self.opt["basename"] + "_conf~"
        )
        runid = 1
        while True:
            candidate = os.path.join(runs_root, "run_" + str(runid))
            try:
                os.makedirs(candidate)
            except FileExistsError:
                # Already taken (possibly by a concurrent run); try the next.
                runid += 1
                continue
            self.saveFolder = candidate
            print("Saving logs, model and evaluation in " + self.saveFolder)
            return

    def saveConf(self):
        """Save a copy of the options as ``conf_copy.tsv`` in ``saveFolder``.

        Each option becomes one ``key<TAB>value`` line, with the value
        rendered via ``str.format``.
        """
        conf_path = os.path.join(self.saveFolder, "conf_copy.tsv")
        with open(conf_path, "w", encoding="utf-8") as fw:
            for k, v in self.opt.items():
                fw.write("{0}\t{1}\n".format(k, v))

    def train(self):
        """Run training; no-op here, intended to be overridden by subclasses."""
        pass

    def load(self):
        """Load model state; no-op here, intended to be overridden by subclasses."""
        pass