Loss function suggestion
Hello @pankajrajdeo!
This looks like a fascinating model, so I wanted to jump in and offer some (unrequested, my apologies) advice.
When training with (anchor, positive, negative) triplets, you can often reach the best performance by training with a loss that uses in-batch negatives. In the literature, this is often called InfoNCE loss, and in Sentence Transformers it's called `MultipleNegativesRankingLoss` (MNRL).
In short:
- In a batch, `TripletLoss` will minimize the distance between anchor_i and positive_i, and maximize the distance between anchor_i and negative_i.
- In a batch, `MultipleNegativesRankingLoss` will minimize the distance between anchor_i and positive_i, and maximize the distance between anchor_i and negative_i, between anchor_i and positive_j for all j != i, and between anchor_i and negative_j for all j != i.

So, MNRL will maximize the distance between the anchor and batch_size * 2 - 1 unrelated texts (its own negative, plus the positives and negatives of every other triplet in the batch), under the automatic assumption that positive_j and negative_j (for j != i) are unrelated to anchor_i. If this is usually true in your data, then MNRL will outperform TripletLoss.
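For illustration, here is a minimal sketch of training with MNRL via the Sentence Transformers (v3+) trainer; the base model checkpoint and the tiny inline dataset are placeholders for your own, not recommendations:

```python
from datasets import Dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import MultipleNegativesRankingLoss

# Placeholder base model; substitute your own checkpoint
model = SentenceTransformer("microsoft/mpnet-base")

# The trainer feeds columns to the loss in order, so keep them as
# (anchor, positive, negative); these three rows are dummy data
train_dataset = Dataset.from_dict({
    "anchor": ["query 1", "query 2", "query 3"],
    "positive": ["relevant doc 1", "relevant doc 2", "relevant doc 3"],
    "negative": ["hard negative 1", "hard negative 2", "hard negative 3"],
})

# Drop-in replacement for TripletLoss; a larger per-device batch size
# directly means more in-batch negatives per anchor
loss = MultipleNegativesRankingLoss(model)

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()
```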
These are all options to consider: `MultipleNegativesRankingLoss`, `CachedMultipleNegativesRankingLoss`, `GISTEmbedLoss`, and `CachedGISTEmbedLoss`.
Some details on the others:
- `CachedMultipleNegativesRankingLoss`: Like MNRL, but uses a clever mechanic called GradCache that allows you to use arbitrarily large batch sizes. Larger batch sizes = more in-batch negatives, which can help depending on how often you'd get a false negative.
- `GISTEmbedLoss`: Like MNRL, but with a "guide model" that can filter out potential false negatives. This is mostly useful if you have lots of false negatives that are easily recognized by a small model, which isn't the case for you.
- `CachedGISTEmbedLoss`: Same as `GISTEmbedLoss`, but also with the GradCache mechanic allowing for larger batch sizes.
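If it helps, here is a hedged sketch of how these alternatives are instantiated; the base model checkpoint, the guide model, and the `mini_batch_size` value are all illustrative assumptions:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import (
    CachedGISTEmbedLoss,
    CachedMultipleNegativesRankingLoss,
    GISTEmbedLoss,
)

# Placeholder base model; substitute your own checkpoint
model = SentenceTransformer("microsoft/mpnet-base")

# GradCache variant of MNRL: the effective contrastive batch size comes
# from the trainer's per_device_train_batch_size, while mini_batch_size
# only bounds peak memory per forward/backward chunk
cached_mnrl = CachedMultipleNegativesRankingLoss(model, mini_batch_size=32)

# GIST variants: a small guide model flags in-batch candidates that look
# like false negatives so they are excluded from the loss
guide = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
gist = GISTEmbedLoss(model, guide=guide)
cached_gist = CachedGISTEmbedLoss(model, guide=guide, mini_batch_size=32)
```

With the cached variants, the design trade-off is extra compute (a second forward pass) in exchange for batch sizes far beyond what would otherwise fit in memory.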
I would experiment with MNRL and leave the others be. I suspect that your final model would be stronger.
To help compare, you can use the `TripletEvaluator`, e.g.:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import TripletEvaluator

# Load the model to evaluate (placeholder name)
model = SentenceTransformer("path/to/your-model")

# Initialize the TripletEvaluator using anchors, positives, and negatives
# (parallel lists of strings)
triplet_evaluator = TripletEvaluator(
    anchors=test_anchors,
    positives=test_positives,
    negatives=test_negatives,
    name="triplet-test",
)
results = triplet_evaluator(model)
# or
# trainer = SentenceTransformerTrainer(
#     ...,
#     evaluator=triplet_evaluator,
# )
```
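The returned `results` is a dictionary of metrics; the main one is triplet accuracy, i.e. the fraction of triplets for which the anchor is embedded closer to the positive than to the negative.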
See an example script here.
Feel free to disregard my advice if you'd like to stick with the `TripletLoss` - it was unrequested advice after all 😅
- Tom Aarsen
Hi @tomaarsen,
Thank you for your thoughtful suggestion. I appreciate the detailed breakdown of the potential benefits of MNRL. I will definitely try it and share an update.