|
import json |
|
import os |
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
from transformers import Trainer |
|
|
|
from ...extras.logging import get_logger |
|
|
|
|
|
if TYPE_CHECKING: |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.trainer import PredictionOutput |
|
|
|
|
|
# Module-level logger for this file, obtained from the project's `extras.logging` helper.
logger = get_logger(__name__)
|
|
|
|
|
class PairwiseTrainer(Trainer):
    r"""
    Trainer subclass that computes a pairwise ranking loss for reward-model training.

    Each batch is expected to contain the chosen responses in its first half and the
    rejected responses in its second half (see `compute_loss`).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): presumably set so the base Trainer's prediction step calls
        # `compute_loss` (and keeps the returned scores) even though the inputs carry
        # no labels -- confirm against the transformers version in use.
        self.can_return_loss = True

    def compute_loss(
        self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        r"""
        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.

        For every (chosen, rejected) pair, per-token rewards are compared from the first
        position where the two `input_ids` rows differ up to the end of the longer sequence.
        If the rows are identical, only the final non-pad position is compared. The pair loss
        is `-logsigmoid(chosen - rejected)` averaged over that window. The returned loss is the
        mean over pairs.

        Subclass and override to inject custom behavior.

        Args:
            model: the model being trained. Its forward call must return a 3-tuple whose last
                element holds the per-token values.
            inputs: batch dict whose `input_ids` has 2n rows (n chosen rows, then n rejected rows).
            return_outputs: if True, also return each sequence's score, i.e. its value at the
                last non-pad token.

        Returns:
            The loss tensor. If `return_outputs` is True, returns
            `(loss, [loss, chosen_scores, rejected_scores])` instead.
            Note that the first element will be removed from the output list by the base Trainer.
            See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509

        Raises:
            ValueError: if a pair's comparison window would start at position 0. This happens
                when the rows differ at their first token, or when they are identical and only
                one token long.
        """
        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)

        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
            # NOTE(review): ChatGLM presumably returns values with the sequence dimension
            # first; swap the first two dims so rows index the batch -- confirm.
            values = torch.transpose(values, 0, 1)

        # First half of the batch = chosen responses, second half = rejected responses.
        batch_size = inputs["input_ids"].size(0) // 2
        chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
        chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
        chosen_scores, rejected_scores = [], []

        pad_token_id = self.tokenizer.pad_token_id
        loss = 0
        for i in range(batch_size):
            # Length up to and including the last non-pad token of each row.
            chosen_length = (chosen_input_ids[i] != pad_token_id).nonzero()[-1] + 1
            rejected_length = (rejected_input_ids[i] != pad_token_id).nonzero()[-1] + 1
            check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()

            if len(check_divergence) == 0:
                # Identical rows: compare only the reward at the final non-pad token.
                end_index = chosen_length
                div_index = end_index - 1
            else:
                # Compare from the first differing token to the end of the longer sequence.
                end_index = max(chosen_length, rejected_length)
                div_index = check_divergence[0]

            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if div_index <= 0:
                raise ValueError(
                    f"Pairwise loss expects a divergence index > 0 for pair {i}, got {int(div_index)}."
                )

            chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
            rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
            if return_outputs:
                # Score each sequence by its value at the last non-pad token.
                chosen_scores.append(chosen_rewards[i, chosen_length - 1])
                rejected_scores.append(rejected_rewards[i, rejected_length - 1])
            loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()

        loss = loss / batch_size
        if return_outputs:
            chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
            return loss, [loss, chosen_scores, rejected_scores]

        return loss

    def save_predictions(self, predict_results: "PredictionOutput") -> None:
        r"""
        Saves the predicted chosen/rejected scores to `generated_predictions.jsonl` in `output_dir`.

        Only world process zero writes the file. Each line is a JSON object of the form
        `{"chosen": ..., "rejected": ...}`, with both scores rounded to two decimals.

        Custom behavior not provided by the base Trainer.
        """
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info(f"Saving prediction results to {output_prediction_file}")
        chosen_scores, rejected_scores = predict_results.predictions

        res: List[str] = [
            json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)})
            for c_score, r_score in zip(chosen_scores, rejected_scores)
        ]
        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            writer.write("\n".join(res))
|
|