""" | |
Loss Implementation based upon | |
https://github.com/eric-mitchell/direct-preference-optimization | |
https://github.com/huggingface/trl | |
""" | |
import logging | |
from typing import Any, KeysView | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
logger = logging.getLogger(__name__) | |
class DPOLoss(nn.Module):
    """
    Implements
    "Direct Preference Optimization:
    Your Language Model is Secretly a Reward Model"
    from https://arxiv.org/abs/2305.18290
    """

    def __init__(self, cfg: Any):
        super().__init__()
        self.cfg = cfg
        self.requires_reference_model = True

    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ):
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        losses = self.get_losses(logits=pi_logratios - ref_logratios)
        chosen_rewards = (
            self.cfg.training.beta
            * (policy_chosen_logps - reference_chosen_logps).detach()
        )
        rejected_rewards = (
            self.cfg.training.beta
            * (policy_rejected_logps - reference_rejected_logps).detach()
        )
        return losses.mean(), chosen_rewards.mean(), rejected_rewards.mean()

    def get_losses(self, logits):
        # The beta is a temperature parameter for the DPO loss,
        # typically something in the range of 0.1 to 0.5.
        # We ignore the reference model as beta -> 0.
        # The label_smoothing parameter encodes our uncertainty about the labels and
        # calculates a conservative DPO loss.
        # For now, set label_smoothing to 0 (original DPO loss).
        # See https://ericmitchell.ai/cdpo.pdf for more details
        label_smoothing = 0
        losses = (
            -F.logsigmoid(self.cfg.training.beta * logits) * (1 - label_smoothing)
            - F.logsigmoid(-self.cfg.training.beta * logits) * label_smoothing
        )
        return losses

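# Hedged usage sketch for DPOLoss. Only ``cfg.training.beta`` is actually read
# by the loss; the helper name ``_example_dpo_loss`` and the dummy
# sequence-level log-probabilities below are illustrative, not part of the
# original module.
def _example_dpo_loss() -> None:
    from types import SimpleNamespace

    cfg = SimpleNamespace(training=SimpleNamespace(beta=0.1))
    loss_fn = DPOLoss(cfg)
    # One log-probability per completion, for a batch of two preference pairs.
    policy_chosen = torch.tensor([-10.0, -12.0])
    policy_rejected = torch.tensor([-11.0, -13.5])
    reference_chosen = torch.tensor([-10.5, -12.2])
    reference_rejected = torch.tensor([-10.8, -13.0])
    loss, chosen_rewards, rejected_rewards = loss_fn(
        policy_chosen, policy_rejected, reference_chosen, reference_rejected
    )
    logger.info(
        "loss=%.4f chosen_reward=%.4f rejected_reward=%.4f",
        loss.item(),
        chosen_rewards.item(),
        rejected_rewards.item(),
    )
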
class DPOHingeLoss(DPOLoss):
    def get_losses(self, logits):
        losses = torch.relu(1 - self.cfg.training.beta * logits)
        return losses


class DPOIPOLoss(DPOLoss):
    """
    Implements "A General Theoretical Paradigm
    to Understand Learning from Human Preferences"
    from https://arxiv.org/pdf/2310.12036.pdf
    """

    def get_losses(self, logits):
        # eqn (17) of https://arxiv.org/pdf/2310.12036.pdf
        # where beta is the real, positive KL parameter for the IPO loss,
        # denoted by tau in the paper (see also eqn (6)).
        losses = (logits - 1 / (2 * self.cfg.training.beta)) ** 2
        return losses

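# Hedged comparison sketch: the sigmoid (DPO), hinge, and IPO objectives differ
# only in ``get_losses``, so they can be probed directly on a grid of log-ratio
# differences. The helper name ``_example_dpo_variants`` and the beta value are
# illustrative, not part of the original module.
def _example_dpo_variants() -> None:
    from types import SimpleNamespace

    cfg = SimpleNamespace(training=SimpleNamespace(beta=0.1))
    # Policy log-ratio minus reference log-ratio, per preference pair.
    logits = torch.linspace(-5.0, 5.0, steps=5)
    for loss_cls in (DPOLoss, DPOHingeLoss, DPOIPOLoss):
        per_pair = loss_cls(cfg).get_losses(logits)
        logger.info("%s: %s", loss_cls.__name__, per_pair.tolist())
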
class KTOPairLoss(nn.Module):
    """
    Implements original paired KTO implementation
    Adopted from https://github.com/ContextualAI/HALOs
    and https://github.com/huggingface/trl
    """

    def __init__(self, cfg: Any):
        super().__init__()
        self.cfg = cfg
        self.requires_reference_model = True

    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ):
        # Batch-level estimates of the chosen/rejected KL terms, clamped at zero.
        chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
        rejected_KL = (
            (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)
        )
        chosen_logratios = policy_chosen_logps - reference_chosen_logps
        rejected_logratios = policy_rejected_logps - reference_rejected_logps
        losses = torch.cat(
            (
                1
                - F.sigmoid(self.cfg.training.beta * (chosen_logratios - rejected_KL)),
                1
                - F.sigmoid(self.cfg.training.beta * (chosen_KL - rejected_logratios)),
            ),
            0,
        )
        chosen_rewards = (
            self.cfg.training.beta
            * (policy_chosen_logps - reference_chosen_logps).detach()
        ).float()
        rejected_rewards = (
            self.cfg.training.beta
            * (policy_rejected_logps - reference_rejected_logps).detach()
        ).float()
        return losses.mean(), chosen_rewards.mean(), rejected_rewards.mean()

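# Hedged usage sketch for KTOPairLoss. It shares the call signature of DPOLoss
# but concatenates the chosen and rejected terms before averaging. The helper
# name ``_example_kto_pair_loss`` and the dummy values are illustrative only.
def _example_kto_pair_loss() -> None:
    from types import SimpleNamespace

    cfg = SimpleNamespace(training=SimpleNamespace(beta=0.1))
    loss_fn = KTOPairLoss(cfg)
    loss, chosen_rewards, rejected_rewards = loss_fn(
        policy_chosen_logps=torch.tensor([-10.0, -12.0]),
        policy_rejected_logps=torch.tensor([-11.0, -13.5]),
        reference_chosen_logps=torch.tensor([-10.5, -12.2]),
        reference_rejected_logps=torch.tensor([-10.8, -13.0]),
    )
    logger.info(
        "KTO loss=%.4f chosen_reward=%.4f rejected_reward=%.4f",
        loss.item(),
        chosen_rewards.item(),
        rejected_rewards.item(),
    )
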
class CPOLoss(nn.Module):
    """
    Implements CPO Loss https://arxiv.org/abs/2401.08417
    Adopted from https://github.com/huggingface/trl
    """

    def __init__(self, cfg: Any):
        super().__init__()
        self.cfg = cfg
        self.requires_reference_model = False

    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
    ):
        logits = policy_chosen_logps - policy_rejected_logps
        losses = self.get_losses(logits)
        chosen_rewards = (self.cfg.training.beta * policy_chosen_logps.detach()).float()
        rejected_rewards = (
            self.cfg.training.beta * policy_rejected_logps.detach()
        ).float()
        return losses.mean(), chosen_rewards.mean(), rejected_rewards.mean()

    def get_losses(self, logits):
        # See DPOLoss.get_losses for the role of label_smoothing; kept at 0 here.
        label_smoothing = 0
        losses = (
            -F.logsigmoid(self.cfg.training.beta * logits) * (1 - label_smoothing)
            - F.logsigmoid(-self.cfg.training.beta * logits) * label_smoothing
        )
        return losses


class SimPOLoss(CPOLoss):
    """
    Implements SimPO Loss https://arxiv.org/abs/2405.14734
    Adopted from https://github.com/princeton-nlp/SimPO
    and https://github.com/huggingface/trl
    """

    def get_losses(self, logits):
        label_smoothing = 0
        # SimPO adds a target reward margin gamma; gamma / beta is subtracted
        # from the log-ratio before applying the sigmoid loss.
        gamma = self.cfg.training.simpo_gamma
        gamma_logratios = gamma / self.cfg.training.beta
        logits = logits - gamma_logratios
        losses = (
            -F.logsigmoid(self.cfg.training.beta * logits) * (1 - label_smoothing)
            - F.logsigmoid(-self.cfg.training.beta * logits) * label_smoothing
        )
        return losses

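# Hedged usage sketch for the reference-free losses. CPOLoss and SimPOLoss only
# need policy log-probabilities; SimPOLoss additionally reads
# ``training.simpo_gamma``. The helper name and the config values below are
# illustrative, not part of the original module.
def _example_reference_free_losses() -> None:
    from types import SimpleNamespace

    cfg = SimpleNamespace(training=SimpleNamespace(beta=2.0, simpo_gamma=0.5))
    for loss_cls in (CPOLoss, SimPOLoss):
        loss_fn = loss_cls(cfg)
        assert not loss_fn.requires_reference_model
        loss, chosen_rewards, rejected_rewards = loss_fn(
            policy_chosen_logps=torch.tensor([-10.0, -12.0]),
            policy_rejected_logps=torch.tensor([-11.0, -13.5]),
        )
        logger.info("%s loss=%.4f", loss_cls.__name__, loss.item())
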
class Losses:
    """Losses factory."""

    _losses = {
        "DPOLoss": DPOLoss,
        "DPOHingeLoss": DPOHingeLoss,
        "DPOIPOLoss": DPOIPOLoss,
        "KTOPairLoss": KTOPairLoss,
        "CPOLoss": CPOLoss,
        "SimPOLoss": SimPOLoss,
    }

    @classmethod
    def names(cls) -> KeysView:
        return cls._losses.keys()

    @classmethod
    def get(cls, name: str) -> Any:
        """Access to Losses.

        Args:
            name: losses name

        Returns:
            A class to build the Losses
        """
        return cls._losses.get(name, DPOLoss)


# see https://github.com/huggingface/trl/commit/29d439a2043edf4455b05cae5a1e2ade69d22794
LOSS_REDUCTION = {
    "DPOLoss": False,
    "KTOPairLoss": False,
    "DPOHingeLoss": True,
    "DPOIPOLoss": True,
    "CPOLoss": False,
    "SimPOLoss": True,
}

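# Hedged usage sketch for the factory. The config field ``training.loss_function``
# is an assumed name for the selected loss; only ``Losses`` and ``LOSS_REDUCTION``
# come from this module.
def _example_losses_factory() -> None:
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        training=SimpleNamespace(loss_function="SimPOLoss", beta=2.0, simpo_gamma=0.5)
    )
    loss_cls = Losses.get(cfg.training.loss_function)
    loss_fn = loss_cls(cfg)
    # LOSS_REDUCTION presumably marks the losses that expect length-averaged
    # rather than summed sequence log-probabilities (cf. the TRL commit above).
    average_log_probs = LOSS_REDUCTION[cfg.training.loss_function]
    logger.info(
        "Using %s (average_log_probs=%s)", loss_cls.__name__, average_log_probs
    )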