Spaces:
Runtime error
Runtime error
File size: 19,934 Bytes
476ac07 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 |
# Copyright (c) OpenMMLab. All rights reserved.
import json
import math
import os
import warnings
from collections import OrderedDict
from contextlib import nullcontext
import torch
import torch.distributed as dist
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from mmengine.runner import load_checkpoint
from peft import get_peft_model, prepare_model_for_kbit_training
from torch import nn
from transformers import (AutoConfig, AutoModelForSequenceClassification,
PreTrainedModel, PreTrainedTokenizer)
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import no_init_weights
from xtuner.parallel.sequence import (gather_forward_split_backward,
get_sequence_parallel_group,
get_sequence_parallel_world_size,
split_for_sequence_parallel)
from xtuner.registry import BUILDER
from .modules import dispatch_modules
from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
from .utils import (LoadWoInit, find_all_linear_names,
get_peft_model_state_dict, make_inputs_require_grad,
traverse_dict)
def reduce_mean(tensor):
    """Average ``tensor`` across all ranks of the default process group.

    When ``torch.distributed`` is unavailable or not initialized, the input
    is returned as-is (the same object, not a copy). Otherwise a clone is
    divided by the world size and summed with ``all_reduce``, so the caller's
    tensor is never modified in place.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    averaged = tensor.clone()
    # Divide before the SUM all-reduce so the result is the mean over ranks.
    dist.all_reduce(
        averaged.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
    return averaged
def smart_tokenizer_and_embedding_resize(
    tokenizer: PreTrainedTokenizer,
    model: PreTrainedModel,
):
    """Resize ``model``'s token embeddings so they cover ``tokenizer``.

    Nothing happens when the input-embedding matrix already has at least
    ``len(tokenizer)`` rows. Otherwise the embeddings are resized with
    ``pad_to_multiple_of=64`` and every row beyond the original size (the new
    tokens *and* any padding rows) is set to the mean of the original rows.
    The output embeddings get the same treatment when
    ``model.get_output_embeddings()`` is not None.

    Under DeepSpeed ZeRO-3 the embedding parameters are gathered (rank 0 is
    the modifier) around every read/write.

    Args:
        tokenizer: Tokenizer whose vocabulary the model must cover.
        model: Model whose embeddings are resized in place.
    """

    def _gathered_embeddings():
        # Build a fresh context per use: ``resize_token_embeddings`` swaps in
        # new parameters, so a context created before the resize would gather
        # the stale pre-resize tensors.
        if not is_deepspeed_zero3_enabled():
            return nullcontext()
        import deepspeed
        params = [model.get_input_embeddings().weight]
        output_module = model.get_output_embeddings()
        if output_module is not None and \
                not model.config.tie_word_embeddings:
            params.append(output_module.weight)
        return deepspeed.zero.GatheredParameters(params, modifier_rank=0)

    with _gathered_embeddings():
        current_embedding_size = model.get_input_embeddings().weight.size(0)

    if len(tokenizer) <= current_embedding_size:
        return

    # ``get_output_embeddings()`` may be None (e.g. when a bare backbone
    # without an LM head is passed, as RewardModel does); only a Linear head
    # is supported otherwise.
    output_module = model.get_output_embeddings()
    assert output_module is None or isinstance(output_module, nn.Linear)

    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
    with _gathered_embeddings():
        # Slice at the original size rather than counting from the end: the
        # padding to a multiple of 64 can make the new matrix longer than
        # ``len(tokenizer)``, so ``[-num_new_tokens:]`` would hit padding
        # rows instead of the new-token rows.
        input_embeddings = model.get_input_embeddings().weight.data
        input_embeddings[current_embedding_size:] = input_embeddings[
            :current_embedding_size].mean(dim=0, keepdim=True)

        output_module = model.get_output_embeddings()
        if output_module is not None:
            output_embeddings = output_module.weight.data
            output_embeddings[current_embedding_size:] = output_embeddings[
                :current_embedding_size].mean(dim=0, keepdim=True)

    print_log(
        f'Resized token embeddings from {current_embedding_size} to '
        f'{len(tokenizer)}.', 'current')
class RewardModel(BaseModel):
    """Scalar reward model: an LLM backbone followed by a linear value head.

    The model built from ``llm`` is reduced to its ``.model`` attribute
    (presumably the backbone of a HF ``*ForCausalLM``; its LM head is not
    used), and ``v_head`` maps every hidden state to a single score.

    Args:
        llm: A built ``nn.Module`` or a config dict built with ``BUILDER``.
        lora: LoRA config. dict / ``Config`` / ``ConfigDict`` values are
            built with ``BUILDER``; anything else is used as-is. ``None``
            disables LoRA.
        peft_model: Optional checkpoint loaded after LoRA is applied.
        use_activation_checkpointing: Enable gradient checkpointing.
        use_varlen_attn: Forwarded to ``dispatch_modules``.
        tokenizer: Optional tokenizer (or config); embeddings are resized
            to cover it.
        max_position_embeddings: If given, linear RoPE scaling is set up.
        reward_token_id: Stored in the exported HF config by ``to_hf``.
        loss_type: ``'ranking'`` or ``'focal'``.
        penalty_type: ``'log_barrier'``, ``'L2'`` or ``'none'``.
        penalty_weight: Multiplier of the penalty term in the loss.
    """

    def __init__(
        self,
        llm,
        lora=None,
        peft_model=None,
        use_activation_checkpointing=True,
        use_varlen_attn=False,
        tokenizer=None,
        max_position_embeddings=None,
        reward_token_id=None,
        loss_type='ranking',
        penalty_type='log_barrier',
        penalty_weight=0.01,
    ):
        super().__init__()
        with LoadWoInit():
            if isinstance(llm, dict):
                llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
            self.llm = self._build_from_cfg_or_module(llm).model
        self.v_head = nn.Linear(self.llm.config.hidden_size, 1, bias=False)
        # zero init: every sequence starts with a score of 0
        self.v_head.weight.data.zero_()

        self.reward_token_id = reward_token_id
        assert loss_type in ('ranking',
                             'focal'), f'Unsupported loss type {loss_type}'
        self.loss_type = loss_type
        assert penalty_type in (
            'log_barrier', 'L2',
            'none'), f'Unsupported penalty type {penalty_type}'
        self.penalty_type = penalty_type
        self.penalty_weight = penalty_weight

        if tokenizer is not None:
            if isinstance(tokenizer, dict):
                tokenizer = BUILDER.build(tokenizer)
            smart_tokenizer_and_embedding_resize(tokenizer, self.llm)

        self.llm.config.use_cache = False
        dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn)

        if use_activation_checkpointing:
            # For backward compatibility
            if hasattr(self.llm, 'enable_input_require_grads'):
                self.llm.enable_input_require_grads()
            else:
                self.llm.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad)
            # enable gradient checkpointing for memory efficiency
            self.gradient_checkpointing_enable()

        if isinstance(lora, (dict, Config, ConfigDict)):
            self.lora = BUILDER.build(lora)
        else:
            self.lora = lora
        self.peft_model = peft_model
        self.use_lora = lora is not None
        if self.use_lora:
            self._prepare_for_lora(peft_model, use_activation_checkpointing)

        self._is_init = True
        # Determines whether to calculate attention based on the
        # seq_len dimension (use_varlen_attn = False) or the actual length of
        # the sequence.
        self.use_varlen_attn = use_varlen_attn

    def gradient_checkpointing_enable(self):
        """Alias of ``activation_checkpointing_enable``."""
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        """Turn on gradient checkpointing in the backbone."""
        self.llm.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        """Alias of ``activation_checkpointing_disable``."""
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        """Turn off gradient checkpointing in the backbone."""
        self.llm.gradient_checkpointing_disable()

    def _prepare_for_lora(self,
                          peft_model=None,
                          use_activation_checkpointing=True):
        """Wrap the backbone with LoRA adapters and optionally load weights.

        When ``self.lora.target_modules`` is None, every linear layer found
        by ``find_all_linear_names`` is targeted.
        """
        self.llm = prepare_model_for_kbit_training(
            self.llm, use_activation_checkpointing)
        if self.lora.target_modules is None:
            self.lora.target_modules = find_all_linear_names(self.llm)
        self.llm = get_peft_model(self.llm, self.lora)
        if peft_model is not None:
            _ = load_checkpoint(self, peft_model)

    def init_weights(self):
        """Intentionally a no-op (``v_head`` is zeroed in ``__init__``)."""
        pass

    @staticmethod
    def _prepare_for_long_context_training(cfg, llm_cfg,
                                           max_position_embeddings):
        """Set linear RoPE scaling on ``cfg`` for ``max_position_embeddings``.

        The factor is ``ceil(max_position_embeddings / original_max)`` when
        the requested length is larger, else 1.0.

        Returns:
            ``cfg`` -- unchanged when ``llm_cfg`` has no ``rope_scaling``.
        """
        if not hasattr(llm_cfg, 'rope_scaling'):
            print_log('Current model does not support RoPE scaling.',
                      'current')
            # Return the config untouched; returning None would make the
            # caller try to build the model from ``None``.
            return cfg

        current_max_length = getattr(llm_cfg, 'max_position_embeddings', None)
        if current_max_length and max_position_embeddings > current_max_length:
            print_log(
                f'Enlarge max model length from {current_max_length} '
                f'to {max_position_embeddings}.', 'current')
            scaling_factor = float(
                math.ceil(max_position_embeddings / current_max_length))
        else:
            print_log(
                'The input `max_position_embeddings` is smaller than '
                'origin max length. Consider increase input length.',
                'current')
            scaling_factor = 1.0
        cfg.rope_scaling = {'type': 'linear', 'factor': scaling_factor}

        return cfg

    @staticmethod
    def _prepare_for_flash_attn(cfg, llm_cfg):
        """Pick the attention implementation (and dtype) for ``cfg``.

        An explicit ``attn_implementation`` in ``cfg`` is kept; only the
        dtype is adjusted for ``flash_attention_2``. Otherwise
        ``flash_attention_2`` is chosen when ``SUPPORT_FLASH2`` is set and
        the config class is listed, falling back to ``sdpa`` when
        ``SUPPORT_FLASH1`` is set and the class is listed.
        """
        cls_name = type(llm_cfg).__name__
        SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
                             'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
                             'Starcoder2Config', 'Phi3Config')
        SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
                               'MistralConfig', 'MixtralConfig', 'Qwen2Config',
                               'Qwen2MoeConfig', 'Starcoder2Config',
                               'Phi3Config')

        torch_dtype = torch.bfloat16 if (
            torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
            else torch.float16

        if getattr(cfg, 'attn_implementation', None) is not None:
            # Flash Attention 2.0 only supports torch.float16 and
            # torch.bfloat16 dtypes
            if cfg.attn_implementation == 'flash_attention_2':
                cfg.torch_dtype = torch_dtype
        elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
            cfg.torch_dtype = torch_dtype
            cfg.attn_implementation = 'flash_attention_2'
        elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
            cfg.attn_implementation = 'sdpa'

        return cfg

    @staticmethod
    def _prepare_for_qlora_zero3(cfg):
        """Align QLoRA dtypes with the compute dtype under ZeRO-3.

        ``cfg`` is returned untouched unless ZeRO-3 is enabled and it has a
        non-None ``quantization_config``.
        """
        if (not is_deepspeed_zero3_enabled()) or getattr(
                cfg, 'quantization_config', None) is None:
            return cfg

        torch_dtype = torch.bfloat16 if (
            torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
            else torch.float16

        cfg.torch_dtype = torch_dtype
        quantization_config = cfg.quantization_config
        quantization_config.bnb_4bit_compute_dtype = torch_dtype
        quantization_config.bnb_4bit_quant_storage = torch_dtype

        return cfg

    def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
        """Apply QLoRA/ZeRO-3, attention and long-context tweaks to ``cfg``."""
        cfg = self._prepare_for_qlora_zero3(cfg)
        pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
        llm_cfg = AutoConfig.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=True)
        cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
        if max_position_embeddings is not None:
            cfg = self._prepare_for_long_context_training(
                cfg, llm_cfg, max_position_embeddings)
        return cfg

    def _build_from_cfg_or_module(self, cfg_or_mod):
        """Return ``cfg_or_mod`` if it is a module, else build it."""
        if isinstance(cfg_or_mod, nn.Module):
            return cfg_or_mod
        elif isinstance(cfg_or_mod, dict):
            traverse_dict(cfg_or_mod)
            return BUILDER.build(cfg_or_mod)
        else:
            raise NotImplementedError(
                f'Cannot build llm from {type(cfg_or_mod).__name__}')

    def forward(self, data, data_samples=None, mode='loss'):
        """Dispatch on ``mode``: ``'loss'``, ``'predict'`` or ``'tensor'``.

        ``labels`` is popped from ``data`` (mutating it) before dispatch.
        """
        labels = data.pop('labels', None)
        if mode == 'loss':
            return self.compute_loss(data, labels)
        elif mode == 'predict':
            return self.predict(data, data_samples)
        elif mode == 'tensor':
            return self._forward(data, data_samples)
        else:
            raise NotImplementedError(f'Unsupported forward mode {mode}')

    def _forward(self, data, data_samples=None):
        """Score every position: backbone hidden states through ``v_head``."""
        hidden_states = self.llm(**data)[0]
        return self.v_head(hidden_states)

    def predict(self, data, data_samples=None):
        """Return one ``{'logits': ...}`` dict per sample in the batch."""
        return [{'logits': log} for log in self._forward(data, data_samples)]

    @staticmethod
    def _split_for_sequence_parallel(data):
        """Split ``input_ids``/``position_ids`` across the SP group."""
        # attention mask should not be split
        ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids')
        sp_group = get_sequence_parallel_group()
        for key in ARGS_NEED_TO_SPLIT:
            val = data.get(key, None)
            if val is not None:
                # `dim` is 1 as the shape of tensor is (bs, seq_len, ...)
                data[key] = split_for_sequence_parallel(
                    val, dim=1, sp_group=sp_group)
        return data

    def compute_loss(self, data, labels=None):
        """Compute the pairwise reward loss plus logging statistics.

        Args:
            data: Keyword arguments for the backbone.
            labels: Tensor with at least two dims (``shape[1]`` is logged as
                ``num_tokens``); positions equal to 0 are scored as chosen
                and positions equal to 1 as rejected. Both selections are
                compared element-wise, so they must have the same count
                (presumably one reward token per response -- confirm
                against the collate function).

        Returns:
            dict with ``loss`` and the logging entries ``acc``,
            ``chosen_score_mean``, ``rejected_score_mean``, ``num_samples``
            and ``num_tokens``.

        Raises:
            ValueError: If ``labels`` is None.
        """
        if labels is None:
            raise ValueError(
                '`labels` are required to compute the reward model loss.')

        sp_world_size = get_sequence_parallel_world_size()
        if sp_world_size > 1:
            data = self._split_for_sequence_parallel(data)

        hidden_states = self.llm(**data)[0]
        logits = self.v_head(hidden_states)

        if sp_world_size > 1:
            logits = gather_forward_split_backward(
                logits,
                dim=1,
                sp_group=get_sequence_parallel_group(),
                grad_scale='up')

        device = hidden_states.device
        chosen_logits = logits[torch.where(labels == 0)]
        rejected_logits = logits[torch.where(labels == 1)]

        num_samples = torch.tensor(len(chosen_logits)).float().to(device)
        # Clamp the denominator: a rank without any pair would otherwise get
        # an infinite factor that ``reduce_mean`` spreads to every rank.
        safe_num_samples = num_samples.clamp(min=1.0)
        avg_factor = reduce_mean(1.0 / safe_num_samples).to(device)

        chosen_mean = reduce_mean(chosen_logits.mean().detach())
        rejected_mean = reduce_mean(rejected_logits.mean().detach())
        acc = reduce_mean(
            (chosen_logits > rejected_logits).sum() /
            safe_num_samples).detach()
        num_tokens = torch.tensor(labels.shape[1]).float()

        # ranking loss
        if self.loss_type == 'ranking':
            rank_loss = self.ranking_loss(
                chosen_logits, rejected_logits, avg_factor=avg_factor)
        elif self.loss_type == 'focal':
            rank_loss = self.focal_loss(
                chosen_logits, rejected_logits, avg_factor=avg_factor)
        else:
            raise NotImplementedError(
                f'Unsupported loss type {self.loss_type}')

        # penalty loss
        all_logits = torch.cat([chosen_logits, rejected_logits])
        if self.penalty_type == 'log_barrier':
            penalty = self.log_barrier_penalty(
                all_logits,
                lower_bound=-5,
                upper_bound=5,
                avg_factor=avg_factor)
        elif self.penalty_type == 'L2':
            penalty = self.l2_penalty(all_logits, avg_factor=avg_factor)
        elif self.penalty_type == 'none':
            penalty = 0
        else:
            raise NotImplementedError(
                f'Unsupported penalty type {self.penalty_type}')

        loss = rank_loss + self.penalty_weight * penalty

        return {
            'loss': loss,
            'acc': acc,
            'chosen_score_mean': chosen_mean,
            'rejected_score_mean': rejected_mean,
            'num_samples': num_samples,
            'num_tokens': num_tokens,
        }

    def ranking_loss(self, chosen_logits, rejected_logits, avg_factor):
        """Sum of ``-log(sigmoid(chosen - rejected))`` times ``avg_factor``."""
        rank_loss = -nn.functional.logsigmoid(chosen_logits - rejected_logits)
        return rank_loss.sum() * avg_factor

    def focal_loss(self,
                   chosen_logits,
                   rejected_logits,
                   avg_factor,
                   gamma=2):
        """Focal ranking loss; ``gamma`` is the focusing exponent.

        Pairs whose preference probability already exceeds 0.5 are
        down-weighted by ``(1 - 2 * relu(p - 0.5)) ** gamma``.
        """
        # focal ranking loss from InternLM2 paper https://arxiv.org/abs/2403.17297 # noqa
        rank_loss = -nn.functional.logsigmoid(chosen_logits - rejected_logits)
        p_ij = torch.sigmoid(chosen_logits - rejected_logits)
        p = 2 * torch.relu(p_ij - 0.5)
        focal_loss = ((1 - p)**gamma) * rank_loss
        return focal_loss.sum() * avg_factor

    def log_barrier_penalty(self,
                            logits,
                            lower_bound,
                            upper_bound,
                            epsilon=1e-3,
                            avg_factor=1):
        """Log-barrier keeping scores inside ``(lower_bound, upper_bound)``.

        Logits are cast to float32 and clamped ``epsilon`` inside the bounds
        so the logarithms stay finite.
        """
        # log barrier penalty from InternLM2 paper https://arxiv.org/abs/2403.17297 # noqa
        logits_fp32 = logits.float()
        logits_clamped = torch.clamp(logits_fp32, lower_bound + epsilon,
                                     upper_bound - epsilon)
        penalty = -torch.log(upper_bound - logits_clamped) - torch.log(
            logits_clamped - lower_bound)
        return penalty.sum() * avg_factor

    def l2_penalty(self, logits, avg_factor=1):
        """Sum of squared logits times ``avg_factor``."""
        return (logits**2).sum() * avg_factor

    def state_dict(self, *args, **kwargs):
        """Return the full state dict, or adapter + value-head under LoRA."""
        state_dict = super().state_dict(*args, **kwargs)
        if not self.use_lora:
            return state_dict
        to_return = OrderedDict(
            get_peft_model_state_dict(self.llm, state_dict=state_dict))
        # ``get_peft_model_state_dict`` presumably keeps only adapter weights
        # (its filter is not visible here). ``v_head`` is trained too, so
        # keep it explicitly; this is a no-op if it was already included.
        prefix = kwargs.get('prefix', args[1] if len(args) > 1 else '')
        head_prefix = f'{prefix}v_head.'
        for key, value in state_dict.items():
            if key.startswith(head_prefix):
                to_return[key] = value
        return to_return

    def __getattr__(self, name: str):
        """Fall back to attributes of the wrapped backbone."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Without this guard a missing ``llm`` (e.g. before __init__ set
            # it) would recurse into ``self.llm`` forever.
            if name == 'llm':
                raise
            return getattr(self.llm, name)

    def to_hf(self,
              cfg,
              save_dir,
              fp32=False,
              save_pretrained_kwargs=None,
              **kwargs):
        """Export tokenizer and reward model to ``save_dir`` in HF format.

        LoRA adapters are merged first. InternLM2 backbones are exported as
        ``InternLM2ForRewardModel``; others as a ``*ForSequenceClassification``
        model with ``num_labels=1``. Weights are saved in float16 unless
        ``fp32`` is True.

        Args:
            cfg: Config providing ``tokenizer`` (and ``reward_token_id`` when
                the model has none).
            save_dir: Output directory.
            fp32: Save float32 weights instead of float16.
            save_pretrained_kwargs: Extra kwargs for ``save_pretrained``.
        """
        if save_pretrained_kwargs is None:
            save_pretrained_kwargs = {}

        print(f'Saving LLM tokenizer to {save_dir}')
        tokenizer = BUILDER.build(cfg.tokenizer)
        tokenizer.save_pretrained(save_dir)

        if 'PeftModel' in self.llm.__class__.__name__:
            # merge adapter
            self.llm = self.llm.merge_and_unload()

        dtype = torch.float32 if fp32 else torch.float16
        if 'InternLM2' in self.llm.__class__.__name__:
            self._save_internlm2_reward_model(cfg, save_dir, dtype,
                                              save_pretrained_kwargs)
        else:
            self._save_seqcls_reward_model(save_dir, dtype,
                                           save_pretrained_kwargs)

    def _save_internlm2_reward_model(self, cfg, save_dir, dtype,
                                     save_pretrained_kwargs):
        """Save as xtuner's remote-code ``InternLM2ForRewardModel``."""
        from xtuner.tools.model_converters.modeling_internlm2_reward.modeling_internlm2 import \
            InternLM2ForRewardModel  # noqa
        print(f'Saving Reward Model to {save_dir}')
        hf_cfg = self.llm.config
        hf_cfg.reward_token_id = self.reward_token_id if \
            self.reward_token_id is not None else cfg.reward_token_id
        with no_init_weights():
            reward_model = InternLM2ForRewardModel._from_config(
                hf_cfg, torch_dtype=dtype)
        reward_model.model.load_state_dict(self.llm.state_dict())
        reward_model.v_head.load_state_dict(self.v_head.state_dict())
        reward_model.save_pretrained(save_dir, **save_pretrained_kwargs)

        # fix auto_map in config: AutoModel -> reward class, drop causal LM
        config_path = os.path.join(save_dir, 'config.json')
        with open(config_path, encoding='utf-8') as fp:
            config_dict = json.load(fp)
        auto_map = config_dict.setdefault('auto_map', {})
        auto_map['AutoModel'] = 'modeling_internlm2.InternLM2ForRewardModel'
        auto_map.pop('AutoModelForCausalLM', None)
        with open(config_path, 'w', encoding='utf-8') as fp:
            json.dump(config_dict, fp, indent=2)

    def _save_seqcls_reward_model(self, save_dir, dtype,
                                  save_pretrained_kwargs):
        """Save as a ``*ForSequenceClassification`` model (1 label)."""
        warnings.warn(
            f'The pretrained model type: {self.llm.__class__.__name__} '
            'has no reward model class defined. Use '
            'the SequenceClassification class instead. '
            'You can refer to `xtuner/tools/model_converters/modeling_internlm2_reward` '  # noqa
            'to implement the reward model class.')

        hf_cfg = self.llm.config
        hf_cfg.num_labels = 1  # set the output dim to 1
        try:
            with no_init_weights():
                reward_model = \
                    AutoModelForSequenceClassification.from_config(hf_cfg)
        except Exception as e:
            # Deliberate fallback: look the class up in the remote code that
            # provides the causal-LM class.
            warnings.warn(f'Cannot find SequenceClassification class '
                          f'from transformers: {e}, \n'
                          'try to find it in the dynamic module.')
            module_file, causal_model_name = hf_cfg.auto_map[
                'AutoModelForCausalLM'].split('.')
            seqcls_model_name = causal_model_name.split(
                'For')[0] + 'ForSequenceClassification'
            seqcls_class = get_class_from_dynamic_module(
                f'{module_file}.{seqcls_model_name}', hf_cfg._name_or_path)
            with no_init_weights():
                reward_model = seqcls_class(hf_cfg)

        reward_model.model.load_state_dict(self.llm.state_dict())
        reward_model.score.load_state_dict(self.v_head.state_dict())
        # Honour ``fp32`` like the InternLM2 path; neither constructor above
        # is given a dtype.
        reward_model.to(dtype)
        reward_model.save_pretrained(save_dir, **save_pretrained_kwargs)
|