Spaces:
Build error
Build error
File size: 2,089 Bytes
546a9ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
class BaseTrainer:
    """Shared plumbing for trainers: device flag, run folder, logging, config dump.

    Options read from ``opt``: ``"cuda"``, ``"datadir"`` and ``"basename"``.
    The presence of the key ``"OFFICIAL"`` switches the trainer into official
    mode. ``train`` and ``load`` are empty placeholders here (presumably meant
    to be overridden by subclasses — confirm against callers).
    """

    def __init__(self, opt):
        """Store ``opt`` and derive basic run settings from it.

        Args:
            opt: Mutable option mapping. It is modified in place:
                ``opt["logFile"]`` is always set to ``"log.txt"``.
        """
        self.opt = opt
        # Strict equality (not truthiness) is kept on purpose: 1/True enable
        # CUDA, but e.g. the string "True" does not.
        if self.opt["cuda"] == True:  # noqa: E712
            self.use_cuda = True
            print("Using Cuda\n")
        else:
            self.use_cuda = False
            print("Using CPU\n")
        # Official mode is keyed on the mere presence of "OFFICIAL", not its value.
        self.is_official = "OFFICIAL" in self.opt
        # Note: unconditionally overrides any logFile value supplied in opt.
        self.opt["logFile"] = "log.txt"
        # Set by getSaveFolder(); log() and saveConf() write into it.
        self.saveFolder = None
        # Lazily opened by log() on first use.
        self.logFileHandle = None
        self.tb_writer = None

    def log(self, s):
        """Append ``s`` to the run's log file and echo it to stdout.

        Does nothing at all in official mode. The log file
        (``opt["logFile"]`` inside ``saveFolder``) is opened lazily on first
        use. If opening or writing fails, the error is printed and ``s`` is
        still echoed to stdout.
        """
        # In the official case, the program does not output logs.
        if self.is_official:
            return
        try:
            if self.logFileHandle is None:
                # Opened lazily because saveFolder is only known once
                # getSaveFolder() has run. Use utf-8 explicitly (matching
                # saveConf) instead of the platform-dependent default.
                self.logFileHandle = open(
                    os.path.join(self.saveFolder, self.opt["logFile"]),
                    "a",
                    encoding="utf-8",
                )
            self.logFileHandle.write(s + "\n")
            # The handle is never closed, so flush every line to avoid losing
            # buffered log output if the process terminates abruptly.
            self.logFileHandle.flush()
        except Exception as e:
            print("ERROR while writing log file:", e)
        print(s)

    def getSaveFolder(self):
        """Create a fresh ``run_<N>`` directory and record it in ``saveFolder``.

        Directories live under ``<datadir>/<basename>_conf~/``. ``N`` starts
        at 1 and increments until a name that does not exist yet is created.
        Missing parent directories are created too.
        """
        runid = 1
        while True:
            saveFolder = os.path.join(
                self.opt["datadir"],
                self.opt["basename"] + "_conf~",
                "run_" + str(runid),
            )
            try:
                # Create-and-check in one step: a separate exists() test
                # followed by makedirs() races with concurrent runs that
                # pick the same run id.
                os.makedirs(saveFolder)
            except FileExistsError:
                runid += 1
                continue
            self.saveFolder = saveFolder
            print("Saving logs, model and evaluation in " + self.saveFolder)
            return

    def saveConf(self):
        """Save a copy of the options as ``conf_copy.tsv`` in the save folder.

        Writes one ``key<TAB>value`` line per entry of ``opt``, in the
        mapping's iteration order, overwriting any existing copy.
        """
        with open(
            os.path.join(self.saveFolder, "conf_copy.tsv"), "w", encoding="utf-8"
        ) as fw:
            for k, v in self.opt.items():
                fw.write("{0}\t{1}\n".format(k, v))

    def train(self):
        """Training entry point; a no-op here."""
        pass

    def load(self):
        """Loading entry point; a no-op here."""
        pass
|