boris commited on
Commit
cbeacb9
·
unverified ·
2 Parent(s): 61f888f 9db361a

Merge pull request #10 from borisdayma/feat-loss

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +9 -34
seq2seq/run_seq2seq_flax.py CHANGED
@@ -487,10 +487,6 @@ def main():
487
 
488
  model_inputs["decoder_input_ids"] = labels
489
 
490
- # We need decoder_attention_mask so we can ignore pad tokens from loss
491
- # TODO: I don't believe we need "decoder_attention_mask" in this case because all labels have same length
492
- #model_inputs["decoder_attention_mask"] = labels["attention_mask"]
493
-
494
  return model_inputs
495
 
496
  if training_args.do_train:
@@ -643,39 +639,19 @@ def main():
643
  state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
644
 
645
  # label smoothed cross entropy
646
- def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
647
- """
648
- The label smoothing implementation is adapted from Flax's official example:
649
- https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
650
- """
651
- vocab_size = logits.shape[-1]
652
- confidence = 1.0 - label_smoothing_factor
653
- low_confidence = (1.0 - confidence) / (vocab_size - 1)
654
- normalizing_constant = -(
655
- confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
656
- )
657
- soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
658
-
659
- loss = optax.softmax_cross_entropy(logits, soft_labels)
660
- loss = loss - normalizing_constant
661
-
662
- if padding_mask is None:
663
- padding_mask = np.ones(loss.shape)
664
-
665
- # ignore padded tokens from loss
666
- loss = loss * padding_mask
667
- loss = loss.sum() / padding_mask.sum()
668
  return loss
669
 
670
  # Define gradient update step fn
671
- def train_step(state, batch, label_smoothing_factor=0.0):
672
  dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
673
 
674
  def compute_loss(params):
675
  labels = batch.pop("labels")
676
  logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
677
- padding_mask = batch.get("decoder_attention_mask", None)
678
- loss = loss_fn(logits, labels, padding_mask, label_smoothing_factor)
679
  return loss
680
 
681
  grad_fn = jax.value_and_grad(compute_loss)
@@ -690,11 +666,10 @@ def main():
690
  return new_state, metrics
691
 
692
  # Define eval fn
693
- def eval_step(params, batch, label_smoothing_factor=0.0):
694
  labels = batch.pop("labels")
695
  logits = model(**batch, params=params, train=False)[0]
696
- padding_mask = batch.get("decoder_attention_mask", None)
697
- loss = loss_fn(logits, labels, padding_mask, label_smoothing_factor)
698
 
699
  # summarize metrics
700
  metrics = {"loss": loss}
@@ -715,9 +690,9 @@ def main():
715
 
716
  # Create parallel version of the train and eval step
717
  p_train_step = jax.pmap(
718
- partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
719
  )
720
- p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
721
  p_generate_step = jax.pmap(generate_step, "batch")
722
 
723
  # Replicate the train state on each device
 
487
 
488
  model_inputs["decoder_input_ids"] = labels
489
 
 
 
 
 
490
  return model_inputs
491
 
492
  if training_args.do_train:
 
639
  state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
640
 
641
  # label smoothed cross entropy
642
+ def loss_fn(logits, labels):
643
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
644
+ loss = loss.mean()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
645
  return loss
646
 
647
  # Define gradient update step fn
648
+ def train_step(state, batch):
649
  dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
650
 
651
  def compute_loss(params):
652
  labels = batch.pop("labels")
653
  logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
654
+ loss = loss_fn(logits, labels)
 
655
  return loss
656
 
657
  grad_fn = jax.value_and_grad(compute_loss)
 
666
  return new_state, metrics
667
 
668
  # Define eval fn
669
+ def eval_step(params, batch):
670
  labels = batch.pop("labels")
671
  logits = model(**batch, params=params, train=False)[0]
672
+ loss = loss_fn(logits, labels)
 
673
 
674
  # summarize metrics
675
  metrics = {"loss": loss}
 
690
 
691
  # Create parallel version of the train and eval step
692
  p_train_step = jax.pmap(
693
+ train_step, "batch", donate_argnums=(0,)
694
  )
695
+ p_eval_step = jax.pmap(eval_step, "batch")
696
  p_generate_step = jax.pmap(generate_step, "batch")
697
 
698
  # Replicate the train state on each device