Spaces:
Runtime error
Runtime error
File size: 8,793 Bytes
476ac07 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
# ORPO Authors: Jiwoo Hong, Noah Lee, and James Thorne
# Official code: https://github.com/xfactlab/orpo
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.distributed as dist
import torch.nn.functional as F
from mmengine import MessageHub
from torch import nn
from xtuner.parallel.sequence import (gather_forward_split_backward,
get_sequence_parallel_group,
get_sequence_parallel_world_size,
split_for_sequence_parallel)
from .sft import SupervisedFinetune
class ORPO(SupervisedFinetune):
    """ORPO: Monolithic Preference Optimization without Reference Model
    https://arxiv.org/abs/2403.07691

    Batches are expected to interleave preference pairs: even-indexed
    samples are the chosen responses and odd-indexed samples the rejected
    ones. The optimized objective is::

        loss = NLL(chosen) - beta * mean(logsigmoid(log_odds_ratio))

    Args:
        beta (float): Weight of the odds_ratio_loss. Defaults to 0.1.
    """

    def __init__(self, *args, beta=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.beta = beta

    def _gather_masked_logits(self, logits, labels, mask):
        """Return the masked log-probability of every target token.

        Args:
            logits (torch.Tensor): Logits of shape (bs, seq_len, vocab_size).
            labels (torch.Tensor): Target ids of shape (bs, seq_len); every
                entry must be a valid vocabulary index.
            mask (torch.Tensor): Mask of shape (bs, seq_len); positions where
                it is False/0 contribute 0.

        Returns:
            torch.Tensor: Per-token log-probs of shape (bs, seq_len).
        """
        logits = torch.gather(
            logits.log_softmax(-1), dim=2,
            index=labels.unsqueeze(2)).squeeze(2)
        return logits * mask

    def _compute_logps(self, logits, labels, average_log_prob):
        """Sum (or average) the target-token log-probs of each sequence.

        Args:
            logits (torch.Tensor): Logits of shape (bs, seq_len, vocab_size).
            labels (torch.Tensor): Targets of shape (bs, seq_len); positions
                equal to -100 are ignored.
            average_log_prob (bool): If True, divide the summed log-prob by
                the number of non-ignored target tokens.

        Returns:
            torch.Tensor: Per-sequence log-probs of shape (bs,).
        """
        # Shift so that the logits at position t score the token at t + 1.
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
        # Build the mask from the ignore index BEFORE overwriting it, so a
        # genuine target token with id 0 is not mistaken for padding.
        loss_mask = labels != -100
        # -100 is not a valid gather index; any valid id works because the
        # mask zeroes these positions afterwards.
        labels[~loss_mask] = 0
        logps = self._gather_masked_logits(logits, labels, loss_mask).sum(-1)
        if average_log_prob:
            logps = logps / loss_mask.sum(-1)
        return logps

    def get_logps(
            self,
            all_logits,  # bs, seqlen, vocab_size
            average_log_prob,  # bool: average instead of sum over tokens
            labels,  # bs, seqlen
    ):
        """Compute chosen and rejected sequence log-probs for a padded batch.

        Even rows of the batch are treated as chosen, odd rows as rejected.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(chosen_logps,
            rejected_logps)``, each of shape (bs // 2,).
        """
        all_logps = self._compute_logps(all_logits, labels, average_log_prob)
        return all_logps[::2], all_logps[1::2]

    @staticmethod
    def _unpack_varlen(packed, seqlens, attention_mask):
        """Split a packed tensor into its per-sample segments along dim 1.

        A non-None ``attention_mask`` signals that the sequence, labels,
        position_ids and cumulative_len were padded for sequence parallelism;
        the trailing padding segment is then dropped.

        Args:
            packed (torch.Tensor): Packed tensor whose dim 1 is the
                concatenation of all samples (presumably batch size 1 —
                confirm with the collate function).
            seqlens (list[int]): Length of every segment.
            attention_mask: None, or the padding mask described above.

        Returns:
            tuple[torch.Tensor, ...]: Segments forming complete
            (chosen, rejected) pairs.
        """
        segments = torch.split(packed, seqlens, dim=1)
        if attention_mask is not None:
            # A padded mask must contain at least one masked position.
            assert False in attention_mask
            segments = segments[:-1]
        # Segments must form complete (chosen, rejected) pairs.
        assert len(segments) % 2 == 0
        return segments

    def get_var_len_atten_logps(self, all_logits, average_log_prob, labels,
                                cu_seqlens, attention_mask):
        """Compute chosen/rejected log-probs for a packed varlen batch.

        Args:
            all_logits (torch.Tensor): Packed logits, split along dim 1.
            average_log_prob (bool): Average instead of sum over tokens.
            labels (torch.Tensor): Packed labels aligned with ``all_logits``.
            cu_seqlens (torch.Tensor): Cumulative segment boundaries.
            attention_mask: None, or the sequence-parallel padding mask.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Stacked per-pair log-probs for
            the chosen (even) and rejected (odd) segments.
        """
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        unpacked_logits = self._unpack_varlen(all_logits, seqlens,
                                              attention_mask)
        unpacked_labels = self._unpack_varlen(labels, seqlens, attention_mask)

        chosen_logps, rejected_logps = [], []
        for i in range(0, len(unpacked_logits), 2):
            chosen_logps.append(
                self._compute_logps(unpacked_logits[i], unpacked_labels[i],
                                    average_log_prob))
            rejected_logps.append(
                self._compute_logps(unpacked_logits[i + 1],
                                    unpacked_labels[i + 1], average_log_prob))
        return torch.stack(chosen_logps), torch.stack(rejected_logps)

    def cross_entropy_loss(self, logits, labels):
        """Mean next-token cross entropy; labels equal to -100 are ignored.

        ``nn.CrossEntropyLoss`` uses ``ignore_index=-100`` by default.
        """
        logits = logits[..., :-1, :].contiguous()
        labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = nn.CrossEntropyLoss()
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)
        # Enable model parallelism
        labels = labels.to(logits.device)
        loss = loss_fct(logits, labels)
        return loss

    def odds_ratio_loss(
        self,
        chosen_logps: torch.FloatTensor,
        rejected_logps: torch.FloatTensor,
    ):
        """Compute the ORPO odds-ratio term and logging statistics.

        Returns:
            tuple: ``(losses, chosen_rewards, rejected_rewards,
            mean_log_sigmoid_ratio, mean_log_odds)``. ``losses`` is
            ``beta * logsigmoid(log_odds)`` with NaN entries removed.
        """
        # modified from https://github.com/huggingface/trl/blob/b031adfdb8708f1f295eab6c3f2cb910e8fe0c23/trl/trainer/orpo_trainer.py#L597 # noqa
        # Derived from Eqs. (4) and (7) from https://arxiv.org/abs/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x) # noqa
        log_odds = (chosen_logps - rejected_logps) - (
            torch.log1p(-torch.exp(chosen_logps)) -
            torch.log1p(-torch.exp(rejected_logps)))
        ratio = F.logsigmoid(log_odds)
        ratio = ratio[~torch.isnan(ratio)]  # select valid loss
        losses = self.beta * ratio

        chosen_rewards = self.beta * chosen_logps
        rejected_rewards = self.beta * rejected_logps

        return losses, chosen_rewards, rejected_rewards, torch.mean(
            ratio), torch.mean(log_odds)

    @staticmethod
    def _split_for_sequence_parallel(data):
        """Split ``input_ids``/``position_ids`` across the SP group (dim 1)."""
        # attention mask should not be split
        ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids')
        sp_group = get_sequence_parallel_group()
        for key in ARGS_NEED_TO_SPLIT:
            val = data.get(key, None)
            if val is not None:
                # `dim` is 1 as the shape of tensor is (bs, seq_len, ...)
                data[key] = split_for_sequence_parallel(
                    val, dim=1, sp_group=sp_group)
        return data

    def compute_loss(self, data, data_samples=None):
        """Run the model and return the ORPO loss plus logging metrics.

        ``data['labels']`` is popped before the forward pass. With sequence
        parallelism the inputs are split and the logits gathered back, so
        logits and full-length labels line up again.

        Returns:
            dict: ``loss`` (NLL on chosen minus the odds-ratio term) and
            detached/aggregate statistics for logging.
        """
        labels_ori = data.pop('labels')

        if get_sequence_parallel_world_size() > 1:
            data = self._split_for_sequence_parallel(data)

        all_logits = self.llm(**data).logits
        if get_sequence_parallel_world_size() > 1:
            all_logits = gather_forward_split_backward(
                all_logits,
                dim=1,
                sp_group=get_sequence_parallel_group(),
                grad_scale='up')

        if not self.use_varlen_attn:
            # Even rows are the chosen responses.
            chosen_nll_loss = self.cross_entropy_loss(all_logits[::2],
                                                      labels_ori[::2])
            chosen_logps, rejected_logps = self.get_logps(
                all_logits, True, labels_ori)
        else:
            message_hub = MessageHub.get_instance('varlen_attn_args')
            rank = dist.get_rank()
            cu_seqlens = message_hub.get_info(f'cumulative_len_rank_{rank}')
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            # A missing key is treated like None: no SP padding to strip.
            attention_mask = data.get('attention_mask')

            # Concatenate the chosen (even) segments into one sequence so a
            # single token-level NLL is computed over all of them.
            # NOTE(review): after concatenation the last logit of one chosen
            # segment is scored against the first label of the next; this is
            # harmless only if segment-initial labels are -100 — confirm.
            chosen_logits = torch.cat(
                self._unpack_varlen(all_logits, seqlens,
                                    attention_mask)[::2],
                dim=1)
            chosen_labels = torch.cat(
                self._unpack_varlen(labels_ori, seqlens,
                                    attention_mask)[::2],
                dim=1)
            chosen_nll_loss = self.cross_entropy_loss(chosen_logits,
                                                      chosen_labels)
            chosen_logps, rejected_logps = self.get_var_len_atten_logps(
                all_logits, True, labels_ori, cu_seqlens, attention_mask)

        (losses, chosen_rewards, rejected_rewards, log_odds_ratio,
         log_odds_chosen) = self.odds_ratio_loss(chosen_logps, rejected_logps)
        losses = losses.mean()
        # skip nan loss; multiplying by 0 keeps the graph connected to the
        # logits so backward still runs on every rank.
        if torch.isnan(chosen_nll_loss):
            chosen_nll_loss = all_logits.mean() * 0
        if torch.isnan(losses):
            losses = all_logits.mean() * 0
        loss = chosen_nll_loss - losses

        reward_acc = (chosen_rewards > rejected_rewards).float().mean()

        loss_dict = {
            'loss': loss,
            'chosen_rewards': chosen_rewards.mean(),
            'rejected_rewards': rejected_rewards.mean(),
            'reward_acc': reward_acc,
            'reward_margin': (chosen_rewards - rejected_rewards).mean(),
            'log_odds_ratio': log_odds_ratio,
            'log_odds_chosen': log_odds_chosen,
            'nll_loss': chosen_nll_loss.detach().mean()
        }
        return loss_dict
|