|
|
|
|
|
|
|
|
|
|
|
|
|
import ast |
|
import hashlib |
|
import logging |
|
import os |
|
import shutil |
|
import sys |
|
from dataclasses import dataclass, field, is_dataclass |
|
from pathlib import Path |
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
import editdistance |
|
import torch |
|
import torch.distributed as dist |
|
from examples.speech_recognition.new.decoders.decoder_config import ( |
|
DecoderConfig, |
|
FlashlightDecoderConfig, |
|
) |
|
from examples.speech_recognition.new.decoders.decoder import Decoder |
|
from fairseq import checkpoint_utils, distributed_utils, progress_bar, tasks, utils |
|
from fairseq.data.data_utils import post_process |
|
from fairseq.dataclass.configs import ( |
|
CheckpointConfig, |
|
CommonConfig, |
|
CommonEvalConfig, |
|
DatasetConfig, |
|
DistributedTrainingConfig, |
|
FairseqDataclass, |
|
) |
|
from fairseq.logging.meters import StopwatchMeter, TimeMeter |
|
from fairseq.logging.progress_bar import BaseProgressBar |
|
from fairseq.models.fairseq_model import FairseqModel |
|
from omegaconf import OmegaConf |
|
|
|
import hydra |
|
from hydra.core.config_store import ConfigStore |
|
|
|
logging.root.setLevel(logging.INFO) |
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
config_path = Path(__file__).resolve().parent / "conf" |
|
|
|
|
|
@dataclass |
|
class DecodingConfig(DecoderConfig, FlashlightDecoderConfig): |
|
unique_wer_file: bool = field( |
|
default=False, |
|
metadata={"help": "If set, use a unique file for storing WER"}, |
|
) |
|
results_path: Optional[str] = field( |
|
default=None, |
|
metadata={ |
|
"help": "If set, write hypothesis and reference sentences into this directory" |
|
}, |
|
) |
|
|
|
|
|
@dataclass |
|
class InferConfig(FairseqDataclass): |
|
task: Any = None |
|
decoding: DecodingConfig = DecodingConfig() |
|
common: CommonConfig = CommonConfig() |
|
common_eval: CommonEvalConfig = CommonEvalConfig() |
|
checkpoint: CheckpointConfig = CheckpointConfig() |
|
distributed_training: DistributedTrainingConfig = DistributedTrainingConfig() |
|
dataset: DatasetConfig = DatasetConfig() |
|
is_ax: bool = field( |
|
default=False, |
|
metadata={ |
|
"help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume" |
|
}, |
|
) |
|
|
|
|
|
def reset_logging(): |
|
root = logging.getLogger() |
|
for handler in root.handlers: |
|
root.removeHandler(handler) |
|
root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper()) |
|
handler = logging.StreamHandler(sys.stdout) |
|
handler.setFormatter( |
|
logging.Formatter( |
|
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", |
|
datefmt="%Y-%m-%d %H:%M:%S", |
|
) |
|
) |
|
root.addHandler(handler) |
|
|
|
|
|
class InferenceProcessor: |
|
cfg: InferConfig |
|
|
|
def __init__(self, cfg: InferConfig) -> None: |
|
self.cfg = cfg |
|
self.task = tasks.setup_task(cfg.task) |
|
self.tgt_dict = self.task.target_dictionary |
|
|
|
models, saved_cfg = self.load_model_ensemble() |
|
self.models = models |
|
self.saved_cfg = saved_cfg |
|
|
|
self.task.load_dataset( |
|
self.cfg.dataset.gen_subset, |
|
task_cfg=saved_cfg.task, |
|
) |
|
self.generator = Decoder(cfg.decoding, self.tgt_dict) |
|
self.gen_timer = StopwatchMeter() |
|
self.wps_meter = TimeMeter() |
|
self.num_sentences = 0 |
|
self.total_errors = 0 |
|
self.total_length = 0 |
|
|
|
self.hypo_words_file = None |
|
self.hypo_units_file = None |
|
self.ref_words_file = None |
|
self.ref_units_file = None |
|
|
|
self.progress_bar = self.build_progress_bar() |
|
|
|
def __enter__(self) -> "InferenceProcessor": |
|
if self.cfg.decoding.results_path is not None: |
|
self.hypo_words_file = self.get_res_file("hypo.word") |
|
self.hypo_units_file = self.get_res_file("hypo.units") |
|
self.ref_words_file = self.get_res_file("ref.word") |
|
self.ref_units_file = self.get_res_file("ref.units") |
|
return self |
|
|
|
def __exit__(self, *exc) -> bool: |
|
if self.cfg.decoding.results_path is not None: |
|
self.hypo_words_file.close() |
|
self.hypo_units_file.close() |
|
self.ref_words_file.close() |
|
self.ref_units_file.close() |
|
return False |
|
|
|
def __iter__(self) -> Any: |
|
for sample in self.progress_bar: |
|
if not self.cfg.common.cpu: |
|
sample = utils.move_to_cuda(sample) |
|
|
|
|
|
if "net_input" not in sample: |
|
continue |
|
yield sample |
|
|
|
def log(self, *args, **kwargs): |
|
self.progress_bar.log(*args, **kwargs) |
|
|
|
def print(self, *args, **kwargs): |
|
self.progress_bar.print(*args, **kwargs) |
|
|
|
def get_res_file(self, fname: str) -> None: |
|
fname = os.path.join(self.cfg.decoding.results_path, fname) |
|
if self.data_parallel_world_size > 1: |
|
fname = f"{fname}.{self.data_parallel_rank}" |
|
return open(fname, "w", buffering=1) |
|
|
|
def merge_shards(self) -> None: |
|
"""Merges all shard files into shard 0, then removes shard suffix.""" |
|
|
|
shard_id = self.data_parallel_rank |
|
num_shards = self.data_parallel_world_size |
|
|
|
if self.data_parallel_world_size > 1: |
|
|
|
def merge_shards_with_root(fname: str) -> None: |
|
fname = os.path.join(self.cfg.decoding.results_path, fname) |
|
logger.info("Merging %s on shard %d", fname, shard_id) |
|
base_fpath = Path(f"{fname}.0") |
|
with open(base_fpath, "a") as out_file: |
|
for s in range(1, num_shards): |
|
shard_fpath = Path(f"{fname}.{s}") |
|
with open(shard_fpath, "r") as in_file: |
|
for line in in_file: |
|
out_file.write(line) |
|
shard_fpath.unlink() |
|
shutil.move(f"{fname}.0", fname) |
|
|
|
dist.barrier() |
|
if shard_id == (0 % num_shards): |
|
merge_shards_with_root("hypo.word") |
|
if shard_id == (1 % num_shards): |
|
merge_shards_with_root("hypo.units") |
|
if shard_id == (2 % num_shards): |
|
merge_shards_with_root("ref.word") |
|
if shard_id == (3 % num_shards): |
|
merge_shards_with_root("ref.units") |
|
dist.barrier() |
|
|
|
def optimize_model(self, model: FairseqModel) -> None: |
|
model.make_generation_fast_() |
|
if self.cfg.common.fp16: |
|
model.half() |
|
if not self.cfg.common.cpu: |
|
model.cuda() |
|
|
|
def load_model_ensemble(self) -> Tuple[List[FairseqModel], FairseqDataclass]: |
|
arg_overrides = ast.literal_eval(self.cfg.common_eval.model_overrides) |
|
models, saved_cfg = checkpoint_utils.load_model_ensemble( |
|
utils.split_paths(self.cfg.common_eval.path, separator="\\"), |
|
arg_overrides=arg_overrides, |
|
task=self.task, |
|
suffix=self.cfg.checkpoint.checkpoint_suffix, |
|
strict=(self.cfg.checkpoint.checkpoint_shard_count == 1), |
|
num_shards=self.cfg.checkpoint.checkpoint_shard_count, |
|
) |
|
for model in models: |
|
self.optimize_model(model) |
|
return models, saved_cfg |
|
|
|
def get_dataset_itr(self, disable_iterator_cache: bool = False) -> None: |
|
return self.task.get_batch_iterator( |
|
dataset=self.task.dataset(self.cfg.dataset.gen_subset), |
|
max_tokens=self.cfg.dataset.max_tokens, |
|
max_sentences=self.cfg.dataset.batch_size, |
|
max_positions=(sys.maxsize, sys.maxsize), |
|
ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test, |
|
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, |
|
seed=self.cfg.common.seed, |
|
num_shards=self.data_parallel_world_size, |
|
shard_id=self.data_parallel_rank, |
|
num_workers=self.cfg.dataset.num_workers, |
|
data_buffer_size=self.cfg.dataset.data_buffer_size, |
|
disable_iterator_cache=disable_iterator_cache, |
|
).next_epoch_itr(shuffle=False) |
|
|
|
def build_progress_bar( |
|
self, |
|
epoch: Optional[int] = None, |
|
prefix: Optional[str] = None, |
|
default_log_format: str = "tqdm", |
|
) -> BaseProgressBar: |
|
return progress_bar.progress_bar( |
|
iterator=self.get_dataset_itr(), |
|
log_format=self.cfg.common.log_format, |
|
log_interval=self.cfg.common.log_interval, |
|
epoch=epoch, |
|
prefix=prefix, |
|
tensorboard_logdir=self.cfg.common.tensorboard_logdir, |
|
default_log_format=default_log_format, |
|
) |
|
|
|
@property |
|
def data_parallel_world_size(self): |
|
if self.cfg.distributed_training.distributed_world_size == 1: |
|
return 1 |
|
return distributed_utils.get_data_parallel_world_size() |
|
|
|
@property |
|
def data_parallel_rank(self): |
|
if self.cfg.distributed_training.distributed_world_size == 1: |
|
return 0 |
|
return distributed_utils.get_data_parallel_rank() |
|
|
|
def process_sentence( |
|
self, |
|
sample: Dict[str, Any], |
|
hypo: Dict[str, Any], |
|
sid: int, |
|
batch_id: int, |
|
) -> Tuple[int, int]: |
|
speaker = None |
|
|
|
if "target_label" in sample: |
|
toks = sample["target_label"] |
|
else: |
|
toks = sample["target"] |
|
toks = toks[batch_id, :] |
|
|
|
|
|
hyp_pieces = self.tgt_dict.string(hypo["tokens"].int().cpu()) |
|
if "words" in hypo: |
|
hyp_words = " ".join(hypo["words"]) |
|
else: |
|
hyp_words = post_process(hyp_pieces, self.cfg.common_eval.post_process) |
|
|
|
|
|
target_tokens = utils.strip_pad(toks, self.tgt_dict.pad()) |
|
tgt_pieces = self.tgt_dict.string(target_tokens.int().cpu()) |
|
tgt_words = post_process(tgt_pieces, self.cfg.common_eval.post_process) |
|
|
|
if self.cfg.decoding.results_path is not None: |
|
print(f"{hyp_pieces} ({speaker}-{sid})", file=self.hypo_units_file) |
|
print(f"{hyp_words} ({speaker}-{sid})", file=self.hypo_words_file) |
|
print(f"{tgt_pieces} ({speaker}-{sid})", file=self.ref_units_file) |
|
print(f"{tgt_words} ({speaker}-{sid})", file=self.ref_words_file) |
|
|
|
if not self.cfg.common_eval.quiet: |
|
logger.info(f"HYPO: {hyp_words}") |
|
logger.info(f"REF: {tgt_words}") |
|
logger.info("---------------------") |
|
|
|
hyp_words, tgt_words = hyp_words.split(), tgt_words.split() |
|
|
|
return editdistance.eval(hyp_words, tgt_words), len(tgt_words) |
|
|
|
def process_sample(self, sample: Dict[str, Any]) -> None: |
|
self.gen_timer.start() |
|
hypos = self.task.inference_step( |
|
generator=self.generator, |
|
models=self.models, |
|
sample=sample, |
|
) |
|
num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) |
|
self.gen_timer.stop(num_generated_tokens) |
|
self.wps_meter.update(num_generated_tokens) |
|
|
|
for batch_id, sample_id in enumerate(sample["id"].tolist()): |
|
errs, length = self.process_sentence( |
|
sample=sample, |
|
sid=sample_id, |
|
batch_id=batch_id, |
|
hypo=hypos[batch_id][0], |
|
) |
|
self.total_errors += errs |
|
self.total_length += length |
|
|
|
self.log({"wps": round(self.wps_meter.avg)}) |
|
if "nsentences" in sample: |
|
self.num_sentences += sample["nsentences"] |
|
else: |
|
self.num_sentences += sample["id"].numel() |
|
|
|
def log_generation_time(self) -> None: |
|
logger.info( |
|
"Processed %d sentences (%d tokens) in %.1fs %.2f " |
|
"sentences per second, %.2f tokens per second)", |
|
self.num_sentences, |
|
self.gen_timer.n, |
|
self.gen_timer.sum, |
|
self.num_sentences / self.gen_timer.sum, |
|
1.0 / self.gen_timer.avg, |
|
) |
|
|
|
|
|
def parse_wer(wer_file: Path) -> float: |
|
with open(wer_file, "r") as f: |
|
return float(f.readline().strip().split(" ")[1]) |
|
|
|
|
|
def get_wer_file(cfg: InferConfig) -> Path: |
|
"""Hashes the decoding parameters to a unique file ID.""" |
|
base_path = "wer" |
|
if cfg.decoding.results_path is not None: |
|
base_path = os.path.join(cfg.decoding.results_path, base_path) |
|
|
|
if cfg.decoding.unique_wer_file: |
|
yaml_str = OmegaConf.to_yaml(cfg.decoding) |
|
fid = int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16) |
|
return Path(f"{base_path}.{fid % 1000000}") |
|
else: |
|
return Path(base_path) |
|
|
|
|
|
def main(cfg: InferConfig) -> float: |
|
"""Entry point for main processing logic. |
|
|
|
Args: |
|
cfg: The inferance configuration to use. |
|
wer: Optional shared memory pointer for returning the WER. If not None, |
|
the final WER value will be written here instead of being returned. |
|
|
|
Returns: |
|
The final WER if `wer` is None, otherwise None. |
|
""" |
|
|
|
yaml_str, wer_file = OmegaConf.to_yaml(cfg.decoding), get_wer_file(cfg) |
|
|
|
|
|
if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: |
|
cfg.dataset.max_tokens = 4000000 |
|
if not cfg.common.cpu and not torch.cuda.is_available(): |
|
raise ValueError("CUDA not found; set `cpu=True` to run without CUDA") |
|
|
|
with InferenceProcessor(cfg) as processor: |
|
for sample in processor: |
|
processor.process_sample(sample) |
|
|
|
processor.log_generation_time() |
|
|
|
if cfg.decoding.results_path is not None: |
|
processor.merge_shards() |
|
|
|
errs_t, leng_t = processor.total_errors, processor.total_length |
|
|
|
if cfg.common.cpu: |
|
logger.warning("Merging WER requires CUDA.") |
|
elif processor.data_parallel_world_size > 1: |
|
stats = torch.LongTensor([errs_t, leng_t]).cuda() |
|
dist.all_reduce(stats, op=dist.ReduceOp.SUM) |
|
errs_t, leng_t = stats[0].item(), stats[1].item() |
|
|
|
wer = errs_t * 100.0 / leng_t |
|
|
|
if distributed_utils.is_master(cfg.distributed_training): |
|
with open(wer_file, "w") as f: |
|
f.write( |
|
( |
|
f"WER: {wer}\n" |
|
f"err / num_ref_words = {errs_t} / {leng_t}\n\n" |
|
f"{yaml_str}" |
|
) |
|
) |
|
|
|
return wer |
|
|
|
|
|
@hydra.main(config_path=config_path, config_name="infer") |
|
def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]: |
|
container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True) |
|
cfg = OmegaConf.create(container) |
|
OmegaConf.set_struct(cfg, True) |
|
|
|
if cfg.common.reset_logging: |
|
reset_logging() |
|
|
|
|
|
wer = float("inf") |
|
|
|
try: |
|
if cfg.common.profile: |
|
with torch.cuda.profiler.profile(): |
|
with torch.autograd.profiler.emit_nvtx(): |
|
distributed_utils.call_main(cfg, main) |
|
else: |
|
distributed_utils.call_main(cfg, main) |
|
|
|
wer = parse_wer(get_wer_file(cfg)) |
|
except BaseException as e: |
|
if not cfg.common.suppress_crashes: |
|
raise |
|
else: |
|
logger.error("Crashed! %s", str(e)) |
|
|
|
logger.info("Word error rate: %.4f", wer) |
|
if cfg.is_ax: |
|
return wer, None |
|
|
|
return wer |
|
|
|
|
|
def cli_main() -> None: |
|
try: |
|
from hydra._internal.utils import ( |
|
get_args, |
|
) |
|
|
|
cfg_name = get_args().config_name or "infer" |
|
except ImportError: |
|
logger.warning("Failed to get config name from hydra args") |
|
cfg_name = "infer" |
|
|
|
cs = ConfigStore.instance() |
|
cs.store(name=cfg_name, node=InferConfig) |
|
|
|
for k in InferConfig.__dataclass_fields__: |
|
if is_dataclass(InferConfig.__dataclass_fields__[k].type): |
|
v = InferConfig.__dataclass_fields__[k].default |
|
cs.store(name=k, node=v) |
|
|
|
hydra_main() |
|
|
|
|
|
if __name__ == "__main__": |
|
cli_main() |
|
|