# -------------------------------------------------------- # InternVL # Copyright (c) 2024 OpenGVLab # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- from typing import Dict, List, Literal, Optional, Tuple, Union import torch from torch import nn from torch.utils.data import ConcatDataset from trl import DPOTrainer from trl.trainer.utils import RunningMoments, pad_to_length def _map(self, *args, **kwargs): return self ConcatDataset.map = _map class MultimodalDPOTrainer(DPOTrainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.loss_type != 'bco_pair' and 'bco_pair' in self.loss_type: self.running = RunningMoments(self.accelerator) @staticmethod def concatenated_inputs( batch: Dict[str, Union[List, torch.LongTensor]], is_encoder_decoder: bool = False, is_vision_model: bool = False, label_pad_token_id: int = -100, padding_value: int = 0, device: Optional[torch.device] = None, ) -> Dict[str, torch.LongTensor]: """Concatenate the chosen and rejected inputs into a single tensor. Args: batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). is_encoder_decoder: Whether the model is an encoder-decoder model. label_pad_token_id: The label pad token id. padding_value: The padding value to use for the concatenated inputs_ids. device: The device for the concatenated inputs. Returns: A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. """ concatenated_batch = {} if is_encoder_decoder: max_length = max(batch['chosen_labels'].shape[1], batch['rejected_labels'].shape[1]) else: max_length = max(batch['chosen_input_ids'].shape[1], batch['rejected_input_ids'].shape[1]) for k in batch: if k.startswith('chosen') and isinstance(batch[k], torch.Tensor): if 'labels' in k or is_encoder_decoder: pad_value = label_pad_token_id elif k.endswith('_input_ids'): pad_value = padding_value elif k.endswith('_attention_mask'): pad_value = 0 concatenated_key = k.replace('chosen', 'concatenated') concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) for k in batch: if k.startswith('rejected') and isinstance(batch[k], torch.Tensor): if 'labels' in k or is_encoder_decoder: pad_value = label_pad_token_id elif k.endswith('_input_ids'): pad_value = padding_value elif k.endswith('_attention_mask'): pad_value = 0 concatenated_key = k.replace('rejected', 'concatenated') concatenated_batch[concatenated_key] = torch.cat( ( concatenated_batch[concatenated_key], pad_to_length(batch[k], max_length, pad_value=pad_value), ), dim=0, ).to(device=device) if is_encoder_decoder: concatenated_batch['concatenated_input_ids'] = batch['prompt_input_ids'].repeat(2, 1).to(device=device) concatenated_batch['concatenated_attention_mask'] = ( batch['prompt_attention_mask'].repeat(2, 1).to(device=device) ) if 'pixel_values' in batch: concatenated_batch['pixel_values'] = batch['pixel_values'].repeat(2, 1, 1, 1) concatenated_batch['image_flags'] = batch['image_flags'].repeat(2) return concatenated_batch def concatenated_forward( self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. We do this to avoid doing two forward passes, because it's faster for FSDP. """ concatenated_batch = self.concatenated_inputs( batch, is_encoder_decoder=self.is_encoder_decoder, is_vision_model=self.is_vision_model, label_pad_token_id=self.label_pad_token_id, padding_value=self.padding_value, device=self.accelerator.device, ) len_chosen = batch['chosen_labels'].shape[0] model_kwargs = {} if self.is_encoder_decoder: model_kwargs['labels'] = concatenated_batch['concatenated_labels'] model_kwargs['decoder_input_ids'] = concatenated_batch.pop('concatenated_decoder_input_ids', None) if self.is_vision_model: model_kwargs['pixel_values'] = concatenated_batch['pixel_values'] model_kwargs['pixel_attention_mask'] = concatenated_batch['pixel_attention_mask'] if self.aux_loss_enabled: model_kwargs['output_router_logits'] = True outputs = model( input_ids=concatenated_batch['concatenated_input_ids'], attention_mask=concatenated_batch['concatenated_attention_mask'], pixel_values=concatenated_batch['pixel_values'], image_flags=concatenated_batch['image_flags'], use_cache=False, **model_kwargs, ) all_logits = outputs.logits all_logps, size_completion = self.get_batch_logps( all_logits, concatenated_batch['concatenated_labels'], # average_log_prob=self.loss_type == "ipo", is_encoder_decoder=self.is_encoder_decoder, label_pad_token_id=self.label_pad_token_id, ) def cross_entropy_loss(logits, labels): if not self.is_encoder_decoder: # Shift so that tokens < n predict n logits = logits[..., :-1, :].contiguous() labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = nn.CrossEntropyLoss() logits = logits.view(-1, logits.shape[-1]) labels = labels.view(-1) # Enable model parallelism labels = labels.to(logits.device) loss = loss_fct(logits, labels) return loss labels = concatenated_batch['concatenated_labels'].clone() nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) if self.loss_type == 'ipo': all_logps = all_logps / size_completion chosen_logps = all_logps[:len_chosen] rejected_logps = all_logps[len_chosen:] chosen_logits = all_logits[:len_chosen] rejected_logits = all_logits[len_chosen:] if self.aux_loss_enabled: return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss) return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss) def _prepare_deepspeed(self, model): deepspeed_plugin = self.accelerator.state.deepspeed_plugin config_kwargs = deepspeed_plugin.deepspeed_config if config_kwargs['zero_optimization']['stage'] == 3: print('Enable DPOTrainer._prepare_deepspeed') return super()._prepare_deepspeed(model) print('Disable DPOTrainer._prepare_deepspeed') for param in model.parameters(): param.requires_grad = False model.eval() model = model.to(self.accelerator.device) return model def get_batch_loss_metrics( self, model, batch: Dict[str, Union[List, torch.LongTensor]], train_eval: Literal['train', 'eval'] = 'train', ): """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" metrics = {} forward_output = self.concatenated_forward(model, batch) ( policy_chosen_logps, policy_rejected_logps, policy_chosen_logits, policy_rejected_logits, policy_nll_loss, ) = forward_output[:5] if self.aux_loss_enabled: aux_loss = forward_output[5] # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model if ( 'reference_chosen_logps' in batch and 'reference_rejected_logps' in batch and self.args.rpo_alpha is not None ): reference_chosen_logps = batch['reference_chosen_logps'] reference_rejected_logps = batch['reference_rejected_logps'] else: with torch.no_grad(): if self.ref_model is None: with self.null_ref_context(): ( reference_chosen_logps, reference_rejected_logps, _, _, _, ) = self.concatenated_forward(self.model, batch) else: ( reference_chosen_logps, reference_rejected_logps, _, _, _, ) = self.concatenated_forward(self.ref_model, batch) if ',' in self.loss_type: loss_type = self.loss_type loss_type_list = loss_type.split(',') losses, chosen_rewards, rejected_rewards = 0, 0, 0 for curr_type in loss_type_list: self.loss_type = curr_type curr_losses, curr_chosen_rewards, curr_rejected_rewards = self.dpo_loss( policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, ) curr_weight = getattr(self.args, f'{curr_type}_loss_weight') losses = losses + curr_losses * curr_weight chosen_rewards = chosen_rewards + curr_chosen_rewards * curr_weight rejected_rewards = rejected_rewards + curr_rejected_rewards * curr_weight self.loss_type = loss_type else: losses, chosen_rewards, rejected_rewards = self.dpo_loss( policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, ) reward_accuracies = (chosen_rewards > rejected_rewards).float() if self.args.rpo_alpha is not None: # losses = losses * self.args.rpo_alpha + policy_nll_loss losses = losses + policy_nll_loss * self.args.rpo_alpha prefix = 'eval_' if train_eval == 'eval' else '' metrics[f'{prefix}rewards/chosen'] = chosen_rewards.mean().cpu() metrics[f'{prefix}rewards/rejected'] = rejected_rewards.mean().cpu() metrics[f'{prefix}rewards/accuracies'] = reward_accuracies.mean().cpu() metrics[f'{prefix}rewards/margins'] = (chosen_rewards - rejected_rewards).mean().cpu() metrics[f'{prefix}logps/rejected'] = policy_rejected_logps.detach().mean().cpu() metrics[f'{prefix}logps/chosen'] = policy_chosen_logps.detach().mean().cpu() metrics[f'{prefix}logits/rejected'] = policy_rejected_logits.detach().mean().cpu() metrics[f'{prefix}logits/chosen'] = policy_chosen_logits.detach().mean().cpu() if self.args.rpo_alpha is not None: metrics[f'{prefix}nll_loss'] = policy_nll_loss.detach().mean().cpu() if self.aux_loss_enabled: return losses.mean() + getattr(model.config, 'router_aux_loss_coef', 0.0) * aux_loss, metrics return losses.mean(), metrics