akhaliq3
spaces demo
546a9ba
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
class BaseTrainer:
    """Common scaffolding for trainers: options, device choice, output folder, logging.

    Keys of ``opt`` read here: ``"cuda"``, ``"OFFICIAL"`` (presence only),
    ``"datadir"`` and ``"basename"`` (in ``getSaveFolder``). ``"logFile"`` is
    written into ``opt`` by ``__init__``.
    """

    def __init__(self, opt):
        """Store the option dict and derive device/logging settings from it.

        Args:
            opt: option dict. It is kept by reference and mutated: its
                ``"logFile"`` entry is set to ``"log.txt"``.
        """
        self.opt = opt
        # Explicit equality (not truthiness) is kept on purpose: only values
        # equal to True (True, 1) enable CUDA; e.g. a non-empty string does not.
        if self.opt["cuda"] == True:  # noqa: E712
            self.use_cuda = True
            print("Using Cuda\n")
        else:
            self.use_cuda = False
            print("Using CPU\n")
        # Official mode is signalled by the mere presence of the key.
        self.is_official = "OFFICIAL" in self.opt
        self.opt["logFile"] = "log.txt"
        self.saveFolder = None  # set by getSaveFolder()
        self.logFileHandle = None  # opened lazily by log()
        self.tb_writer = None

    def log(self, s):
        """Append ``s`` to ``<saveFolder>/<opt["logFile"]>`` and echo it to stdout.

        Does nothing at all (not even printing) in official mode. Logging is
        best-effort: any failure to open or write the file (for example when
        ``getSaveFolder`` has not been called yet, so ``saveFolder`` is None)
        is reported on stdout instead of raised, and ``s`` is still printed.
        """
        # In official case, the program does not output logs
        if self.is_official:
            return
        try:
            if self.logFileHandle is None:
                self.logFileHandle = open(
                    os.path.join(self.saveFolder, self.opt["logFile"]),
                    "a",
                    encoding="utf-8",  # match saveConf; don't depend on locale
                )
            self.logFileHandle.write(s + "\n")
            # Flush every line so the log is readable while training runs and
            # is not lost if the process dies before the handle is closed.
            self.logFileHandle.flush()
        except Exception as e:  # deliberate best-effort: never let logging crash training
            print("ERROR while writing log file:", e)
        print(s)

    def getSaveFolder(self):
        """Create a fresh run directory and store its path in ``self.saveFolder``.

        Tries ``<datadir>/<basename>_conf~/run_1``, ``run_2``, ... and claims
        the first one that does not exist yet (parent folders are created too).
        """
        runid = 1
        while True:
            saveFolder = os.path.join(
                self.opt["datadir"],
                self.opt["basename"] + "_conf~",
                "run_" + str(runid),
            )
            try:
                # makedirs itself is the existence test: an exists()-then-
                # makedirs() check lets two concurrent runs claim the same id.
                os.makedirs(saveFolder)
            except FileExistsError:
                runid += 1
                continue
            self.saveFolder = saveFolder
            print("Saving logs, model and evaluation in " + self.saveFolder)
            return

    def saveConf(self):
        """Write every option as a ``key<TAB>value`` line to ``<saveFolder>/conf_copy.tsv``."""
        with open(
            os.path.join(self.saveFolder, "conf_copy.tsv"), "w", encoding="utf-8"
        ) as fw:
            for k in self.opt:
                fw.write("{0}\t{1}\n".format(k, self.opt[k]))

    def train(self):
        """Placeholder; presumably overridden by concrete trainers."""
        pass

    def load(self):
        """Placeholder; presumably overridden by concrete trainers."""
        pass