PyTorch
ssl-aasist
custom_code
ash56's picture
Add files using upload-large-folder tool
d28af7f verified
raw
history blame
3.79 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from mmpt.utils import recursive_config
class BaseJob:
    """Common state and helpers shared by job launchers.

    Attributes:
        yaml_file: the yaml path given to the constructor.
        config: whatever ``recursive_config(yaml_file)`` returns.
        dryrun: flag stored for subclasses to consult when submitting.
    """

    def __init__(self, yaml_file, dryrun=False):
        """Remember the yaml path and dry-run flag, and load the config."""
        self.yaml_file, self.dryrun = yaml_file, dryrun
        self.config = recursive_config(yaml_file)

    def submit(self, **kwargs):
        """Launch the job; subclasses must override this."""
        raise NotImplementedError()

    def _normalize_cmd(self, cmd_list):
        """Return a new list with the first "[yaml]" token replaced.

        The input sequence is not modified. ``list.index`` raises
        ``ValueError`` when no "[yaml]" token is present.
        """
        tokens = list(cmd_list)
        slot = tokens.index("[yaml]")
        # Only the first placeholder is substituted; later ones are kept.
        return tokens[:slot] + [self.yaml_file] + tokens[slot + 1:]
class LocalJob(BaseJob):
    """Launch a job on the local machine as a child process.

    The command line is taken from ``CMD_CONFIG`` (keyed by job type), the
    "[yaml]" placeholder is replaced with the yaml path, and, for job types
    whose name does not contain "predict", every entry found under the
    ``fairseq`` section of the loaded config is appended as a ``--option``.
    """

    # Command templates per job type; "[yaml]" is substituted at submit time.
    CMD_CONFIG = {
        "local_single": [
            "fairseq-train", "[yaml]", "--user-dir", "mmpt",
            "--task", "mmtask", "--arch", "mmarch",
            "--criterion", "mmloss",
        ],
        "local_small": [
            "fairseq-train", "[yaml]", "--user-dir", "mmpt",
            "--task", "mmtask", "--arch", "mmarch",
            "--criterion", "mmloss",
            "--distributed-world-size", "2"
        ],
        "local_big": [
            "fairseq-train", "[yaml]", "--user-dir", "mmpt",
            "--task", "mmtask", "--arch", "mmarch",
            "--criterion", "mmloss",
            "--distributed-world-size", "8"
        ],
        "local_predict": ["python", "mmpt_cli/predict.py", "[yaml]"],
    }

    # Config keys that are passed as bare switches (no value on the command
    # line) rather than as ``--key value`` pairs.
    _BINARY_FLAGS = frozenset(
        ["fp16", "reset_optimizer", "reset_dataloader", "reset_meters"]
    )

    def __init__(self, yaml_file, job_type=None, dryrun=False):
        """Pick the job type and warn about large local batch sizes.

        Args:
            yaml_file: path of the yaml config; forwarded to ``BaseJob``.
            job_type: key into ``CMD_CONFIG``. When ``None``, the config's
                ``task_type`` is used if it is not ``None``, otherwise
                "local_single".
            dryrun: when true, ``submit`` only prints the command.
        """
        super().__init__(yaml_file, dryrun)
        if job_type is None:
            self.job_type = "local_single"
            if self.config.task_type is not None:
                self.job_type = self.config.task_type
        else:
            self.job_type = job_type
        if self.job_type in ["local_single", "local_small"]:
            batch_size = self.config.fairseq.dataset.batch_size
            # batch_size may be unset (None); comparing None with an int
            # would raise TypeError, so only hint when a number is present.
            # The config itself is never modified here.
            if batch_size is not None and batch_size > 32:
                print("decreasing batch_size to 32 for local testing?")

    @classmethod
    def _fairseq_args(cls, fairseq_config):
        """Flatten a two-level ``{section: {key: value}}`` mapping into argv.

        Keys have "_" replaced by "-" and are prefixed with "--".
        """
        args = []
        for field in fairseq_config:
            section = fairseq_config[field]
            for key in section:
                value = section[key]
                option = "--" + key.replace("_", "-")
                if key in cls._BINARY_FLAGS:
                    # A switch is only meaningful when enabled; emitting it
                    # for a false value would turn the feature on.
                    if value:
                        args.append(option)
                    continue
                if key == "lr" and not isinstance(value, (int, float, str)):
                    # A sequence of learning rates: only the first is passed.
                    value = value[0]
                # Values go straight into argv, so no manual shell quoting
                # (e.g. around adam_betas) is needed.
                args.extend([option, str(value)])
        return args

    def submit(self):
        """Build the command, print it, and run it unless ``dryrun`` is set.

        Returns:
            A ``JobStatus`` carrying the hard-coded placeholder id
            "12345678".

        Raises:
            KeyError: if ``self.job_type`` is not a key of ``CMD_CONFIG``.
        """
        import shlex
        import subprocess

        cmd_list = self._normalize_cmd(LocalJob.CMD_CONFIG[self.job_type])
        if "predict" not in self.job_type:
            # append fairseq args.
            from mmpt.utils import load_config

            config = load_config(config_file=self.yaml_file)
            cmd_list.extend(self._fairseq_args(config.fairseq))

        # Shell-quote only for display so the printed line can be pasted
        # into a terminal as-is.
        print("launching", " ".join(shlex.quote(arg) for arg in cmd_list))
        if not self.dryrun:
            try:
                # Run with an argv list (no shell) so paths or values that
                # contain spaces or shell metacharacters are passed intact.
                # The exit status is deliberately not checked.
                subprocess.run(cmd_list, check=False)
            except OSError as err:
                # A command that cannot be started is reported, not raised,
                # matching the previous non-raising launch behaviour.
                print("failed to launch {}: {}".format(cmd_list[0], err))
        return JobStatus("12345678")
class JobStatus(object):
    """Minimal status handle for a launched job.

    ``done`` and ``running`` are fixed to ``False`` in this implementation,
    so ``result`` (and therefore ``stderr``/``stdout``) always reports the
    job as running.
    """

    def __init__(self, job_id):
        """Store the job identifier (any object convertible with ``str``)."""
        self.job_id = job_id

    def __repr__(self):
        # __repr__ must return a str; convert so non-string ids (e.g. ints)
        # do not raise TypeError.
        return str(self.job_id)

    def __str__(self):
        # Same conversion as __repr__, for the same reason.
        return str(self.job_id)

    def done(self):
        """Always ``False``: completion is not tracked."""
        return False

    def running(self):
        """Always ``False``: execution state is not tracked."""
        return False

    def result(self):
        """Return a one-line status message chosen by ``done()``."""
        if self.done():
            return "{} is done.".format(self.job_id)
        else:
            return "{} is running.".format(self.job_id)

    def stderr(self):
        """Return the same message as ``result()``."""
        return self.result()

    def stdout(self):
        """Return the same message as ``result()``."""
        return self.result()