|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field |
|
import logging |
|
import os |
|
import math |
|
import torch |
|
from typing import Dict, Optional |
|
|
|
from fairseq import search |
|
from fairseq.data import FairseqDataset, iterators |
|
from fairseq.optim.amp_optimizer import AMPOptimizer |
|
from fairseq.dataclass import FairseqDataclass |
|
from fairseq.tasks import FairseqTask, register_task |
|
from omegaconf import DictConfig |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@dataclass
class OFAConfig(FairseqDataclass):
    """Configuration fields for the OFA task.

    Each field carries its command-line help text in ``metadata["help"]``.
    """

    # ---- data location and tokenisation ----
    data: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "comma separated path to data list, will be iterated upon during epochs "
                "in round-robin manner; valid data are always in the last"
            )
        },
    )
    selected_cols: Optional[str] = field(
        default=None, metadata={"help": "selected cols"}
    )
    bpe: Optional[str] = field(
        default="gpt2", metadata={"help": "which bpe to use"}
    )
    bpe_dir: Optional[str] = field(
        default=None, metadata={"help": "bpe dir"}
    )

    # ---- sequence length limits ----
    max_source_positions: int = field(
        default=1024,
        metadata={"help": "max number of tokens in the source sequence"},
    )
    max_target_positions: int = field(
        default=1024,
        metadata={"help": "max number of tokens in the target sequence"},
    )
    max_src_length: int = field(
        default=128,
        metadata={"help": "the maximum src sequence length"},
    )
    max_tgt_length: int = field(
        default=30,
        metadata={"help": "the maximum target sequence length"},
    )

    # ---- extra vocabulary sizes and image patch sizes ----
    code_dict_size: int = field(
        default=8192,
        metadata={"help": "code dict size"},
    )
    patch_image_size: int = field(
        default=480,
        metadata={"help": "patch image size"},
    )
    orig_patch_image_size: int = field(
        default=256,
        metadata={"help": "patch image size"},
    )
    num_bins: int = field(
        default=1000,
        metadata={"help": "number of quantization bins"},
    )

    # ---- normalisation and decoding constraints ----
    imagenet_default_mean_and_std: bool = field(
        default=False, metadata={"help": "imagenet normalize"}
    )
    constraint_range: Optional[str] = field(
        default=None, metadata={"help": "constraint range"}
    )
|
|
|
|
|
@register_task("ofa", dataclass=OFAConfig)
class OFATask(FairseqTask):
    """Fairseq task registered under the name ``ofa``.

    Holds the source/target dictionaries and provides the batch iterator,
    BPE construction, sequence-generator construction and the training step.
    """

    def __init__(self, cfg: OFAConfig, src_dict, tgt_dict):
        """Store the task config (via the base class) and both dictionaries."""
        super().__init__(cfg)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, cfg: DictConfig, **kwargs):
        """Setup the task.

        Loads ``dict.txt`` from ``cfg.bpe_dir`` once for the source side and
        once for the target side, then extends both with ``<mask>``,
        ``cfg.code_dict_size`` ``<code_i>`` symbols and ``cfg.num_bins``
        ``<bin_i>`` symbols, in that order.

        Raises:
            ValueError: if ``cfg.bpe_dir`` is not set.
        """
        if cfg.bpe_dir is None:
            # os.path.join(None, ...) would otherwise fail with an opaque TypeError.
            raise ValueError("bpe_dir must be set to the directory containing dict.txt")

        dict_path = os.path.join(cfg.bpe_dir, "dict.txt")
        # Two independent Dictionary objects are loaded, as before.
        src_dict = cls.load_dictionary(dict_path)
        tgt_dict = cls.load_dictionary(dict_path)

        # Symbol insertion order determines the assigned indices; keep it fixed.
        extra_symbols = ["<mask>"]
        extra_symbols.extend("<code_{}>".format(i) for i in range(cfg.code_dict_size))
        extra_symbols.extend("<bin_{}>".format(i) for i in range(cfg.num_bins))
        for symbol in extra_symbols:
            src_dict.add_symbol(symbol)
            tgt_dict.add_symbol(symbol)

        logger.info("source dictionary: %s types", len(src_dict))
        logger.info("target dictionary: %s types", len(tgt_dict))
        return cls(cfg, src_dict, tgt_dict)

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        """Build an ``EpochBatchIterator`` of fixed-size sequential batches.

        Batches are consecutive index ranges of at most ``max_sentences``
        items. ``max_tokens``, ``max_positions``, ``ignore_invalid_inputs``,
        ``required_batch_size_multiple``, ``shard_id`` and
        ``disable_iterator_cache`` are accepted but not used here.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch; it must
                expose ``dataset.dataset.get_total_row_count()``.
            max_sentences (int): batch size; required and must be positive.
            num_shards (int): only used to compute the expected number of
                batches per shard.
            seed, num_workers, epoch, data_buffer_size: forwarded to
                ``EpochBatchIterator``.

        Returns:
            ~fairseq.data.iterators.EpochBatchIterator

        Raises:
            ValueError: if ``max_sentences`` is ``None`` or not positive.
        """
        assert isinstance(dataset, FairseqDataset)
        if max_sentences is None or max_sentences <= 0:
            # Batching below is purely by sentence count, so a batch size is mandatory.
            raise ValueError(
                "get_batch_iterator requires a positive max_sentences, got {}".format(
                    max_sentences
                )
            )

        dataset.set_epoch(epoch)

        dataset_len = len(dataset)
        batch_sampler = [
            list(range(start, min(start + max_sentences, dataset_len)))
            for start in range(0, dataset_len, max_sentences)
        ]

        # Pad with one empty batch when this dataset yields fewer batches than
        # the per-shard count derived from the total row count.
        # NOTE(review): presumably keeps all shards in step - confirm with the dataset.
        total_row_count = dataset.dataset.get_total_row_count()
        num_batches = math.ceil(math.ceil(total_row_count / num_shards) / max_sentences)
        if len(batch_sampler) < num_batches:
            batch_sampler.append([])

        # The iterator is deliberately built unsharded (num_shards=1, shard_id=0).
        # NOTE(review): presumably the dataset already restricts itself to its
        # own shard - verify against the dataset implementation.
        return iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=1,
            shard_id=0,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size,
        )

    def build_model(self, cfg: FairseqDataclass):
        """Build the model via the base class and attach ``self.bpe``.

        A BERT BPE is built when ``self.cfg.bpe == 'bert'``; any other value
        falls back to GPT-2 BPE. Vocabulary files are read from
        ``self.cfg.bpe_dir``.
        """
        model = super().build_model(cfg)
        bpe_dir = self.cfg.bpe_dir
        if self.cfg.bpe == "bert":
            bpe_dict = {
                "_name": "bert",
                "bpe_vocab_file": os.path.join(bpe_dir, "vocab.txt"),
                "bpe_cased": False,
            }
        else:
            bpe_dict = {
                "_name": "gpt2",
                "gpt2_encoder_json": os.path.join(bpe_dir, "encoder.json"),
                "gpt2_vocab_bpe": os.path.join(bpe_dir, "vocab.bpe"),
            }
        self.bpe = self.build_bpe(DictConfig(bpe_dict))
        return model

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
    ):
        """
        Build a :class:`~fairseq.SequenceGenerator` instance for this
        task.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            args (fairseq.dataclass.configs.GenerationConfig):
                configuration object (dataclass) for generation
            seq_gen_cls: generator class to instantiate; when ``None`` it is
                chosen from ``args.print_alignment``
            extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
                through to SequenceGenerator; the caller's dict is not
                modified
            prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
                If provided, this function constrains the beam search to
                allowed tokens only at each step. The provided function
                should take 2 arguments: the batch ID (`batch_id: int`)
                and a unidimensional tensor of token ids (`inputs_ids:
                torch.Tensor`). It has to return a `List[int]` with the
                allowed tokens for the next generation step conditioned
                on the previously generated tokens (`inputs_ids`) and
                the batch ID (`batch_id`). This argument is useful for
                constrained generation conditioned on the prefix, as
                described in "Autoregressive Entity Retrieval"
                (https://arxiv.org/abs/2010.00904) and
                https://github.com/facebookresearch/GENRE.

        Raises:
            ValueError: if more than one of sampling, diverse beam groups,
                match-source-len and diversity rate is requested.
        """
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        from fairseq.sequence_generator import SequenceGeneratorWithAlignment
        from models.sequence_generator import SequenceGenerator

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        constrained = getattr(args, "constraints", False)
        if prefix_allowed_tokens_fn is None:
            prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)

        exclusive_options = [
            sampling,
            diverse_beam_groups > 0,
            match_source_len,
            diversity_rate > 0,
        ]
        if sum(int(cond) for cond in exclusive_options) > 1:
            raise ValueError("Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        if sampling:
            search_strategy = search.Sampling(
                self.target_dictionary, sampling_topk, sampling_topp
            )
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(
                self.target_dictionary, diverse_beam_groups, diverse_beam_strength
            )
        elif match_source_len:
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        # NOTE(review): the exclusivity check above tests `> 0` while this
        # branch tests `> -1`; kept unchanged to preserve existing behaviour.
        elif diversity_rate > -1:
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate
            )
        elif constrained:
            search_strategy = search.LexicallyConstrainedBeamSearch(
                self.target_dictionary, args.constraints
            )
        elif prefix_allowed_tokens_fn:
            search_strategy = search.PrefixConstrainedBeamSearch(
                self.target_dictionary, prefix_allowed_tokens_fn
            )
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        # Copy so that adding "print_alignment" below never mutates the
        # dictionary object owned by the caller.
        extra_gen_cls_kwargs = dict(extra_gen_cls_kwargs or {})
        if seq_gen_cls is None:
            if getattr(args, "print_alignment", False):
                seq_gen_cls = SequenceGeneratorWithAlignment
                extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
            else:
                seq_gen_cls = SequenceGenerator

        # NOTE(review): constraint_range is forwarded to every generator
        # class - confirm that non-default classes accept this keyword.
        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            constraint_range=self.cfg.constraint_range,
            **extra_gen_cls_kwargs,
        )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False, **extra_kwargs
    ):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)
        # Autocast is enabled only when training with the AMP optimizer.
        use_amp = isinstance(optimizer, AMPOptimizer)
        with torch.autograd.profiler.record_function("forward"):
            with torch.cuda.amp.autocast(enabled=use_amp):
                loss, sample_size, logging_output = criterion(model, sample, update_num=update_num)
        if ignore_grad:
            loss *= 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)
        return loss, sample_size, logging_output

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.cfg.max_source_positions, self.cfg.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.tgt_dict
|
|