# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse import os from omegaconf import OmegaConf from mmpt.utils import recursive_config, overwrite_dir from mmpt_cli.localjob import LocalJob class JobLauncher(object): JOB_CONFIG = { "local": LocalJob, } def __init__(self, yaml_file): self.yaml_file = yaml_file job_key = "local" if yaml_file.endswith(".yaml"): config = recursive_config(yaml_file) if config.task_type is not None: job_key = config.task_type.split("_")[0] else: raise ValueError("unknown extension of job file:", yaml_file) self.job_key = job_key def __call__(self, job_type=None, dryrun=False): if job_type is not None: self.job_key = job_type.split("_")[0] print("[JobLauncher] job_key", self.job_key) job = JobLauncher.JOB_CONFIG[self.job_key]( self.yaml_file, job_type=job_type, dryrun=dryrun) return job.submit() class Pipeline(object): """a job that loads yaml config.""" def __init__(self, fn): """ load a yaml config of a job and save generated configs as yaml for each task. return: a list of files to run as specified by `run_task`. """ if fn.endswith(".py"): # a python command. self.backend = "python" self.run_yamls = [fn] return job_config = recursive_config(fn) if job_config.base_dir is None: # single file job config. self.run_yamls = [fn] return self.project_dir = os.path.join("projects", job_config.project_dir) self.run_dir = os.path.join("runs", job_config.project_dir) if job_config.run_task is not None: run_yamls = [] for stage in job_config.run_task: # each stage can have multiple tasks running in parallel. if OmegaConf.is_list(stage): stage_yamls = [] for task_file in stage: stage_yamls.append( os.path.join(self.project_dir, task_file)) run_yamls.append(stage_yamls) else: run_yamls.append(os.path.join(self.project_dir, stage)) self.run_yamls = run_yamls configs_to_save = self._overwrite_task(job_config) self._save_configs(configs_to_save) def __getitem__(self, idx): yaml_files = self.run_yamls[idx] if isinstance(yaml_files, list): return [JobLauncher(yaml_file) for yaml_file in yaml_files] return [JobLauncher(yaml_files)] def __len__(self): return len(self.run_yamls) def _save_configs(self, configs_to_save: dict): # save os.makedirs(self.project_dir, exist_ok=True) for config_file in configs_to_save: config = configs_to_save[config_file] print("saving", config_file) OmegaConf.save(config=config, f=config_file) def _overwrite_task(self, job_config): configs_to_save = {} self.base_project_dir = os.path.join("projects", job_config.base_dir) self.base_run_dir = os.path.join("runs", job_config.base_dir) for config_sets in job_config.task_group: overwrite_config = job_config.task_group[config_sets] if ( overwrite_config.task_list is None or len(overwrite_config.task_list) == 0 ): print( "[warning]", job_config.task_group, "has no task_list specified.") # we don't want this added to a final config. task_list = overwrite_config.pop("task_list", None) for config_file in task_list: config_file_path = os.path.join( self.base_project_dir, config_file) config = recursive_config(config_file_path) # overwrite it. if overwrite_config: config = OmegaConf.merge(config, overwrite_config) overwrite_dir(config, self.run_dir, basedir=self.base_run_dir) save_file_path = os.path.join(self.project_dir, config_file) configs_to_save[save_file_path] = config return configs_to_save def main(args): job_type = args.jobtype if args.jobtype else None # parse multiple pipelines. pipelines = [Pipeline(fn) for fn in args.yamls.split(",")] for pipe_id, pipeline in enumerate(pipelines): if not hasattr(pipeline, "project_dir"): for job in pipeline[0]: job(job_type=job_type, dryrun=args.dryrun) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("yamls", type=str) parser.add_argument( "--dryrun", action="store_true", help="run config and prepare to submit without launch the job.", ) parser.add_argument( "--jobtype", type=str, default="", help="force to run jobs as specified.") args = parser.parse_args() main(args)