|
|
|
|
|
|
|
|
|
|
|
from fairseq.criterions import register_criterion |
|
from fairseq.criterions.label_smoothed_cross_entropy import ( |
|
LabelSmoothedCrossEntropyCriterion, |
|
) |
|
|
|
|
|
@register_criterion("latency_augmented_label_smoothed_cross_entropy")
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
    LabelSmoothedCrossEntropyCriterion
):
    """Label-smoothed cross entropy with an additional latency term.

    The translation loss is delegated to the parent criterion; a latency
    loss produced by ``LatencyTraining`` is then added to it in
    :meth:`compute_loss`.
    """

    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        ignore_prefix_size,
        report_accuracy,
        latency_weight_avg,
        latency_weight_avg_type,
        latency_weight_var,
        latency_weight_var_type,
        mass_preservation,
        average_method,
    ):
        """Store the latency settings and build the latency loss helper.

        Args:
            task, sentence_avg, label_smoothing, ignore_prefix_size,
            report_accuracy: forwarded unchanged to the parent criterion.
            latency_weight_avg: weight of the average-latency loss
                (``--latency-weight-avg``).
            latency_weight_avg_type: statistic used for the average loss
                (``--latency-weight-avg-type``).
            latency_weight_var: weight of the variance loss
                (``--latency-weight-var``).
            latency_weight_var_type: statistic used for the variance loss
                (``--latency-weight-var-type``).
            mass_preservation: forwarded to ``LatencyTraining``. It is not
                registered by :meth:`add_args`, so it is presumably supplied
                by another component's arguments -- TODO confirm.
            average_method: averaging method forwarded to
                ``LatencyTraining`` (``--average-method``).
        """
        super().__init__(
            task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy
        )
        # Deliberately imported at construction time instead of module level;
        # presumably so this module can be imported when the ``examples``
        # package is not importable -- TODO confirm.
        from examples.simultaneous_translation.utils.latency import LatencyTraining

        self.eps = label_smoothing
        self.latency_weight_avg = latency_weight_avg
        self.latency_weight_avg_type = latency_weight_avg_type
        self.latency_weight_var = latency_weight_var
        self.latency_weight_var_type = latency_weight_var_type
        self.mass_preservation = mass_preservation
        self.average_method = average_method
        self.latency_train = LatencyTraining(
            self.latency_weight_avg,
            self.latency_weight_var,
            self.latency_weight_avg_type,
            self.latency_weight_var_type,
            self.mass_preservation,
            self.average_method,
        )

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""

        def boolean(value):
            """Convert a command-line string into a real ``bool``.

            ``type=bool`` must not be used with argparse: ``bool("False")``
            is ``True`` because any non-empty string is truthy.
            """
            text = str(value).strip().lower()
            if text in ("true", "t", "yes", "y", "on", "1"):
                return True
            # The empty string stays False, as it was under ``bool("")``.
            if text in ("false", "f", "no", "n", "off", "0", ""):
                return False
            # argparse turns ValueError into a regular usage error.
            raise ValueError("expected a boolean, got {!r}".format(value))

        # This is a staticmethod, so the zero-argument form of super() is
        # unavailable; the class is named explicitly instead.
        super(
            LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
            LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
        ).add_args(parser)

        parser.add_argument(
            "--label-smoothing",
            default=0.0,
            type=float,
            metavar="D",
            help="epsilon for label smoothing, 0 means no label smoothing",
        )
        # The hyphenated spelling matches every other option here; the
        # original underscore spelling is kept as an alias. Both map to the
        # same destination, ``ignore_prefix_size``.
        parser.add_argument(
            "--ignore-prefix-size",
            "--ignore_prefix_size",
            default=0,
            type=int,
            help="ignore first N tokens",
        )
        # Accepts a bare ``--report-accuracy`` (-> True) as well as an explicit
        # value such as ``--report-accuracy False``.
        parser.add_argument(
            "--report-accuracy",
            nargs="?",
            const=True,
            default=False,
            type=boolean,
            help="report accuracy metric",
        )
        parser.add_argument(
            "--latency-weight-avg",
            default=0.0,
            type=float,
            metavar="D",
            help="Average loss weight",
        )
        parser.add_argument(
            "--latency-weight-var",
            default=0.0,
            type=float,
            metavar="D",
            help="Variance loss weight",
        )
        parser.add_argument(
            "--latency-weight-avg-type",
            default="differentiable_average_lagging",
            help="Statistics for Average loss type",
        )
        parser.add_argument(
            "--latency-weight-var-type",
            default="variance_delay",
            help="Statistics for variance loss type",
        )
        parser.add_argument(
            "--average-method",
            default="weighted_average",
            help="Average loss type",
        )

    def compute_loss(self, model, net_output, sample, reduce=True):
        """Return ``(loss, nll_loss)`` with the latency loss added to ``loss``.

        The parent criterion computes ``loss`` and ``nll_loss``. The latency
        loss is computed from the ``"alpha"`` entry of every item in
        ``net_output[-1]["attn_list"]`` together with the source and target
        padding masks, and is added to ``loss`` only; ``nll_loss`` is returned
        as produced by the parent.
        """
        loss, nll_loss = super().compute_loss(model, net_output, sample, reduce)

        # One "alpha" value per entry of the model's attention list.
        attn_list = [item["alpha"] for item in net_output[-1]["attn_list"]]

        # True where the target equals the padding index; ``padding_idx`` is
        # not set in this class and is expected to come from the base class.
        target_padding_mask = model.get_targets(sample, net_output).eq(self.padding_idx)

        # Optional: None when the model output carries no encoder padding mask.
        source_padding_mask = net_output[-1].get("encoder_padding_mask", None)

        latency_loss = self.latency_train.loss(
            attn_list, source_padding_mask, target_padding_mask
        )

        loss += latency_loss

        return loss, nll_loss
|
|