ssl-aasist/fairseq/fairseq_cli/hydra_validate.py
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
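
# Example invocation (a sketch; the exact override keys and config groups
# depend on the Hydra configs shipped alongside this file):
#
#   python fairseq_cli/hydra_validate.py \
#       common_eval.path=/path/to/checkpoint.pt \
#       task.data=/path/to/data \
#       dataset.batch_size=8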

import logging
import os
import sys
from itertools import chain

import hydra
import torch
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf, open_dict

from fairseq import checkpoint_utils, utils
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.initialize import add_defaults, hydra_init
from fairseq.dataclass.utils import omegaconf_no_object_check
from fairseq.distributed import utils as distributed_utils
from fairseq.logging import metrics, progress_bar
from fairseq.utils import reset_logging

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)

logger = logging.getLogger("fairseq_cli.validate")


@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
def hydra_main(cfg: FairseqConfig) -> float:
    return _hydra_main(cfg)
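

# The Hydra-decorated entry point above simply forwards to _hydra_main, so the
# same logic can also be invoked programmatically without Hydra's CLI layer.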
def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
    add_defaults(cfg)

    if cfg.common.reset_logging:
        reset_logging()  # Hydra hijacks logging, fix that
    else:
        # check if directly called or called through hydra_main
        if HydraConfig.initialized():
            with open_dict(cfg):
                # make hydra logging work with ddp
                # (see https://github.com/facebookresearch/hydra/issues/1126)
                cfg.job_logging_cfg = OmegaConf.to_container(
                    HydraConfig.get().job_logging, resolve=True
                )

    # Round-trip through a plain container so the config holds only primitive
    # values, then lock it against accidental key creation.
    with omegaconf_no_object_check():
        cfg = OmegaConf.create(
            OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
        )
    OmegaConf.set_struct(cfg, True)

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    distributed_utils.call_main(cfg, validate, **kwargs)


def validate(cfg):
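    """Load the checkpoint at cfg.common_eval.path and validate it on each
    subset named in cfg.dataset.valid_subset, reporting aggregated metrics."""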
    utils.import_user_module(cfg.common)

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    if cfg.distributed_training.distributed_world_size > 1:
        data_parallel_world_size = distributed_utils.get_data_parallel_world_size()
        data_parallel_rank = distributed_utils.get_data_parallel_rank()
    else:
        data_parallel_world_size = 1
        data_parallel_rank = 0
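
    # Override the data path stored in the checkpoint so the task loads the
    # dataset from the directory given in the current config.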
    overrides = {"task": {"data": cfg.task.data}}

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
    )
    model = models[0]

    # Move models to GPU
    for model in models:
        model.eval()
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print args
    logger.info(saved_cfg)

    # Build criterion
    criterion = task.build_criterion(saved_cfg.criterion, from_checkpoint=True)
    criterion.eval()
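
    # Validate each requested subset; cfg.dataset.valid_subset may name several
    # comma-separated splits (e.g. "valid,test").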
    for subset in cfg.dataset.valid_subset.split(","):
        try:
            task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg.task)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception("Cannot find dataset: " + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=cfg.dataset.max_tokens,
            max_sentences=cfg.dataset.batch_size,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
            seed=cfg.common.seed,
            num_shards=data_parallel_world_size,
            shard_id=data_parallel_rank,
            num_workers=cfg.dataset.num_workers,
            data_buffer_size=cfg.dataset.data_buffer_size,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        )
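
        # Cast float32 tensors to half precision; non-float tensors (e.g.
        # integer token ids) pass through unchanged.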
        def apply_half(t):
            if t.dtype is torch.float32:
                return t.to(dtype=torch.half)
            return t

        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if use_fp16:
                sample = utils.apply_to_sample(apply_half, sample)
            _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
            with metrics.aggregate() as agg:
                task.reduce_metrics([log_output], criterion)
                progress.log(agg.get_smoothed_values(), step=i)
            log_outputs.append(log_output)
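
        # With more than one data-parallel worker, gather every worker's
        # logging outputs before the final reduction so the printed metrics
        # cover the whole subset.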
        if data_parallel_world_size > 1:
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=cfg.common.all_gather_list_size,
                group=distributed_utils.get_data_parallel_group(),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=i)


def cli_main():
    try:
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except Exception:
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    hydra_init(cfg_name)
    hydra_main()


if __name__ == "__main__":
    cli_main()