|
|
|
|
|
|
|
|
|
"""isort:skip_file""" |
|
|
|
import argparse |
|
import importlib |
|
import os |
|
|
|
from fairseq.dataclass import FairseqDataclass |
|
from fairseq.dataclass.utils import merge_with_parent |
|
from hydra.core.config_store import ConfigStore |
|
|
|
from .fairseq_task import FairseqTask, LegacyFairseqTask |
|
|
|
|
|
|
|
# Task name -> dataclass registered for that task (filled by ``register_task``
# only when a dataclass is supplied).
TASK_DATACLASS_REGISTRY: dict = {}
# Task name -> registered task class.
TASK_REGISTRY: dict = {}
# ``__name__`` of every registered task class, used to reject duplicates.
TASK_CLASS_NAMES: set = set()
|
|
|
|
|
def setup_task(cfg: FairseqDataclass, **kwargs):
    """Resolve the task class selected by ``cfg`` and delegate to it.

    The task can be named in one of two ways:

    * ``cfg.task`` is a string: the class is read from ``TASK_REGISTRY`` and,
      if a dataclass is registered under the same name, ``cfg`` is converted
      with that dataclass's ``from_namespace``.
    * otherwise ``cfg._name`` is consulted: when it names a registered
      dataclass, ``cfg`` is passed to ``merge_with_parent`` together with a
      freshly constructed instance of that dataclass, and the class is read
      from ``TASK_REGISTRY``.

    Args:
        cfg: the task configuration.
        **kwargs: forwarded unchanged to the task class's ``setup_task``.

    Returns:
        The result of ``setup_task(cfg, **kwargs)`` on the resolved class.

    Raises:
        KeyError: if the resolved name is missing from ``TASK_REGISTRY``.
        AssertionError: if no task class could be determined from ``cfg``.
    """
    task_cls = None
    legacy_name = getattr(cfg, "task", None)

    if not isinstance(legacy_name, str):
        # No string task name: fall back to the ``_name`` attribute.
        config_name = getattr(cfg, "_name", None)
        if config_name and config_name in TASK_DATACLASS_REGISTRY:
            config_cls = TASK_DATACLASS_REGISTRY[config_name]
            cfg = merge_with_parent(config_cls(), cfg)
            task_cls = TASK_REGISTRY[config_name]
    else:
        # String task name: look the class up directly.
        task_cls = TASK_REGISTRY[legacy_name]
        if legacy_name in TASK_DATACLASS_REGISTRY:
            config_cls = TASK_DATACLASS_REGISTRY[legacy_name]
            cfg = config_cls.from_namespace(cfg)

    assert (
        task_cls is not None
    ), f"Could not infer task type from {cfg}. Available argparse tasks: {TASK_REGISTRY.keys()}. Available hydra tasks: {TASK_DATACLASS_REGISTRY.keys()}"

    return task_cls.setup_task(cfg, **kwargs)
|
|
|
|
|
def register_task(name, dataclass=None):
    """
    New tasks can be added to fairseq with the
    :func:`~fairseq.tasks.register_task` function decorator.

    For example::

        @register_task('classification')
        class ClassificationTask(FairseqTask):
            (...)

    .. note::

        All Tasks must implement the :class:`~fairseq.tasks.FairseqTask`
        interface.

    Args:
        name (str): the name of the task
        dataclass (optional): a :class:`FairseqDataclass` subclass holding the
            task's configuration. When given, it is recorded in
            ``TASK_DATACLASS_REGISTRY`` and a default instance is stored in
            the hydra ``ConfigStore`` under group ``"task"``.

    Returns:
        A class decorator that registers the decorated class and returns it
        unchanged.

    Raises:
        ValueError: (from the returned decorator) if ``name`` or the class
            name is already registered, if the class does not extend
            ``FairseqTask``, or if ``dataclass`` does not extend
            ``FairseqDataclass``.
    """

    def register_task_cls(cls):
        if name in TASK_REGISTRY:
            raise ValueError("Cannot register duplicate task ({})".format(name))
        if not issubclass(cls, FairseqTask):
            raise ValueError(
                "Task ({}: {}) must extend FairseqTask".format(name, cls.__name__)
            )
        if cls.__name__ in TASK_CLASS_NAMES:
            raise ValueError(
                "Cannot register task with duplicate class name ({})".format(
                    cls.__name__
                )
            )
        # Validate the dataclass *before* mutating any registry, so a rejected
        # registration leaves no partial state behind (otherwise a corrected
        # retry under the same name would fail as a "duplicate task").
        if dataclass is not None and not issubclass(dataclass, FairseqDataclass):
            raise ValueError(
                "Dataclass {} must extend FairseqDataclass".format(dataclass)
            )

        TASK_REGISTRY[name] = cls
        TASK_CLASS_NAMES.add(cls.__name__)

        # This function is not inside a class body, so no name mangling
        # applies: the attribute is literally named ``__dataclass``.
        cls.__dataclass = dataclass
        if dataclass is not None:
            TASK_DATACLASS_REGISTRY[name] = dataclass

            # Store a default config node, tagged with the task name.
            cs = ConfigStore.instance()
            node = dataclass()
            node._name = name
            cs.store(name=name, group="task", node=node, provider="fairseq")

        return cls

    return register_task_cls
|
|
|
|
|
def get_task(name):
    """Return the task class registered under ``name``.

    Raises:
        KeyError: if ``name`` is not present in ``TASK_REGISTRY``.
    """
    task_cls = TASK_REGISTRY[name]
    return task_cls
|
|
|
|
|
def import_tasks(tasks_dir, namespace):
    """Import every task module or directory found directly in ``tasks_dir``.

    Entries whose name starts with ``_`` or ``.`` are skipped, as is anything
    that is neither a ``.py`` file nor a directory. Each remaining entry is
    imported as ``namespace + "." + <name>``. If the imported name is present
    in ``TASK_REGISTRY`` afterwards, an ``argparse`` parser describing the
    task's arguments is stored in this module's globals as ``<name>_parser``.

    Args:
        tasks_dir: directory to scan (non-recursive).
        namespace: dotted package name used as the import prefix.
    """
    for entry in os.listdir(tasks_dir):
        # Skip private ("_") and hidden (".") entries.
        if entry.startswith(("_", ".")):
            continue

        is_py_file = entry.endswith(".py")
        if not (is_py_file or os.path.isdir(os.path.join(tasks_dir, entry))):
            continue

        # For files, the module name is everything before the first ".py".
        task_name = entry.partition(".py")[0] if is_py_file else entry
        importlib.import_module(namespace + "." + task_name)

        if task_name not in TASK_REGISTRY:
            continue

        # Build a standalone parser for this task and publish it as a
        # module-level ``<task_name>_parser`` attribute.
        parser = argparse.ArgumentParser(add_help=False)
        name_group = parser.add_argument_group("Task name")
        name_group.add_argument(
            '--task',
            metavar=task_name,
            help='Enable this task with: ``--task=' + task_name + '``',
        )
        extra_group = parser.add_argument_group("Additional command-line arguments")
        TASK_REGISTRY[task_name].add_args(extra_group)
        globals()[task_name + "_parser"] = parser
|
|
|
|
|
|
|
# Import everything that lives next to this file under the "fairseq.tasks"
# namespace as soon as the package itself is imported.
tasks_dir = os.path.dirname(__file__)
import_tasks(tasks_dir=tasks_dir, namespace="fairseq.tasks")
|
|