Spaces:
Running
Running
File size: 6,958 Bytes
5caedb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
"""
Loss implementations based on:
https://github.com/eric-mitchell/direct-preference-optimization
https://github.com/huggingface/trl
"""
import logging
from typing import Any, KeysView
import torch
import torch.nn.functional as F
from torch import nn
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class DPOLoss(nn.Module):
    """
    Implements
    "Direct Preference Optimization:
    Your Language Model is Secretly a Reward Model"
    from https://arxiv.org/abs/2305.18290

    Subclasses customize the per-pair objective by overriding ``get_losses``;
    ``forward`` supplies it with the policy-vs-reference log-ratio margin.
    """

    def __init__(self, cfg: Any):
        super().__init__()
        self.cfg = cfg
        # The margin below is measured relative to a reference model.
        self.requires_reference_model = True

    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ):
        """Compute the mean loss and mean implicit rewards for a batch.

        Args:
            policy_chosen_logps: policy log-probs of the chosen responses.
            policy_rejected_logps: policy log-probs of the rejected responses.
            reference_chosen_logps: reference log-probs of the chosen responses.
            reference_rejected_logps: reference log-probs of the rejected
                responses.
            (All four presumably hold one value per example and share a shape
            - confirm against the caller.)

        Returns:
            Tuple ``(loss, chosen_reward, rejected_reward)``, each reduced with
            ``.mean()``. The rewards are detached, so they carry no gradient.
        """
        beta = self.cfg.training.beta

        policy_margin = policy_chosen_logps - policy_rejected_logps
        reference_margin = reference_chosen_logps - reference_rejected_logps
        per_pair_loss = self.get_losses(logits=policy_margin - reference_margin)

        # Implicit rewards: beta-scaled policy/reference log-ratio, for logging.
        chosen_reward = beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_reward = (
            beta * (policy_rejected_logps - reference_rejected_logps).detach()
        )

        return per_pair_loss.mean(), chosen_reward.mean(), rejected_reward.mean()

    def get_losses(self, logits):
        """Return the element-wise (unreduced) sigmoid DPO loss for ``logits``.

        ``beta`` (``cfg.training.beta``) is a temperature: the reference model
        matters less as beta -> 0. ``smoothing`` is the conservative-DPO
        label-smoothing weight (https://ericmitchell.ai/cdpo.pdf). It is fixed at
        0 here, which yields the original DPO loss.
        """
        smoothing = 0
        scaled = self.cfg.training.beta * logits
        return (
            -F.logsigmoid(scaled) * (1 - smoothing)
            - F.logsigmoid(-scaled) * smoothing
        )
class DPOHingeLoss(DPOLoss):
    """DPO variant that uses a hinge loss instead of the log-sigmoid loss.

    Pairs whose beta-scaled margin already reaches 1 contribute zero loss.
    """

    def get_losses(self, logits):
        """Return the element-wise hinge loss ``relu(1 - beta * logits)``."""
        scaled_margin = self.cfg.training.beta * logits
        return torch.relu(1 - scaled_margin)
class DPOIPOLoss(DPOLoss):
    """
    Implements "A General Theoretical Paradigm
    to Understand Learning from Human Preferences"
    from https://arxiv.org/pdf/2310.12036.pdf
    """

    def get_losses(self, logits):
        """Return the element-wise IPO loss for ``logits``.

        This follows eqn (17) of the paper. ``cfg.training.beta`` is the real,
        positive KL parameter, called tau in the paper (see also eqn (6)). The
        loss is the squared distance between the log-ratio margin and the
        target ``1 / (2 * beta)``.
        """
        target_margin = 1 / (2 * self.cfg.training.beta)
        return (logits - target_margin).pow(2)
class KTOPairLoss(nn.Module):
    """
    Implements original paired KTO implementation
    Adopted from https://github.com/ContextualAI/HALOs
    and https://github.com/huggingface/trl
    """

    def __init__(self, cfg: Any):
        super().__init__()
        self.cfg = cfg
        # Log-ratios below are taken against a reference model.
        self.requires_reference_model = True

    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ):
        """Compute the paired KTO loss and mean implicit rewards for a batch.

        Args:
            policy_chosen_logps: policy log-probs of the chosen responses.
            policy_rejected_logps: policy log-probs of the rejected responses.
            reference_chosen_logps: reference log-probs of the chosen responses.
            reference_rejected_logps: reference log-probs of the rejected
                responses.
            (Presumably 1-D, one value per example; ``torch.cat`` below needs
            at least 1-D inputs - confirm against the caller.)

        Returns:
            Tuple ``(loss, chosen_reward, rejected_reward)``, each a mean. The
            rewards are detached and cast to float32.
        """
        beta = self.cfg.training.beta

        chosen_logratios = policy_chosen_logps - reference_chosen_logps
        rejected_logratios = policy_rejected_logps - reference_rejected_logps

        # Batch-mean KL estimates, clamped at zero. Each half of the batch is
        # compared against the KL baseline estimated from the *other* half.
        chosen_KL = chosen_logratios.mean().clamp(min=0)
        rejected_KL = rejected_logratios.mean().clamp(min=0)

        # 1 - sigmoid(x) == sigmoid(-x). The right-hand form avoids catastrophic
        # cancellation once sigmoid(x) rounds to 1 (early in bf16/fp16, at
        # x ~ 17 in float32), which would otherwise zero out the loss.
        losses = torch.cat(
            (
                torch.sigmoid(-beta * (chosen_logratios - rejected_KL)),
                torch.sigmoid(-beta * (chosen_KL - rejected_logratios)),
            ),
            0,
        )

        chosen_rewards = (beta * chosen_logratios.detach()).float()
        rejected_rewards = (beta * rejected_logratios.detach()).float()

        return losses.mean(), chosen_rewards.mean(), rejected_rewards.mean()
class CPOLoss(nn.Module):
    """
    Implements CPO Loss https://arxiv.org/abs/2401.08417
    Adopted from https://github.com/huggingface/trl

    Reference-free: the objective only sees policy log-probs.
    """

    def __init__(self, cfg: Any):
        super().__init__()
        self.cfg = cfg
        # No reference model is consulted by this loss.
        self.requires_reference_model = False

    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
    ):
        """Compute the mean loss and mean implicit rewards for a batch.

        Args:
            policy_chosen_logps: policy log-probs of the chosen responses.
            policy_rejected_logps: policy log-probs of the rejected responses.

        Returns:
            Tuple ``(loss, chosen_reward, rejected_reward)``, each a mean. The
            rewards are ``beta * logps``, detached and cast to float32.
        """
        beta = self.cfg.training.beta
        margins = policy_chosen_logps - policy_rejected_logps
        per_pair_loss = self.get_losses(margins)

        chosen_reward, rejected_reward = (
            (beta * logps.detach()).float()
            for logps in (policy_chosen_logps, policy_rejected_logps)
        )
        return per_pair_loss.mean(), chosen_reward.mean(), rejected_reward.mean()

    def get_losses(self, logits):
        """Return the element-wise log-sigmoid loss of ``beta * logits``.

        ``smoothing`` is a label-smoothing weight, fixed at 0 here.
        """
        smoothing = 0
        scaled = self.cfg.training.beta * logits
        return (
            -F.logsigmoid(scaled) * (1 - smoothing)
            - F.logsigmoid(-scaled) * smoothing
        )
class SimPOLoss(CPOLoss):
    """
    Implements SimPO Loss https://arxiv.org/abs/2405.14734
    Adopted from https://github.com/princeton-nlp/SimPO
    and https://github.com/huggingface/trl
    """

    def get_losses(self, logits):
        """Return the element-wise SimPO loss for ``logits``.

        ``cfg.training.simpo_gamma / beta`` is subtracted from the margin before
        it is scaled by beta. Up to rounding, the loss is therefore
        ``-logsigmoid(beta * logits - simpo_gamma)``. ``smoothing`` is a
        label-smoothing weight, fixed at 0 here.
        """
        smoothing = 0
        beta = self.cfg.training.beta
        shifted = logits - self.cfg.training.simpo_gamma / beta
        scaled = beta * shifted
        return (
            -F.logsigmoid(scaled) * (1 - smoothing)
            - F.logsigmoid(-scaled) * smoothing
        )
class Losses:
    """Losses factory."""

    # Registry mapping config names to loss classes.
    _losses = {
        "DPOLoss": DPOLoss,
        "DPOHingeLoss": DPOHingeLoss,
        "DPOIPOLoss": DPOIPOLoss,
        "KTOPairLoss": KTOPairLoss,
        "CPOLoss": CPOLoss,
        "SimPOLoss": SimPOLoss,
    }

    @classmethod
    def names(cls) -> KeysView:
        """Return the names of all registered losses."""
        return cls._losses.keys()

    @classmethod
    def get(cls, name: str) -> Any:
        """Access to Losses.

        Args:
            name: losses name

        Returns:
            A class to build the Losses. Unknown names still fall back to
            ``DPOLoss``, but a warning is now logged so that a misspelled config
            value does not silently train with the wrong objective.
        """
        if name not in cls._losses:
            # Same logger object as this module's ``logger`` (same name).
            logging.getLogger(__name__).warning(
                "Unknown loss %r; falling back to DPOLoss. Available losses: %s",
                name,
                ", ".join(cls._losses),
            )
            return DPOLoss
        return cls._losses[name]
# see https://github.com/huggingface/trl/commit/29d439a2043edf4455b05cae5a1e2ade69d22794
# Per-loss boolean flag, keyed by the same names as ``Losses``. This table is
# consumed elsewhere. NOTE(review): True presumably means per-sequence log-probs
# are averaged over tokens rather than summed - confirm against the callers.
LOSS_REDUCTION = dict(
    DPOLoss=False,
    KTOPairLoss=False,
    DPOHingeLoss=True,
    DPOIPOLoss=True,
    CPOLoss=False,
    SimPOLoss=True,
)
|