# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python sdlm/train_reward_model.py \
--model_name_or_path=facebook/opt-350m \
--output_dir="reward_modeling_anthropic_hh" \
--per_device_train_batch_size=16 \
--num_train_epochs=1 \
--gradient_accumulation_steps=2 \
--gradient_checkpointing=True \
--learning_rate=1.41e-5 \
--report_to="wandb" \
--remove_unused_columns=False \
--optim="adamw_torch" \
--logging_steps=10 \
--eval_strategy="steps" \
--eval_steps=500 \
--max_length=512
"""
import warnings
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim.lr_scheduler import LambdaLR
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser, PreTrainedModel
from transformers.trainer_pt_utils import nested_detach
from trl import ModelConfig, RewardConfig, RewardTrainer, get_kbit_device_map, get_peft_config, get_quantization_config
from sdlm.models.mistral.modeling_mistral import MistralforSequenceClassificationWithPadding
from sdlm.models.utils import get_torch_dtype
from sdlm.schedulers import SimplexDDPMScheduler
tqdm.pandas()
# TODO: allow end_lr to be changed via some config.
def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, end_lr_ratio: float = 0.1):
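    """Linear warmup to the peak LR, then linear decay down to end_lr_ratio * peak LR (instead of 0)."""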
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
else:
return end_lr_ratio + (1.0 - end_lr_ratio) * max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, end_lr_ratio, last_epoch=-1):
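    """Build a LambdaLR that warms up linearly and then decays linearly to end_lr_ratio * initial LR."""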
lr_lambda = partial(
_get_linear_schedule_with_warmup_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
end_lr_ratio=end_lr_ratio,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
# RewardTrainer subclass with a custom LR scheduler and optional training on noisy (diffusion-style) inputs.
class RewardTrainerScheduler(RewardTrainer):
def __init__(self, *args, train_on_noisy_inputs=False, **kwargs):
super().__init__(*args, **kwargs)
self.train_on_noisy_inputs = train_on_noisy_inputs
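        # DDPM-style scheduler used to noise the token simplexes when train_on_noisy_inputs is set.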
self.noise_scheduler = SimplexDDPMScheduler(
num_train_timesteps=5000,
beta_schedule="squaredcos_improved_ddpm",
simplex_value=5,
clip_sample=False,
device='cuda',
)
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
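        """Override the default Trainer scheduler: linear warmup, then linear decay to 10% of the peak LR."""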
if self.lr_scheduler is None:
self.lr_scheduler = get_linear_schedule_with_warmup(optimizer, self.args.warmup_steps, num_training_steps, end_lr_ratio=0.1)
self._created_lr_scheduler = True
return self.lr_scheduler
def prediction_step(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
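        """Copied from trl.RewardTrainer.prediction_step, but the chosen/rejected logits are returned without the softmax."""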
inputs = self._prepare_inputs(inputs)
if ignore_keys is None:
if hasattr(self.model, "config"):
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
else:
ignore_keys = []
with torch.no_grad():
loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
if prediction_loss_only:
return (loss, None, None)
loss = loss.detach()
logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
logits = nested_detach(logits)
        # Stack chosen against rejected and average over the logits dimension.
        # The stock TRL implementation additionally applies a softmax so the two scores sum to 1:
        # logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
        # We skip the softmax here to keep the raw reward logits.
logits = torch.stack(logits).mean(dim=2).T
labels = torch.zeros(logits.shape[0])
labels = self._prepare_inputs(labels)
return loss, logits, labels
# hacky override to set cache to false
# required to fix FA2 + mistral issues
# see https://github.com/huggingface/trl/issues/1217
def compute_loss(
self,
model,
inputs,
return_outputs=False,
):
if not self.use_reward_data_collator:
warnings.warn(
"The current compute_loss is implemented for RewardDataCollatorWithPadding,"
" if you are using a custom data collator make sure you know what you are doing or"
" implement your own compute_loss method."
)
if self.train_on_noisy_inputs:
from sdlm.utils import convert_to_simplex
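            # Emulate diffusion training: map token ids to a scaled simplex, add Gaussian noise at a
            # random timestep, and feed the resulting soft token distribution through the embedding
            # matrix instead of discrete input_ids.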
def construct_noisy_simplex(input_ids):
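                """Convert input_ids to a scaled simplex and apply forward-diffusion noise at a random timestep."""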
# hardcoded simplex value for now TODO: make this a config
simplex = convert_to_simplex(
input_ids, 5, len(self.tokenizer)
)
noise = 5 * torch.randn(
simplex.shape, device=simplex.device, dtype=torch.float32
)
bsz = simplex.shape[0]
                # One timestep per sequence. (The original code carried a disabled token-wise
                # branch, is_tokenwise_cdcd_check, that would sample one timestep per token.)
                timesteps = torch.randint(
                    0,
                    5000,  # hardcoded number of train timesteps for now TODO: make this a config
                    (bsz,),
                    device=simplex.device,
                    dtype=torch.int64,
                )
timesteps = timesteps[:, None].expand(-1, input_ids.shape[1])
# Adds noise to each simplex representation (Forward diffusion process).
noisy_simplex = self.noise_scheduler.add_noise(simplex, noise, timesteps)
return noisy_simplex.detach() # detach to avoid backpropagating through the noise
simplex_chosen = construct_noisy_simplex(inputs["input_ids_chosen"])
simplex_chosen = torch.softmax(simplex_chosen, dim=-1).to(torch.bfloat16)
            simplex_rejected = construct_noisy_simplex(inputs["input_ids_rejected"])
            simplex_rejected = torch.softmax(simplex_rejected, dim=-1).to(torch.bfloat16)
            # unwrap the model for FSDP so the full (unsharded) embedding matrix is available,
            # then project both soft token distributions through it:
            # (batch, seq, vocab) @ (vocab, hidden) -> (batch, seq, hidden)
            with FSDP.summon_full_params(model):
                embedding_weight = model.get_input_embeddings().weight.data
                inputs_embeds_chosen = F.linear(simplex_chosen, embedding_weight.T)
                inputs_embeds_rejected = F.linear(simplex_rejected, embedding_weight.T)
rewards_chosen = model(
inputs_embeds=inputs_embeds_chosen,
attention_mask=inputs["attention_mask_chosen"],
return_dict=True,
use_cache=False,
)["logits"]
rewards_rejected = model(
inputs_embeds=inputs_embeds_rejected,
attention_mask=inputs["attention_mask_rejected"],
return_dict=True,
use_cache=False,
)["logits"]
else:
rewards_chosen = model(
input_ids=inputs["input_ids_chosen"],
attention_mask=inputs["attention_mask_chosen"],
return_dict=True,
use_cache=False,
)["logits"]
rewards_rejected = model(
input_ids=inputs["input_ids_rejected"],
attention_mask=inputs["attention_mask_rejected"],
return_dict=True,
use_cache=False,
)["logits"]
# calculate loss, optionally modulate with margin
if "margin" in inputs:
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
else:
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
if return_outputs:
return loss, {
"rewards_chosen": rewards_chosen,
"rewards_rejected": rewards_rejected,
}
return loss
@dataclass
class RewardModelingArguments:
include_padding: bool = False # if true, we pad the input_ids to the max_length and compute reward at final token.
use_tulu_chat_template: bool = False # if true, we use the tulu chat template for the input_ids.
    end_lr: float = 1e-6  # final learning rate for the scheduler (currently unused: create_scheduler hard-codes end_lr_ratio=0.1).
dataset_name: str = "argilla/ultrafeedback-binarized-preferences-cleaned" # dataset to use for reward modeling.
use_flash_attention2: bool = False # if true, we use the flash attention2 implementation.
eval_only: bool = False # if true, we only evaluate the model.
train_on_noisy_inputs: bool = False # if true, we emulate the diffusion noise as input during training.
if __name__ == "__main__":
parser = HfArgumentParser((RewardConfig, ModelConfig, RewardModelingArguments))
config, model_config, reward_config = parser.parse_args_into_dataclasses()
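    # Note: `config` is the trl RewardConfig, `model_config` the trl ModelConfig,
    # and `reward_config` the RewardModelingArguments defined above.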
config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
################
# Model & Tokenizer
################
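    # NOTE: the torch_dtype computed below is currently unused; the dtype passed to
    # from_pretrained comes from get_torch_dtype(config) in model_kwargs.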
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
attn_implementation="flash_attention_2"
if reward_config.use_flash_attention2
else "eager",
torch_dtype=get_torch_dtype(config),
)
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, revision=model_config.model_revision, use_fast=True)
# just always add the pad token.
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
# make sure the pad token is set correctly.
tokenizer.pad_token = "[PAD]"
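    # assumes a 32000-token base vocab (e.g. Mistral/Llama), so the newly added [PAD] token gets id 32000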
tokenizer.pad_token_id = 32000
if reward_config.use_tulu_chat_template:
tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
if reward_config.include_padding:
model = MistralforSequenceClassificationWithPadding.from_pretrained(
model_config.model_name_or_path, num_labels=1, **model_kwargs
)
else:
model = AutoModelForSequenceClassification.from_pretrained(
model_config.model_name_or_path, num_labels=1, **model_kwargs
)
# resize model embeddings
vocab_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > vocab_size:
model.resize_token_embeddings(len(tokenizer))
# make sure the model knows the pad token id
model.config.pad_token_id = tokenizer.pad_token_id
tokenizer.padding_side = "right"
if model_config.lora_task_type != "SEQ_CLS":
        warnings.warn(
            "You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs."
            " Make sure to pass --lora_task_type SEQ_CLS when using this script."
        )
# Dataset loading
raw_datasets = load_dataset(reward_config.dataset_name)
# use reward bench for validation.
eval_dataset = load_dataset("allenai/reward-bench", split="filtered")
# Tokenize chosen/rejected pairs of inputs
# Adapt this section to your needs for custom datasets
def preprocess_function(examples):
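        """Tokenize chosen/rejected conversations (lists of chat messages) with the chat template."""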
new_examples = {
"input_ids_chosen": [],
"attention_mask_chosen": [],
"input_ids_rejected": [],
"attention_mask_rejected": [],
}
for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
# flatten from 2d to 1d
tokenize_func = lambda x: tokenizer.apply_chat_template(
x,
return_tensors="pt",
max_length=config.max_length,
                padding="max_length" if reward_config.include_padding else False,  # pad to max_length when include_padding is set
).flatten()
tokenized_chosen = tokenize_func(chosen)
tokenized_rejected = tokenize_func(rejected)
new_examples["input_ids_chosen"].append(tokenized_chosen)
new_examples["attention_mask_chosen"].append(torch.ones_like(tokenized_chosen))
new_examples["input_ids_rejected"].append(tokenized_rejected)
new_examples["attention_mask_rejected"].append(torch.ones_like(tokenized_rejected))
return new_examples
def preprocess_function_no_list(examples):
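        """Same as above, but for datasets (e.g. reward-bench) where prompt/chosen/rejected are plain strings."""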
new_examples = {
"input_ids_chosen": [],
"attention_mask_chosen": [],
"input_ids_rejected": [],
"attention_mask_rejected": [],
}
for prompt, chosen, rejected in zip(examples["prompt"], examples["chosen"], examples["rejected"]):
# construct lists
chosen = [{"role": "user", "content": prompt}, {"role": "assistant", "content": chosen}]
rejected = [{"role": "user", "content": prompt}, {"role": "assistant", "content": rejected}]
# same as above
tokenize_func = lambda x: tokenizer.apply_chat_template(
x,
return_tensors="pt",
max_length=config.max_length,
                padding="max_length" if reward_config.include_padding else False,  # pad to max_length when include_padding is set
).flatten()
tokenized_chosen = tokenize_func(chosen)
tokenized_rejected = tokenize_func(rejected)
new_examples["input_ids_chosen"].append(tokenized_chosen)
new_examples["attention_mask_chosen"].append(torch.ones_like(tokenized_chosen))
new_examples["input_ids_rejected"].append(tokenized_rejected)
new_examples["attention_mask_rejected"].append(torch.ones_like(tokenized_rejected))
return new_examples
# Preprocess the dataset and filter out examples that are longer than args.max_length
raw_datasets = raw_datasets.map(
preprocess_function,
batched=True,
num_proc=4,
)
raw_datasets = raw_datasets.filter(
lambda x: len(x["input_ids_chosen"]) <= config.max_length and len(x["input_ids_rejected"]) <= config.max_length
)
train_dataset = raw_datasets["train"]
eval_dataset = eval_dataset.map(preprocess_function_no_list, batched=True, num_proc=4)
eval_dataset = eval_dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= config.max_length and len(x["input_ids_rejected"]) <= config.max_length
)
################
# Training
################
trainer = RewardTrainerScheduler(
model=model,
tokenizer=tokenizer,
args=config,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=get_peft_config(model_config),
train_on_noisy_inputs=reward_config.train_on_noisy_inputs,
)
if not reward_config.eval_only:
trainer.train()
trainer.save_model(config.output_dir)
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
print(metrics)