# coding=utf-8 # Copyright 2022 HuggingFace Inc. team and BigScience workshop. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch BLOOM model.""" import math import warnings from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss from torch.nn import functional as F from ...cache_utils import Cache, DynamicCache, StaticCache from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel from ...utils import logging from .configuration_bloom import BloomConfig logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "bigscience/bloom-560m" _CONFIG_FOR_DOC = "BloomConfig" # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, dtype: torch.dtype, device: torch.device, min_dtype: float, cache_position: torch.Tensor, batch_size: int, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. Args: attention_mask (`torch.Tensor`): A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): The device to plcae the 4D attention mask on. min_dtype (`float`): The minimum value representable with the dtype `dtype`. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask else: causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) return causal_mask def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: """ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value `softmax(l+a) = softmax(l)`. Based on https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. Args: Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) attention_mask (`torch.Tensor`): Token-wise attention mask, this should be of shape (batch_size, max_seq_len). num_heads (`int`): number of heads dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): dtype of the output tensor """ batch_size, seq_length = attention_mask.shape closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) base = torch.tensor( 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 ) powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) slopes = torch.pow(base, powers) if closest_power_of_2 != num_heads: extra_base = torch.tensor( 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 ) num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32) slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) # Note: alibi will added to the attention bias that will be applied to the query, key product of attention # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) # => the query_length dimension will then be broadcasted correctly # This is more or less identical to T5's relative position bias: # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] alibi = slopes[..., None] * arange_tensor return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: """ Dropout add function Args: x (`torch.tensor`): input tensor residual (`torch.tensor`): residual tensor prob (`float`): dropout probability training (`bool`): training mode """ out = F.dropout(x, p=prob, training=training) out = residual + out return out def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor: """ Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to make the model jitable. Args: x (`torch.tensor`): input hidden states """ return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor: """ gradient of tanh approximation of gelu gradient of actual gelu is: 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) Args: g (`torch.tensor`): gradient output tensor x (`torch.tensor`): input tensor """ x = x[0] # x is a tuple of 1 element, needs to unpack it first tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) return ff * g class GeLUFunction(torch.autograd.Function): @staticmethod def forward(ctx, input: torch.Tensor) -> torch.Tensor: ctx.save_for_backward(input) return bloom_gelu_forward(input) @staticmethod def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: input = ctx.saved_tensors tmp = bloom_gelu_back(grad_output, input) return tmp class BloomGelu(nn.Module): """ BloomBiasGelu wrapper function that make use of the simple function on inference mode to make the model torchscriptable and use the autograd function in training mode to get the accurate results of the gradients Partly copied from Megatron-DeepSpeed code and adapted for our needs See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329 """ def __init__(self): super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: if self.training: return GeLUFunction.apply(x) else: return bloom_gelu_forward(x) class BloomAttention(nn.Module): def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None): super().__init__() self.pretraining_tp = config.pretraining_tp self.slow_but_exact = config.slow_but_exact self.hidden_size = config.hidden_size self.num_heads = config.n_head self.head_dim = self.hidden_size // self.num_heads self.split_size = self.hidden_size self.hidden_dropout = config.hidden_dropout if self.head_dim * self.num_heads != self.hidden_size: raise ValueError( f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" f" {self.num_heads})." ) # Layer-wise attention scaling self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) self.beta = 1.0 self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) self.dense = nn.Linear(self.hidden_size, self.hidden_size) self.attention_dropout = nn.Dropout(config.attention_dropout) def _reshape(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Split the last dimension into (num_heads, head_dim) and reshapes to (bs, heads, len, dim) shape without making any copies, results share same memory storage as `fused_qkv` Args: fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim] Returns: query: [batch_size, num_heads, seq_length, head_dim] key: [batch_size, num_heads, seq_length, head_dim] value: [batch_size, num_heads, seq_length, head_dim] """ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) query_layer = fused_qkv[..., 0, :].transpose(1, 2) key_layer = fused_qkv[..., 1, :].transpose(1, 2) value_layer = fused_qkv[..., 2, :].transpose(1, 2) return query_layer, key_layer, value_layer def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: """ Merge heads together over the last dimension Args: x (`torch.tensor`): [batch_size * num_heads, seq_length, head_dim] Returns: torch.tensor: [batch_size, seq_length, num_heads * head_dim] """ # What we want to achieve is: # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim batch_size_and_num_heads, seq_length, _ = x.shape batch_size = batch_size_and_num_heads // self.num_heads # First view to decompose the batch size # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim x = x.view(batch_size, self.num_heads, seq_length, self.head_dim) # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim x = x.permute(0, 2, 1, 3) # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim) def forward( self, hidden_states: torch.Tensor, residual: torch.Tensor, alibi: torch.Tensor, attention_mask: torch.Tensor, layer_past: Optional[Cache] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, ): batch_size, q_length, _ = hidden_states.shape fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] # 3 x [batch_size, num_heads, seq_length, head_dim] query_layer, key_layer, value_layer = self._reshape(fused_qkv) if layer_past is not None: cache_kwargs = {"cache_position": cache_position} key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) # reshape qkv for further computations query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim) key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(1, 2) value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim) kv_length = cache_position[-1] + 1 # cache position is 0-indexed while length should start from 1 # [batch_size * num_heads, q_length, kv_length] # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11 matmul_result = alibi.baddbmm( batch1=query_layer, batch2=key_layer, beta=self.beta, alpha=self.inv_norm_factor, ) # change view to [batch_size, num_heads, q_length, kv_length] attn_weights = matmul_result.view(batch_size, self.num_heads, q_length, kv_length) if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, :kv_length] attn_weights = attn_weights + causal_mask # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype) # [batch_size, num_heads, q_length, kv_length] attention_probs = self.attention_dropout(attention_probs) if head_mask is not None: attention_probs = attention_probs * head_mask # change view [batch_size x num_heads, q_length, kv_length] attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length) # matmul: [batch_size * num_heads, q_length, head_dim] context_layer = torch.bmm(attention_probs_reshaped, value_layer) # change view [batch_size, q_length, num_heads * head_dim] context_layer = self._merge_heads(context_layer) # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 if self.pretraining_tp > 1 and self.slow_but_exact: slices = self.hidden_size / self.pretraining_tp output_tensor = torch.zeros_like(context_layer) for i in range(self.pretraining_tp): output_tensor = output_tensor + F.linear( context_layer[:, :, int(i * slices) : int((i + 1) * slices)], self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], ) else: output_tensor = self.dense(context_layer) output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) outputs = (output_tensor, layer_past) if output_attentions: outputs += (attention_probs,) return outputs class BloomMLP(nn.Module): def __init__(self, config: BloomConfig): super().__init__() hidden_size = config.hidden_size self.pretraining_tp = config.pretraining_tp self.slow_but_exact = config.slow_but_exact self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size) self.gelu_impl = BloomGelu() self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size) self.hidden_dropout = config.hidden_dropout def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states)) if self.pretraining_tp > 1 and self.slow_but_exact: intermediate_output = torch.zeros_like(residual) slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp for i in range(self.pretraining_tp): intermediate_output = intermediate_output + F.linear( hidden_states[:, :, int(i * slices) : int((i + 1) * slices)], self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)], ) else: intermediate_output = self.dense_4h_to_h(hidden_states) output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training) return output class BloomBlock(nn.Module): def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None): super().__init__() hidden_size = config.hidden_size self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.num_heads = config.n_head self.self_attention = BloomAttention(config, layer_idx) self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = BloomMLP(config) self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm self.hidden_dropout = config.hidden_dropout def forward( self, hidden_states: torch.Tensor, alibi: torch.Tensor, attention_mask: torch.Tensor, layer_past: Optional[Cache] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, ): # hidden_states: [batch_size, seq_length, hidden_size] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Layer norm post the self attention. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = hidden_states # Self attention. attn_outputs = self.self_attention( layernorm_output, residual, layer_past=layer_past, attention_mask=attention_mask, alibi=alibi, head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, ) attention_output = attn_outputs[0] outputs = attn_outputs[1:] layernorm_output = self.post_attention_layernorm(attention_output) # Get residual if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = attention_output # MLP. output = self.mlp(layernorm_output, residual) if use_cache: outputs = (output,) + outputs else: outputs = (output,) + outputs[1:] return outputs # hidden_states, past_kv, attentions class BloomPreTrainedModel(PreTrainedModel): config_class = BloomConfig base_model_prefix = "transformer" supports_gradient_checkpointing = True _no_split_modules = ["BloomBlock"] _skip_keys_device_placement = "past_key_values" _supports_cache_class = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) def _init_weights(self, module: nn.Module): """Initialize the weights.""" if isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) BLOOM_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`BloomConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ BLOOM_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as `input_ids`. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - a [`~cache_utils.Cache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see `past_key_values`). use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length. """ @add_start_docstrings( "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.", BLOOM_START_DOCSTRING, ) class BloomModel(BloomPreTrainedModel): def __init__(self, config: BloomConfig): super().__init__(config) self.embed_dim = config.hidden_size self.num_heads = config.n_head # Embedding + LN Embedding self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) # Transformer blocks self.h = nn.ModuleList([BloomBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) # Final Layer Norm self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: return build_alibi_tensor(attention_mask, num_heads, dtype) def get_input_embeddings(self): return self.word_embeddings def set_input_embeddings(self, new_embeddings: torch.Tensor): self.word_embeddings = new_embeddings @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: if deprecated_arguments.pop("position_ids", False) is not False: # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` warnings.warn( "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.", FutureWarning, ) if len(deprecated_arguments) > 0: raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) # kept for BC (non `Cache` `past_key_values` inputs) use_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache) and not self.training: use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( "Using `past_key_values` as a tuple is deprecated and will be removed in v4.45. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) batch_size, seq_length, _ = inputs_embeds.shape past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 seq_length_with_past = seq_length + past_length if cache_position is None: cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) hidden_states = self.word_embeddings_layernorm(inputs_embeds) next_decoder_cache = None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None # Compute alibi tensor: check build_alibi_tensor documentation if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) else: attention_mask = attention_mask.to(hidden_states.device) alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) for i, block in enumerate(self.h): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: outputs = self._gradient_checkpointing_func( block.__call__, hidden_states, alibi, causal_mask, past_key_values, head_mask[i], use_cache, output_attentions, cache_position, ) else: outputs = block( hidden_states, layer_past=past_key_values, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, cache_position=cache_position, ) hidden_states = outputs[0] if use_cache: next_decoder_cache = outputs[1] if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) # Add last hidden state hidden_states = self.ln_f(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) next_cache = None if use_cache: next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache if not return_dict: return tuple( v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, ) # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @add_start_docstrings( """ The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings). """, BLOOM_START_DOCSTRING, ) class BloomForCausalLM(BloomPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: BloomConfig): super().__init__(config) self.transformer = BloomModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings: torch.Tensor): self.lm_head = new_embeddings def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, use_cache=True, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: if inputs_embeds is not None: # Exception 1 input_ids = input_ids[:, -cache_position.shape[0] :] elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, } ) return model_inputs @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` """ if deprecated_arguments.pop("position_ids", False) is not False: # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` warnings.warn( "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.", FutureWarning, ) if len(deprecated_arguments) > 0: raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) loss = None if labels is not None: # move labels to correct device to enable model parallelism labels = labels.to(lm_logits.device) # Shift so that tokens < n predict n shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() batch_size, seq_length, vocab_size = shift_logits.shape # Flatten the tokens loss_fct = CrossEntropyLoss() loss = loss_fct( shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return CausalLMOutputWithCrossAttentions( loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) def _reorder_cache( self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct beam_idx at every generation step. Output shares the same memory storage as `past`. """ # Get a copy of `beam_idx` on all the devices where we need those indices. device_to_beam_idx = { past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past } reordered_past = tuple( ( layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), ) for layer_past in past ) return reordered_past @add_start_docstrings( """ The Bloom Model transformer with a sequence classification head on top (linear layer). [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-1) do. Since it does classification on the last token, it requires to know the position of the last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """, BLOOM_START_DOCSTRING, ) class BloomForSequenceClassification(BloomPreTrainedModel): def __init__(self, config: BloomConfig): super().__init__(config) self.num_labels = config.num_labels self.transformer = BloomModel(config) self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=SequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ if deprecated_arguments.pop("position_ids", False) is not False: # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` warnings.warn( "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.", FutureWarning, ) if len(deprecated_arguments) > 0: raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) if input_ids is not None: batch_size = input_ids.shape[0] else: batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 sequence_lengths = sequence_lengths % input_ids.shape[-1] sequence_lengths = sequence_lengths.to(logits.device) else: sequence_lengths = -1 logger.warning_once( f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " "unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] loss = None if labels is not None: if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) else: loss = loss_fct(pooled_logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(pooled_logits, labels) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) @add_start_docstrings( """ Bloom Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, BLOOM_START_DOCSTRING, ) class BloomForTokenClassification(BloomPreTrainedModel): def __init__(self, config: BloomConfig): super().__init__(config) self.num_labels = config.num_labels self.transformer = BloomModel(config) if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: classifier_dropout = config.classifier_dropout elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: classifier_dropout = config.hidden_dropout else: classifier_dropout = 0.1 self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ if deprecated_arguments.pop("position_ids", False) is not False: # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` warnings.warn( "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.", FutureWarning, ) if len(deprecated_arguments) > 0: raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] hidden_states = self.dropout(hidden_states) logits = self.classifier(hidden_states) loss = None if labels is not None: # move labels to correct device to enable model parallelism labels = labels.to(logits.device) batch_size, seq_length = labels.shape loss_fct = CrossEntropyLoss() loss = loss_fct( logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) ) if not return_dict: output = (logits,) + transformer_outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) @add_start_docstrings( """ The BLOOM Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """, BLOOM_START_DOCSTRING, ) class BloomForQuestionAnswering(BloomPreTrainedModel): def __init__(self, config): super().__init__(config) self.transformer = BloomModel(config) self.qa_outputs = nn.Linear(config.hidden_size, 2) # Initialize weights and apply final processing self.post_init() @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, start_positions: Optional[torch.LongTensor] = None, end_positions: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, QuestionAnsweringModelOutput]: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the end of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.transformer( input_ids, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1).contiguous() end_logits = end_logits.squeeze(-1).contiguous() total_loss = None if start_positions is not None and end_positions is not None: # If we are on multi-GPU, split add a dimension if len(start_positions.size()) > 1: start_positions = start_positions.squeeze(-1) if len(end_positions.size()) > 1: end_positions = end_positions.squeeze(-1) # sometimes the start/end positions are outside our model inputs, we ignore these terms ignored_index = start_logits.size(1) start_positions = start_positions.clamp(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index) start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 if not return_dict: output = (start_logits, end_logits) + outputs[2:] return ((total_loss,) + output) if total_loss is not None else output return QuestionAnsweringModelOutput( loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )