|
|
|
|
|
|
|
|
|
"""isort:skip_file""" |
|
|
|
import importlib |
|
import os |
|
|
|
from fairseq import registry |
|
from fairseq.criterions.fairseq_criterion import ( |
|
FairseqCriterion, |
|
LegacyFairseqCriterion, |
|
) |
|
from omegaconf import DictConfig |
|
|
|
|
|
# Create the criterion registry for the "--criterion" option. Criteria must
# subclass FairseqCriterion, and "cross_entropy" is passed as the default.
# ``registry.setup_registry`` returns four values, unpacked here as:
#   build_criterion_              - builder (wrapped by ``build_criterion`` below)
#   register_criterion            - presumably the registration decorator
#   CRITERION_REGISTRY            - presumably name -> criterion class
#   CRITERION_DATACLASS_REGISTRY  - presumably name -> config dataclass
# NOTE(review): the roles of the last three are inferred from their names;
# confirm against fairseq.registry.setup_registry.
[
    build_criterion_,
    register_criterion,
    CRITERION_REGISTRY,
    CRITERION_DATACLASS_REGISTRY,
] = registry.setup_registry(
    "--criterion",
    base_class=FairseqCriterion,
    default="cross_entropy",
)
|
|
|
|
|
def build_criterion(cfg: DictConfig, task):
    """Build a criterion through the criterion registry's builder.

    Public entry point that forwards ``cfg`` and ``task`` unchanged to the
    registry-generated ``build_criterion_`` and returns its result.

    Args:
        cfg: configuration passed through to the registry builder.
        task: task object passed through to the registry builder.

    Returns:
        Whatever ``build_criterion_`` returns (presumably a
        ``FairseqCriterion`` instance — confirm against the registry).
    """
    criterion = build_criterion_(cfg, task)
    return criterion
|
|
|
|
|
|
|
# Import every public ``*.py`` module that sits next to this file so that
# each one's module-level registration code runs on package import.
# ``sorted`` makes the import order deterministic, because ``os.listdir``
# returns entries in arbitrary order. Files starting with "_" (for example
# this ``__init__.py``) are skipped.
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        # Strip only the trailing ".py" suffix, which ``endswith`` guarantees
        # is present. Slicing at the first ".py" occurrence would turn
        # e.g. "foo.pyx.py" into "foo" and import the wrong module.
        file_name = file[: -len(".py")]
        importlib.import_module("fairseq.criterions." + file_name)
|
|