File size: 2,655 Bytes
26fd00c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from fairseq.dataclass.initialize import add_defaults, hydra_init
from fairseq_cli.train import main as pre_main
from fairseq import distributed_utils, metrics
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import omegaconf_no_object_check
from fairseq.utils import reset_logging
import hydra
from hydra.core.hydra_config import HydraConfig
import torch
from omegaconf import OmegaConf, open_dict
# Module-level logger. The name is a fixed string rather than ``__name__``
# (presumably so records keep the same logger name when this file is run
# directly as ``__main__`` -- confirm).
logger = logging.getLogger("fairseq_cli.hydra_train")
@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
def hydra_main(cfg: FairseqConfig) -> float:
    """Hydra task function: train with the config composed by Hydra.

    Args:
        cfg: the configuration Hydra composed from the ``config`` entry under
            ``../fairseq/config``.

    Returns:
        Whatever ``_hydra_main`` returns for ``cfg`` (the best validation
        value, or ``float("inf")`` when none is available).
    """
    # Propagate the result instead of dropping it: the function is annotated
    # ``-> float`` and ``_hydra_main`` computes its return value specifically
    # so that sweepers can consume it.
    return _hydra_main(cfg)
def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
    """Run training for ``cfg`` and report the best validation value.

    Args:
        cfg: fairseq configuration. It is passed through ``add_defaults`` and
            then rebuilt as a fully resolved, struct-mode OmegaConf object
            before training starts.
        **kwargs: forwarded unchanged to ``distributed_utils.call_main``.

    Returns:
        The smoothed ``"valid"`` value of
        ``cfg.checkpoint.best_checkpoint_metric``, or ``float("inf")`` when
        that value is missing or cannot be looked up.

    Raises:
        BaseException: anything raised by training is re-raised unless
            ``cfg.common.suppress_crashes`` is set, in which case it is
            logged and execution continues to the metric lookup.
    """
    add_defaults(cfg)

    if cfg.common.reset_logging:
        reset_logging()  # Hydra hijacks logging, fix that
    elif HydraConfig.initialized():
        # Only reached when called through hydra_main (Hydra is initialised),
        # not when this function is called directly.
        with open_dict(cfg):
            # make hydra logging work with ddp
            # (see https://github.com/facebookresearch/hydra/issues/1126)
            cfg.job_logging_cfg = OmegaConf.to_container(
                HydraConfig.get().job_logging, resolve=True
            )

    # Round-trip through a plain container so every interpolation is resolved
    # and enums become strings, then lock the structure.
    with omegaconf_no_object_check():
        cfg = OmegaConf.create(
            OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
        )
    OmegaConf.set_struct(cfg, True)

    try:
        if cfg.common.profile:
            with torch.cuda.profiler.profile(), torch.autograd.profiler.emit_nvtx():
                distributed_utils.call_main(cfg, pre_main, **kwargs)
        else:
            distributed_utils.call_main(cfg, pre_main, **kwargs)
    except BaseException as e:
        # Deliberately broad: with suppress_crashes set, any failure of the
        # training run is logged and we still fall through to report a value.
        if not cfg.common.suppress_crashes:
            raise
        logger.error("Crashed! %s", e)

    # get best val and return - useful for sweepers
    try:
        best_val = metrics.get_smoothed_value(
            "valid", cfg.checkpoint.best_checkpoint_metric
        )
    except Exception:
        # Best-effort lookup, but do not swallow KeyboardInterrupt/SystemExit
        # the way a bare ``except:`` would.
        best_val = None

    if best_val is None:
        best_val = float("inf")

    return best_val
def cli_main():
    """Command-line entry point: initialise Hydra for fairseq, then train.

    The config name is read from Hydra's parsed command-line arguments and
    falls back to ``"config"`` when it is absent or cannot be determined.
    It is handed to ``hydra_init`` before ``hydra_main`` is invoked.
    """
    try:
        # Imported lazily and guarded: this helper lives in an
        # underscore-prefixed (internal) hydra module, so failure to import or
        # call it is tolerated.
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except Exception:
        # Narrower than a bare ``except:`` so that KeyboardInterrupt and
        # SystemExit are not mistaken for "could not read the config name".
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    hydra_init(cfg_name)
    hydra_main()
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    cli_main()
|