# HuBERT/fairseq/tasks/sentence_prediction.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os

import numpy as np

from fairseq import utils
from fairseq.data import (
    ConcatSentencesDataset,
    Dictionary,
    IdDataset,
    NestedDictionaryDataset,
    NumelDataset,
    NumSamplesDataset,
    OffsetTokensDataset,
    PrependTokenDataset,
    RawLabelDataset,
    RightPadDataset,
    RollDataset,
    SortDataset,
    StripTokenDataset,
    data_utils,
)
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.tasks import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)


@register_task("sentence_prediction")
class SentencePredictionTask(LegacyFairseqTask):
    """
    Sentence (or sentence pair) prediction (classification or regression) task.

    Args:
        dictionary (Dictionary): the dictionary for the input of the task
    """
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument("data", metavar="FILE", help="file prefix for data")
        parser.add_argument(
            "--num-classes",
            type=int,
            default=-1,
            help="number of classes or regression targets",
        )
        parser.add_argument(
            "--init-token",
            type=int,
            default=None,
            help="add token at the beginning of each batch item",
        )
        parser.add_argument(
            "--separator-token",
            type=int,
            default=None,
            help="add separator token between inputs",
        )
        parser.add_argument("--regression-target", action="store_true", default=False)
        parser.add_argument("--no-shuffle", action="store_true", default=False)
        parser.add_argument(
            "--shorten-method",
            default="none",
            choices=["none", "truncate", "random_crop"],
            help="if not none, shorten sequences that exceed --tokens-per-sample",
        )
        parser.add_argument(
            "--shorten-data-split-list",
            default="",
            help="comma-separated list of dataset splits to apply shortening to, "
            'e.g., "train,valid" (default: all dataset splits)',
        )
        parser.add_argument(
            "--add-prev-output-tokens",
            action="store_true",
            default=False,
            help="add prev_output_tokens to sample, used for encoder-decoder arch",
        )
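    # Typical invocation (illustrative values only, following the RoBERTa GLUE
    # fine-tuning recipes, where <s> has index 0 and </s> has index 2 in the
    # model's dictionary):
    #
    #   fairseq-train DATA-bin \
    #       --task sentence_prediction \
    #       --criterion sentence_prediction \
    #       --num-classes 2 \
    #       --init-token 0 --separator-token 2 \
    #       --max-positions 512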
    def __init__(self, args, data_dictionary, label_dictionary):
        super().__init__(args)
        self.dictionary = data_dictionary
        self._label_dictionary = label_dictionary
        if not hasattr(args, "max_positions"):
            self._max_positions = (
                args.max_source_positions,
                args.max_target_positions,
            )
        else:
            self._max_positions = args.max_positions
        args.tokens_per_sample = self._max_positions
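    # Note: the position budget is mirrored into args.tokens_per_sample so that
    # components reading that attribute (cf. the --shorten-method help text) see
    # a value consistent with max_positions; seq2seq configs without a single
    # max_positions keep their (source, target) pair instead.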
    @classmethod
    def load_dictionary(cls, args, filename, source=True):
        """Load the dictionary from the filename

        Args:
            filename (str): the filename
        """
        dictionary = Dictionary.load(filename)
        dictionary.add_symbol("<mask>")
        return dictionary
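    # Appending "<mask>" above keeps the dictionary aligned with checkpoints
    # from masked-LM pretraining (e.g. RoBERTa), which reserve an index for the
    # mask symbol even though it never appears in the binarized task data.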
    @classmethod
    def setup_task(cls, args, **kwargs):
        assert args.num_classes > 0, "Must set --num-classes"

        # load data dictionary
        data_dict = cls.load_dictionary(
            args,
            os.path.join(args.data, "input0", "dict.txt"),
            source=True,
        )
        logger.info("[input] dictionary: {} types".format(len(data_dict)))

        # load label dictionary
        if not args.regression_target:
            label_dict = cls.load_dictionary(
                args,
                os.path.join(args.data, "label", "dict.txt"),
                source=False,
            )
            logger.info("[label] dictionary: {} types".format(len(label_dict)))
        else:
            label_dict = data_dict
        return cls(args, data_dict, label_dict)
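    # Expected layout under --data, as consumed by setup_task / load_dataset
    # (binarized file extensions shown for the mmap format that
    # fairseq-preprocess produces; they depend on --dataset-impl):
    #
    #   <data>/input0/dict.txt, <data>/input0/<split>.{bin,idx}
    #   <data>/input1/...              (optional second sentence)
    #   <data>/label/dict.txt, ...     (classification targets)
    #   <data>/label/<split>.label     (regression targets, raw text)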
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""

        def get_path(key, split):
            return os.path.join(self.args.data, key, split)

        def make_dataset(key, dictionary):
            split_path = get_path(key, split)

            try:
                dataset = data_utils.load_indexed_dataset(
                    split_path,
                    dictionary,
                    self.args.dataset_impl,
                    combine=combine,
                )
            except Exception as e:
                if "StorageException: [404] Path not found" in str(e):
                    logger.warning(f"dataset {e} not found")
                    dataset = None
                else:
                    raise e
            return dataset

        input0 = make_dataset("input0", self.source_dictionary)
        assert input0 is not None, "could not find dataset: {}".format(
            get_path("input0", split)
        )
        input1 = make_dataset("input1", self.source_dictionary)

        if self.args.init_token is not None:
            input0 = PrependTokenDataset(input0, self.args.init_token)

        if input1 is None:
            src_tokens = input0
        else:
            if self.args.separator_token is not None:
                input1 = PrependTokenDataset(input1, self.args.separator_token)

            src_tokens = ConcatSentencesDataset(input0, input1)
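        # src_tokens now holds, per example, the concatenation
        # [init_token] + input0 (+ [separator_token] + input1, when a second
        # input is present), so a single token sequence feeds the encoder.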
        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_tokens))

        src_tokens = maybe_shorten_dataset(
            src_tokens,
            split,
            self.args.shorten_data_split_list,
            self.args.shorten_method,
            self.max_positions(),
            self.args.seed,
        )

        dataset = {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad(),
                ),
                "src_lengths": NumelDataset(src_tokens, reduce=False),
            },
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens, reduce=True),
        }

        if self.args.add_prev_output_tokens:
            prev_tokens_dataset = RightPadDataset(
                RollDataset(src_tokens, 1),
                pad_idx=self.dictionary.pad(),
            )
            dataset["net_input"].update(
                prev_output_tokens=prev_tokens_dataset,
            )
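        # RollDataset(src_tokens, 1) rotates each sequence one step to the
        # right (the last token wraps to position 0), which mimics the
        # prev_output_tokens layout an encoder-decoder model expects for
        # teacher forcing; encoder-only models skip this branch.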
        if not self.args.regression_target:
            label_dataset = make_dataset("label", self.label_dictionary)
            if label_dataset is not None:
                dataset.update(
                    target=OffsetTokensDataset(
                        StripTokenDataset(
                            label_dataset,
                            id_to_strip=self.label_dictionary.eos(),
                        ),
                        offset=-self.label_dictionary.nspecial,
                    )
                )
        else:
            label_path = "{0}.label".format(get_path("label", split))
            if os.path.exists(label_path):

                def parse_regression_target(i, line):
                    values = line.split()
                    assert (
                        len(values) == self.args.num_classes
                    ), f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"'
                    return [float(x) for x in values]

                with open(label_path) as h:
                    dataset.update(
                        target=RawLabelDataset(
                            [
                                parse_regression_target(i, line.strip())
                                for i, line in enumerate(h.readlines())
                            ]
                        )
                    )
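        # Classification targets are binarized label tokens, so the eos
        # appended during preprocessing is stripped and the dictionary's
        # special symbols are subtracted off, mapping labels onto
        # [0, num_classes). Regression targets are read directly from the raw
        # "<split>.label" text file instead.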
        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
        )

        if self.args.no_shuffle:
            dataset = nested_dataset
        else:
            dataset = SortDataset(
                nested_dataset,
                # shuffle
                sort_order=[shuffle],
            )

        logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]
    def build_model(self, args):
        from fairseq import models

        model = models.build_model(args, self)

        model.register_classification_head(
            getattr(args, "classification_head_name", "sentence_classification_head"),
            num_classes=self.args.num_classes,
        )

        return model
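    # The head name defaults to "sentence_classification_head"; the companion
    # sentence_prediction criterion selects the head by the same name (via its
    # --classification-head-name option), so the two must stay in sync.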
    def max_positions(self):
        return self._max_positions

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary

    @property
    def label_dictionary(self):
        return self._label_dictionary
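
# Minimal programmatic sketch (illustrative only; assumes "DATA-bin" was
# prepared with fairseq-preprocess into input0/ and label/ subdirectories, and
# that args carries every attribute the task reads):
#
#   from argparse import Namespace
#
#   args = Namespace(
#       data="DATA-bin",
#       num_classes=2,
#       init_token=0,
#       separator_token=2,
#       regression_target=False,
#       no_shuffle=False,
#       add_prev_output_tokens=False,
#       max_positions=512,
#       shorten_method="none",
#       shorten_data_split_list="",
#       dataset_impl=None,
#       seed=1,
#   )
#   task = SentencePredictionTask.setup_task(args)
#   task.load_dataset("valid")
#   print(len(task.datasets["valid"]))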