Loss function suggestion

#1
by tomaarsen

Hello @pankajrajdeo !

This looks like a fascinating model, so I wanted to jump in and offer some (unrequested, my apologies) advice.
When training with (anchor, positive, negative) triplets, you can often reach the best performance by training with a loss that uses in-batch negatives. In the literature, this is often called InfoNCE loss, and in Sentence Transformers it's called MultipleNegativesRankingLoss (MNRL).

In short:

  • In a batch, TripletLoss will minimize the distance between anchor_i and positive_i, and maximize the distance between anchor_i and negative_i.
  • In a batch, MultipleNegativesRankingLoss will minimize the distance between anchor_i and positive_i, and maximize the distance between anchor_i and negative_i, between anchor_i and positive_j for all j != i, and between anchor_i and negative_j for all j != i.

So, for each anchor, MNRL will maximize the distance to 2 * batch_size - 1 unrelated texts (its own negative, plus every other in-batch positive and negative), under the implicit assumption that positive_j and negative_j for j != i really are unrelated to anchor_i. If this is usually true in your data, then MNRL will outperform TripletLoss.
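To make that concrete, here is a minimal sketch of training with MNRL via the SentenceTransformerTrainer API; the base checkpoint and the toy triplets are placeholders, so swap in your own model and dataset:

from datasets import Dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import MultipleNegativesRankingLoss

# Placeholder checkpoint; use the base model you are fine-tuning
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Hypothetical triplets; MNRL accepts (anchor, positive, negative) columns
train_dataset = Dataset.from_dict({
    "anchor": ["How do I bake bread?"],
    "positive": ["Steps for baking a loaf of bread"],
    "negative": ["How do I change a flat tire?"],
})

# Each anchor is contrasted against its own negative plus every other
# in-batch positive and negative
loss = MultipleNegativesRankingLoss(model)

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()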

These are all options to consider:

(screenshot: table of the candidate losses - MultipleNegativesRankingLoss, CachedMultipleNegativesRankingLoss, GISTEmbedLoss, CachedGISTEmbedLoss)

Some details on the others:

  • CachedMultipleNegativesRankingLoss: Like MNRL, but uses a clever mechanism called GradCache that lets you use arbitrarily large batch sizes at a roughly constant memory cost. Larger batch sizes mean more in-batch negatives, which can help depending on how often you'd get a false negative. See the sketch after this list.
  • GISTEmbedLoss: Like MNRL, but with a "guide model" that can filter out potential false negatives. This is mostly useful if you have lots of false negatives that are easily recognized by a small model, which isn't the case for you.
  • CachedGISTEmbedLoss: Same as GISTEmbedLoss, but also with the GradCache mechanic allowing for larger batch sizes.
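If memory is the bottleneck, the cached variants are close to drop-in replacements. A minimal sketch, reusing the model from above (the mini_batch_size value and the guide checkpoint are illustrative assumptions):

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss, GISTEmbedLoss

# The effective batch size is still set by per_device_train_batch_size;
# mini_batch_size only controls how many samples are embedded per chunk,
# so memory use stays roughly flat as the batch grows
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=32)

# GISTEmbedLoss additionally takes a small "guide" model that filters
# likely false negatives out of the in-batch negatives
guide = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
loss = GISTEmbedLoss(model, guide)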

I would experiment with MNRL and leave the others be. I suspect the final model will end up stronger than with TripletLoss.

To help compare, you can use the TripletEvaluator, e.g.:

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.evaluation import TripletEvaluator

# Initialize the TripletEvaluator using anchors, positives, and negatives
# (test_anchors, test_positives, test_negatives are lists of strings
# from your held-out test split)
triplet_evaluator = TripletEvaluator(
    anchors=test_anchors,
    positives=test_positives,
    negatives=test_negatives,
    name="triplet-test",
)

# Evaluate a trained SentenceTransformer directly ...
results = triplet_evaluator(model)
# ... or let the trainer run it during training
# trainer = SentenceTransformerTrainer(
#     ...
#     evaluator=triplet_evaluator,
# )
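The evaluator reports triplet accuracy: the fraction of triplets where the anchor lands closer to its positive than to its negative. The exact metric key below is an assumption based on recent Sentence Transformers versions, where keys are prefixed with the evaluator name:

# With cosine similarity, the primary metric is typically named
# "<evaluator name>_cosine_accuracy"
print(results["triplet-test_cosine_accuracy"])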

See an example script here.

Feel free to disregard my advice if you'd like to stick with the TripletLoss - it was unrequested advice after all 😅

  • Tom Aarsen

Hi @tomaarsen ,

Thank you for your thoughtful suggestion. I appreciate the detailed breakdown you provided regarding the potential benefits of MNRL. I will definitely try it and report back.
