Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- fairseq/examples/textless_nlp/speech-resynth/img/fig.png +3 -0
- fairseq/fairseq/benchmark/__init__.py +7 -0
- fairseq/fairseq/benchmark/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/fairseq/benchmark/__pycache__/dummy_dataset.cpython-310.pyc +0 -0
- fairseq/fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc +0 -0
- fairseq/fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc +0 -0
- fairseq/fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc +0 -0
- fairseq/fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc +0 -0
- fairseq/fairseq/benchmark/benchmark_multihead_attention.py +172 -0
- fairseq/fairseq/benchmark/dummy_dataset.py +36 -0
- fairseq/fairseq/benchmark/dummy_lm.py +83 -0
- fairseq/fairseq/benchmark/dummy_masked_lm.py +94 -0
- fairseq/fairseq/benchmark/dummy_model.py +96 -0
- fairseq/fairseq/benchmark/dummy_mt.py +119 -0
- fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp +55 -0
- fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu +82 -0
- fairseq/fairseq/clib/libbase/balanced_assignment.cpp +109 -0
- fairseq/fairseq/clib/libbleu/libbleu.cpp +157 -0
- fairseq/fairseq/clib/libbleu/module.cpp +33 -0
- fairseq/fairseq/clib/libnat/edit_dist.cpp +231 -0
- fairseq/fairseq/clib/libnat_cuda/binding.cpp +67 -0
- fairseq/fairseq/clib/libnat_cuda/edit_dist.cu +344 -0
- fairseq/fairseq/clib/libnat_cuda/edit_dist.h +25 -0
- fairseq/fairseq/config/__init__.py +4 -0
- fairseq/fairseq/config/config.yaml +19 -0
- fairseq/fairseq/config/fb_run_config/slurm.yaml +29 -0
- fairseq/fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml +36 -0
- fairseq/fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml +36 -0
- fairseq/fairseq/config/model/transformer_lm/transformer_lm_big.yaml +36 -0
- fairseq/fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml +36 -0
- fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml +36 -0
- fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml +36 -0
- fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml +36 -0
- fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml +36 -0
- fairseq/fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml +36 -0
- fairseq/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml +5 -0
- fairseq/fairseq/config/model/wav2vec2/wav2vec2_base.yaml +8 -0
- fairseq/fairseq/config/model/wav2vec2/wav2vec2_large.yaml +20 -0
- fairseq/fairseq/criterions/__init__.py +36 -0
- fairseq/fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/adaptive_loss.py +124 -0
- fairseq/fairseq/criterions/composite_loss.py +100 -0
- fairseq/fairseq/criterions/cross_entropy.py +91 -0
- fairseq/fairseq/criterions/ctc.py +325 -0
- fairseq/fairseq/criterions/fairseq_criterion.py +121 -0
- fairseq/fairseq/criterions/fastspeech2_loss.py +137 -0
- fairseq/fairseq/criterions/hubert_criterion.py +195 -0
- fairseq/fairseq/criterions/label_smoothed_cross_entropy.py +168 -0
- fairseq/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py +221 -0
.gitattributes
CHANGED
@@ -39,3 +39,4 @@ fairseq/alignment_train_cuda_binding.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
 fairseq/alignment_train_cpu_binding.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
 fairseq/docs/fairseq.gif filter=lfs diff=lfs merge=lfs -text
 fairseq/examples/hubert/tests/6313-76958-0021.flac filter=lfs diff=lfs merge=lfs -text
+fairseq/examples/textless_nlp/speech-resynth/img/fig.png filter=lfs diff=lfs merge=lfs -text
fairseq/examples/textless_nlp/speech-resynth/img/fig.png
ADDED
Binary image file, tracked with Git LFS.
fairseq/fairseq/benchmark/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# import models/tasks to register them
from . import dummy_dataset, dummy_lm, dummy_masked_lm, dummy_model, dummy_mt  # noqa
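For context: fairseq registers tasks and models through import side effects, so importing this package is what makes the dummy benchmark tasks visible to the framework. A minimal sketch (assumes a working fairseq install; TASK_REGISTRY is fairseq's internal task registry dict):

# Sketch: importing the benchmark package triggers the @register_task calls in its modules.
import fairseq.benchmark  # noqa: F401  (imported for side effects only)
from fairseq.tasks import TASK_REGISTRY

assert "dummy_lm" in TASK_REGISTRY and "dummy_mt" in TASK_REGISTRY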
fairseq/fairseq/benchmark/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (335 Bytes).
fairseq/fairseq/benchmark/__pycache__/dummy_dataset.cpython-310.pyc
ADDED
Binary file (1.79 kB).
fairseq/fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc
ADDED
Binary file (3.18 kB).
fairseq/fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc
ADDED
Binary file (3.39 kB).
fairseq/fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc
ADDED
Binary file (3.48 kB).
fairseq/fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc
ADDED
Binary file (4.56 kB).
fairseq/fairseq/benchmark/benchmark_multihead_attention.py
ADDED
@@ -0,0 +1,172 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import random

import torch
from torch.utils import benchmark

from fairseq.modules.multihead_attention import MultiheadAttention

BATCH = [20, 41, 97]
SEQ = 64
EMB = 48
HEADS = 4
DROP = 0.1
DEVICE = torch.device("cuda")
ATTN_MASK_DTYPE = [torch.uint8, torch.bool, torch.float]
KEY_PADDING_MASK_DTYPE = [torch.uint8, torch.bool]


def _reset_seeds():
    torch.manual_seed(0)
    random.seed(0)


def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
    if to_dtype == torch.float:
        mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
        return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
    return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)


def benchmark_multihead_attention(
    label="",
    attn_dtype=torch.uint8,
    key_padding_dtype=torch.uint8,
    add_bias_kv=False,
    add_zero_attn=False,
    static_kv=False,
    batch_size=20,
    embedding=EMB,
    seq_len=SEQ,
    num_heads=HEADS,
):

    results = []
    # device = torch.device("cuda")

    xformers_att_config = '{"name": "scaled_dot_product"}'

    attn_mask = _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len)
    key_padding_mask = _get_mask(
        to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len
    )

    q = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
    k = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
    v = torch.rand(seq_len, batch_size, embedding, requires_grad=True)

    _reset_seeds()

    original_mha = MultiheadAttention(
        embedding,
        num_heads,
        dropout=0.0,
        xformers_att_config=None,
        add_bias_kv=add_bias_kv,
        add_zero_attn=add_zero_attn,
    )

    xformers_mha = MultiheadAttention(
        embedding,
        num_heads,
        dropout=0.0,
        xformers_att_config=xformers_att_config,
        add_bias_kv=add_bias_kv,
        add_zero_attn=add_zero_attn,
    )

    def original_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
        original_mha(
            query=q,
            key=k,
            value=v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            static_kv=static_kv,
        )

    def xformers_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
        xformers_mha(
            query=q,
            key=k,
            value=v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            static_kv=static_kv,
        )

    def original_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
        output, _ = original_mha(
            query=q,
            key=k,
            value=v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            static_kv=static_kv,
        )
        loss = torch.norm(output)
        loss.backward()

    def xformers_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
        output, _ = xformers_mha(
            query=q,
            key=k,
            value=v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            static_kv=static_kv,
        )
        loss = torch.norm(output)
        loss.backward()

    fns = [
        original_bench_fw,
        xformers_bench_fw,
        original_bench_fw_bw,
        xformers_bench_fw_bw,
    ]

    for fn in fns:
        results.append(
            benchmark.Timer(
                stmt="fn(q, k, v, key_padding_mask, attn_mask, static_kv)",
                globals={
                    "q": q,
                    "k": k,
                    "v": v,
                    "key_padding_mask": key_padding_mask,
                    "attn_mask": attn_mask,
                    "static_kv": static_kv,
                    "fn": fn,
                },
                label="multihead fw + bw",
                sub_label=f"{fn.__name__}",
                description=label,
            ).blocked_autorange(min_run_time=1)
        )

    compare = benchmark.Compare(results)
    compare.print()


def run_benchmarks():
    for attn_dtype, key_padding_dtype, add_bias_kv, add_zero_attn in itertools.product(
        ATTN_MASK_DTYPE, KEY_PADDING_MASK_DTYPE, [True, False], [True, False]
    ):
        label = f"attn_dtype {attn_dtype}, key_padding_dtype {key_padding_dtype}, \
add_bias_kv {add_bias_kv}, add_zero_attn {add_zero_attn}"
        benchmark_multihead_attention(
            label=label,
            attn_dtype=attn_dtype,
            key_padding_dtype=key_padding_dtype,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )


run_benchmarks()
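Note that run_benchmarks() executes at module scope, so the full 3 x 2 x 2 x 2 sweep runs as soon as this file is imported or executed. A sketch of re-timing a single configuration after that (assumes a CUDA-capable fairseq build with xformers installed):

import torch
import fairseq.benchmark.benchmark_multihead_attention as bench  # runs the full sweep on import

# Afterwards, one configuration of interest can be re-timed directly:
bench.benchmark_multihead_attention(
    label="bool attn mask, bool key padding mask",
    attn_dtype=torch.bool,
    key_padding_dtype=torch.bool,
)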
fairseq/fairseq/benchmark/dummy_dataset.py
ADDED
@@ -0,0 +1,36 @@
import numpy as np
from fairseq.data import FairseqDataset


class DummyDataset(FairseqDataset):
    def __init__(self, batch, num_items, item_size):
        super().__init__()
        self.batch = batch
        self.num_items = num_items
        self.item_size = item_size

    def __getitem__(self, index):
        return index

    def __len__(self):
        return self.num_items

    def collater(self, samples):
        return self.batch

    @property
    def sizes(self):
        return np.array([self.item_size] * self.num_items)

    def num_tokens(self, index):
        return self.item_size

    def size(self, index):
        return self.item_size

    def ordered_indices(self):
        return np.arange(self.num_items)

    @property
    def supports_prefetch(self):
        return False
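The whole point of DummyDataset is that collater() ignores the sampled indices and hands back the same pre-built batch object every step, so the data pipeline contributes essentially nothing to step time and throughput measures compute alone. A minimal sketch of that behavior:

import torch
from fairseq.benchmark.dummy_dataset import DummyDataset

batch = {"net_input": {"src_tokens": torch.zeros(8, 512, dtype=torch.long)}}
ds = DummyDataset(batch, num_items=1000, item_size=512)
samples = [ds[i] for i in range(8)]   # __getitem__ just returns the index
assert ds.collater(samples) is batch  # same object, every step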
fairseq/fairseq/benchmark/dummy_lm.py
ADDED
@@ -0,0 +1,83 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field
from typing import Optional

import torch
from .dummy_dataset import DummyDataset
from fairseq.data import Dictionary
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II


logger = logging.getLogger(__name__)


@dataclass
class DummyLMConfig(FairseqDataclass):
    dict_size: int = 49996
    dataset_size: int = 100000
    tokens_per_sample: int = field(
        default=512, metadata={"help": "max sequence length"}
    )
    add_bos_token: bool = False
    batch_size: Optional[int] = II("dataset.batch_size")
    max_tokens: Optional[int] = II("dataset.max_tokens")
    max_target_positions: int = II("task.tokens_per_sample")


@register_task("dummy_lm", dataclass=DummyLMConfig)
class DummyLMTask(FairseqTask):
    def __init__(self, cfg: DummyLMConfig):
        super().__init__(cfg)

        # load dictionary
        self.dictionary = Dictionary()
        for i in range(cfg.dict_size):
            self.dictionary.add_symbol("word{}".format(i))
        self.dictionary.pad_to_multiple_(8)  # often faster if divisible by 8
        logger.info("dictionary: {} types".format(len(self.dictionary)))

        seq = torch.arange(cfg.tokens_per_sample + 1) + self.dictionary.pad() + 1

        self.dummy_src = seq[:-1]
        self.dummy_tgt = seq[1:]

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if self.cfg.batch_size is not None:
            bsz = self.cfg.batch_size
        else:
            bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
        self.datasets[split] = DummyDataset(
            {
                "id": 1,
                "net_input": {
                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
                    "src_lengths": torch.full(
                        (bsz,), self.cfg.tokens_per_sample, dtype=torch.long
                    ),
                },
                "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
                "nsentences": bsz,
                "ntokens": bsz * self.cfg.tokens_per_sample,
            },
            num_items=self.cfg.dataset_size,
            item_size=self.cfg.tokens_per_sample,
        )

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary
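The dummy LM batch is just an arithmetic token sequence shifted by one position, so target[t] equals source[t + 1], mimicking ordinary next-token prediction. A small sketch of the construction (fairseq's default Dictionary uses pad index 1, hence the "+ pad + 1" offset to stay clear of the special symbols):

import torch

pad = 1  # Dictionary().pad() in fairseq's default symbol order
tokens_per_sample = 8
seq = torch.arange(tokens_per_sample + 1) + pad + 1
src, tgt = seq[:-1], seq[1:]
print(src.tolist())  # [2, 3, 4, 5, 6, 7, 8, 9]
print(tgt.tolist())  # [3, 4, 5, 6, 7, 8, 9, 10]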
fairseq/fairseq/benchmark/dummy_masked_lm.py
ADDED
@@ -0,0 +1,94 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field
from typing import Optional

import torch
from omegaconf import II

from .dummy_dataset import DummyDataset
from fairseq.data import Dictionary
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task

logger = logging.getLogger(__name__)


@dataclass
class DummyMaskedLMConfig(FairseqDataclass):
    dict_size: int = 49996
    dataset_size: int = 100000
    tokens_per_sample: int = field(
        default=512,
        metadata={
            "help": "max number of total tokens over all"
            " segments per sample for BERT dataset"
        },
    )
    batch_size: Optional[int] = II("dataset.batch_size")
    max_tokens: Optional[int] = II("dataset.max_tokens")
    max_target_positions: int = II("task.tokens_per_sample")


@register_task("dummy_masked_lm", dataclass=DummyMaskedLMConfig)
class DummyMaskedLMTask(FairseqTask):
    def __init__(self, cfg: DummyMaskedLMConfig):
        super().__init__(cfg)

        self.dictionary = Dictionary()
        for i in range(cfg.dict_size):
            self.dictionary.add_symbol("word{}".format(i))
        logger.info("dictionary: {} types".format(len(self.dictionary)))
        # add mask token
        self.mask_idx = self.dictionary.add_symbol("<mask>")
        self.dictionary.pad_to_multiple_(8)  # often faster if divisible by 8

        mask_idx = 0
        pad_idx = 1
        seq = torch.arange(cfg.tokens_per_sample) + pad_idx + 1
        mask = torch.arange(2, cfg.tokens_per_sample, 7)  # ~15%
        src = seq.clone()
        src[mask] = mask_idx
        tgt = torch.full_like(seq, pad_idx)
        tgt[mask] = seq[mask]

        self.dummy_src = src
        self.dummy_tgt = tgt

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if self.cfg.batch_size is not None:
            bsz = self.cfg.batch_size
        else:
            bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
        self.datasets[split] = DummyDataset(
            {
                "id": 1,
                "net_input": {
                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
                    "src_lengths": torch.full(
                        (bsz,), self.cfg.tokens_per_sample, dtype=torch.long
                    ),
                },
                "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
                "nsentences": bsz,
                "ntokens": bsz * self.cfg.tokens_per_sample,
            },
            num_items=self.cfg.dataset_size,
            item_size=self.cfg.tokens_per_sample,
        )

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary
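Here the masking pattern is fixed rather than random: every 7th position starting at index 2 is replaced by the mask token (roughly 1/7, about 15% of positions), and the target is pad everywhere except at the masked positions, which is where the loss is computed. A sketch of the pattern:

import torch

mask_idx, pad_idx = 0, 1
tokens_per_sample = 16
seq = torch.arange(tokens_per_sample) + pad_idx + 1
mask = torch.arange(2, tokens_per_sample, 7)  # positions 2 and 9 for length 16
src = seq.clone()
src[mask] = mask_idx                          # inputs with masked slots
tgt = torch.full_like(seq, pad_idx)
tgt[mask] = seq[mask]                         # targets only at masked slots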
fairseq/fairseq/benchmark/dummy_model.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch.nn.functional as F
from fairseq.data import Dictionary
from fairseq.models import (
    FairseqDecoder,
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)


@register_model("dummy_model")
class DummyModel(FairseqLanguageModel):
    def __init__(self, args, encoder):
        super().__init__(encoder)
        self.args = args

    @staticmethod
    def add_args(parser):
        parser.add_argument("--num-layers", type=int, default=24)
        parser.add_argument("--embed-dim", type=int, default=1024)

    @classmethod
    def build_model(cls, args, task):
        encoder = DummyEncoder(
            num_embed=len(task.target_dictionary),
            embed_dim=args.embed_dim,
            num_layers=args.num_layers,
        )
        return cls(args, encoder)

    def forward(self, src_tokens, masked_tokens=None, **kwargs):
        return self.decoder(src_tokens, masked_tokens=masked_tokens)


class DummyEncoder(FairseqDecoder):
    def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
        super().__init__(Dictionary())
        self.embed = nn.Embedding(
            num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
        )
        self.layers_a = nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(embed_dim),
                    nn.Linear(embed_dim, 3 * embed_dim),  # q, k, v input projection
                    nn.Linear(3 * embed_dim, embed_dim),  # skip self-attention
                    nn.Linear(embed_dim, embed_dim),  # output projection
                    nn.Dropout(),
                )
                for i in range(num_layers)
            ]
        )
        self.layers_b = nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(embed_dim),
                    nn.Linear(embed_dim, 4 * embed_dim),  # FFN
                    nn.ReLU(),
                    nn.Linear(4 * embed_dim, embed_dim),  # FFN
                    nn.Dropout(0.1),
                )
                for i in range(num_layers)
            ]
        )
        self.out_proj = nn.Linear(embed_dim, num_embed)

    def forward(self, tokens, masked_tokens=None):
        x = self.embed(tokens)
        for layer_a, layer_b in zip(self.layers_a, self.layers_b):
            x = x + layer_a(x)
            x = x + layer_b(x)
        x = self.out_proj(x)
        if masked_tokens is not None:
            x = x[masked_tokens]
        return (x,)

    def max_positions(self):
        return 1024

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        logits = net_output[0].float()
        if log_probs:
            return F.log_softmax(logits, dim=-1)
        else:
            return F.softmax(logits, dim=-1)


@register_model_architecture("dummy_model", "dummy_model")
def base_architecture(args):
    pass
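DummyEncoder can also be exercised standalone, without the task and registry machinery; a minimal sketch of the forward shapes (the sizes here are arbitrary):

import torch
from fairseq.benchmark.dummy_model import DummyEncoder

enc = DummyEncoder(num_embed=128, embed_dim=32, num_layers=2)
tokens = torch.randint(1, 128, (4, 16))               # (batch, seq)
(logits,) = enc(tokens)                               # (4, 16, 128): full vocab logits
masked = torch.zeros(4, 16, dtype=torch.bool)
masked[:, ::4] = True
(masked_logits,) = enc(tokens, masked_tokens=masked)  # (16, 128): masked positions only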
fairseq/fairseq/benchmark/dummy_mt.py
ADDED
@@ -0,0 +1,119 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch

from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import LegacyFairseqTask, register_task

logger = logging.getLogger(__name__)


@register_task("dummy_mt")
class DummyMTTask(LegacyFairseqTask):
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument("--dict-size", default=49996, type=int)
        parser.add_argument("--dataset-size", default=100000, type=int)
        parser.add_argument("--src-len", default=30, type=int)
        parser.add_argument("--tgt-len", default=30, type=int)

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed

        dictionary.pad_to_multiple_(8)  # often faster if divisible by 8

        self.dummy_src = torch.arange(args.src_len + 1) + dictionary.pad() + 1
        self.dummy_tgt = torch.arange(args.tgt_len + 1) + dictionary.pad() + 1

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task."""
        dictionary = Dictionary()
        for i in range(args.dict_size):
            dictionary.add_symbol("word{}".format(i))
        logger.info("dictionary: {} types".format(len(dictionary)))

        args.max_source_positions = args.src_len + dictionary.pad() + 2
        args.max_target_positions = args.tgt_len + dictionary.pad() + 2

        return cls(args, dictionary)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        item_size = max(self.args.src_len, self.args.tgt_len)
        if self.args.batch_size is not None:
            bsz = self.args.batch_size
        else:
            bsz = max(1, self.args.max_tokens // item_size)
        tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
        self.datasets[split] = DummyDataset(
            {
                "id": 1,
                "net_input": {
                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
                    "src_lengths": torch.full(
                        (bsz,), self.args.src_len, dtype=torch.long
                    ),
                    "prev_output_tokens": tgt.clone(),
                },
                "target": tgt,
                "nsentences": bsz,
                "ntokens": bsz * self.args.tgt_len,
            },
            num_items=self.args.dataset_size,
            item_size=item_size,
        )

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary


class DummyDataset(FairseqDataset):
    def __init__(self, batch, num_items, item_size):
        super().__init__()
        self.batch = batch
        self.num_items = num_items
        self.item_size = item_size

    def __getitem__(self, index):
        return index

    def __len__(self):
        return self.num_items

    def collater(self, samples):
        return self.batch

    @property
    def sizes(self):
        return np.array([self.item_size] * self.num_items)

    def num_tokens(self, index):
        return self.item_size

    def size(self, index):
        return self.item_size

    def ordered_indices(self):
        return np.arange(self.num_items)

    @property
    def supports_prefetch(self):
        return False
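Because dummy_mt is a LegacyFairseqTask, it can also be driven without the CLI by handing setup_task a plain namespace. A sketch (the Namespace fields below are hand-filled assumptions for what the fairseq CLI would normally provide):

from argparse import Namespace
from fairseq.benchmark.dummy_mt import DummyMTTask

args = Namespace(dict_size=100, dataset_size=50, src_len=8, tgt_len=8,
                 seed=1, batch_size=4, max_tokens=None)
task = DummyMTTask.setup_task(args)
task.load_dataset("train")
batch = task.datasets["train"].collater([0, 1, 2, 3])
print(batch["net_input"]["src_tokens"].shape)  # torch.Size([4, 9]): src_len + 1 tokens per row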
fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp
ADDED
@@ -0,0 +1,55 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/

#include <torch/extension.h>
#include <vector>

/*
CPP Binding for CUDA OP
*/

// CUDA forward declarations
torch::Tensor ngram_repeat_block_cuda_forward(
    torch::Tensor tokens,
    torch::Tensor lprobs,
    int bsz,
    int step,
    int beam_size,
    int no_repeat_ngram_size);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

// Input check and call to CUDA OP
// Backward method not required
torch::Tensor ngram_repeat_block_forward(
    torch::Tensor tokens,
    torch::Tensor lprobs,
    int bsz,
    int step,
    int beam_size,
    int no_repeat_ngram_size) {
  CHECK_INPUT(tokens);
  CHECK_INPUT(lprobs);
  assert(bsz > 0);
  assert(step >= 0);
  assert(beam_size > 0);
  assert(no_repeat_ngram_size > 0);

  return ngram_repeat_block_cuda_forward(
      tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "forward",
      &ngram_repeat_block_forward,
      "No Repeat Ngram Block forward (CUDA)");
}
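This binding and the kernel below compile into a single extension. A sketch of building and calling it with PyTorch's JIT extension loader (the paths and extension name here are assumptions; fairseq's own build setup may differ):

import torch
from torch.utils.cpp_extension import load

ext = load(name="ngram_repeat_block_cuda",
           sources=["fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp",
                    "fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu"])

bsz, beam, step, vocab = 2, 4, 5, 100
# tokens holds the step+1 tokens generated so far for each of the bsz*beam rows
tokens = torch.randint(0, vocab, (bsz * beam, step + 1), dtype=torch.long, device="cuda")
lprobs = torch.randn(bsz * beam, vocab, device="cuda")
out = ext.forward(tokens, lprobs, bsz, step, beam, 3)  # bans repeated 3-grams in lprobs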
fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu
ADDED
@@ -0,0 +1,82 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/

/*
Kernel implementation for blocking repeated n-grams.
*/

#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <torch/extension.h>
#include <vector>

// Ban repeated ngrams of length = 'no_repeat_ngram_size'
__global__ void banRepeatedTokens(
    long* __restrict__ tokens,
    float* __restrict__ lprobs,
    int max_predict_len,
    int vocab_size,
    int no_repeat_ngram_size) {
  auto row = blockIdx.x;
  auto col = threadIdx.x;
  auto start = row * (max_predict_len) + col;
  // Each thread compares the ngram starting from its
  // thread index with the final ngram starting from
  // step - no_repeat_ngram_size + 2
  auto check_start_pos = blockDim.x;
  auto lprob_start = row * vocab_size;
  bool is_banned = true;
  extern __shared__ long tokens_shm[];
  tokens_shm[col] = tokens[start];
  if (col == blockDim.x - 1) {
    for (int i = 1; i < no_repeat_ngram_size; i++) {
      if (col + i < max_predict_len) {
        tokens_shm[col + i] = tokens[start + i];
      }
    }
  }
  __syncthreads();

  for (int k = 0; k < no_repeat_ngram_size - 1; k++) {
    if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) {
      is_banned = false;
    }
  }
  if (is_banned == true) {
    auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
    lprobs[lprob_start + token_to_be_banned] = -INFINITY;
  }
}

// Allocate blocks and threads based on
// batch size and sequence length and launch
// kernel
torch::Tensor ngram_repeat_block_cuda_forward(
    const torch::Tensor tokens,
    torch::Tensor lprobs,
    int bsz,
    int step,
    int beam_size,
    int no_repeat_ngram_size) {
  int threads = step - no_repeat_ngram_size + 2;
  if (threads <= 0)
    return lprobs;
  int max_predict_len = tokens.size(1);
  int vocab_size = lprobs.size(1);
  auto token_ptr = tokens.data_ptr<long>();
  auto lprob_ptr = lprobs.data_ptr<float>();
  int blocks = bsz * beam_size;
  int shared_mem_size = (step + 1) * sizeof(long);

  // Launching N blocks where N is number of samples in a batch (beams*bsz)
  // Launching T threads where T is number of previous ngrams in a sample
  // Allocating shared mem per block for faster access of input tokens since
  // each token will be accessed N times to compare with current Ngram where
  // N is Ngram size.
  banRepeatedTokens<<<blocks, threads, shared_mem_size>>>(
      token_ptr, lprob_ptr, max_predict_len, vocab_size, no_repeat_ngram_size);
  return lprobs;
}
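As a cross-check on the indexing, here is a pure-Python reference of what one kernel launch computes: each beam row compares every earlier (n-1)-gram (one CUDA thread each) against the most recent (n-1)-gram, and bans the token that followed any match. This is a sketch of the semantics, not the CUDA code:

def ban_repeated_ngrams(tokens, lprobs, step, n):
    """Reference semantics of banRepeatedTokens for lists of token ids."""
    for row in range(len(tokens)):
        tail = tokens[row][step - n + 2 : step + 1]   # last n-1 generated tokens
        for start in range(step - n + 2):             # one CUDA thread per start
            if tokens[row][start : start + n - 1] == tail:
                banned = tokens[row][start + n - 1]   # token that followed the match
                lprobs[row][banned] = float("-inf")
    return lprobs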
fairseq/fairseq/clib/libbase/balanced_assignment.cpp
ADDED
@@ -0,0 +1,109 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

/*
C++ code for solving the linear assignment problem.
Based on the Auction Algorithm from
https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf and the
implementation from: https://github.com/bkj/auction-lap Adapted to be more
efficient when each worker is looking for k jobs instead of 1.
*/
#include <torch/extension.h>
#include <iostream>
using namespace torch::indexing;
torch::Tensor balanced_assignment(torch::Tensor job_and_worker_to_score) {
  int max_iterations = 100;
  torch::Tensor epsilon =
      (job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50;
  epsilon.clamp_min_(1e-04);
  torch::Tensor worker_and_job_to_score =
      job_and_worker_to_score.detach().transpose(0, 1).contiguous();
  int num_workers = worker_and_job_to_score.size(0);
  int num_jobs = worker_and_job_to_score.size(1);
  auto device = worker_and_job_to_score.device();
  int jobs_per_worker = num_jobs / num_workers;
  torch::Tensor value = worker_and_job_to_score.clone();
  int counter = 0;
  torch::Tensor max_value = worker_and_job_to_score.max();

  torch::Tensor bid_indices;
  torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs});
  torch::Tensor bids =
      worker_and_job_to_score.new_empty({num_workers, num_jobs});
  torch::Tensor bid_increments =
      worker_and_job_to_score.new_empty({num_workers, jobs_per_worker});
  torch::Tensor top_values =
      worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1});
  torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs});

  torch::Tensor top_index = top_values.to(torch::kLong);
  torch::Tensor high_bidders = top_index.new_empty({num_jobs});
  torch::Tensor have_bids = high_bidders.to(torch::kBool);
  torch::Tensor jobs_indices =
      torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device));
  torch::Tensor true_tensor =
      torch::ones({1}, torch::dtype(torch::kBool).device(device));

  while (true) {
    bids.zero_();
    torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1);

    // Each worker bids the difference in value between that job and the k+1th
    // job
    torch::sub_out(
        bid_increments,
        top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}),
        top_values.index({Slice(None, None), jobs_per_worker}).unsqueeze(1));

    bid_increments.add_(epsilon);
    bids.scatter_(
        1,
        top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}),
        bid_increments);

    if (counter < max_iterations && counter > 0) {
      // Put in a minimal bid to retain items from the last round if no-one else
      // bids for them this round
      bids.view(-1).index_put_({bid_indices}, epsilon);
    }

    // Find the highest bidding worker per job
    torch::max_out(high_bids, high_bidders, bids, 0);
    torch::gt_out(have_bids, high_bids, 0);

    if (have_bids.all().item<bool>()) {
      // All jobs were bid for
      break;
    }

    // Make popular items more expensive
    cost.add_(high_bids);
    torch::sub_out(value, worker_and_job_to_score, cost);

    bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids});

    if (counter < max_iterations) {
      // Make sure that this item will be in the winning worker's top-k next
      // time.
      value.view(-1).index_put_({bid_indices}, max_value);
    } else {
      // Suboptimal approximation that converges quickly from current solution
      value.view(-1).index_put_(
          {bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices}));
    }

    counter += 1;
  }

  return top_index.index({Slice(None, None), Slice(0, jobs_per_worker)})
      .reshape(-1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("balanced_assignment", &balanced_assignment, "Balanced Assignment");
}
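A usage sketch (the build step and extension name are assumptions; in fairseq this op backs expert-to-token assignment in BASE layers). Scores are (num_jobs, num_workers), and the result assigns each worker exactly num_jobs / num_workers jobs:

import torch
from torch.utils.cpp_extension import load

libbase = load(name="libbase",
               sources=["fairseq/fairseq/clib/libbase/balanced_assignment.cpp"])

num_jobs, num_workers = 8, 4                        # each worker gets 2 jobs
scores = torch.randn(num_jobs, num_workers)
assignment = libbase.balanced_assignment(scores)    # flat tensor of job indices
print(assignment.view(num_workers, -1))             # row i = jobs given to worker i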
fairseq/fairseq/clib/libbleu/libbleu.cpp
ADDED
@@ -0,0 +1,157 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <array>
#include <cstdio>
#include <cstring>
#include <map>

// NOLINTNEXTLINE
typedef struct {
  size_t reflen;
  size_t predlen;
  size_t match1;
  size_t count1;
  size_t match2;
  size_t count2;
  size_t match3;
  size_t count3;
  size_t match4;
  size_t count4;
} bleu_stat;

// left trim (remove pad)
void bleu_ltrim(size_t* len, int** sent, int pad) {
  size_t start = 0;
  while (start < *len) {
    if (*(*sent + start) != pad) {
      break;
    }
    start++;
  }
  *sent += start;
  *len -= start;
}

// right trim (remove eos)
void bleu_rtrim(size_t* len, int** sent, int pad, int eos) {
  size_t end = *len - 1;
  while (end > 0) {
    if (*(*sent + end) != eos && *(*sent + end) != pad) {
      break;
    }
    end--;
  }
  *len = end + 1;
}

// left and right trim
void bleu_trim(size_t* len, int** sent, int pad, int eos) {
  bleu_ltrim(len, sent, pad);
  bleu_rtrim(len, sent, pad, eos);
}

size_t bleu_hash(int len, int* data) {
  size_t h = 14695981039346656037ul;
  size_t prime = 0x100000001b3;
  char* b = (char*)data;
  size_t blen = sizeof(int) * len;

  while (blen-- > 0) {
    h ^= *b++;
    h *= prime;
  }

  return h;
}

void bleu_addngram(
    size_t* ntotal,
    size_t* nmatch,
    size_t n,
    size_t reflen,
    int* ref,
    size_t predlen,
    int* pred) {
  if (predlen < n) {
    return;
  }

  predlen = predlen - n + 1;
  (*ntotal) += predlen;

  if (reflen < n) {
    return;
  }

  reflen = reflen - n + 1;

  std::map<size_t, size_t> count;
  while (predlen > 0) {
    size_t w = bleu_hash(n, pred++);
    count[w]++;
    predlen--;
  }

  while (reflen > 0) {
    size_t w = bleu_hash(n, ref++);
    if (count[w] > 0) {
      (*nmatch)++;
      count[w] -= 1;
    }
    reflen--;
  }
}

extern "C" {

#ifdef _WIN64
__declspec(dllexport)
#endif
void bleu_zero_init(bleu_stat* stat) {
  std::memset(stat, 0, sizeof(bleu_stat));
}

#ifdef _WIN64
__declspec(dllexport)
#endif
void bleu_one_init(bleu_stat* stat) {
  bleu_zero_init(stat);
  stat->count1 = 0;
  stat->count2 = 1;
  stat->count3 = 1;
  stat->count4 = 1;
  stat->match1 = 0;
  stat->match2 = 1;
  stat->match3 = 1;
  stat->match4 = 1;
}

#ifdef _WIN64
__declspec(dllexport)
#endif
void bleu_add(
    bleu_stat* stat,
    size_t reflen,
    int* ref,
    size_t predlen,
    int* pred,
    int pad,
    int eos) {

  bleu_trim(&reflen, &ref, pad, eos);
  bleu_trim(&predlen, &pred, pad, eos);
  stat->reflen += reflen;
  stat->predlen += predlen;

  bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred);
}
}
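Since the accumulator functions are exported with C linkage, they can be driven from Python with ctypes, which mirrors how fairseq's BLEU scorer uses this library. A sketch (the shared-library path is an assumption, and the Structure layout must match bleu_stat field for field):

import ctypes

class BleuStat(ctypes.Structure):
    # Field order must match the bleu_stat struct above exactly.
    _fields_ = [(name, ctypes.c_size_t) for name in
                ("reflen", "predlen", "match1", "count1", "match2", "count2",
                 "match3", "count3", "match4", "count4")]

lib = ctypes.cdll.LoadLibrary("libbleu.so")  # path/name is an assumption
stat = BleuStat()
lib.bleu_zero_init(ctypes.byref(stat))
ref = (ctypes.c_int * 3)(5, 6, 7)
pred = (ctypes.c_int * 3)(5, 6, 8)
lib.bleu_add(ctypes.byref(stat),
             ctypes.c_size_t(3), ref,
             ctypes.c_size_t(3), pred,
             ctypes.c_int(0), ctypes.c_int(2))  # pad=0, eos=2
print(stat.match1, stat.count1)  # 2 unigram matches out of 3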
fairseq/fairseq/clib/libbleu/module.cpp
ADDED
@@ -0,0 +1,33 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <Python.h>

static PyMethodDef method_def[] = {{NULL, NULL, 0, NULL}}; // NOLINT

static struct PyModuleDef module_def = {
    PyModuleDef_HEAD_INIT,
    "libbleu", /* name of module */
    // NOLINTNEXTLINE
    NULL, /* module documentation, may be NULL */
    -1, /* size of per-interpreter state of the module,
           or -1 if the module keeps state in global variables. */
    method_def}; // NOLINT

#if PY_MAJOR_VERSION == 2
PyMODINIT_FUNC init_libbleu()
#else
PyMODINIT_FUNC PyInit_libbleu()
#endif
{
  PyObject* m = PyModule_Create(&module_def);
  if (!m) {
    return NULL;
  }
  return m;
}
fairseq/fairseq/clib/libnat/edit_dist.cpp
ADDED
@@ -0,0 +1,231 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
#include <torch/torch.h> // @manual=//caffe2:torch_extension
#include <algorithm>
#include <cstdint>
#include <iosfwd>
#include <memory>
#include <new>
#include <string>
#include <utility>
#include <vector>

using namespace ::std;

vector<vector<uint32_t>> edit_distance2_with_dp(
    vector<uint32_t>& x,
    vector<uint32_t>& y) {
  uint32_t lx = x.size();
  uint32_t ly = y.size();
  vector<vector<uint32_t>> d(lx + 1, vector<uint32_t>(ly + 1));
  for (uint32_t i = 0; i < lx + 1; i++) {
    d[i][0] = i;
  }
  for (uint32_t j = 0; j < ly + 1; j++) {
    d[0][j] = j;
  }
  for (uint32_t i = 1; i < lx + 1; i++) {
    for (uint32_t j = 1; j < ly + 1; j++) {
      d[i][j] =
          min(min(d[i - 1][j], d[i][j - 1]) + 1,
              d[i - 1][j - 1] + 2 * (x.at(i - 1) == y.at(j - 1) ? 0 : 1));
    }
  }
  return d;
}

vector<vector<uint32_t>> edit_distance2_backtracking(
    vector<vector<uint32_t>>& d,
    vector<uint32_t>& x,
    vector<uint32_t>& y,
    uint32_t terminal_symbol) {
  vector<uint32_t> seq;
  vector<vector<uint32_t>> edit_seqs(x.size() + 2, vector<uint32_t>());
  /*
  edit_seqs:
  0~x.size() cell is the insertion sequences
  last cell is the delete sequence
  */

  if (x.size() == 0) {
    edit_seqs.at(0) = y;
    return edit_seqs;
  }

  uint32_t i = d.size() - 1;
  uint32_t j = d.at(0).size() - 1;

  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
      seq.push_back(1); // insert
      seq.push_back(y.at(j - 1));
      j--;
    } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
      seq.push_back(2); // delete
      seq.push_back(x.at(i - 1));
      i--;
    } else {
      seq.push_back(3); // keep
      seq.push_back(x.at(i - 1));
      i--;
      j--;
    }
  }

  uint32_t prev_op, op, s, word;
  prev_op = 0, s = 0;
  for (uint32_t k = 0; k < seq.size() / 2; k++) {
    op = seq.at(seq.size() - 2 * k - 2);
    word = seq.at(seq.size() - 2 * k - 1);
    if (prev_op != 1) {
      s++;
    }
    if (op == 1) // insert
    {
      edit_seqs.at(s - 1).push_back(word);
    } else if (op == 2) // delete
    {
      edit_seqs.at(x.size() + 1).push_back(1);
    } else {
      edit_seqs.at(x.size() + 1).push_back(0);
    }

    prev_op = op;
  }

  for (uint32_t k = 0; k < edit_seqs.size(); k++) {
    if (edit_seqs[k].size() == 0) {
      edit_seqs[k].push_back(terminal_symbol);
    }
  }
  return edit_seqs;
}

vector<vector<uint32_t>> edit_distance2_backtracking_with_delete(
    vector<vector<uint32_t>>& d,
    vector<uint32_t>& x,
    vector<uint32_t>& y,
    uint32_t terminal_symbol,
    uint32_t deletion_symbol) {
  vector<uint32_t> seq;
  vector<vector<uint32_t>> edit_seqs(x.size() + 1, vector<uint32_t>());
  /*
  edit_seqs:
  0~x.size() cell is the insertion sequences
  last cell is the delete sequence
  */

  if (x.size() == 0) {
    edit_seqs.at(0) = y;
    return edit_seqs;
  }

  uint32_t i = d.size() - 1;
  uint32_t j = d.at(0).size() - 1;

  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
      seq.push_back(1); // insert
      seq.push_back(y.at(j - 1));
      j--;
    } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
      seq.push_back(2); // delete
      seq.push_back(x.at(i - 1));
      i--;
    } else {
      seq.push_back(3); // keep
      seq.push_back(x.at(i - 1));
      i--;
      j--;
    }
  }

  uint32_t prev_op, op, s, word;
  prev_op = 0, s = 0;
  for (uint32_t k = 0; k < seq.size() / 2; k++) {
    op = seq.at(seq.size() - 2 * k - 2);
    word = seq.at(seq.size() - 2 * k - 1);
    if (prev_op != 1) {
      s++;
    }
    if (op == 1) // insert
    {
      edit_seqs.at(s - 1).push_back(word);
    } else if (op == 2) // delete
    {
      edit_seqs.at(s - 1).push_back(deletion_symbol);
    }

    prev_op = op;
  }

  for (uint32_t k = 0; k < edit_seqs.size(); k++) {
    if (edit_seqs.at(k).size() == 0) {
      edit_seqs.at(k).push_back(terminal_symbol);
    }
  }
  return edit_seqs;
}

vector<uint32_t> compute_ed2(
    vector<vector<uint32_t>>& xs,
    vector<vector<uint32_t>>& ys) {
  vector<uint32_t> distances(xs.size());
  for (uint32_t i = 0; i < xs.size(); i++) {
    vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
    distances.at(i) = d.at(xs.at(i).size()).at(ys.at(i).size());
  }
  return distances;
}

vector<vector<vector<uint32_t>>> suggested_ed2_path(
    vector<vector<uint32_t>>& xs,
    vector<vector<uint32_t>>& ys,
    uint32_t terminal_symbol) {
  vector<vector<vector<uint32_t>>> seq(xs.size());
  for (uint32_t i = 0; i < xs.size(); i++) {
    vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
    seq.at(i) =
        edit_distance2_backtracking(d, xs.at(i), ys.at(i), terminal_symbol);
  }
  return seq;
}

vector<vector<vector<uint32_t>>> suggested_ed2_path_with_delete(
    vector<vector<uint32_t>>& xs,
    vector<vector<uint32_t>>& ys,
    uint32_t terminal_symbol,
    uint32_t deletion_symbol) {
  vector<vector<vector<uint32_t>>> seq(xs.size());
  for (uint32_t i = 0; i < xs.size(); i++) {
    vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
    seq.at(i) = edit_distance2_backtracking_with_delete(
        d, xs.at(i), ys.at(i), terminal_symbol, deletion_symbol);
  }
  return seq;
}

PYBIND11_MODULE(libnat, m) {
  m.def("compute_ed2", &compute_ed2, "compute_ed2");
  m.def("suggested_ed2_path", &suggested_ed2_path, "suggested_ed2_path");
  m.def(
      "suggested_ed2_path_with_delete",
      &suggested_ed2_path_with_delete,
      "suggested_ed2_path_with_delete");
}
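PYBIND11_MODULE names this extension libnat, and fairseq imports it as fairseq.libnat once built; pybind11 converts nested Python lists into the vector<vector<uint32_t>> arguments. A usage sketch (the terminal-symbol value 0 below just stands in for whatever padding index the caller uses):

from fairseq import libnat  # assumes fairseq was built with its C++ extensions

xs = [[1, 2, 3, 4]]   # batch of hypothesis token sequences
ys = [[1, 3, 4, 5]]   # batch of reference token sequences
print(libnat.compute_ed2(xs, ys))             # per-pair edit distances
paths = libnat.suggested_ed2_path(xs, ys, 0)  # 0 = terminal symbol (assumption)
# paths[0][:-1] are insertion sequences per slot; paths[0][-1] is the
# per-token keep(0)/delete(1) sequence, per the edit_seqs comment above.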
fairseq/fairseq/clib/libnat_cuda/binding.cpp
ADDED
@@ -0,0 +1,67 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

/*
This code is partially adopted from
https://github.com/1ytic/pytorch-edit-distance
*/

#include <torch/types.h>
#include "edit_dist.h"

#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

torch::Tensor LevenshteinDistance(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length) {
  CHECK_INPUT(source);
  CHECK_INPUT(target);
  CHECK_INPUT(source_length);
  CHECK_INPUT(target_length);
  return LevenshteinDistanceCuda(source, target, source_length, target_length);
}

torch::Tensor GenerateDeletionLabel(
    torch::Tensor source,
    torch::Tensor operations) {
  CHECK_INPUT(source);
  CHECK_INPUT(operations);
  return GenerateDeletionLabelCuda(source, operations);
}

std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabel(
    torch::Tensor target,
    torch::Tensor operations) {
  CHECK_INPUT(target);
  CHECK_INPUT(operations);
  return GenerateInsertionLabelCuda(target, operations);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance");
  m.def(
      "generate_deletion_labels",
      &GenerateDeletionLabel,
      "Generate Deletion Label");
  m.def(
      "generate_insertion_labels",
      &GenerateInsertionLabel,
      "Generate Insertion Label");
}
fairseq/fairseq/clib/libnat_cuda/edit_dist.cu
ADDED
@@ -0,0 +1,344 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "edit_dist.h"

#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <utility> // std::pair

template <typename scalar_t>
__global__ void generate_deletion_label_kernel(
    const scalar_t* __restrict__ source,
    const size_t source_size,
    const size_t operation_size,
    int* __restrict__ operations,
    int* __restrict__ labels) {
  const int index = blockIdx.x;
  const int offset = index * operation_size;
  const int offset_label = index * source_size;

  for (int i = 0; i < source_size; i++) {
    labels[offset_label + i] = 0;
  }

  int k = 0;
  for (int i = 0; i < operation_size; i++) {
    if (operations[offset + i] == 0) {
      break;
    } else if (operations[offset + i] == 1) {
      continue;
    } else {
      labels[offset_label + k] = 3 - operations[offset + i];
      k++;
    }
  }
}

template <typename scalar_t>
__global__ void generate_insertion_label_kernel(
    const scalar_t* __restrict__ target,
    const size_t target_size,
    const size_t operation_size,
    int* __restrict__ operations,
    int* __restrict__ labels,
    int* __restrict__ masks) {
  const int index = blockIdx.x;
  const int offset = index * operation_size;
  const int offset_label = index * target_size;

  int k = 0;
  int u = 0;
  int m = 0;

  for (int i = 0; i < target_size; i++) {
    labels[offset_label + i] = 0;
    masks[offset_label + i] = 0;
  }

  for (int i = 0; i < operation_size - 1; i++) {
    if (operations[offset + i] == 0) {
      break;
    } else if (operations[offset + i] == 2) {
      continue;
    } else if (operations[offset + i] == 1) {
      masks[offset_label + m] = 1;
      u++;
      m++;
    } else {
      labels[offset_label + k] = u;
      masks[offset_label + m] = 0;
      k++;
      m++;
      u = 0;
    }
  }
}

template <typename scalar_t>
__global__ void levenshtein_distance_kernel(
    const scalar_t* __restrict__ source,
    const scalar_t* __restrict__ target,
    const int* __restrict__ source_length,
    const int* __restrict__ target_length,
    const size_t source_size,
    const size_t target_size,
    int* __restrict__ operations,
    int* __restrict__ errors_curr) {
  const int index = blockIdx.x;
  const int offset = index * (source_size + target_size);
  const int d = index * (source_size + 1) * (target_size + 1);
  const int t = target_size + 1;

  auto err_idx = [d, t](int i, int j) { return d + i * t + j; };
  auto opt_idx = [offset](int k) { return offset + k; };

  const int hyp_len = source_length[index];
  const int ref_len = target_length[index];
  const scalar_t* hyp_begin = source + index * source_size;
  const scalar_t* ref_begin = target + index * target_size;

  // dynamic programming
  for (int i = 0; i <= hyp_len; i++) {
    errors_curr[err_idx(i, 0)] = i;
  }
  for (int j = 0; j <= ref_len; j++) {
    errors_curr[err_idx(0, j)] = j;
  }
  for (int i = 1; i <= hyp_len; i++) {
    for (int j = 1; j <= ref_len; j++) {
      errors_curr[err_idx(i, j)] = min(
          min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) +
              1,
          errors_curr[err_idx(i - 1, j - 1)] +
              2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
    }
  }

  // back-tracing
  int i = hyp_len;
  int j = ref_len;
  int o = hyp_len + ref_len;

  for (int k = 0; k < source_size + target_size; k++) {
    operations[opt_idx(k)] = 0;
  }

  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) &&
        (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 1;
      j--; // insertion
    } else if (
        (i > 0) &&
        (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 2;
      i--; // deletion
    } else {
      o--;
      operations[opt_idx(o)] = 3;
      i--;
      j--; // do nothing
    }
  }

  // moving to the left
  for (int k = 0; k < hyp_len + ref_len; k++) {
    if (k + o < hyp_len + ref_len) {
      operations[opt_idx(k)] = operations[opt_idx(k + o)];
    } else {
      operations[opt_idx(k)] = 0; // padding
    }
  }
}

template <typename scalar_t>
__global__ void faster_levenshtein_distance_kernel(
    const scalar_t* __restrict__ source,
    const scalar_t* __restrict__ target,
    const int* __restrict__ source_length,
    const int* __restrict__ target_length,
    const size_t source_size,
    const size_t target_size,
    int* __restrict__ operations) {
  extern __shared__ short errors[];
  auto errors_curr = errors;

  const int index = blockIdx.x;
  const int offset = index * (source_size + target_size);
  const int t = target_size + 1;

  auto err_idx = [t](int i, int j) { return i * t + j; };
  auto opt_idx = [offset](int k) { return offset + k; };

  const int hyp_len = source_length[index];
  const int ref_len = target_length[index];
  const scalar_t* hyp_begin = source + index * source_size;
  const scalar_t* ref_begin = target + index * target_size;

  // dynamic programming
  for (int i = 0; i <= hyp_len; i++) {
    errors_curr[err_idx(i, 0)] = i;
  }
  for (int j = 0; j <= ref_len; j++) {
    errors_curr[err_idx(0, j)] = j;
  }
  for (int i = 1; i <= hyp_len; i++) {
    for (int j = 1; j <= ref_len; j++) {
      errors_curr[err_idx(i, j)] = min(
          min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) +
              1,
          errors_curr[err_idx(i - 1, j - 1)] +
              2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
    }
  }

  // back-tracing
  int i = hyp_len;
  int j = ref_len;
  int o = hyp_len + ref_len;

  for (int k = 0; k < source_size + target_size; k++) {
    operations[opt_idx(k)] = 0;
  }

  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) &&
        (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 1;
      j--; // insertion
    } else if (
        (i > 0) &&
        (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 2;
      i--; // deletion
    } else {
      o--;
      operations[opt_idx(o)] = 3;
      i--;
      j--; // do nothing
    }
  }

  // moving to the left
  for (int k = 0; k < hyp_len + ref_len; k++) {
    if (k + o < hyp_len + ref_len) {
      operations[opt_idx(k)] = operations[opt_idx(k + o)];
    } else {
      operations[opt_idx(k)] = 0; // padding
    }
  }
}

torch::Tensor GenerateDeletionLabelCuda(
    torch::Tensor source,
    torch::Tensor operations) {
  const auto batch_size = source.size(0);
  at::TensorOptions options(source.device());
  options = options.dtype(at::ScalarType::Int);
  auto labels = torch::empty({batch_size, source.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(source.device().index());

  AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] {
                          generate_deletion_label_kernel<scalar_t>
                              <<<batch_size, 1, 0, stream>>>(
                                  source.data_ptr<scalar_t>(),
                                  source.size(1),
                                  operations.size(1),
                                  operations.data_ptr<int>(),
                                  labels.data_ptr<int>());
                        }));

  return labels;
}

std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
    torch::Tensor target,
    torch::Tensor operations) {
  const auto batch_size = target.size(0);
  at::TensorOptions options(target.device());
  options = options.dtype(at::ScalarType::Int);
  auto labels = torch::empty({batch_size, target.size(1)}, options);
  auto masks = torch::empty({batch_size, target.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(target.device().index());

  AT_DISPATCH_ALL_TYPES(
      target.scalar_type(), "generate_insertion_labels", ([&] {
        generate_insertion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
            target.data_ptr<scalar_t>(),
            target.size(1),
            operations.size(1),
            operations.data_ptr<int>(),
            labels.data_ptr<int>(),
            masks.data_ptr<int>());
      }));

  return std::make_pair(labels, masks);
}

torch::Tensor LevenshteinDistanceCuda(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length) {
  const auto batch_size = source.size(0);
  const auto shared_size =
      (source.size(1) + 1) * (target.size(1) + 1) * sizeof(short);

  at::TensorOptions options(source.device());
  options = options.dtype(at::ScalarType::Int);
  auto operations =
      torch::empty({batch_size, source.size(1) + target.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(source.device().index());

  if (shared_size > 40000) {
    auto distances = torch::empty(
        {batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options);
    AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] {
                            levenshtein_distance_kernel<scalar_t>
                                <<<batch_size, 1, 0, stream>>>(
                                    source.data_ptr<scalar_t>(),
                                    target.data_ptr<scalar_t>(),
                                    source_length.data_ptr<int>(),
                                    target_length.data_ptr<int>(),
                                    source.size(1),
                                    target.size(1),
                                    operations.data_ptr<int>(),
                                    distances.data_ptr<int>());
                          }));
  } else {
    AT_DISPATCH_ALL_TYPES(
        source.scalar_type(), "faster_levenshtein_distance", ([&] {
          faster_levenshtein_distance_kernel<scalar_t>
              <<<batch_size, 1, shared_size, stream>>>(
                  source.data_ptr<scalar_t>(),
                  target.data_ptr<scalar_t>(),
                  source_length.data_ptr<int>(),
                  target_length.data_ptr<int>(),
                  source.size(1),
                  target.size(1),
                  operations.data_ptr<int>());
        }));
  }

  return operations;
}
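For intuition, here is a minimal CPU sketch mirroring levenshtein_distance_kernel for a single sequence pair (substitutions cost 2, as in the kernel, and the op codes match: 1 = insert, 2 = delete, 3 = keep/substitute):

```python
def levenshtein_ops(hyp, ref):
    """Pure-Python mirror of the kernel's DP table plus backtrace."""
    H, R = len(hyp), len(ref)
    err = [[0] * (R + 1) for _ in range(H + 1)]
    for i in range(H + 1):
        err[i][0] = i
    for j in range(R + 1):
        err[0][j] = j
    for i in range(1, H + 1):
        for j in range(1, R + 1):
            err[i][j] = min(
                min(err[i - 1][j], err[i][j - 1]) + 1,
                err[i - 1][j - 1] + 2 * (0 if hyp[i - 1] == ref[j - 1] else 1),
            )
    # Backtrace, preferring insert, then delete, then keep/substitute.
    ops, i, j = [], H, R
    while (i, j) != (0, 0):
        if j > 0 and err[i][j - 1] < err[i][j]:
            ops.append(1)  # insertion
            j -= 1
        elif i > 0 and err[i - 1][j] < err[i][j]:
            ops.append(2)  # deletion
            i -= 1
        else:
            ops.append(3)  # keep or substitute
            i -= 1
            j -= 1
    return ops[::-1]

print(levenshtein_ops([1, 2, 3], [1, 3, 4]))  # [3, 2, 3, 1]
```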
fairseq/fairseq/clib/libnat_cuda/edit_dist.h
ADDED
@@ -0,0 +1,25 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <torch/extension.h>

torch::Tensor LevenshteinDistanceCuda(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length);

torch::Tensor GenerateDeletionLabelCuda(
    torch::Tensor source,
    torch::Tensor operations);

std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
    torch::Tensor source,
    torch::Tensor operations);
fairseq/fairseq/config/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
fairseq/fairseq/config/config.yaml
ADDED
@@ -0,0 +1,19 @@
# @package _group_

hydra:
  run:
    dir: .

defaults:
  - _self_
  - task: null
  - model: null
  - criterion: cross_entropy
  - optimizer: null
  - lr_scheduler: fixed
  - bpe: null
  - tokenizer: null
  - scoring: null
  - generation: null
  - common_eval: null
  - eval_lm: null
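This is the root of fairseq's Hydra configuration: the defaults list pulls one option per group (task, model, criterion, ...) and overrides select alternatives such as model=transformer_lm/transformer_lm_gpt. A hedged sketch of composing it programmatically with Hydra's compose API (assumes Hydra >= 1.1; the config_dir path is a placeholder for wherever this tree lives):

```python
from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf

# Placeholder absolute path to the directory containing config.yaml above.
with initialize_config_dir(config_dir="/path/to/fairseq/fairseq/config"):
    cfg = compose(
        config_name="config",
        overrides=["model=transformer_lm/transformer_lm_gpt"],
    )
print(OmegaConf.to_yaml(cfg))  # shows where each "@package _group_" node landed
```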
fairseq/fairseq/config/fb_run_config/slurm.yaml
ADDED
@@ -0,0 +1,29 @@
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '__'
        exclude_keys:
          - fb_run_config
          - distributed_training.distributed_port
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
  launcher:
    cpus_per_task: 60
    gpus_per_node: ???
    tasks_per_node: 1
    nodes: 1
    partition: learnfair
    mem_gb: 400
    timeout_min: 4320
    max_num_timeout: 10
    name: ${env:PREFIX}_${hydra.job.config_name}
    submitit_folder: ${hydra.sweep.dir}

distributed_training:
  ddp_backend: c10d
  distributed_world_size: ???
  distributed_port: ???
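The ??? entries are OmegaConf's mandatory-missing marker: composition succeeds, but reading such a key without an override raises. A tiny self-contained illustration:

```python
from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create({"distributed_training": {"distributed_world_size": "???"}})
try:
    _ = cfg.distributed_training.distributed_world_size
except MissingMandatoryValue:
    # Supplied at launch time, e.g. distributed_training.distributed_world_size=8
    print("distributed_world_size must be overridden")
```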
fairseq/fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml
ADDED
@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "relu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 512
decoder_output_dim: 512
decoder_input_dim: 512
decoder_ffn_embed_dim: 4096
decoder_layers: 12
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
fairseq/fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml
ADDED
@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "relu"
dropout: 0.3
attention_dropout: 0.1
activation_dropout: 0.1
relu_dropout: 0.1
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 16
decoder_attention_heads: 8
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: "20000,60000"
adaptive_softmax_dropout: 0.2
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: true
adaptive_input_factor: 4
adaptive_input_cutoff: "20000,60000"
tie_adaptive_weights: true
tie_adaptive_proj: true
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
fairseq/fairseq/config/model/transformer_lm/transformer_lm_big.yaml
ADDED
@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "relu"
dropout: 0.1
attention_dropout: 0.0
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 12
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
fairseq/fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml
ADDED
@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "relu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 512
decoder_output_dim: 512
decoder_input_dim: 512
decoder_ffn_embed_dim: 4096
decoder_layers: 12
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml
ADDED
@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 768
decoder_output_dim: 768
decoder_input_dim: 768
decoder_ffn_embed_dim: 3072
decoder_layers: 12
decoder_attention_heads: 12
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
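As a sanity check on these sizes, a rough non-embedding parameter count for this GPT-style config (ignoring biases and LayerNorm parameters): each layer contributes four d x d attention projections plus a d x d_ffn up- and down-projection.

```python
d, d_ffn, n_layers = 768, 3072, 12  # values from the config above
per_layer = 4 * d * d + 2 * d * d_ffn  # attention (q, k, v, out) + feed-forward
print(f"{n_layers * per_layer / 1e6:.1f}M non-embedding parameters")  # ~84.9M
```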
fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml
ADDED
@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1600
decoder_output_dim: 1600
decoder_input_dim: 1600
decoder_ffn_embed_dim: 6400
decoder_layers: 48
decoder_attention_heads: 25
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml
ADDED
@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1280
decoder_output_dim: 1280
decoder_input_dim: 1280
decoder_ffn_embed_dim: 5120
decoder_layers: 36
decoder_attention_heads: 20
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml
ADDED
@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 24
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
fairseq/fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml
ADDED
@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "relu"
dropout: 0.3
attention_dropout: 0.1
activation_dropout: 0.1
relu_dropout: 0.1
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 16
decoder_attention_heads: 8
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: "20000,60000"
adaptive_softmax_dropout: 0.2
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: true
adaptive_input_factor: 4
adaptive_input_cutoff: "20000,60000"
tie_adaptive_weights: true
tie_adaptive_proj: true
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
fairseq/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml
ADDED
@@ -0,0 +1,5 @@
# @package _group_
activation: gelu
vq_type: gumbel
vq_depth: 2
combine_groups: true
fairseq/fairseq/config/model/wav2vec2/wav2vec2_base.yaml
ADDED
@@ -0,0 +1,8 @@
# @package _group_

quantize_targets: true
final_dim: 256
encoder_layerdrop: 0.05
dropout_input: 0.1
dropout_features: 0.1
feature_grad_mult: 0.1
fairseq/fairseq/config/model/wav2vec2/wav2vec2_large.yaml
ADDED
@@ -0,0 +1,20 @@
# @package _group_

quantize_targets: true
extractor_mode: layer_norm
layer_norm_first: true
final_dim: 768
latent_temp: [2.0,0.1,0.999995]
encoder_layerdrop: 0.0
dropout_input: 0.0
dropout_features: 0.0
dropout: 0.0
attention_dropout: 0.0
conv_bias: true

encoder_layers: 24
encoder_embed_dim: 1024
encoder_ffn_embed_dim: 4096
encoder_attention_heads: 16

feature_grad_mult: 1.0
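latent_temp here reads as (start, end, decay) for the Gumbel-softmax temperature, annealed multiplicatively per update: tau_t = max(end, start * decay^t). A quick sketch of how long the schedule takes to hit its floor under these values:

```python
import math

start, end, decay = 2.0, 0.1, 0.999995  # latent_temp from the config above
updates_to_floor = math.log(end / start) / math.log(decay)
print(f"~{updates_to_floor:,.0f} updates until the temperature floor")  # ~599,146
```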
fairseq/fairseq/criterions/__init__.py
ADDED
@@ -0,0 +1,36 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

import importlib
import os

from fairseq import registry
from fairseq.criterions.fairseq_criterion import (  # noqa
    FairseqCriterion,
    LegacyFairseqCriterion,
)
from omegaconf import DictConfig


(
    build_criterion_,
    register_criterion,
    CRITERION_REGISTRY,
    CRITERION_DATACLASS_REGISTRY,
) = registry.setup_registry(
    "--criterion", base_class=FairseqCriterion, default="cross_entropy"
)


def build_criterion(cfg: DictConfig, task, from_checkpoint=False):
    return build_criterion_(cfg, task, from_checkpoint=from_checkpoint)


# automatically import any Python files in the criterions/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        file_name = file[: file.find(".py")]
        importlib.import_module("fairseq.criterions." + file_name)
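Given the registry and the auto-import loop above, a new criterion only needs a decorated class dropped into this directory. A minimal hedged sketch following the pattern of the criterions below (the module, class, and criterion names here are hypothetical, not part of the upload):

```python
# Hypothetical fairseq/fairseq/criterions/l2_regression.py
from dataclasses import dataclass

from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class L2RegressionConfig(FairseqDataclass):
    pass  # criterion-specific fields would go here


@register_criterion("l2_regression", dataclass=L2RegressionConfig)
class L2RegressionCriterion(FairseqCriterion):
    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        target = model.get_targets(sample, net_output).float()
        # Summed squared error, reduced like the criterions below.
        loss = ((net_output[0].float() - target) ** 2).sum()
        sample_size = sample["ntokens"]
        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output
```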
fairseq/fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc
ADDED
Binary file (4.53 kB).
fairseq/fairseq/criterions/adaptive_loss.py
ADDED
@@ -0,0 +1,124 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass

import torch.nn.functional as F
from fairseq import utils
from fairseq.logging import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.constants import DDP_BACKEND_CHOICES
from omegaconf import II


@dataclass
class AdaptiveLossConfig(FairseqDataclass):
    sentence_avg: bool = II("optimization.sentence_avg")
    ddp_backend: DDP_BACKEND_CHOICES = II("distributed_training.ddp_backend")


@register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig)
class AdaptiveLoss(FairseqCriterion):
    """This is an implementation of the loss function accompanying the adaptive softmax approximation for
    graphics processing units (GPU), described in the paper "Efficient softmax approximation for GPUs"
    (http://arxiv.org/abs/1609.04309)."""

    def __init__(self, task, sentence_avg):
        super().__init__(task)
        self.sentence_avg = sentence_avg

    @classmethod
    def build_criterion(cls, cfg: AdaptiveLossConfig, task):
        if cfg.ddp_backend in {"c10d", "pytorch_ddp"}:
            raise Exception(
                "AdaptiveLoss is not compatible with the PyTorch "
                "version of DistributedDataParallel. Please use "
                "`--ddp-backend=legacy_ddp` instead."
            )
        return cls(task, cfg.sentence_avg)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """

        assert (
            hasattr(model.decoder, "adaptive_softmax")
            and model.decoder.adaptive_softmax is not None
        )
        adaptive_softmax = model.decoder.adaptive_softmax

        net_output = model(**sample["net_input"])
        orig_target = model.get_targets(sample, net_output)

        nsentences = orig_target.size(0)
        orig_target = orig_target.view(-1)

        bsz = orig_target.size(0)

        logits, target = adaptive_softmax(net_output[0], orig_target)
        assert len(target) == len(logits)

        loss = net_output[0].new(1 if reduce else bsz).zero_()

        for i in range(len(target)):
            if target[i] is not None:
                assert target[i].min() >= 0 and target[i].max() <= logits[i].size(1)
                loss += F.cross_entropy(
                    logits[i],
                    target[i],
                    ignore_index=self.padding_idx,
                    reduction="sum" if reduce else "none",
                )

        orig = utils.strip_pad(orig_target, self.padding_idx)
        ntokens = orig.numel()
        sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
        logging_output = {
            "loss": loss.data,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
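For intuition, PyTorch ships the same Grave et al. approximation as torch.nn.AdaptiveLogSoftmaxWithLoss; a self-contained sketch with cutoffs and div_value mirroring the "20000,60000" and factor-4 settings in the transformer_lm configs above (the dimensions are toy values):

```python
import torch
from torch.nn import AdaptiveLogSoftmaxWithLoss

vocab_size, hidden_dim = 100000, 64
asm = AdaptiveLogSoftmaxWithLoss(
    hidden_dim, vocab_size, cutoffs=[20000, 60000], div_value=4.0
)
hidden = torch.randn(8, hidden_dim)
targets = torch.randint(0, vocab_size, (8,))
out = asm(hidden, targets)
print(out.loss)  # mean NLL; frequent words hit the head cluster, rare ones the tails
```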
fairseq/fairseq/criterions/composite_loss.py
ADDED
@@ -0,0 +1,100 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq import utils
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from torch import nn


@register_criterion("composite_loss")
class CompositeLoss(LegacyFairseqCriterion):
    """This is a composite loss that, given a list of model outputs and a list of targets,
    computes an average of losses for each output-target pair"""

    def __init__(self, args, task):
        super().__init__(args, task)
        self.underlying_criterion = args.underlying_criterion

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True,
                            help='underlying criterion to use for the composite loss')
        # fmt: on

    @staticmethod
    def build_underlying_criterion(args, task):
        saved_criterion = args.criterion
        args.criterion = args.underlying_criterion
        assert saved_criterion != args.underlying_criterion
        underlying_criterion = task.build_criterion(args)
        args.criterion = saved_criterion
        return underlying_criterion

    @classmethod
    def build_criterion(cls, args, task):
        underlying_criterion = CompositeLoss.build_underlying_criterion(args, task)

        class FakeModel(nn.Module):
            def __init__(self, model, net_out, target):
                super().__init__()
                self.model = model
                self.net_out = net_out
                self.target = target

            def forward(self, **unused):
                return self.net_out

            def get_normalized_probs(self, net_output, log_probs, sample=None):
                return self.model.get_normalized_probs(
                    net_output, log_probs, sample=sample
                )

            def get_targets(self, *unused):
                return self.target

            @property
            def decoder(self):
                return self.model.decoder

        class _CompositeLoss(LegacyFairseqCriterion):
            def __init__(self, args, task, underlying_criterion):
                super().__init__(args, task)
                self.underlying_criterion = underlying_criterion

            def forward(self, model, sample, reduce=True):
                net_outputs = model(**sample["net_input"])
                targets = sample["target"]

                bsz = targets[0].size(0)
                loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_()

                sample_size = 0
                logging_output = {}
                for o, t in zip(net_outputs[0], targets):
                    m = FakeModel(model, (o, net_outputs[1]), t)
                    sample["target"] = t
                    l, ss, logging_output = self.underlying_criterion(m, sample, reduce)
                    loss += l
                    sample_size += ss

                loss.div_(len(targets))
                sample_size /= len(targets)

                logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
                return loss, sample_size, logging_output

            @staticmethod
            def aggregate_logging_outputs(logging_outputs):
                return underlying_criterion.__class__.aggregate_logging_outputs(
                    logging_outputs
                )

            @staticmethod
            def reduce_metrics(logging_outputs) -> None:
                underlying_criterion.__class__.reduce_metrics(logging_outputs)

        return _CompositeLoss(args, task, underlying_criterion)
fairseq/fairseq/criterions/cross_entropy.py
ADDED
@@ -0,0 +1,91 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass

import torch.nn.functional as F
from fairseq import utils
from fairseq.logging import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II


@dataclass
class CrossEntropyCriterionConfig(FairseqDataclass):
    sentence_avg: bool = II("optimization.sentence_avg")


@register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig)
class CrossEntropyCriterion(FairseqCriterion):
    def __init__(self, task, sentence_avg):
        super().__init__(task)
        self.sentence_avg = sentence_avg

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1)
        loss = F.nll_loss(
            lprobs,
            target,
            ignore_index=self.padding_idx,
            reduction="sum" if reduce else "none",
        )
        return loss, loss

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
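The math.log(2) division in reduce_metrics converts the summed NLL from nats to bits per token, so perplexity falls out as 2 to the logged loss; equivalently, exp of the per-token NLL in nats. A quick numeric check:

```python
import math

nll_nats = 3.0                     # per-token NLL in nats (base e)
nll_bits = nll_nats / math.log(2)  # what reduce_metrics logs as "loss"
print(2 ** nll_bits, math.exp(nll_nats))  # both ~20.086
```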
fairseq/fairseq/criterions/ctc.py
ADDED
@@ -0,0 +1,325 @@
1 |
+
# All rights reserved.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the license found in the LICENSE file in
|
4 |
+
# the root directory of this source tree. An additional grant of patent rights
|
5 |
+
# can be found in the PATENTS file in the same directory.
|
6 |
+
|
7 |
+
import math
|
8 |
+
from argparse import Namespace
|
9 |
+
from dataclasses import dataclass, field
|
10 |
+
from omegaconf import II
|
11 |
+
from typing import Optional
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
from fairseq import utils
|
17 |
+
from fairseq.logging import metrics
|
18 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
19 |
+
from fairseq.dataclass import FairseqDataclass
|
20 |
+
from fairseq.data.data_utils import post_process
|
21 |
+
from fairseq.tasks import FairseqTask
|
22 |
+
from fairseq.logging.meters import safe_round
|
23 |
+
|
24 |
+
|
25 |
+
@dataclass
|
26 |
+
class CtcCriterionConfig(FairseqDataclass):
|
27 |
+
zero_infinity: bool = field(
|
28 |
+
default=False,
|
29 |
+
metadata={"help": "zero inf loss when source length <= target length"},
|
30 |
+
)
|
31 |
+
sentence_avg: bool = II("optimization.sentence_avg")
|
32 |
+
post_process: str = field(
|
33 |
+
default="letter",
|
34 |
+
metadata={
|
35 |
+
"help": "how to post process predictions into words. can be letter, "
|
36 |
+
"wordpiece, BPE symbols, etc. "
|
37 |
+
"See fairseq.data.data_utils.post_process() for full list of options"
|
38 |
+
},
|
39 |
+
)
|
40 |
+
wer_kenlm_model: Optional[str] = field(
|
41 |
+
default=None,
|
42 |
+
metadata={
|
43 |
+
"help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
|
44 |
+
},
|
45 |
+
)
|
46 |
+
wer_lexicon: Optional[str] = field(
|
47 |
+
default=None,
|
48 |
+
metadata={"help": "lexicon to use with wer_kenlm_model"},
|
49 |
+
)
|
50 |
+
wer_lm_weight: float = field(
|
51 |
+
default=2.0,
|
52 |
+
metadata={"help": "lm weight to use with wer_kenlm_model"},
|
53 |
+
)
|
54 |
+
wer_word_score: float = field(
|
55 |
+
default=-1.0,
|
56 |
+
metadata={"help": "lm word score to use with wer_kenlm_model"},
|
57 |
+
)
|
58 |
+
wer_sil_weight: float = field(
|
59 |
+
default=0,
|
60 |
+
metadata={"help": "lm word score to use with wer_kenlm_model"},
|
61 |
+
)
|
62 |
+
|
63 |
+
wer_args: Optional[str] = field(
|
64 |
+
default=None,
|
65 |
+
metadata={
|
66 |
+
"help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
|
67 |
+
},
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
@register_criterion("ctc", dataclass=CtcCriterionConfig)
|
72 |
+
class CtcCriterion(FairseqCriterion):
|
73 |
+
def __init__(
|
74 |
+
self, cfg: CtcCriterionConfig, task: FairseqTask, rdrop_alpha: int = 0.0
|
75 |
+
):
|
76 |
+
super().__init__(task)
|
77 |
+
self.blank_idx = (
|
78 |
+
task.target_dictionary.index(task.blank_symbol)
|
79 |
+
if hasattr(task, "blank_symbol")
|
80 |
+
else 0
|
81 |
+
)
|
82 |
+
self.pad_idx = task.target_dictionary.pad()
|
83 |
+
self.eos_idx = task.target_dictionary.eos()
|
84 |
+
self.post_process = cfg.post_process
|
85 |
+
|
86 |
+
self.rdrop_alpha = rdrop_alpha
|
87 |
+
|
88 |
+
if cfg.wer_args is not None:
|
89 |
+
(
|
90 |
+
cfg.wer_kenlm_model,
|
91 |
+
cfg.wer_lexicon,
|
92 |
+
cfg.wer_lm_weight,
|
93 |
+
cfg.wer_word_score,
|
94 |
+
) = eval(cfg.wer_args)
|
95 |
+
|
96 |
+
if cfg.wer_kenlm_model is not None and cfg.wer_kenlm_model != "":
|
97 |
+
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
|
98 |
+
|
99 |
+
dec_args = Namespace()
|
100 |
+
dec_args.nbest = 1
|
101 |
+
dec_args.criterion = "ctc"
|
102 |
+
dec_args.kenlm_model = cfg.wer_kenlm_model
|
103 |
+
dec_args.lexicon = cfg.wer_lexicon
|
104 |
+
dec_args.beam = 50
|
105 |
+
dec_args.beam_size_token = min(50, len(task.target_dictionary))
|
106 |
+
dec_args.beam_threshold = min(50, len(task.target_dictionary))
|
107 |
+
dec_args.lm_weight = cfg.wer_lm_weight
|
108 |
+
dec_args.word_score = cfg.wer_word_score
|
109 |
+
dec_args.sil_weight = cfg.wer_sil_weight
|
110 |
+
dec_args.unk_weight = -math.inf
|
111 |
+
dec_args.sil_weight = 0
|
112 |
+
|
113 |
+
self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
|
114 |
+
else:
|
115 |
+
self.w2l_decoder = None
|
116 |
+
|
117 |
+
self.zero_infinity = cfg.zero_infinity
|
118 |
+
self.sentence_avg = cfg.sentence_avg
|
119 |
+
|
120 |
+
def forward(self, model, sample, reduce=True, **kwargs):
|
121 |
+
net_output = model(**sample["net_input"])
|
122 |
+
lprobs = model.get_normalized_probs(
|
123 |
+
net_output, log_probs=True
|
124 |
+
).contiguous() # (T, B, C) from the encoder
|
125 |
+
|
126 |
+
# CTC loss is calculated over duplicated inputs
|
127 |
+
# sample is already duplicated for R-Drop
|
128 |
+
if self.rdrop_alpha > 0:
|
129 |
+
for k, v in sample.items():
|
130 |
+
if k in ["target", "target_lengths"]:
|
131 |
+
sample[k] = torch.cat([v, v.clone()], dim=0)
|
132 |
+
elif k == "net_input":
|
133 |
+
if sample[k]["src_tokens"].size(1) != sample[k]["src_lengths"].size(
|
134 |
+
0
|
135 |
+
):
|
136 |
+
# for decoder CTC loss
|
137 |
+
sample[k]["src_lengths"] = torch.cat(
|
138 |
+
[
|
139 |
+
sample[k]["src_lengths"],
|
140 |
+
sample[k]["src_lengths"].clone(),
|
141 |
+
],
|
142 |
+
dim=0,
|
143 |
+
)
|
144 |
+
|
145 |
+
if "src_lengths" in sample["net_input"]:
|
146 |
+
input_lengths = sample["net_input"]["src_lengths"]
|
147 |
+
else:
|
148 |
+
if net_output["padding_mask"] is not None:
|
149 |
+
non_padding_mask = ~net_output["padding_mask"]
|
150 |
+
input_lengths = non_padding_mask.long().sum(-1)
|
151 |
+
else:
|
152 |
+
input_lengths = lprobs.new_full(
|
153 |
+
(lprobs.size(1),), lprobs.size(0), dtype=torch.long
|
154 |
+
)
|
155 |
+
|
156 |
+
pad_mask = (sample["target"] != self.pad_idx) & (
|
157 |
+
sample["target"] != self.eos_idx
|
158 |
+
)
|
159 |
+
targets_flat = sample["target"].masked_select(pad_mask)
|
160 |
+
if "target_lengths" in sample:
|
161 |
+
target_lengths = sample["target_lengths"]
|
162 |
+
else:
|
        target_lengths = pad_mask.sum(-1)

        with torch.backends.cudnn.flags(enabled=False):
            loss = F.ctc_loss(
                lprobs,
                targets_flat,
                input_lengths,
                target_lengths,
                blank=self.blank_idx,
                reduction="sum",
                zero_infinity=self.zero_infinity,
            )

        ntokens = (
            sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
        )

        sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ntokens": ntokens,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
        }

        if not model.training:
            import editdistance

            with torch.no_grad():
                lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()

                c_err = 0
                c_len = 0
                w_errs = 0
                w_len = 0
                wv_errs = 0
                for lp, t, inp_l in zip(
                    lprobs_t,
                    sample["target_label"]
                    if "target_label" in sample
                    else sample["target"],
                    input_lengths,
                ):
                    lp = lp[:inp_l].unsqueeze(0)

                    decoded = None
                    if self.w2l_decoder is not None:
                        decoded = self.w2l_decoder.decode(lp)
                        if len(decoded) < 1:
                            decoded = None
                        else:
                            decoded = decoded[0]
                            if len(decoded) < 1:
                                decoded = None
                            else:
                                decoded = decoded[0]

                    p = (t != self.task.target_dictionary.pad()) & (
                        t != self.task.target_dictionary.eos()
                    )
                    targ = t[p]
                    targ_units = self.task.target_dictionary.string(targ)
                    targ_units_arr = targ.tolist()

                    toks = lp.argmax(dim=-1).unique_consecutive()
                    pred_units_arr = toks[toks != self.blank_idx].tolist()

                    c_err += editdistance.eval(pred_units_arr, targ_units_arr)
                    c_len += len(targ_units_arr)

                    targ_words = post_process(targ_units, self.post_process).split()

                    pred_units = self.task.target_dictionary.string(pred_units_arr)
                    pred_words_raw = post_process(pred_units, self.post_process).split()

                    if decoded is not None and "words" in decoded:
                        pred_words = decoded["words"]
                        w_errs += editdistance.eval(pred_words, targ_words)
                        wv_errs += editdistance.eval(pred_words_raw, targ_words)
                    else:
                        dist = editdistance.eval(pred_words_raw, targ_words)
                        w_errs += dist
                        wv_errs += dist

                    w_len += len(targ_words)

                logging_output["wv_errors"] = wv_errs
                logging_output["w_errors"] = w_errs
                logging_output["w_total"] = w_len
                logging_output["c_errors"] = c_err
                logging_output["c_total"] = c_len

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""

        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log.get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_w_errors", w_errors)
        wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_wv_errors", wv_errors)
        w_total = sum(log.get("w_total", 0) for log in logging_outputs)
        metrics.log_scalar("_w_total", w_total)

        if c_total > 0:
            metrics.log_derived(
                "uer",
                lambda meters: safe_round(
                    meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
                )
                if meters["_c_total"].sum > 0
                else float("nan"),
            )
        if w_total > 0:
            metrics.log_derived(
                "wer",
                lambda meters: safe_round(
                    meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )
            metrics.log_derived(
                "raw_wer",
                lambda meters: safe_round(
                    meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
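
The validation branch above derives character and word error counts from a greedy CTC decode: framewise argmax, collapse of consecutive repeats, then removal of the blank symbol. A quick illustration of that collapse in isolation (frame indices below are made up, with blank_idx = 0):

# Toy walk-through of the greedy CTC collapse used above; the frame
# predictions are invented for illustration.
import torch

blank_idx = 0
frame_preds = torch.tensor([0, 3, 3, 0, 5, 5, 5, 0, 3])
toks = frame_preds.unique_consecutive()          # [0, 3, 0, 5, 0, 3]
pred_units_arr = toks[toks != blank_idx].tolist()
print(pred_units_arr)                            # [3, 5, 3]
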
fairseq/fairseq/criterions/fairseq_criterion.py
ADDED
@@ -0,0 +1,121 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from typing import Any, Dict, List

from fairseq import utils
from fairseq.logging import metrics
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass
from torch.nn.modules.loss import _Loss


class FairseqCriterion(_Loss):
    def __init__(self, task):
        super().__init__()
        self.task = task
        if hasattr(task, "target_dictionary"):
            tgt_dict = task.target_dictionary
            self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100

    @classmethod
    def add_args(cls, parser):
        """Add criterion-specific arguments to the parser."""
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    @classmethod
    def build_criterion(cls, cfg: FairseqDataclass, task):
        """Construct a criterion from command-line args."""
        # arguments in the __init__.
        init_args = {}
        for p in inspect.signature(cls).parameters.values():
            if (
                p.kind == p.POSITIONAL_ONLY
                or p.kind == p.VAR_POSITIONAL
                or p.kind == p.VAR_KEYWORD
            ):
                # we haven't implemented inference for these argument types,
                # but PRs welcome :)
                raise NotImplementedError("{} not supported".format(p.kind))

            assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}

            if p.name == "task":
                init_args["task"] = task
            elif p.name == "cfg":
                init_args["cfg"] = cfg
            elif hasattr(cfg, p.name):
                init_args[p.name] = getattr(cfg, p.name)
            elif p.default != p.empty:
                pass  # we'll use the default value
            else:
                raise NotImplementedError(
                    "Unable to infer Criterion arguments, please implement "
                    "{}.build_criterion".format(cls.__name__)
                )
        return cls(**init_args)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        raise NotImplementedError

    @staticmethod
    def aggregate_logging_outputs(
        logging_outputs: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Aggregate logging outputs from data parallel training."""
        utils.deprecation_warning(
            "The aggregate_logging_outputs API is deprecated. "
            "Please use the reduce_metrics API instead."
        )
        raise NotImplementedError

    @classmethod
    def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
        """Aggregate logging outputs from data parallel training."""
        utils.deprecation_warning(
            "Criterions should implement the reduce_metrics API. "
            "Falling back to deprecated aggregate_logging_outputs API."
        )
        agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
        for k, v in agg_logging_outputs.items():
            if k in {"nsentences", "ntokens", "sample_size"}:
                continue
            metrics.log_scalar(k, v)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False


class LegacyFairseqCriterion(FairseqCriterion):
    def __init__(self, args, task):
        super().__init__(task=task)
        self.args = args

        utils.deprecation_warning(
            "Criterions should take explicit arguments instead of an "
            "argparse.Namespace object, please update your criterion by "
            "extending FairseqCriterion instead of LegacyFairseqCriterion."
        )

    @classmethod
    def build_criterion(cls, args, task):
        """Construct a criterion from command-line args."""
        return cls(args, task)
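
Since build_criterion above infers constructor arguments from the config by introspecting __init__, a subclass usually only needs a dataclass config, the registration decorator, and a forward method. A minimal sketch of such a subclass (the criterion name "my_cross_entropy" and the config class are illustrative names, not part of fairseq):

# Minimal sketch of a custom criterion built on FairseqCriterion.
# "my_cross_entropy" / MyCrossEntropyConfig are invented for illustration.
from dataclasses import dataclass

import torch.nn.functional as F
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class MyCrossEntropyConfig(FairseqDataclass):
    pass  # no extra options; build_criterion will only pass `task`


@register_criterion("my_cross_entropy", dataclass=MyCrossEntropyConfig)
class MyCrossEntropyCriterion(FairseqCriterion):
    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1)
        # self.padding_idx was set in FairseqCriterion.__init__ from the task
        loss = F.nll_loss(
            lprobs,
            target,
            ignore_index=self.padding_idx,
            reduction="sum" if reduce else "none",
        )
        sample_size = sample["ntokens"]
        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output
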
fairseq/fairseq/criterions/fastspeech2_loss.py
ADDED
@@ -0,0 +1,137 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from typing import List, Dict, Any
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F

from fairseq import utils
from fairseq.logging import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import lengths_to_mask
from fairseq.models.fairseq_model import FairseqEncoderModel


@dataclass
class FastSpeech2CriterionConfig(FairseqDataclass):
    ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})


@register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig)
class FastSpeech2Loss(FairseqCriterion):
    def __init__(self, task, ctc_weight):
        super().__init__(task)
        self.ctc_weight = ctc_weight

    def forward(self, model: FairseqEncoderModel, sample, reduction="mean"):
        src_tokens = sample["net_input"]["src_tokens"]
        src_lens = sample["net_input"]["src_lengths"]
        tgt_lens = sample["target_lengths"]
        _feat_out, _feat_out_post, _, log_dur_out, pitch_out, energy_out = model(
            src_tokens=src_tokens,
            src_lengths=src_lens,
            prev_output_tokens=sample["net_input"]["prev_output_tokens"],
            incremental_state=None,
            target_lengths=tgt_lens,
            speaker=sample["speaker"],
            durations=sample["durations"],
            pitches=sample["pitches"],
            energies=sample["energies"],
        )

        src_mask = lengths_to_mask(sample["net_input"]["src_lengths"])
        tgt_mask = lengths_to_mask(sample["target_lengths"])

        pitches, energies = sample["pitches"], sample["energies"]
        pitch_out, pitches = pitch_out[src_mask], pitches[src_mask]
        energy_out, energies = energy_out[src_mask], energies[src_mask]

        feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask]
        l1_loss = F.l1_loss(feat_out, feat, reduction=reduction)
        if _feat_out_post is not None:
            l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, reduction=reduction)

        pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction)
        energy_loss = F.mse_loss(energy_out, energies, reduction=reduction)

        log_dur_out = log_dur_out[src_mask]
        dur = sample["durations"].float()
        dur = dur.half() if log_dur_out.type().endswith(".HalfTensor") else dur
        log_dur = torch.log(dur + 1)[src_mask]
        dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction)

        ctc_loss = torch.tensor(0.0).type_as(l1_loss)
        if self.ctc_weight > 0.0:
            lprobs = model.get_normalized_probs((_feat_out,), log_probs=True)
            lprobs = lprobs.transpose(0, 1)  # T x B x C
            src_mask = lengths_to_mask(src_lens)
            src_tokens_flat = src_tokens.masked_select(src_mask)
            ctc_loss = (
                F.ctc_loss(
                    lprobs,
                    src_tokens_flat,
                    tgt_lens,
                    src_lens,
                    reduction=reduction,
                    zero_infinity=True,
                )
                * self.ctc_weight
            )

        loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss

        sample_size = sample["nsentences"]
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "l1_loss": utils.item(l1_loss.data),
            "dur_loss": utils.item(dur_loss.data),
            "pitch_loss": utils.item(pitch_loss.data),
            "energy_loss": utils.item(energy_loss.data),
            "ctc_loss": utils.item(ctc_loss.data),
        }
        return loss, sample_size, logging_output

    @classmethod
    def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
        ns = [log.get("sample_size", 0) for log in logging_outputs]
        ntot = sum(ns)
        ws = [n / (ntot + 1e-8) for n in ns]
        for key in [
            "loss",
            "l1_loss",
            "dur_loss",
            "pitch_loss",
            "energy_loss",
            "ctc_loss",
        ]:
            vals = [log.get(key, 0) for log in logging_outputs]
            val = sum(val * w for val, w in zip(vals, ws))
            metrics.log_scalar(key, val, ntot, round=3)
        metrics.log_scalar("sample_size", ntot, len(logging_outputs))

        # inference metrics
        if "targ_frames" not in logging_outputs[0]:
            return
        n = sum(log.get("targ_frames", 0) for log in logging_outputs)
        for key, new_key in [
            ("mcd_loss", "mcd_loss"),
            ("pred_frames", "pred_ratio"),
            ("nins", "ins_rate"),
            ("ndel", "del_rate"),
        ]:
            val = sum(log.get(key, 0) for log in logging_outputs)
            metrics.log_scalar(new_key, val / n, n, round=3)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        return False
fairseq/fairseq/criterions/hubert_criterion.py
ADDED
@@ -0,0 +1,195 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import re
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.logging import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class HubertCriterionConfig(FairseqDataclass):
    pred_masked_weight: float = field(
        default=1.0,
        metadata={"help": "weight for predictive loss for masked frames"},
    )
    pred_nomask_weight: float = field(
        default=0.0,
        metadata={"help": "weight for predictive loss for unmasked frames"},
    )
    loss_weights: Optional[List[float]] = field(
        default=None,
        metadata={"help": "weights for additional loss terms (not first one)"},
    )
    log_keys: List[str] = field(
        default_factory=lambda: [],
        metadata={"help": "output keys to log"},
    )


@register_criterion("hubert", dataclass=HubertCriterionConfig)
class HubertCriterion(FairseqCriterion):
    def __init__(
        self,
        task,
        pred_masked_weight,
        pred_nomask_weight,
        loss_weights=None,
        log_keys=None,
    ):
        super().__init__(task)
        self.pred_masked_weight = pred_masked_weight
        self.pred_nomask_weight = pred_nomask_weight
        self.loss_weights = loss_weights
        self.log_keys = [] if log_keys is None else log_keys

    def forward(self, model, sample, reduce=True, log_pred=False):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(target_list=sample["target_list"], **sample["net_input"])
        loss = 0.0
        sample_size = 0
        logging_output = {}
        reduction = "sum" if reduce else "none"

        loss_m_list = []
        logp_m_list = model.get_logits(net_output, True)
        targ_m_list = model.get_targets(net_output, True)
        assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
            loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
            loss_m_list.append(loss_m)
            logging_output[f"loss_m_{i}"] = loss_m.detach().item()
        if self.pred_masked_weight > 0:
            loss += self.pred_masked_weight * sum(loss_m_list)
            sample_size += targ_m_list[0].numel()

        loss_u_list = []
        logp_u_list = model.get_logits(net_output, False)
        targ_u_list = model.get_targets(net_output, False)
        assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
        for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
            loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
            loss_u_list.append(loss_u)
            logging_output[f"loss_u_{i}"] = loss_u.detach().item()
        if self.pred_nomask_weight > 0:
            loss += self.pred_nomask_weight * sum(loss_u_list)
            sample_size += targ_u_list[0].numel()

        if self.loss_weights is not None:
            assert hasattr(model, "get_extra_losses")
            extra_losses, names = model.get_extra_losses(net_output)
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
                names = [names]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                self.loss_weights
            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, n, coef in zip(extra_losses, names, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    logging_output[f"loss_{n}"] = p.item()

        logging_output = {
            "loss": loss.item() if reduce else loss,
            "ntokens": sample_size,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
            **logging_output,
        }

        for lk in self.log_keys:
            if lk in net_output:
                logging_output[lk] = float((net_output[lk]))

        def compute_correct(logits):
            if logits.numel() == 0:
                return 0, 0
            else:
                assert logits.dim() > 1, logits.shape
                max = logits.argmax(-1) == 0
                min = logits.argmin(-1) == 0
                both = max & min
                corr = max.long().sum().item() - both.long().sum().item()
                count = max.numel()
                return corr, count

        with torch.no_grad():
            for i, logp_m in enumerate(logp_m_list):
                corr_m, count_m = compute_correct(logp_m)
                logging_output[f"correct_m_{i}"] = corr_m
                logging_output[f"count_m_{i}"] = count_m

            for i, logp_u in enumerate(logp_u_list):
                corr_u, count_u = compute_correct(logp_u)
                logging_output[f"correct_u_{i}"] = corr_u
                logging_output[f"count_u_{i}"] = count_u

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

        counts = {}
        for lk in logging_outputs[0].keys():
            if lk.startswith("count_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val)
                counts[lk] = val

        for lk in logging_outputs[0].keys():
            if lk.startswith("loss_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
            elif lk.startswith("correct_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        raise NotImplementedError()

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False
fairseq/fairseq/criterions/label_smoothed_cross_entropy.py
ADDED
@@ -0,0 +1,168 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass, field

import torch
from fairseq import utils
from fairseq.logging import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II


@dataclass
class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
    label_smoothing: float = field(
        default=0.0,
        metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
    )
    report_accuracy: bool = field(
        default=False,
        metadata={"help": "report accuracy metric"},
    )
    ignore_prefix_size: int = field(
        default=0,
        metadata={"help": "Ignore first N tokens"},
    )
    sentence_avg: bool = II("optimization.sentence_avg")


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
    if reduce:
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
    eps_i = epsilon / (lprobs.size(-1) - 1)
    loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss


@register_criterion(
    "label_smoothed_cross_entropy", dataclass=LabelSmoothedCrossEntropyCriterionConfig
)
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        ignore_prefix_size=0,
        report_accuracy=False,
    ):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.eps = label_smoothing
        self.ignore_prefix_size = ignore_prefix_size
        self.report_accuracy = report_accuracy

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, net_output, sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)
        return loss, sample_size, logging_output

    def get_lprobs_and_target(self, model, net_output, sample):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        target = model.get_targets(sample, net_output)
        if self.ignore_prefix_size > 0:
            # lprobs: B x T x C
            lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
            target = target[:, self.ignore_prefix_size :].contiguous()
        return lprobs.view(-1, lprobs.size(-1)), target.view(-1)

    def compute_loss(self, model, net_output, sample, reduce=True):
        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs,
            target,
            self.eps,
            ignore_index=self.padding_idx,
            reduce=reduce,
        )
        return loss, nll_loss

    def compute_accuracy(self, model, net_output, sample):
        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
        mask = target.ne(self.padding_idx)
        n_correct = torch.sum(
            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
        )
        total = torch.sum(mask)
        return n_correct, total

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar(
            "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
        )

        total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
        if total > 0:
            metrics.log_scalar("total", total)
            n_correct = utils.item(
                sum(log.get("n_correct", 0) for log in logging_outputs)
            )
            metrics.log_scalar("n_correct", n_correct)
            metrics.log_derived(
                "accuracy",
                lambda meters: round(
                    meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
                )
                if meters["total"].sum > 0
                else float("nan"),
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
fairseq/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py
ADDED
@@ -0,0 +1,221 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
import torch
from fairseq import utils
from fairseq.logging import metrics
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion,
    LabelSmoothedCrossEntropyCriterionConfig,
)

try:
    from simuleval.metrics.latency import (
        AverageLagging,
        AverageProportion,
        DifferentiableAverageLagging,
    )

    LATENCY_METRICS = {
        "average_lagging": AverageLagging,
        "average_proportion": AverageProportion,
        "differentiable_average_lagging": DifferentiableAverageLagging,
    }
except ImportError:
    LATENCY_METRICS = None


@dataclass
class LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig(
    LabelSmoothedCrossEntropyCriterionConfig
):
    latency_avg_weight: float = field(
        default=0.0,
        metadata={"help": "weight for average latency loss."},
    )
    latency_var_weight: float = field(
        default=0.0,
        metadata={"help": "weight for variance latency loss."},
    )
    latency_avg_type: str = field(
        default="differentiable_average_lagging",
        metadata={"help": "latency type for average loss"},
    )
    latency_var_type: str = field(
        default="variance_delay",
        metadata={"help": "latency type for variance loss"},
    )
    latency_gather_method: str = field(
        default="weighted_average",
        metadata={"help": "method to gather latency loss for all heads"},
    )
    latency_update_after: int = field(
        default=0,
        metadata={"help": "Add latency loss after certain steps"},
    )


@register_criterion(
    "latency_augmented_label_smoothed_cross_entropy",
    dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig,
)
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
    LabelSmoothedCrossEntropyCriterion
):
    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        ignore_prefix_size,
        report_accuracy,
        latency_avg_weight,
        latency_var_weight,
        latency_avg_type,
        latency_var_type,
        latency_gather_method,
        latency_update_after,
    ):
        super().__init__(
            task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy
        )
        assert LATENCY_METRICS is not None, "Please make sure SimulEval is installed."

        self.latency_avg_weight = latency_avg_weight
        self.latency_var_weight = latency_var_weight
        self.latency_avg_type = latency_avg_type
        self.latency_var_type = latency_var_type
        self.latency_gather_method = latency_gather_method
        self.latency_update_after = latency_update_after

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        # 1. Compute cross entropy loss
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)

        # 2. Compute latency loss
        latency_loss, expected_latency, expected_delays_var = self.compute_latency_loss(
            model, sample, net_output
        )

        if self.latency_update_after > 0:
            num_updates = getattr(model.decoder, "num_updates", None)
            assert (
                num_updates is not None
            ), "model.decoder doesn't have attribute 'num_updates'"
            if num_updates <= self.latency_update_after:
                latency_loss = 0

        loss += latency_loss

        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )

        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
            "latency": expected_latency,
            "delays_var": expected_delays_var,
            "latency_loss": latency_loss,
        }

        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, net_output, sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)
        return loss, sample_size, logging_output

    def compute_latency_loss(self, model, sample, net_output):
        assert (
            net_output[-1].encoder_padding_mask is None
            or not net_output[-1].encoder_padding_mask[:, 0].any()
        ), "Only right padding on source is supported."
        # 1. Obtain the expected alignment
        alpha_list = [item["alpha"] for item in net_output[1].attn_list]
        num_layers = len(alpha_list)
        bsz, num_heads, tgt_len, src_len = alpha_list[0].size()

        # bsz * num_layers * num_heads, tgt_len, src_len
        alpha_all = torch.cat(alpha_list, dim=1).view(-1, tgt_len, src_len)

        # 2. Compute expected delays
        # bsz * num_heads * num_layers, tgt_len, src_len for MMA
        steps = (
            torch.arange(1, 1 + src_len)
            .unsqueeze(0)
            .unsqueeze(1)
            .expand_as(alpha_all)
            .type_as(alpha_all)
        )

        expected_delays = torch.sum(steps * alpha_all, dim=-1)

        target_padding_mask = (
            model.get_targets(sample, net_output)
            .eq(self.padding_idx)
            .unsqueeze(1)
            .expand(bsz, num_layers * num_heads, tgt_len)
            .contiguous()
            .view(-1, tgt_len)
        )

        src_lengths = (
            sample["net_input"]["src_lengths"]
            .unsqueeze(1)
            .expand(bsz, num_layers * num_heads)
            .contiguous()
            .view(-1)
        )
        expected_latency = LATENCY_METRICS[self.latency_avg_type](
            expected_delays, src_lengths, None, target_padding_mask=target_padding_mask
        )

        # 2.1 Average expected latency over heads
        # bsz, num_layers * num_heads
        expected_latency = expected_latency.view(bsz, -1)
        if self.latency_gather_method == "average":
            # bsz * tgt_len
            expected_latency = expected_delays.mean(dim=1)
        elif self.latency_gather_method == "weighted_average":
            weights = torch.nn.functional.softmax(expected_latency, dim=1)
            expected_latency = torch.sum(expected_latency * weights, dim=1)
        elif self.latency_gather_method == "max":
            expected_latency = expected_latency.max(dim=1)[0]
        else:
            raise NotImplementedError

        expected_latency = expected_latency.sum()
        avg_loss = self.latency_avg_weight * expected_latency

        # 2.2 Variance of expected delays
        expected_delays_var = (
            expected_delays.view(bsz, -1, tgt_len).var(dim=1).mean(dim=1)
        )
        expected_delays_var = expected_delays_var.sum()
        # use the variance weight here (the original line reused
        # latency_avg_weight, which left latency_var_weight unused)
        var_loss = self.latency_var_weight * expected_delays_var

        # 3. Final loss
        latency_loss = avg_loss + var_loss

        return latency_loss, expected_latency, expected_delays_var

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        super().reduce_metrics(logging_outputs)
        latency = sum(log.get("latency", 0) for log in logging_outputs)
        delays_var = sum(log.get("delays_var", 0) for log in logging_outputs)
        latency_loss = sum(log.get("latency_loss", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        metrics.log_scalar("latency", latency.float() / nsentences, nsentences, round=3)
        metrics.log_scalar("delays_var", delays_var / nsentences, nsentences, round=3)
        metrics.log_scalar(
            "latency_loss", latency_loss / nsentences, nsentences, round=3
        )
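
The core of compute_latency_loss above turns each attention distribution over source steps into an expected source position: attention mass on later source steps means the model waits longer before emitting a target token. A minimal sketch of that step in isolation (attention values made up):

# Expected-delay computation as in compute_latency_loss, on invented data.
import torch

tgt_len, src_len = 2, 4
alpha = torch.tensor(
    [
        [0.7, 0.2, 0.1, 0.0],   # target step 1 attends mostly to source step 1
        [0.0, 0.1, 0.3, 0.6],   # target step 2 attends mostly to source step 4
    ]
)
steps = torch.arange(1, 1 + src_len).float().expand_as(alpha)
expected_delays = torch.sum(steps * alpha, dim=-1)
print(expected_delays)          # tensor([1.4000, 3.5000])
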