diff --git "a/dependencies.py" "b/dependencies.py" new file mode 100644--- /dev/null +++ "b/dependencies.py" @@ -0,0 +1,5430 @@ +# coding=utf-8 +# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import inspect +import os +import warnings +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +import numpy as np +import torch +import torch.distributed as dist +from huggingface_hub import file_exists +from packaging import version +from torch import nn +from torch.nn import functional as F + +from ..cache_utils import ( + Cache, + DynamicCache, + EncoderDecoderCache, + HybridChunkedCache, + OffloadedCache, + OffloadedHybridCache, + QuantizedCacheConfig, +) +from ..configuration_utils import PretrainedConfig +from ..dynamic_module_utils import ( + check_python_requirements, + get_cached_module_file, + get_class_in_module, + resolve_trust_remote_code, +) +from ..integrations.deepspeed import is_deepspeed_zero3_enabled +from ..integrations.fsdp import is_fsdp_managed_module +from ..masking_utils import create_masks_for_generate +from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput +from ..pytorch_utils import isin_mps_friendly +from ..tokenization_utils import ExtensionsTrie +from ..utils import ( + ModelOutput, + is_accelerate_available, + is_hqq_available, + is_optimum_quanto_available, + is_torchdynamo_exporting, + logging, +) +from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint +from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer +from .candidate_generator import ( + AssistantVocabTranslatorCache, + AssistedCandidateGenerator, + AssistedCandidateGeneratorDifferentTokenizers, + CandidateGenerator, + EarlyExitCandidateGenerator, + PromptLookupCandidateGenerator, + UniversalSpeculativeDecodingGenerator, + _crop_past_key_values, + _prepare_attention_mask, + _prepare_token_type_ids, +) +from .configuration_utils import ( + NEED_SETUP_CACHE_CLASSES_MAPPING, + QUANT_BACKEND_CLASSES_MAPPING, + CompileConfig, + GenerationConfig, + GenerationMode, +) +from .continuous_batching import ContinuousMixin +from .logits_process import ( + EncoderNoRepeatNGramLogitsProcessor, + EncoderRepetitionPenaltyLogitsProcessor, + EpsilonLogitsWarper, + EtaLogitsWarper, + ExponentialDecayLengthPenalty, + ForcedBOSTokenLogitsProcessor, + ForcedEOSTokenLogitsProcessor, + HammingDiversityLogitsProcessor, + InfNanRemoveLogitsProcessor, + LogitNormalization, + LogitsProcessorList, + MinLengthLogitsProcessor, + MinNewTokensLengthLogitsProcessor, + MinPLogitsWarper, + NoBadWordsLogitsProcessor, + NoRepeatNGramLogitsProcessor, + PrefixConstrainedLogitsProcessor, + RepetitionPenaltyLogitsProcessor, + SequenceBiasLogitsProcessor, + SuppressTokensAtBeginLogitsProcessor, + SuppressTokensLogitsProcessor, + TemperatureLogitsWarper, + 
TopKLogitsWarper, + TopPLogitsWarper, + TypicalLogitsWarper, + UnbatchedClassifierFreeGuidanceLogitsProcessor, +) +from .stopping_criteria import ( + ConfidenceCriteria, + EosTokenCriteria, + MaxLengthCriteria, + MaxTimeCriteria, + StoppingCriteria, + StoppingCriteriaList, + StopStringCriteria, +) + + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + from ..tokenization_utils_base import PreTrainedTokenizerBase + from .streamers import BaseStreamer + +logger = logging.get_logger(__name__) + +if is_accelerate_available(): + from accelerate.hooks import AlignDevicesHook, add_hook_to_module + + +# Variable names used to hold the cache at generation time +ALL_CACHE_NAMES = [ + "past_key_values", # default + "cache_params", # mamba-based models + "state", # rwkv + "mems", # xlnet + "past_buckets_states", # reformer +] + + +@dataclass +class GenerateDecoderOnlyOutput(ModelOutput): + """ + Outputs of decoder-only generation models, when using non-beam methods. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): + Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. + """ + + sequences: torch.LongTensor + scores: Optional[tuple[torch.FloatTensor]] = None + logits: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None + hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None + past_key_values: Optional[tuple[tuple[tuple[torch.FloatTensor]]]] = None + + +@dataclass +class GenerateEncoderDecoderOutput(ModelOutput): + """ + Outputs of encoder-decoder generation models, when using non-beam methods. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): + The generated sequences. 
The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): + Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, + sequence_length, sequence_length)`. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. + """ + + sequences: torch.LongTensor + scores: Optional[tuple[torch.FloatTensor]] = None + logits: Optional[tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor]] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None + cross_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None + past_key_values: Optional[tuple[tuple[tuple[torch.FloatTensor]]]] = None + + +@dataclass +class GenerateBeamDecoderOnlyOutput(ModelOutput): + """ + Outputs of decoder-only generation models, when using beam methods. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): + The generated sequences. 
The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`): + Final beam scores of the generated `sequences`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): + Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting + of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. + Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), + with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): + Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. + beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`): + Beam indices of generated token id at each generation step. `torch.LongTensor` of shape + `(batch_size*num_return_sequences, sequence_length)`. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. + """ + + sequences: torch.LongTensor + sequences_scores: Optional[torch.FloatTensor] = None + scores: Optional[tuple[torch.FloatTensor]] = None + logits: Optional[tuple[torch.FloatTensor]] = None + beam_indices: Optional[torch.LongTensor] = None + attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None + hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None + past_key_values: Optional[tuple[tuple[tuple[torch.FloatTensor]]]] = None + + +@dataclass +class GenerateBeamEncoderDecoderOutput(ModelOutput): + """ + Outputs of encoder-decoder generation models, when using beam methods. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`): + Final beam scores of the generated `sequences`. 
+ scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): + Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting + of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. + Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), + with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): + Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. + beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`): + Beam indices of generated token id at each generation step. `torch.LongTensor` of shape + `(batch_size*num_return_sequences, sequence_length)`. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, + sequence_length, sequence_length)`. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`. + decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length, + sequence_length)`. + cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. 
+ """ + + sequences: torch.LongTensor + sequences_scores: Optional[torch.FloatTensor] = None + scores: Optional[tuple[torch.FloatTensor]] = None + logits: Optional[tuple[torch.FloatTensor]] = None + beam_indices: Optional[torch.LongTensor] = None + encoder_attentions: Optional[tuple[torch.FloatTensor]] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None + cross_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None + past_key_values: Optional[tuple[tuple[tuple[torch.FloatTensor]]]] = None + + +# TODO (joao): remove the equivalent classes and typing shortcuts below in v5 +# Equivalent classes (kept for retrocompatibility purposes) +GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput +ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput +SampleDecoderOnlyOutput = GenerateDecoderOnlyOutput + +ContrastiveSearchEncoderDecoderOutput = GenerateEncoderDecoderOutput +GreedySearchEncoderDecoderOutput = GenerateEncoderDecoderOutput +SampleEncoderDecoderOutput = GenerateEncoderDecoderOutput + +BeamSearchDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput +BeamSampleDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput + +BeamSearchEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput +BeamSampleEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput + +GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput] +SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] +BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput] +BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput] +ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput] + +# Typing shortcuts +GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput] +GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput] +GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput] + + +class GenerationMixin(ContinuousMixin): + """ + A class containing all functions for auto-regressive text generation, to be used as a mixin in model classes. + Inheriting from this class causes the model to have special generation-related behavior, such as loading a + `GenerationConfig` at initialization time or ensuring `generate`-related tests are run in `transformers` CI. + + A model class should inherit from `GenerationMixin` to enable calling methods like `generate`, or when it + has defined a custom `generate` method that relies on `GenerationMixin`, directly or indirectly, which + approximately shares the same interface to public methods like `generate`. Three examples: + - `LlamaForCausalLM` should inherit from `GenerationMixin` to enable calling `generate` and other public + methods in the mixin; + - `BlipForQuestionAnswering` has a custom `generate` method that approximately shares the same interface as + `GenerationMixin.generate` (it has a few extra arguments, and the same output). That function also calls + `GenerationMixin.generate` indirectly, through an inner model. As such, `BlipForQuestionAnswering` should + inherit from `GenerationMixin` to benefit from all generation-related automation in our codebase; + - `BarkModel` has a custom `generate` method and one of its inner models calls `GenerationMixin.generate`. 
+ However, its `generate` does not share the same interface as `GenerationMixin.generate`. In this case, + `BarkModel` should NOT inherit from `GenerationMixin`, as it breaks the `generate` interface. + + The class exposes [`~generation.GenerationMixin.generate`], which can be used for: + - *greedy decoding* if `num_beams=1` and `do_sample=False` + - *contrastive search* if `penalty_alpha>0` and `top_k>1` + - *multinomial sampling* if `num_beams=1` and `do_sample=True` + - *beam-search decoding* if `num_beams>1` and `do_sample=False` + - *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True` + - *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1` + - *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None` + - *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()` + + To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies). + """ + + def load_custom_generate( + self, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + trust_remote_code: Optional[bool] = None, + **kwargs, + ) -> Callable: + """ + Loads and returns a custom generate function, given a model repo. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + trust_remote_code (`bool`, *optional*): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + **kwargs: + Additional keyword arguments for remote code loading. + + Raises: + OSError: If `pretrained_model_name_or_path` does not contain a `custom_generate` subdirectory. + + Returns: + A callable that can be used to generate text. + """ + # Does `pretrained_model_name_or_path` have a `custom_generate` subdirectory? If not -> OSError + is_local_code = os.path.exists(pretrained_model_name_or_path) + has_custom_generate_folder = True + if is_local_code: + if not os.path.exists(os.path.join(pretrained_model_name_or_path, "custom_generate/generate.py")): + has_custom_generate_folder = False + else: + if not file_exists(pretrained_model_name_or_path, "custom_generate/generate.py"): + has_custom_generate_folder = False + + if not has_custom_generate_folder: + raise OSError( + f"`{pretrained_model_name_or_path}` does not contain a `custom_generate` subdirectory with a " + "`generate.py` file, can't load the custom generate function." + ) + + # Handle opt-in `trust_remote_code` and related exceptions + error_message = ( + f"The repository `{pretrained_model_name_or_path}` contains custom generation code that will override " + "the default `generate` method." 
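+ # --- Editor's illustrative sketch (comments only, not part of the original module). It shows one plausible way to
+ # use the callable returned by `load_custom_generate`; the repo id is hypothetical and the callable's exact
+ # signature is whatever that repo's `custom_generate/generate.py` defines. ---
+ # custom_generate = model.load_custom_generate(
+ #     "some-user/repo-with-custom-generate",  # must contain custom_generate/generate.py
+ #     trust_remote_code=True,  # opt-in: code downloaded from the Hub runs locally
+ # )
+ # outputs = custom_generate(model, input_ids, ...)  # argument list depends on the repo's generate.py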
+ ) + resolve_trust_remote_code( + trust_remote_code, + pretrained_model_name_or_path, + has_local_code=is_local_code, + has_remote_code=not is_local_code, + error_message=error_message, + ) + + # Load the custom generate function + check_python_requirements( + pretrained_model_name_or_path, requirements_file="custom_generate/requirements.txt", **kwargs + ) + module = get_cached_module_file( + pretrained_model_name_or_path, module_file="custom_generate/generate.py", **kwargs + ) + custom_generate_function = get_class_in_module("generate", module) + return custom_generate_function + + def _cache_dependant_input_preparation( + self, + input_ids: torch.LongTensor, + inputs_embeds: Optional[torch.FloatTensor], + cache_position: Optional[torch.LongTensor], + ) -> tuple[torch.FloatTensor, torch.LongTensor]: + """ + Generic cache-dependent input preparation + The code is put in a separate function to allow granular unit testing + as it needs a different implementation to be exportable. + + If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + - Exception 1: when passing input_embeds, input_ids may be missing entries + - Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + - Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + - Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and + generate the first token for each sequence. Later use the generated Input ids for continuation. + + The current implementation does not rely on ``self`` and could be + a class method. It is left as a standard method to be easily rewritten. + """ + if is_torchdynamo_exporting(): + return self._cache_dependant_input_preparation_exporting(input_ids, inputs_embeds, cache_position) + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif ( + inputs_embeds is not None # Exception 1 + or (cache_position[-1] >= input_ids.shape[1]) # Exception 3 + ): + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + return inputs_embeds, input_ids + + def _cache_dependant_input_preparation_exporting( + self, + input_ids: torch.LongTensor, + inputs_embeds: Optional[torch.FloatTensor], + cache_position: Optional[torch.LongTensor], + ) -> tuple[torch.FloatTensor, torch.LongTensor]: + """ + This method implements method ``_cache_dependant_input_preparation`` + with :func:`torch.cond` to make it exportable with :func:`torch.export.export`. + The code is put in a separate function to allow granular unit testing. + """ + if inputs_embeds is None: + input_ids = input_ids[:, cache_position] + else: + # This is the code we need to implemented with torch.cond. 
+ # if input_ids.shape[1] == 0: + # inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + # else: + # if cache_position[-1] >= input_ids.shape[1]: + # input_ids = input_ids[:, -cache_position.shape[0] :] + # else: + # if input_ids.shape[1] != cache_position.shape[0]: + # input_ids = input_ids[:, cache_position] + def branch_1(inputs_embeds, cache_position): + return inputs_embeds[:, -cache_position.shape[0] :] + + def branch_2(input_ids, cache_position): + return input_ids[:, -cache_position.shape[0] :] + + def branch_3(input_ids, cache_position): + return input_ids[:, cache_position] + + inputs_embeds, input_ids = torch.cond( + input_ids.shape[1] == 0, + ( + lambda input_ids, inputs_embeds, cache_position: ( + branch_1(inputs_embeds, cache_position), + input_ids, + ) + ), + ( + lambda input_ids, inputs_embeds, cache_position: ( + inputs_embeds, + torch.cond( + cache_position[-1] >= input_ids.shape[1], + branch_2, + lambda input_ids, cache_position: ( + torch.cond( + input_ids.shape[1] != cache_position.shape[0], + branch_3, + (lambda input_ids, cache_position: input_ids), + [input_ids, cache_position], + ) + ), + [input_ids, cache_position], + ), + ) + ), + [input_ids, inputs_embeds, cache_position], + ) + return inputs_embeds, input_ids + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + """ + Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or + slicing inputs given the existing cache. + + See the forward pass in the model documentation for expected arguments (different models might have different + requirements for e.g. `past_key_values`). This function should work as is for most LLMs. + """ + + # 1. Handle BC: + model_inputs = {} + # - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`) + if self._supports_cache_class: + model_inputs["cache_position"] = cache_position + # - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this + # function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly + # (this alternative is not as robust as calling `generate` and letting it create `cache_position`) + elif cache_position is None: + past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + + # 2. Generic cache-dependent input preparation + if past_key_values is not None: + model_inputs["past_key_values"] = past_key_values + inputs_embeds, input_ids = self._cache_dependant_input_preparation( + input_ids, inputs_embeds, cache_position + ) + + # 3. Prepare base model inputs + input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step for every prompt. + if not self.config.is_encoder_decoder: + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: + model_inputs[input_ids_key] = None + model_inputs["inputs_embeds"] = inputs_embeds + else: + # `clone` calls in this function ensure a consistent stride. 
See #32227 + model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format) + model_inputs["inputs_embeds"] = None + else: + model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format) + + # 4. Create missing `position_ids` on the fly + encoder_attention_mask = attention_mask if self.config.is_encoder_decoder else None + attention_mask = ( + kwargs.pop("decoder_attention_mask", None) if self.config.is_encoder_decoder else attention_mask + ) + attention_mask_key = "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask" + position_ids_key = "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids" + if ( + attention_mask is not None + and kwargs.get(position_ids_key) is None + and position_ids_key in set(inspect.signature(self.forward).parameters.keys()) + ): + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + kwargs[position_ids_key] = position_ids # placed in kwargs for further processing (see below) + + # 5. Slice model inputs if it's an input that should have the same length as `input_ids` + for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]: + model_input = kwargs.get(model_input_name) + if model_input is not None: + if past_key_values is not None: + current_input_length = ( + model_inputs["inputs_embeds"].shape[1] + if model_inputs.get("inputs_embeds") is not None + else model_inputs[input_ids_key].shape[1] + ) + model_input = model_input[:, -current_input_length:] + model_input = model_input.clone(memory_format=torch.contiguous_format) + model_inputs[model_input_name] = model_input + + # 6. Create 4D attention mask is we are using a compilable cache (important for performant compiled forward + # pass) + if ( + isinstance(past_key_values, Cache) + and past_key_values.is_compileable + and attention_mask is not None + and attention_mask.ndim == 2 + ): + if not self.config.is_encoder_decoder and model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + else: + batch_size, sequence_length = model_inputs[input_ids_key].shape[:2] + + # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create + # the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder. 
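+ # Editor's note (illustrative, not part of the original module): the helper looked up below expands the 2D padding
+ # mask of shape (batch_size, sequence_length) into a 4D causal mask, typically of shape
+ # (batch_size, 1, sequence_length, target_length), with target_length taken from
+ # past_key_values.get_max_cache_shape(); keeping target_length fixed is what avoids recompilation of the
+ # forward pass as decoding advances.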
+ base_model = getattr(self, self.base_model_prefix, self) + decoder = base_model.get_decoder() if hasattr(base_model, "get_decoder") else None + causal_mask_creation_function = getattr( + base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None + ) + if causal_mask_creation_function is None and decoder is not None: # it may be in the decoder + causal_mask_creation_function = getattr( + decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None + ) + + # If it's not defined, it means the model uses the new general mask API + if causal_mask_creation_function is None: # can't be found + token_type_ids = getattr(model_input, "token_type_ids", None) + # Some models may overwrite the general one + causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate) + attention_mask = causal_mask_creation_function( + config=self.config, + # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings + input_embeds=torch.empty((batch_size, sequence_length), dtype=self.dtype), + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + token_type_ids=token_type_ids, + ) + else: + attention_mask = causal_mask_creation_function( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.dtype, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + if attention_mask is not None: + model_inputs[attention_mask_key] = attention_mask + + if encoder_attention_mask is not None: + model_inputs["attention_mask"] = encoder_attention_mask + + # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples) + model_inputs.pop("labels", None) + return model_inputs + + def _prepare_model_inputs( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[torch.Tensor] = None, + model_kwargs: Optional[dict[str, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]: + """ + This function extracts the model-specific `inputs` for generation. + """ + # 1. retrieve all kwargs that are non-None or non-model input related. + # some encoder-decoder models have different names for model and encoder + if ( + self.config.is_encoder_decoder + and hasattr(self, "encoder") + and self.encoder.main_input_name != self.main_input_name + ): + input_name = self.encoder.main_input_name + else: + input_name = self.main_input_name + + model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} + + # 2. check whether model_input_name is passed as kwarg + # if yes and `inputs` is None use kwarg inputs + inputs_kwarg = model_kwargs.pop(input_name, None) + if inputs_kwarg is not None and inputs is not None: + raise ValueError( + f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. " + f"Make sure to either pass {inputs} or {input_name}=..." + ) + elif inputs_kwarg is not None: + inputs = inputs_kwarg + + # 3. In the presence of `inputs_embeds` for text models: + # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model + # doesn't have its forwarding implemented. 
`inputs_embeds` is kept in `model_kwargs` and can coexist with + # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`) + # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and + # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states. + if input_name == "input_ids" and "inputs_embeds" in model_kwargs: + if model_kwargs["inputs_embeds"] is None: + model_kwargs.pop("inputs_embeds") + elif not self.config.is_encoder_decoder: + has_inputs_embeds_forwarding = "inputs_embeds" in set( + inspect.signature(self.prepare_inputs_for_generation).parameters.keys() + ) + if not has_inputs_embeds_forwarding: + raise ValueError( + f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} " + "doesn't have its forwarding implemented. See the GPT2 implementation for an example " + "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!" + ) + # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of + # the attention mask) can rely on the actual model input. + model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation( + inputs, bos_token_id, model_kwargs=model_kwargs + ) + inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" + else: + if inputs is not None: + raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.") + inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" + + # 4. if `inputs` is still None, try to create `input_ids` from BOS token + inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) + return inputs, input_name, model_kwargs + + def _maybe_initialize_input_ids_for_generation( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[torch.Tensor] = None, + model_kwargs: Optional[dict[str, torch.Tensor]] = None, + ) -> torch.LongTensor: + """Initializes input ids for generation, if necessary.""" + if inputs is not None: + return inputs + + encoder_outputs = model_kwargs.get("encoder_outputs") + if self.config.is_encoder_decoder and encoder_outputs is not None: + # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding + shape = encoder_outputs.last_hidden_state.size()[:-1] + return torch.ones(shape, dtype=torch.long, device=self.device) * -100 + + # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with + # soft-prompting or in multimodal implementations built on top of decoder-only language models. 
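+ # Editor's note (illustrative, not part of the original module): e.g. when only `inputs_embeds` of shape
+ # (batch_size, seq_len, hidden_size) is passed, its first dimension provides the batch size used to build the
+ # dummy `input_ids` below.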
+ batch_size = 1 + for value in model_kwargs.values(): + if isinstance(value, torch.Tensor): + batch_size = value.shape[0] + break + + if "inputs_embeds" in model_kwargs: + return torch.ones((batch_size, 0), dtype=torch.long, device=self.device) + + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + + return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id + + def _prepare_attention_mask_for_generation( + self, + inputs_tensor: torch.Tensor, + generation_config: GenerationConfig, + model_kwargs: dict[str, Any], + ) -> torch.LongTensor: + pad_token_id = generation_config._pad_token_tensor + eos_token_id = generation_config._eos_token_tensor + + # `input_ids` may be present in the model kwargs, instead of being the main input (e.g. multimodal model) + if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0: + inputs_tensor = model_kwargs["input_ids"] + + # No information for attention mask inference -> return default attention mask + default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device) + if pad_token_id is None: + return default_attention_mask + + is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long] + if not is_input_ids: + return default_attention_mask + + is_pad_token_in_inputs = (pad_token_id is not None) and ( + isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any() + ) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~( + isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any() + ) + can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id + attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long() + + attention_mask = ( + attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask + ) + return attention_mask + + def _prepare_encoder_decoder_kwargs_for_generation( + self, + inputs_tensor: torch.Tensor, + model_kwargs, + model_input_name: Optional[str], + generation_config: GenerationConfig, + ) -> dict[str, Any]: + # 1. get encoder + encoder = self.get_encoder() + # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device + # as the inputs. + if hasattr(self, "hf_device_map"): + if hasattr(encoder, "_hf_hook"): + encoder._hf_hook.io_same_device = True + else: + add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True)) + + # 2. Prepare encoder args and encoder kwargs from model kwargs and generation config. + irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] + encoder_kwargs = { + argument: value + for argument, value in model_kwargs.items() + if not any(argument.startswith(p) for p in irrelevant_prefix) + } + encoder_signature = set(inspect.signature(encoder.forward).parameters) + encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature + if not encoder_accepts_wildcard: + encoder_kwargs = { + argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature + } + encoder_kwargs["output_attentions"] = generation_config.output_attentions + encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states + + # 3. 
make sure that encoder returns `ModelOutput` + model_input_name = model_input_name if model_input_name is not None else self.main_input_name + encoder_kwargs["return_dict"] = True + encoder_kwargs[model_input_name] = inputs_tensor + model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) # type: ignore + + return model_kwargs + + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + model_input_name: str, + model_kwargs: dict[str, torch.Tensor], + decoder_start_token_id: torch.Tensor, + device: Optional[torch.device] = None, + ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, + # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + elif "input_ids" in model_kwargs and model_input_name != "input_ids": + decoder_input_ids = model_kwargs.pop("input_ids") + else: + decoder_input_ids = None + + # 2. `decoder_start_token_id` must have shape (batch_size, 1) + if device is None: + device = self.device + if decoder_start_token_id.ndim == 1: + if decoder_start_token_id.shape[0] != batch_size: + raise ValueError( + f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}" + ) + decoder_start_token_id = decoder_start_token_id.view(-1, 1) + else: + decoder_start_token_id = ( + torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id + ) + + # 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. + # no user input -> use decoder_start_token_id as decoder_input_ids + if decoder_input_ids is None: + decoder_input_ids = decoder_start_token_id + # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the + # original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic. + # See: https://github.com/huggingface/transformers/pull/31470 + elif "donut" in self.__class__.__name__.lower() or ( + self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower() + ): + pass + elif self.config.model_type in ["whisper"]: + pass + # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust + # decoder_attention_mask if provided) + elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item(): + decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1) + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + decoder_attention_mask = torch.cat( + (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), + dim=-1, + ) + model_kwargs["decoder_attention_mask"] = decoder_attention_mask + + return decoder_input_ids, model_kwargs + + @staticmethod + def _expand_inputs_for_generation( + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + """Expands tensors from [batch_size, ...] 
to [batch_size * expand_size, ...]""" + # Do not call torch.repeat_interleave if expand_size is 1 because it clones + # the input tensor and thus requires more memory although no change is applied + if expand_size == 1: + return input_ids, model_kwargs + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> dict[str, Any]: + # update past_key_values keeping its naming used in model code + for possible_cache_name in ALL_CACHE_NAMES: + if possible_cache_name in outputs: + # TODO (joao): remove output/input mismatch when these old models (xlnet, reformer) are deprecated + if possible_cache_name in ("past_buckets_states", "mems"): + cache_name = "past_key_values" + else: + cache_name = possible_cache_name + model_kwargs[cache_name] = getattr(outputs, possible_cache_name) + break + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + + if not is_encoder_decoder: + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + else: + # update decoder attention mask + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + model_kwargs["decoder_attention_mask"] = torch.cat( + [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))], + dim=-1, + ) + + if model_kwargs.get("use_cache", True): + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens + else: + past_positions = model_kwargs.pop("cache_position") + new_positions = torch.arange( + past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype + ).to(past_positions.device) + model_kwargs["cache_position"] = torch.cat((past_positions, new_positions)) + return model_kwargs + + def _reorder_cache(self, past_key_values, beam_idx): + raise NotImplementedError( + f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to" + f" enable beam search for {self.__class__}" + ) + + def _get_candidate_generator( + self, + generation_config: GenerationConfig, + input_ids: torch.LongTensor, + inputs_tensor: torch.Tensor, + assistant_model: "PreTrainedModel", + logits_processor: LogitsProcessorList, + target_tokenizer: "PreTrainedTokenizerBase", + assistant_tokenizer: "PreTrainedTokenizerBase", + 
model_kwargs: dict, + ) -> CandidateGenerator: + """ + Returns the candidate generator to be used in `assisted_generation` + """ + different_tokenizers = all(v is not None for v in (assistant_model, target_tokenizer, assistant_tokenizer)) + + if generation_config.assistant_early_exit is not None: + candidate_generator = EarlyExitCandidateGenerator( + input_ids=input_ids, + assistant_model=self, + generation_config=generation_config, + model_kwargs=model_kwargs, + inputs_tensor=inputs_tensor, + logits_processor=logits_processor, + ) + elif generation_config.prompt_lookup_num_tokens is not None: + candidate_generator = PromptLookupCandidateGenerator( + eos_token_id=generation_config._eos_token_tensor, + num_output_tokens=generation_config.prompt_lookup_num_tokens, + max_matching_ngram_size=generation_config.max_matching_ngram_size, + max_length=generation_config.max_length, + ) + elif different_tokenizers: + if generation_config.do_sample is True: + atm_translator = AssistantVocabTranslatorCache.get_translator( + target_tokenizer, + assistant_tokenizer, + self.config.get_text_config().vocab_size, + assistant_model=assistant_model, + assistant_prune_lm_head=True, # prune LM head of assistant model + ) + # Since we prune the LM head, we cannot use the repetition penalty on the assistant model due to mismatches between token ids and logits index + assistant_model.generation_config.repetition_penalty = None + candidate_generator = UniversalSpeculativeDecodingGenerator( + input_ids=input_ids, + assistant_model=assistant_model, + generation_config=generation_config, + model_kwargs=model_kwargs, + inputs_tensor=inputs_tensor, + logits_processor=logits_processor, + target_tokenizer=target_tokenizer, + assistant_tokenizer=assistant_tokenizer, + atm_translator=atm_translator, + ) + elif generation_config.do_sample is False: + candidate_generator = AssistedCandidateGeneratorDifferentTokenizers( + input_ids=input_ids, + assistant_model=assistant_model, + generation_config=generation_config, + model_kwargs=model_kwargs, + inputs_tensor=inputs_tensor, + logits_processor=logits_processor, + target_tokenizer=target_tokenizer, + assistant_tokenizer=assistant_tokenizer, + ) + else: + raise ValueError( + f"Invalid value for `do_sample`: expected a boolean, got {type(generation_config.do_sample).__name__}" + ) + else: + candidate_generator = AssistedCandidateGenerator( + input_ids=input_ids, + assistant_model=assistant_model, + generation_config=generation_config, + model_kwargs=model_kwargs, + inputs_tensor=inputs_tensor, + logits_processor=logits_processor, + ) + return candidate_generator + + def _get_logits_processor( + self, + generation_config: GenerationConfig, + input_ids_seq_length: Optional[int] = None, + encoder_input_ids: torch.LongTensor = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None, + logits_processor: Optional[LogitsProcessorList] = None, + device: Optional[str] = None, + model_kwargs: Optional[dict[str, Any]] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + ) -> LogitsProcessorList: + """ + This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] + instances used to modify the scores of the language model head. 
+ """ + # instantiate processors list + processors = LogitsProcessorList() + if logits_processor is None: + logits_processor = [] + + if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1: + processors.append( + UnbatchedClassifierFreeGuidanceLogitsProcessor( + generation_config.guidance_scale, + self, + unconditional_ids=negative_prompt_ids, + unconditional_attention_mask=negative_prompt_attention_mask, + use_cache=generation_config.use_cache, + ) + ) + if generation_config.sequence_bias is not None: + processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias)) + + if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0: + processors.append( + HammingDiversityLogitsProcessor( + diversity_penalty=generation_config.diversity_penalty, + num_beams=generation_config.num_beams, + num_beam_groups=generation_config.num_beam_groups, + ) + ) + if ( + generation_config.encoder_repetition_penalty is not None + and generation_config.encoder_repetition_penalty != 1.0 + ): + if len(encoder_input_ids.shape) == 2: + processors.append( + EncoderRepetitionPenaltyLogitsProcessor( + penalty=generation_config.encoder_repetition_penalty, + encoder_input_ids=encoder_input_ids, + ) + ) + else: + warnings.warn( + "Passing `encoder_repetition_penalty` requires some form of `input_ids` to be passed to " + "`generate`, ignoring the argument.", + UserWarning, + ) + if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: + processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty)) + if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0: + processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)) + if ( + generation_config.encoder_no_repeat_ngram_size is not None + and generation_config.encoder_no_repeat_ngram_size > 0 + ): + if len(encoder_input_ids.shape) == 2: + processors.append( + EncoderNoRepeatNGramLogitsProcessor( + generation_config.encoder_no_repeat_ngram_size, + encoder_input_ids, + ) + ) + else: + warnings.warn( + "Passing `encoder_no_repeat_ngram_size` requires some form of `input_ids` to be passed to " + "`generate`, ignoring the argument.", + UserWarning, + ) + if generation_config.bad_words_ids is not None: + processors.append( + NoBadWordsLogitsProcessor( + generation_config.bad_words_ids, + generation_config._eos_token_tensor, + ) + ) + if ( + generation_config.min_length is not None + and getattr(generation_config, "_eos_token_tensor", None) is not None + and generation_config.min_length > 0 + ): + processors.append( + MinLengthLogitsProcessor( + generation_config.min_length, + generation_config._eos_token_tensor, + device=device, + ) + ) + if ( + generation_config.min_new_tokens is not None + and getattr(generation_config, "_eos_token_tensor", None) is not None + and generation_config.min_new_tokens > 0 + ): + processors.append( + MinNewTokensLengthLogitsProcessor( + input_ids_seq_length, + generation_config.min_new_tokens, + generation_config._eos_token_tensor, + device=device, + ) + ) + if prefix_allowed_tokens_fn is not None: + processors.append( + PrefixConstrainedLogitsProcessor( + prefix_allowed_tokens_fn, + generation_config.num_beams // generation_config.num_beam_groups, + ) + ) + if generation_config.forced_bos_token_id is not None: + processors.append( + ForcedBOSTokenLogitsProcessor( + 
generation_config.forced_bos_token_id, + ) + ) + if generation_config.forced_eos_token_id is not None: + processors.append( + ForcedEOSTokenLogitsProcessor( + generation_config.max_length, + generation_config.forced_eos_token_id, + device=device, + ) + ) + if generation_config.remove_invalid_values is True: + processors.append(InfNanRemoveLogitsProcessor()) + if generation_config.exponential_decay_length_penalty is not None: + processors.append( + ExponentialDecayLengthPenalty( + generation_config.exponential_decay_length_penalty, + generation_config._eos_token_tensor, + input_ids_seq_length, + ) + ) + if generation_config.suppress_tokens is not None: + processors.append( + SuppressTokensLogitsProcessor( + generation_config.suppress_tokens, + device=device, + ) + ) + if generation_config.begin_suppress_tokens is not None: + begin_index = input_ids_seq_length + begin_index = ( + begin_index + if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) + else begin_index + 1 + ) + processors.append( + SuppressTokensAtBeginLogitsProcessor( + generation_config.begin_suppress_tokens, + begin_index, + device=device, + ) + ) + + # TODO (joao): find a strategy to specify the order of the processors + processors = self._merge_criteria_processor_list(processors, logits_processor) + + # Processors previously known as `LogitsWarpers`, only applied with sampling strategies + if generation_config.do_sample: + # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a + # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1) + if generation_config.num_beams > 1: + if isinstance(generation_config._eos_token_tensor, list): + min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1 + elif isinstance(generation_config._eos_token_tensor, torch.Tensor): + min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1 + else: + min_tokens_to_keep = 2 + else: + min_tokens_to_keep = 1 + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if generation_config.temperature is not None and generation_config.temperature != 1.0: + processors.append(TemperatureLogitsWarper(generation_config.temperature)) + if generation_config.top_k is not None and generation_config.top_k != 0: + processors.append( + TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.top_p is not None and generation_config.top_p < 1.0: + processors.append( + TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.min_p is not None: + # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084) + processors.append( + MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.typical_p is not None and generation_config.typical_p < 1.0: + processors.append( + TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0: + processors.append( + EpsilonLogitsWarper( + epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep + ) + ) + if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: + processors.append( + 
EtaLogitsWarper( + epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device + ) + ) + + # Watermarking should be after all logits processing is finished (see #34630) + if generation_config.watermarking_config is not None: + processors.append( + generation_config.watermarking_config.construct_processor( + self.config.get_text_config().vocab_size, device + ) + ) + + # `LogitNormalization` should always be the last logit processor, when present + if generation_config.renormalize_logits is True: + processors.append(LogitNormalization()) + return processors + + def _get_stopping_criteria( + self, + generation_config: GenerationConfig, + stopping_criteria: Optional[StoppingCriteriaList], + tokenizer: Optional["PreTrainedTokenizerBase"] = None, + **kwargs, + ) -> StoppingCriteriaList: + criteria = StoppingCriteriaList() + if generation_config.max_length is not None: + max_position_embeddings = getattr(self.config, "max_position_embeddings", None) + criteria.append( + MaxLengthCriteria( + max_length=generation_config.max_length, + max_position_embeddings=max_position_embeddings, + ) + ) + if generation_config.max_time is not None: + criteria.append(MaxTimeCriteria(max_time=generation_config.max_time)) + if generation_config.stop_strings is not None: + if tokenizer is None: + raise ValueError( + "There are one or more stop strings, either in the arguments to `generate` or in the " + "model's generation config, but we could not locate a tokenizer. When generating with " + "stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`." + ) + criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer)) + if generation_config._eos_token_tensor is not None: + criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor)) + if ( + generation_config.is_assistant + and generation_config.assistant_confidence_threshold is not None + and generation_config.assistant_confidence_threshold > 0 + ): + criteria.append( + ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold) + ) + criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) + return criteria + + def _merge_criteria_processor_list( + self, + default_list: Union[LogitsProcessorList, StoppingCriteriaList], + custom_list: Union[LogitsProcessorList, StoppingCriteriaList], + ) -> Union[LogitsProcessorList, StoppingCriteriaList]: + """ + Merge user-defined processors/criteria with the ones instantiated inside `generate`. In case the same + processor/criteria is present on both lists, use the user-defined one. + + (Note: up to v4.49.0, this function threw an exception is the same logit processor was found twice.) + """ + if len(custom_list) == 0: + return default_list + + final_list = type(default_list)() + for default in default_list: + using_custom = False + for custom in custom_list: + if type(custom) is type(default): + object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor" + logger.warning_once( + f"A custom {object_type} of type {type(custom)} has been passed to `.generate()`, but it " + f"was also created in `.generate()`, given its parameterization. The custom {type(custom)} " + f"will take precedence. Please check the docstring of {type(custom)} to see related " + "`.generate()` flags." 
+ ) + final_list.append(custom) + using_custom = True + break + if not using_custom: + final_list.append(default) + + for custom in custom_list: + if custom not in final_list: + final_list.append(custom) + return final_list + + def compute_transition_scores( + self, + sequences: torch.Tensor, + scores: tuple[torch.Tensor], + beam_indices: Optional[torch.Tensor] = None, + normalize_logits: bool = False, + ) -> torch.Tensor: + """ + Computes the transition scores of sequences given the generation scores (and beam indices, if beam search was + used). This is a convenient method to quickly obtain the scores of the selected tokens at generation time. + + Parameters: + sequences (`torch.LongTensor`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or + shorter if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)`): + Transition scores for each vocabulary token at each generation step. Beam transition scores consisting + of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. + Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), + with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. + beam_indices (`torch.LongTensor`, *optional*): + Beam indices of generated token id at each generation step. `torch.LongTensor` of shape + `(batch_size*num_return_sequences, sequence_length)`. Only required if a `num_beams>1` at + generate-time. + normalize_logits (`bool`, *optional*, defaults to `False`): + Whether to normalize the logits (which, for legacy reasons, may be unnormalized). + + Return: + `torch.Tensor`: A `torch.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)` containing + the transition scores (logits) + + Examples: + + ```python + >>> from transformers import GPT2Tokenizer, AutoModelForCausalLM + >>> import numpy as np + + >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer.pad_token_id = tokenizer.eos_token_id + >>> inputs = tokenizer(["Today is"], return_tensors="pt") + + >>> # Example 1: Print the scores for each token generated with Greedy Search + >>> outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True) + >>> transition_scores = model.compute_transition_scores( + ... outputs.sequences, outputs.scores, normalize_logits=True + ... ) + >>> # input_length is the length of the input prompt for decoder-only models, like the GPT family, and 1 for + >>> # encoder-decoder models, like BART or T5. + >>> input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1] + >>> generated_tokens = outputs.sequences[:, input_length:] + >>> for tok, score in zip(generated_tokens[0], transition_scores[0]): + ... # | token | token string | log probability | probability + ... print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}") + | 262 | the | -1.414 | 24.33% + | 1110 | day | -2.609 | 7.36% + | 618 | when | -2.010 | 13.40% + | 356 | we | -1.859 | 15.58% + | 460 | can | -2.508 | 8.14% + + >>> # Example 2: Reconstruct the sequence scores from Beam Search + >>> outputs = model.generate( + ... **inputs, + ... max_new_tokens=5, + ... num_beams=4, + ... num_return_sequences=4, + ... return_dict_in_generate=True, + ... output_scores=True, + ... 
) + >>> transition_scores = model.compute_transition_scores( + ... outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False + ... ) + >>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores. + >>> # Tip 1: recomputing the scores is only guaranteed to match with `normalize_logits=False`. Depending on the + >>> # use case, you might want to recompute it with `normalize_logits=True`. + >>> # Tip 2: the output length does NOT include the input length + >>> output_length = np.sum(transition_scores.numpy() < 0, axis=1) + >>> length_penalty = model.generation_config.length_penalty + >>> reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty) + >>> print(np.allclose(outputs.sequences_scores, reconstructed_scores)) + True + ```""" + # 1. In absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent + # to a beam search approach where the first (and only) beam is always selected + if beam_indices is None: + beam_indices = torch.arange(scores[0].shape[0]).view(-1, 1).to(sequences.device) + beam_indices = beam_indices.expand(-1, len(scores)) + + # 2. reshape scores as [batch_size*vocab_size, # generation steps] with # generation steps being + # seq_len - input_length + scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1) + + # 3. Optionally normalize the logits (across the vocab dimension) + if normalize_logits: + scores = scores.reshape(-1, self.config.get_text_config().vocab_size, scores.shape[-1]) + scores = torch.nn.functional.log_softmax(scores, dim=1) + scores = scores.reshape(-1, scores.shape[-1]) + + # 4. cut beam_indices to longest beam length + beam_indices_mask = beam_indices < 0 + max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max() + beam_indices = beam_indices.clone()[:, :max_beam_length] + beam_indices_mask = beam_indices_mask[:, :max_beam_length] + + # 5. Set indices of beams that finished early to 0; such indices will be masked correctly afterwards + beam_indices[beam_indices_mask] = 0 + + # 6. multiply beam_indices with vocab size to gather correctly from scores + beam_sequence_indices = beam_indices * self.config.get_text_config().vocab_size + + # 7. Define which indices contributed to scores + cut_idx = sequences.shape[-1] - max_beam_length + indices = sequences[:, cut_idx:] + beam_sequence_indices + + # 8. Compute scores + transition_scores = scores.gather(0, indices) + + # 9. Mask out transition_scores of beams that stopped early + transition_scores[beam_indices_mask] = 0 + + return transition_scores + + def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer): + if assistant_model is None: + return + + if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder: + attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"] + attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check] + are_equal = all( + getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check + ) + if not are_equal: + raise ValueError( + "The main model and the assistant don't have compatible encoder-dependent input shapes. " + "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
+ ) + + doc_reference = ( + "(see https://huggingface.co/docs/transformers/en/generation_strategies#universal-assisted-decoding)" + ) + if self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size: + if assistant_tokenizer is not None: + raise ValueError( + f"`assistant_tokenizer` is not required when the main and assistant models use the same tokenizer. Please omit `assistant_tokenizer` from `generate()` {doc_reference}." + ) + else: + if tokenizer is None or assistant_tokenizer is None: + raise ValueError( + f"The main and assistant models have different tokenizers. Please provide `tokenizer` and `assistant_tokenizer` to `generate()` {doc_reference}." + ) + + def _validate_model_kwargs(self, model_kwargs: dict[str, Any]): + """Validates model kwargs for generation. Generate argument typos will also be caught here.""" + # If a `Cache` instance is passed, checks whether the model is compatible with it + if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class: + raise ValueError( + f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please " + "check the model documentation for supported cache formats." + ) + + # Excludes arguments that are handled before calling any model function + if self.config.is_encoder_decoder: + for key in ["decoder_input_ids"]: + model_kwargs.pop(key, None) + + unused_model_args = [] + model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters) + # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If + # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;) + if "kwargs" in model_args or "model_kwargs" in model_args: + model_args |= set(inspect.signature(self.forward).parameters) + + # Encoder-Decoder models may also need Encoder arguments from `model_kwargs` + if self.config.is_encoder_decoder: + base_model = getattr(self, self.base_model_prefix, None) + + # allow encoder kwargs + encoder = getattr(self, "encoder", None) + # `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`. + # Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder` + # TODO: A better way to handle this. + if encoder is None and base_model is not None: + encoder = getattr(base_model, "encoder", None) + + if encoder is not None: + encoder_model_args = set(inspect.signature(encoder.forward).parameters) + model_args |= encoder_model_args + + # allow decoder kwargs + decoder = getattr(self, "decoder", None) + if decoder is None and base_model is not None: + decoder = getattr(base_model, "decoder", None) + + if decoder is not None: + decoder_model_args = set(inspect.signature(decoder.forward).parameters) + model_args |= {f"decoder_{x}" for x in decoder_model_args} + + for key, value in model_kwargs.items(): + if value is not None and key not in model_args: + unused_model_args.append(key) + + if unused_model_args: + raise ValueError( + f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the" + " generate arguments will also show up in this list)" + ) + + def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): + """Performs validation related to the resulting generated length""" + # 1.
Max length warnings related to poor parameterization + if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: + # 20 is the default max_length of the generation config + warnings.warn( + f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the " + "generation length. We recommend setting `max_new_tokens` to control the maximum length of the " + "generation.", + UserWarning, + ) + if input_ids_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + raise ValueError( + f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_length` or, better yet, setting `max_new_tokens`." + ) + + # 2. Min length warnings due to unfeasible parameter combinations + min_length_error_suffix = ( + " Generation will stop at the defined maximum length. You should decrease the minimum length and/or " + "increase the maximum length." + ) + if has_default_max_length: + min_length_error_suffix += ( + f" Note that `max_length` is set to {generation_config.max_length}, its default value." + ) + if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: + warnings.warn( + f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than" + f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix, + UserWarning, + ) + if generation_config.min_new_tokens is not None: + min_length = generation_config.min_new_tokens + input_ids_length + if min_length > generation_config.max_length: + warnings.warn( + f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when " + f"added to the prompt length ({input_ids_length}), is larger than" + f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix, + UserWarning, + ) + + def _prepare_generated_length( + self, + generation_config, + has_default_max_length, + has_default_min_length, + model_input_name, + input_ids_length, + inputs_tensor, + ): + """Prepared max and min length in generation configs to avoid clashes between similar attributes""" + + if generation_config.max_new_tokens is not None: + if not has_default_max_length and generation_config.max_length is not None: + logger.warning( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. 
" + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + generation_config.max_length = generation_config.max_new_tokens + input_ids_length + + # if both `inputs_embeds` and `input_ids` are passed, we do not correct the length + # otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length`` + elif ( + model_input_name == "inputs_embeds" + and input_ids_length != inputs_tensor.shape[1] + and not self.config.is_encoder_decoder + ): + generation_config.max_length -= inputs_tensor.shape[1] + elif has_default_max_length: # by default let's always generate 20 new tokens + if generation_config.max_length == GenerationConfig().max_length: + generation_config.max_length = generation_config.max_length + input_ids_length + max_position_embeddings = getattr(self.config, "max_position_embeddings", None) + if max_position_embeddings is not None: + generation_config.max_length = min(generation_config.max_length, max_position_embeddings) + + # same for min length + if generation_config.min_new_tokens is not None: + if not has_default_min_length: + logger.warning( + f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(=" + f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + generation_config.min_length = generation_config.min_new_tokens + input_ids_length + + elif ( + model_input_name == "inputs_embeds" + and input_ids_length != inputs_tensor.shape[1] + and not self.config.is_encoder_decoder + ): + generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0) + + return generation_config + + def _prepare_generation_config( + self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: dict + ) -> tuple[GenerationConfig, dict]: + """ + Prepares the base generation config, then applies any generation configuration options from kwargs. This + function handles retrocompatibility with respect to configuration files. + """ + # parameterization priority: + # kwargs > non-global default values in `generation_config` > `model.generation_config` > GenerationConfig() + # TODO (joao): per-model generation config classes. + + using_model_generation_config = False + if generation_config is None: + # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, + # the following conditions must be met + # 1) the generation config must have been created from the model config (`_from_model_config` field); + # 2) the generation config must have seen no modification since its creation (the hash is the same); + # 3) there are non-default generation parameters in the model config. + # 4) the user must have set new generation parameters in the model config. + if ( + self.generation_config._from_model_config # 1) + and self.generation_config._original_object_hash == hash(self.generation_config) # 2) + and len(self.config._get_non_default_generation_parameters()) > 0 # 3) + ): + new_generation_config = GenerationConfig.from_model_config(self.config) + if new_generation_config != self.generation_config: # 4) + warnings.warn( + "You have modified the pretrained model configuration to control generation. This is a" + " deprecated strategy to control generation and will be removed in v5." 
+ " Please use and modify the model generation configuration (see" + " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )", + UserWarning, + ) + self.generation_config = new_generation_config + + generation_config = self.generation_config + using_model_generation_config = True + + # `torch.export.export` usually raises an exception if it is called + # with ``strict=True``. deepcopy can only be processed if ``strict=False``. + generation_config = copy.deepcopy(generation_config) + + if not using_model_generation_config: + # If `generation_config` is provided: + # - `use_model_defaults`: let's fallback ALL default values to the model's generation config + # - otherwise: legacy behavior, let's just make sure we have the tokens defined + model_base_version = version.parse(version.parse(self.generation_config.transformers_version).base_version) + if use_model_defaults is True or ( + use_model_defaults is None and model_base_version >= version.parse("4.50.0") + ): + modified_values = {} + global_default_generation_config = GenerationConfig() + model_generation_config = self.generation_config + # we iterate over the model's generation config: it may hold custom keys, which we'll want to copy + for key, model_gen_config_value in model_generation_config.__dict__.items(): + if key.startswith("_") or key == "transformers_version": # metadata + continue + global_default_value = getattr(global_default_generation_config, key, None) + custom_gen_config_value = getattr(generation_config, key, None) + if ( + custom_gen_config_value == global_default_value + and model_gen_config_value != global_default_value + ): + modified_values[key] = model_gen_config_value + setattr(generation_config, key, model_gen_config_value) + if use_model_defaults is None and len(modified_values) > 0: + logger.warning_once( + f"`generation_config` default values have been modified to match model-specific defaults: " + f"{modified_values}. If this is not desired, please set these values explicitly." 
+ ) + else: + if generation_config.bos_token_id is None: + generation_config.bos_token_id = self.generation_config.bos_token_id + if generation_config.eos_token_id is None: + generation_config.eos_token_id = self.generation_config.eos_token_id + if generation_config.pad_token_id is None: + generation_config.pad_token_id = self.generation_config.pad_token_id + if generation_config.decoder_start_token_id is None: + generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id + + # Finally, apply any passed kwargs + model_kwargs = generation_config.update(**kwargs) + + return generation_config, model_kwargs + + def _get_initial_cache_position(self, seq_length, device, model_kwargs): + """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" + # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` + if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None: + return model_kwargs + if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder: + cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 + elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder: + cache_position = ( + torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 + ) + else: + cache_position = torch.ones(seq_length, dtype=torch.int64, device=device).cumsum(0) - 1 + + past_length = 0 + if model_kwargs.get("past_key_values") is not None: + cache = model_kwargs["past_key_values"] + past_length = 0 + if not isinstance(cache, Cache): + past_length = cache[0][0].shape[2] + elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: + past_length = cache.get_seq_length() + + cache_position = cache_position[past_length:] + + model_kwargs["cache_position"] = cache_position + return model_kwargs + + def _get_layer_device_map_for_cache_init(self) -> Optional[dict[int, Union[str, int]]]: + """ + Returns the device map for each decoder layer, to allocate the cache on the right device. + Inspired by `dispatch_model` in accelerate. + """ + execution_device_map = None + + if hasattr(self, "hf_device_map"): + if set(self.hf_device_map.values()) == {"cpu"} or set(self.hf_device_map.values()) == {"cpu", "disk"}: + main_device = "cpu" + else: + main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0] + execution_device_map = { + name: main_device if device in ["cpu", "disk"] else device + for name, device in self.hf_device_map.items() + } + + # No `execution_device_map` -> rely on `self.device` to allocate the cache + if execution_device_map is None: + return None + + # Single device for all layers + num_hidden_layers = self.config.get_text_config().num_hidden_layers + if len(execution_device_map) == 1 and "" in execution_device_map: + return dict.fromkeys(range(num_hidden_layers), execution_device_map[""]) + + # Multiple devices in `execution_device_map` -> we need to map decoder layers to the correct device. + layer_device_map = {} + # Case 1: The model has a `get_decoder` method, we can use it to find the decoder name. + if hasattr(self, "get_decoder"): + decoder_name = None + for name, module in self.named_modules(): + if module is self.get_decoder(): + decoder_name = name + break + if decoder_name is None: + raise RuntimeError( + "`model.get_decoder()` is not returning a named module of the model. 
This is unexpected, please " + "open an issue on GitHub." + ) + + decoder_mapped_modules = [ + module_name for module_name in execution_device_map.keys() if decoder_name in module_name + ] + # The decoder name may be present in `execution_device_map` in two forms: + # a) each layer has a device mapping + if len(decoder_mapped_modules) >= num_hidden_layers: + for idx in range(num_hidden_layers): + for module_name in decoder_mapped_modules: + if f".{idx}." in f"{module_name}.": + layer_device_map[idx] = execution_device_map[module_name] + break + + # b) the whole module is mapped to a single device. If the decoder name is NOT present in the device map, + # then the mapping is done in a parent module + else: + while True: + if decoder_name in execution_device_map: + layer_device_map = dict.fromkeys(range(num_hidden_layers), execution_device_map[decoder_name]) + break + elif "." in decoder_name: + decoder_name = decoder_name.rsplit(".", 1)[0] # gets the name of the parent module + else: + raise RuntimeError(f"Decoder name {decoder_name} not found in execution device map") + + # Case 2: Legacy code path: assume the decoder layers are named as `(...).X` (X being the layer index) + else: + for layer in execution_device_map: + for idx in range(num_hidden_layers): + if f".{idx}." in f"{layer}.": + layer_device_map[idx] = execution_device_map[layer] + break + + for idx in range(num_hidden_layers): + if idx not in layer_device_map: + raise RuntimeError(f"layer {idx} has not been mapped to a device.") + return layer_device_map + + def _get_cache( + self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs + ) -> Cache: + """ + Sets a cache for `generate`, which will persist across calls. A new cache will only be initialized if a + new `generate` call requires a larger cache or uses a different batch size. + + Returns the resulting cache object.
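+
+        Illustrative sketch of the reuse behavior described above (the variable names below are hypothetical and
+        only meant as an example), assuming a model configured for a static cache:
+
+            # the first call allocates `model._cache` sized for this batch size and generation length
+            model.generate(**batch_of_2_short_prompts, cache_implementation="static", max_new_tokens=100)
+            # same batch size and an equal or smaller cache length -> the existing cache is reset and reused
+            model.generate(**batch_of_2_short_prompts, cache_implementation="static", max_new_tokens=50)
+            # a larger batch or a longer maximum length -> a new cache is initialized instead
+            model.generate(**batch_of_4_long_prompts, cache_implementation="static", max_new_tokens=200)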
+ """ + if cache_implementation == "hybrid" and "llama4" in getattr(self.config, "model_type", ""): + cache_implementation = "hybrid_chunked" + + cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation] + requires_cross_attention_cache = ( + self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None + ) + + if hasattr(self, "_cache"): + cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache + + if cache_implementation == "sliding_window": + max_cache_len = min(self.config.sliding_window, max_cache_len) + + need_new_cache = ( + not hasattr(self, "_cache") + or (not isinstance(cache_to_check, cache_cls)) + or cache_to_check.max_batch_size != batch_size + or isinstance( + cache_to_check, (HybridChunkedCache, OffloadedHybridCache) + ) # due to internal slicing, we always re-init + ) + if cache_implementation != "mamba": + need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len + + if requires_cross_attention_cache and hasattr(self, "_cache"): + need_new_cache = ( + need_new_cache + or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1] + ) + + if need_new_cache: + if hasattr(self.config, "_pre_quantization_dtype"): + cache_dtype = self.config._pre_quantization_dtype + else: + cache_dtype = self.dtype + + layer_device_map = self._get_layer_device_map_for_cache_init() + cache_kwargs = { + "config": self.config.get_text_config(), + "max_batch_size": batch_size, + "max_cache_len": max_cache_len, + "dtype": cache_dtype, + "device": device, + "layer_device_map": layer_device_map, + } + self._cache = cache_cls(**cache_kwargs) + if requires_cross_attention_cache: + encoder_kwargs = cache_kwargs.copy() + encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1] + self._cache = EncoderDecoderCache(self._cache, cache_cls(**encoder_kwargs)) + else: + self._cache.reset() + return self._cache + + def _supports_default_dynamic_cache(self) -> bool: + """ + Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. + This is mostly the same as `_supports_cache_class` attribute, but add exception for `Jamba` model which + uses its own `HybridMambaAttentionDynamicCache` and do not need to initialize the Cache in advance in + order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed + for `HybridMambaAttentionDynamicCache`). + """ + return ( + self._supports_cache_class + and "jamba" not in self.__class__.__name__.lower() + and "zamba" not in self.__class__.__name__.lower() + and "bamba" not in self.__class__.__name__.lower() + and "minimax" not in self.__class__.__name__.lower() + ) + + def _prepare_cache_for_generation( + self, + generation_config: GenerationConfig, + model_kwargs: dict, + assistant_model: "PreTrainedModel", + batch_size: int, + max_cache_length: int, + device: torch.device, + ) -> bool: + """ + Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is + instantiated, writes it to `model_kwargs`, under the name expected by the model. 
+ """ + + is_hybrid_cache = any(class_name in self.__class__.__name__.lower() for class_name in ["mamba", "falconh1"]) + cache_name = "past_key_values" if not is_hybrid_cache else "cache_params" + + requires_cross_attention_cache = ( + self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None + ) + + # Quick escape route 1: if the user specifies a cache, we only need to: + # a) check for conflicting `generate` arguments + # b) convert to the new cache format (if the user passes a legacy cache and model supports it) + user_defined_cache = model_kwargs.get(cache_name) + if user_defined_cache is not None: + if generation_config.cache_implementation is not None: + raise ValueError( + f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a " + "Cache object) is unsupported. Please use only one of the two." + ) + if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache(): + model_kwargs[cache_name] = ( + DynamicCache.from_legacy_cache(user_defined_cache) + if not requires_cross_attention_cache + else EncoderDecoderCache.from_legacy_cache(user_defined_cache) + ) + return + + # Quick escape route 2: if the user specifies no cache is to be used. (conflicting arguments are handled in + # `generation_config.validate()`) + if generation_config.use_cache is False: + return + + # Quick escape route 3: model that only supports legacy caches = nothing to prepare + if not self._supports_default_dynamic_cache(): + if generation_config.cache_implementation is not None: + warnings.warn( + "This model does not support `Cache` instances, it only supports the legacy cache format (tuple " + f"of tuples). `cache_implementation` (set to {generation_config.cache_implementation}) will be " + "ignored.", + UserWarning, + ) + return + + # Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation` + + # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches, + # which is only supported in dynamic caches atm + if assistant_model is not None and generation_config.cache_implementation is not None: + logger.warning_once( + "An assistant model is provided, using a dynamic cache instead of a cache of type=" + f"'{generation_config.cache_implementation}'." + ) + generation_config.cache_implementation = None + + generation_config.cache_implementation = generation_config.cache_implementation or getattr( + self.config.get_text_config(), "cache_implementation", None + ) + if generation_config.cache_implementation is not None: + if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation == "static" and not self._supports_static_cache: + raise ValueError( + "This model does not support `cache_implementation='static'`. Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981" + ) + model_kwargs[cache_name] = self._get_cache( + cache_implementation=generation_config.cache_implementation, + batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size, + max_cache_len=max_cache_length, + device=device, + model_kwargs=model_kwargs, + ) + elif generation_config.cache_implementation == "quantized": + if not self._supports_quantized_cache: + raise ValueError( + "This model does not support the quantized cache. If you want your model to support quantized " + "cache, please open an issue and tag @zucchini-nlp." 
+ ) + + cache_config = ( + generation_config.cache_config + if generation_config.cache_config is not None + else QuantizedCacheConfig() + ) + cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] + + if cache_config.backend == "quanto" and not is_optimum_quanto_available(): + raise ImportError( + "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. " + "Please install it with `pip install optimum-quanto`" + ) + elif cache_config.backend == "HQQ" and not is_hqq_available(): + raise ImportError( + "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " + "Please install it with `pip install hqq`" + ) + + model_kwargs[cache_name] = cache_class(cache_config) + elif generation_config.cache_implementation == "offloaded": + model_kwargs[cache_name] = OffloadedCache() + elif generation_config.cache_implementation == "dynamic": + model_kwargs[cache_name] = DynamicCache() + + # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that + # keeps copying the cache thus using much more memory + else: + model_kwargs[cache_name] = ( + DynamicCache() + if not requires_cross_attention_cache + else EncoderDecoderCache(DynamicCache(), DynamicCache()) + ) + + def _supports_logits_to_keep(self) -> bool: + """ + Return True if the current model supports the keyword argument `logits_to_keep` in forward() + to save memory. Checking it in this way allows us to avoid using a new model attribute. + """ + return "logits_to_keep" in set(inspect.signature(self.forward).parameters.keys()) + + def _prepare_special_tokens( + self, + generation_config: GenerationConfig, + kwargs_has_attention_mask: Optional[bool] = None, + device: Optional[Union[torch.device, str]] = None, + ): + """ + Prepares the special tokens for generation, overwriting the generation config with their processed versions + converted to tensor. + + Note that `generation_config` is changed in place and stops being serializable after this method is called. + That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the + function). However, if called outside `generate`, consider creating a copy of `generation_config` first. + """ + + # Convert special tokens to tensors + def _tensor_or_none(token, device=None): + if token is None: + return token + + device = device if device is not None else self.device + if isinstance(token, torch.Tensor): + return token.to(device) + return torch.tensor(token, device=device, dtype=torch.long) + + bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device) + eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device) + pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device) + decoder_start_token_tensor = _tensor_or_none(generation_config.decoder_start_token_id, device=device) + + # for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892) + if self.config.is_encoder_decoder: + decoder_start_token_tensor = ( + decoder_start_token_tensor if decoder_start_token_tensor is not None else bos_token_tensor + ) + + # We can have more than one eos token. Always treat it as a 1D tensor (when it exists). 
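+        # (illustrative example: `eos_token_id=2` becomes `tensor([2])` and `eos_token_id=[2, 32000]` becomes
+        # `tensor([2, 32000])`, so downstream consumers such as `EosTokenCriteria` can compare generated ids
+        # against every eos id at once)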
+ if eos_token_tensor is not None and eos_token_tensor.ndim == 0: + eos_token_tensor = eos_token_tensor.unsqueeze(0) + + # Set pad token if unset (and there are conditions to do so) + if pad_token_tensor is None and eos_token_tensor is not None: + if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) + pad_token_tensor = eos_token_tensor[0] + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.") + + # Sanity checks/warnings + if self.config.is_encoder_decoder and decoder_start_token_tensor is None: + raise ValueError( + "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." + ) + if ( + eos_token_tensor is not None + and isin_mps_friendly(elements=eos_token_tensor, test_elements=pad_token_tensor).any() + ): + if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: + logger.warning_once( + "The attention mask is not set and cannot be inferred from input because pad token is same as " + "eos token. As a consequence, you may observe unexpected behavior. Please pass your input's " + "`attention_mask` to obtain reliable results." + ) + if eos_token_tensor is not None and ( + torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any() + ): + logger.warning( + f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation " + "will not stop until the maximum length is reached. Depending on other flags, it may even crash." + ) + + # Update generation config with the updated special tokens tensors + # NOTE: this must be written into a different attribute name than the one holding the original special tokens + # (in their non-tensor form), in order to enable end-to-end compilation. See + # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations + generation_config._bos_token_tensor = bos_token_tensor + generation_config._eos_token_tensor = eos_token_tensor + generation_config._pad_token_tensor = pad_token_tensor + generation_config._decoder_start_token_tensor = decoder_start_token_tensor + + def _valid_auto_compile_criteria(self, model_kwargs: dict, generation_config: GenerationConfig) -> bool: + """ + Determines whether to trigger auto-compilation of the model's forward pass at generation time. 
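+
+        In short (a summary of the checks below, not additional behavior): compilation is only attempted when the
+        model is on a CUDA device (or `compile_config._compile_all_devices` is set), the cache in `model_kwargs`
+        is a compileable `Cache`, the model supports static caches, and neither the quantizer nor CPU/disk
+        offloading rules it out; `generation_config.disable_compile` unconditionally turns it off.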
+ """ + # Override: honor `disable_compile` flag + if generation_config.disable_compile: + return False + + # Base logic + valid_hardware = self.device.type == "cuda" or bool( + generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices + ) + using_compilable_cache = ( + isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable + ) + can_compile = valid_hardware and using_compilable_cache and self._supports_static_cache + + # Exception 1: Some quantization methods do not support compilation + if getattr(self, "hf_quantizer", None) is not None: + can_compile &= self.hf_quantizer.is_compileable + + if hasattr(self, "hf_device_map"): + all_model_devices = set(self.hf_device_map.values()) + # Exception 2: Don't compile if the model is using CPU offload (as of April 2025, this results in a crash) + has_cpu_offload = "cpu" in all_model_devices and len(all_model_devices) > 1 + can_compile &= not has_cpu_offload + + # Exception 3: Disk offload is not supported for compilation + has_disk_offload = "disk" in all_model_devices + can_compile &= not has_disk_offload + + # Finally: if the user has manually specified compilation options, but compilation is not possible, let's warn + # them + if generation_config.compile_config is not None and not can_compile: + logger.warning_once( + "You have set `compile_config`, but we are unable to meet the criteria for compilation. Compilation " + "will be skipped." + ) + + return can_compile + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + use_model_defaults: Optional[bool] = None, + custom_generate: Optional[str] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + r""" + + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](../generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + generation_config ([`~generation.GenerationConfig`], *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. 
If + `generation_config` is not provided, the default will be used, which has the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complements the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. If your stopping criteria depends on the `scores` input, make + sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is + intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://huggingface.co/papers/2010.00904). + synced_gpus (`bool`, *optional*): + Whether to continue running the while loop until max_length. Unless overridden, this flag will be set + to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid + deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`. + assistant_model (`PreTrainedModel`, *optional*): + An assistant model that can be used to accelerate generation. The assistant model must have the exact + same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model + is much faster than running generation with the model you're calling generate from. As such, the + assistant model should be much smaller. + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + The negative prompt needed for some processors such as CFG. The batch size must match the input batch + size. This is an experimental feature, subject to breaking API changes in future versions. + negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Attention_mask for `negative_prompt_ids`. + use_model_defaults (`bool`, *optional*): + When it is `True`, unset parameters in `generation_config` will be set to the model-specific default + generation configuration (`model.generation_config`), as opposed to the global defaults + (`GenerationConfig()`). 
If unset, models saved starting from `v4.50` will consider this flag to be + `True`. + custom_generate (`str`, *optional*): + A string containing the name of a huggingface.co repository. If provided, the custom `generate` + function defined in that repository's `custom_generate/generate.py` file will be executed instead of the + standard `generate` method. Note that the generation logic is entirely defined in that + repository, and the return type may be different from the standard `generate` method. + kwargs (`dict[str, Any]`, *optional*): + Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GenerateDecoderOnlyOutput`], + - [`~generation.GenerateBeamDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] + """ + # 0. If requested, load an arbitrary generation recipe from the Hub and run it instead + trust_remote_code = kwargs.pop("trust_remote_code", None) + if custom_generate is not None: + # Get all `generate` arguments in a single variable. Custom functions are responsible for handling them: + # they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments to + # trigger the custom generation. They can access methods from `GenerationMixin` through `model`. + global_keys_to_exclude = { + "self", + "kwargs", + "global_keys_to_exclude", + "trust_remote_code", + "custom_generate", + } + generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude} + generate_arguments.update(kwargs) + + custom_generate_function = self.load_custom_generate( + custom_generate, trust_remote_code=trust_remote_code, **kwargs + ) + return custom_generate_function(model=self, **generate_arguments) + + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria + assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation + + generation_config, model_kwargs = self._prepare_generation_config( + generation_config, use_model_defaults, **kwargs + ) + self._validate_model_kwargs(model_kwargs.copy()) + self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer) + + # 2.
Set generation parameters if not already defined + if synced_gpus is None: + synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1 + + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) + requires_attention_mask = "encoder_outputs" not in model_kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None + + # 3. Define model inputs + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + + device = inputs_tensor.device + self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) + + # decoder-only models must use left-padding for batched generation. + if not self.config.is_encoder_decoder: + # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` + # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. + if ( + generation_config._pad_token_tensor is not None + and batch_size > 1 + and len(inputs_tensor.shape) == 2 + and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." + ) + + # 4. Define other model kwargs + # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are + # generating the first new token or not, and we only want to use the embeddings for the first new token) + if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": + generation_config.use_cache = True + + if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, generation_config, model_kwargs + ) + elif kwargs_has_attention_mask: + # TODO (joao): generalize this check with other types of inputs + if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2: + raise ValueError("`attention_mask` passed to `generate` must be 2D.") + + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name, generation_config + ) + + # 5. Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( + batch_size=batch_size, + model_input_name=model_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=generation_config._decoder_start_token_tensor, + device=inputs_tensor.device, + ) + else: + input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") + + if generation_config.token_healing: + input_ids = self.heal_tokens(input_ids, tokenizer) + + if streamer is not None: + streamer.put(input_ids.cpu()) + + # 6. 
Prepare `max_length` depending on other stopping criteria. + input_ids_length = input_ids.shape[1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=inputs_tensor, + input_ids_length=input_ids_length, + ) + + # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole + # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding + # dynamically overrides this value as it can need more than the last token logits + if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs: + model_kwargs["logits_to_keep"] = 1 + + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + + # 7. Prepare the cache. + # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`. + # - different models have a different cache name expected by the model (default = "past_key_values") + # - `max_length`, prepared above, is used to determine the maximum cache length + max_cache_length = generation_config.max_length - 1 + if ( + inputs_tensor.shape[1] != input_ids_length + and model_input_name == "inputs_embeds" + and not self.config.is_encoder_decoder + ): + max_cache_length += inputs_tensor.shape[1] + self._prepare_cache_for_generation( + generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device + ) + + # 8. determine generation mode + generation_mode = generation_config.get_generation_mode(assistant_model) + + if streamer is not None and (generation_config.num_beams > 1): + raise ValueError( + "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." + ) + + if self.device.type != input_ids.device.type: + warnings.warn( + "You are calling .generate() with the `input_ids` being on a device type different" + f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" + f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." + " Please make sure that you have put `input_ids` to the" + f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" + " running `.generate()`.", + UserWarning, + ) + + # 9. prepare logits processors and stopping criteria + prepared_logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + device=inputs_tensor.device, + model_kwargs=model_kwargs, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + prepared_stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs + ) + + # Set model_kwargs `use_cache` so we can use it later in forward runs + model_kwargs["use_cache"] = generation_config.use_cache + + # 10. 
go into different generation modes + if generation_mode == GenerationMode.ASSISTED_GENERATION: + if generation_config.num_return_sequences > 1: + raise ValueError( + "num_return_sequences has to be 1 when doing assisted generate, " + f"but is {generation_config.num_return_sequences}." + ) + if batch_size > 1: + raise ValueError("assisted generate is only supported for batch_size = 1") + if not model_kwargs["use_cache"]: + raise ValueError("assisted generate requires `use_cache=True`") + if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]: + raise ValueError("assisted generate is not supported with static cache classes") + if self._is_stateful: + # In assisted generation we need the ability to confirm whether the model would pick certain tokens, + # which is not possible with stateful models (they can't reset to a previous subset of generated text) + raise ValueError( + f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}" + ) + + # 11. Get the candidate generator, given the parameterization + candidate_generator = self._get_candidate_generator( + generation_config=generation_config, + input_ids=input_ids, + inputs_tensor=inputs_tensor, + assistant_model=assistant_model, + logits_processor=logits_processor, + target_tokenizer=tokenizer, + assistant_tokenizer=assistant_tokenizer, + model_kwargs=model_kwargs, + ) + + # 12. run assisted generate + result = self._assisted_decoding( + input_ids, + candidate_generator=candidate_generator, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) + elif generation_mode == GenerationMode.DOLA_GENERATION: + if not trust_remote_code: + logger.warning_once( + "DoLa Decoding is scheduled to be moved to a `custom_generate` repository in v4.55.0. " + "To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call." + ) + if self._is_stateful: + # DoLa decoding was not designed for stateful models, and would require some changes + raise ValueError( + f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}" + ) + result = self._dola_decoding( + input_ids, + dola_layers=generation_config.dola_layers, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) + + elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH: + if not trust_remote_code: + logger.warning_once( + "Contrastive Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. " + "To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call."
+ ) + if not model_kwargs["use_cache"]: + raise ValueError("Contrastive search requires `use_cache=True`") + if self._is_stateful: + # Just like assisted generation, we need to be able to rollback to a previous state (see comment above) + raise ValueError( + f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}" + ) + + result = self._contrastive_search( + input_ids, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) + + elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): + # 11. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`) + result = self._sample( + input_ids, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) + + elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH): + # 11. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 12. run beam sample + result = self._beam_search( + input_ids, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH: + logger.warning_once( + "Group Beam Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. " + "To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call." + ) + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + num_beam_groups=generation_config.num_beam_groups, + max_length=generation_config.max_length, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + result = self._group_beam_search( + input_ids, + beam_scorer, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH: + logger.warning_once( + "Constrained Beam Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. " + "To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call." 
+ ) + final_constraints = [] + if generation_config.constraints is not None: + final_constraints = generation_config.constraints + + if generation_config.force_words_ids is not None: + + def typeerror(): + raise ValueError( + "`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]` " + f"of positive integers, but is {generation_config.force_words_ids}." + ) + + if ( + not isinstance(generation_config.force_words_ids, list) + or len(generation_config.force_words_ids) == 0 + ): + typeerror() + + for word_ids in generation_config.force_words_ids: + if isinstance(word_ids[0], list): + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any(not isinstance(token_ids, list) for token_ids in word_ids): + typeerror() + if any( + any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) + for token_ids in word_ids + ): + typeerror() + + constraint = DisjunctiveConstraint(word_ids) + else: + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): + typeerror() + + constraint = PhrasalConstraint(word_ids) + final_constraints.append(constraint) + + # 11. prepare beam search scorer + constrained_beam_scorer = ConstrainedBeamSearchScorer( + constraints=final_constraints, + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + max_length=generation_config.max_length, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + result = self._constrained_beam_search( + input_ids, + constrained_beam_scorer=constrained_beam_scorer, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + # Convert to legacy cache format if requested + if ( + generation_config.return_legacy_cache is True + and hasattr(result, "past_key_values") + and getattr(result.past_key_values, "to_legacy_cache") is not None + ): + result.past_key_values = result.past_key_values.to_legacy_cache() + return result + + def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool: + """ + Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is + fed through `this_peer_finished`. ZeRO stage 3-friendly. + """ + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? 
the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + return False + elif this_peer_finished: + return False + return True + + def heal_tokens( + self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None + ) -> torch.LongTensor: + r""" + Applies token healing: the tail token of each prompt is replaced with a vocabulary extension that better + matches the tokenizer's segmentation, using a single constrained generation step. + Parameters: + input_ids (`torch.LongTensor`): The sequence used as a prompt for the generation. + tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids. + Return: + `torch.LongTensor` where each sequence has its tail token replaced with its appropriate extension. + """ + if tokenizer is None: + raise ValueError( + "When generating with token healing, you must pass the model's tokenizer to the `tokenizer` " + "argument of `generate`." + ) + + bos_token_id, pad_token_id = tokenizer.bos_token_id, tokenizer.pad_token_id + vocab_trie = ExtensionsTrie(tokenizer.get_vocab()) + generation_config = GenerationConfig(max_new_tokens=1, pad_token_id=pad_token_id) + + # assumption: leading/trailing whitespace is not meaningful, so the prompts are + # stripped before re-tokenizing to desensitize generation to whitespace artefacts + prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)] + input_ids = tokenizer( + prompts, + return_tensors="pt", + padding=True, + ).input_ids.to(input_ids.device) + + # replace bos with pad to not condition healing on it + input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids) + + # the code below assumes `input_ids` is not empty, so return early if it contains no elements + if input_ids.numel() == 0: + return input_ids + + tail_ids = input_ids[:, -1].tolist() + + space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0] + # tail tokens are used for a prefix search, thus, whitespaces are replaced with + # their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace + tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids) + + for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)): + batch_ids = input_ids[batch_idx] + if torch.all(batch_ids == pad_token_id).item(): + continue # skip empty sequences (all pad ids) + + # apply bias for alternatives (extensions) to the tail token + # `sequence_bias` keys must be tuples of token ids (ints), so each candidate token string is converted + # with the tokenizer before being used as a key + seq_bias = { + (tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok) + } + + if len(seq_bias) == 1: + continue # skip if there are no token alternatives to heal with + + # slightly favor original token to limit aggressive healing e.g.
'http' -> 'https' + seq_bias[(tail_id,)] += 1.0 + generation_config.update(sequence_bias=seq_bias) + + trimmed_ids = batch_ids[:-1] + + # the code below assumes `trimmed_ids` is not empty, so skip the sequence if nothing is left after trimming + if trimmed_ids.numel() == 0: + continue + + # if the prompt is a single (non-pad) token, regenerate from bos + if len(batch_ids[batch_ids != pad_token_id]) == 1: + trimmed_ids[-1] = bos_token_id + + input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=generation_config) + + return input_ids + + def _dola_decoding( + self, + input_ids: torch.LongTensor, + dola_layers: Union[str, list[int]], + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: "BaseStreamer", + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **DoLa decoding** and can be + used for decoder-only text models. + The method is based on the paper "DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language + Models" (https://huggingface.co/papers/2309.03883) in ICLR 2024. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + dola_layers (`Union[str, list[int]]`): + The candidate (premature) layers to contrast with the final layer in DoLa. It can be either 1) 'low' or 'high', + which means the lower part or higher part of the model layers, respectively, or 2) a list of layer indices + to be used for candidate layers. The 0-th layer is the word embedding layer of the model. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `forward` function of the model. + If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] + or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`.
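+
+        Example (an illustrative sketch rather than canonical usage; the checkpoint name and generation parameters
+        are assumptions, and this version of `generate` may additionally ask for `trust_remote_code=True`):
+
+        ```python
+        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+        >>> inputs = tokenizer("What is the highest peak in the world?", return_tensors="pt")
+        >>> # passing `dola_layers` routes `generate` into `_dola_decoding`
+        >>> outputs = model.generate(**inputs, dola_layers="low", max_new_tokens=20, do_sample=False)
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ```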
+ """ + + if self.config.is_encoder_decoder: + raise ValueError("DoLa decoding is only available for decoder-only models.") + # init values + + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # keep track of which sequences are already finished + batch_size, cur_length = input_ids.shape[:2] + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(cur_length, input_ids.device, model_kwargs) + + this_peer_finished = False + + # prepare layers for DoLa decoding + final_layer = self.config.get_text_config().num_hidden_layers + # if the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer, + # as the early exit from word embeddings will become identity function + # if the model is really shallow (<=2 layers), we use the 1st layer if it's not the final layer and the 0-th + # layer otherwise. Notice that DoLa does not help shallow models much. + if not self.config.tie_word_embeddings: + start_layer = 0 + elif final_layer > 2: + start_layer = 2 + elif final_layer == 2: + start_layer = 1 + else: + start_layer = 0 + + # For `N`-layer models with `N <= 40` layers, the layers of `range(0, N // 2, 2)` and `range(N // 2, N, 2)` + # are used for `'low'` and `'high'` layers, respectively. + # For models with `N > 40` layers, the layers of `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for + # `'low'` and `'high'` layers, respectively. + if isinstance(dola_layers, str) and dola_layers == "low": + if start_layer == final_layer // 2: + candidate_premature_layers = [start_layer] + else: + candidate_premature_layers = ( + list(range(start_layer, final_layer // 2, 2)) + if final_layer <= 40 + else list(range(start_layer, 20, 2)) + ) + elif isinstance(dola_layers, str) and dola_layers == "high": + candidate_premature_layers = ( + list(range(final_layer // 2, final_layer, 2)) + if final_layer <= 40 + else list(range(final_layer - 20, final_layer, 2)) + ) + # Set the `dola_layers` to a list of integers for layer indices to contrast manually specified layers. 
+ elif isinstance(dola_layers, list): + candidate_premature_layers = [i for i in dola_layers if i < final_layer] + else: + raise ValueError("dola_layers must be either 'low', 'high' or a list of integers.") + + lm_head = self.get_output_embeddings() + if lm_head is None: + raise ValueError("DoLa is not supported for models that don't have output embeddings.") + + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=True, + ) + + # .float() is needed to retain precision for later logits manipulations + final_layer_next_token_logits = outputs.logits[:, -1, :].detach().to(copy=True, dtype=torch.float32) + final_logits = outputs.logits[:, -1, :].float() + candidate_premature_logits = {} + for candidate_premature_layer in candidate_premature_layers: + candidate_premature_logits[candidate_premature_layer] = lm_head( + outputs.hidden_states[candidate_premature_layer][:, -1, :] + ).to(final_logits.device) + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if synced_gpus and this_peer_finished: + continue + + next_token_logits = _dola_select_contrast( + candidate_premature_layers, candidate_premature_logits, final_logits + ) + next_token_logits = next_token_logits.to(input_ids.device) + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (final_layer_next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + if do_sample: # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: # argmax + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + + # stop when each sentence is finished + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids + 
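+    # Illustrative note on the DoLa layer bucketing in `_dola_decoding` above (the layer counts are hypothetical
+    # examples, not defaults): for a 32-layer model with tied word embeddings, `start_layer = 2`, so
+    #     dola_layers="low"  -> candidate_premature_layers = [2, 4, 6, 8, 10, 12, 14]
+    #     dola_layers="high" -> candidate_premature_layers = [16, 18, 20, 22, 24, 26, 28, 30]
+    # while a 48-layer model (> 40 layers) uses [2, 4, ..., 18] and [28, 30, ..., 46] respectively.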
+ @torch.no_grad() + def _contrastive_search( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **contrastive search** and can + be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `forward` function of the model. + If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] + or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. 
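+
+        Example (an illustrative sketch; the checkpoint and hyperparameters are assumptions, and this version of
+        `generate` may additionally ask for `trust_remote_code=True`):
+
+        ```python
+        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+        >>> inputs = tokenizer("DeepMind Company is", return_tensors="pt")
+        >>> # a positive `penalty_alpha` together with `top_k > 1` selects contrastive search
+        >>> outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=20)
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ```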
+ """ + # init values + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + top_k = generation_config.top_k + penalty_alpha = generation_config.penalty_alpha + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + sequential = generation_config.low_memory + + # init attention / hidden states / scores tuples + raw_logits = () if (return_dict_in_generate and output_logits) else None + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size, cur_len = input_ids.shape[:2] + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) + + # Create cosine_matrix_mask based on the attention_mask + cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long) + if self.config.is_encoder_decoder: + if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None: + cosine_matrix_mask = model_kwargs["decoder_attention_mask"] + else: + cosine_matrix_mask = model_kwargs["attention_mask"] + cosine_matrix_mask = cosine_matrix_mask.repeat_interleave(top_k, dim=0) + + this_peer_finished = False + + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values; + # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step + if model_kwargs.get("past_key_values") is None or ( + isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache)) + and model_kwargs["past_key_values"].get_seq_length() == 0 + ): + # prepare inputs + model_kwargs["use_cache"] = True + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save + # the `encoder_outputs` + outputs = self( + **model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions + ) + + # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with + # previous tokens) + if self.config.is_encoder_decoder: + last_hidden_states = outputs.decoder_hidden_states[-1] + else: + last_hidden_states = outputs.hidden_states[-1] + + # next logit for contrastive search to select top-k candidate tokens + # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first 
iteration + # (the clone itself is always small) + # torch.float32 is needed to retain precision for later logits manipulations + logit_for_next_step = outputs.logits[:, -1, :].to( + copy=True, dtype=torch.float32, device=input_ids.device + ) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + + if not sequential: + # Expands model inputs top_k times, for batched forward passes (akin to beam search). + # input_ids is required for expanding visual inputs in qwen2vl + _, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=top_k, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + past_key_values = model_kwargs.get("past_key_values") + if past_key_values is None: + raise ValueError( + f"{self.__class__.__name__} does not support caching and therefore **can't** be used " + "for contrastive search." + ) + elif ( + not isinstance(past_key_values[0], (tuple, torch.Tensor)) + or past_key_values[0][0].shape[0] != batch_size + ): + raise ValueError( + f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be " + "used for contrastive search without further modifications." + ) + + # contrastive_search main logic start: + # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by + # degeneration penalty + processed_logit_for_next_step = logits_processor(input_ids, logit_for_next_step) + next_probs = nn.functional.softmax(processed_logit_for_next_step, dim=-1) + + top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_logits: + raw_logits += (logit_for_next_step,) + if output_scores: + scores += (processed_logit_for_next_step,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # This is needed to properly delete outputs.logits which may be very large for this first iteration + # Otherwise a reference to outputs.logits is kept all along until after the next call to self.forward() + del outputs + + if not sequential: + # Replicates the new past_key_values to match the `top_k` candidates + past = model_kwargs["past_key_values"] + # If it is a static cache, modify it in-place layer after layer to save memory + if isinstance(past, DynamicCache) or ( + isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache) + ): + past.batch_repeat_interleave(top_k) + else: + new_key_values = [] + for layer in past: + items = [] + # item is either the key or the value matrix + for item in layer: + items.append(item.repeat_interleave(top_k, dim=0)) + new_key_values.append(tuple(items)) + + past = tuple(new_key_values) + + model_kwargs["past_key_values"] = past + + if sequential: + all_outputs = [] + for i in range(top_k): + # compute the candidate tokens by the language model and collect their hidden_states + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs) + + outputs = self( + **next_model_inputs, + return_dict=True, + 
output_hidden_states=True, + output_attentions=output_attentions, + ) + if isinstance(outputs["past_key_values"], DynamicCache) or ( + isinstance(outputs["past_key_values"], EncoderDecoderCache) + and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache) + ): + # Remove past K-V from output since we don't need to stack later + outputs["past_key_values"] = None + # Remove last token from past K-V since we don't want to append it at this point + model_kwargs["past_key_values"].crop(-1) + + all_outputs.append(outputs) + outputs = stack_model_outputs(all_outputs, self.config.get_text_config()) + + else: + # compute the candidate tokens by the language model and collect their hidden_states + # assembles top_k_ids into batch of size k + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) + + outputs = self( + **next_model_inputs, + return_dict=True, + output_hidden_states=True, + output_attentions=output_attentions, + ) + + # This is essential to avoid having a last reference to the big past K-V and double the necessary memory + # in the next loop + del next_model_inputs + + # name is different for encoder-decoder and decoder-only models + if self.config.is_encoder_decoder: + next_hidden = outputs.decoder_hidden_states[-1] + full_hidden_states = outputs.decoder_hidden_states + else: + next_hidden = outputs.hidden_states[-1] + full_hidden_states = outputs.hidden_states + + # .float() is needed to retain precision for later logits manipulations + logits = outputs.logits[:, -1, :].float() + context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) + + # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the + # model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't + # introduce (noticeable) slowdowns on single-device runs. 
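+            # As a rough sketch of the re-ranking rule (following the contrastive search paper, arXiv 2202.06417),
+            # each candidate token x_i is scored as
+            #     score(x_i) = (1 - penalty_alpha) * model_confidence(x_i) - penalty_alpha * degeneration_penalty(x_i)
+            # where the degeneration penalty is the maximum cosine similarity between the candidate's hidden state
+            # and the hidden states of the existing context; `_ranking_fast` returns the best candidate index per batch item.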
+ selected_idx = _ranking_fast( + context_hidden, next_hidden, top_k_probs, cosine_matrix_mask, penalty_alpha, top_k + ) + cosine_matrix_mask = torch.cat( + [cosine_matrix_mask, cosine_matrix_mask.new_ones((cosine_matrix_mask.shape[0], 1))], dim=-1 + ) + selected_idx = selected_idx.to("cpu") + + # This will be used instead of the previous inefficient torch.stack(torch.split()) + augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)]) + + # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing + # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores + # (model confidence minus degeneration penalty); (6) decoder hidden_states + next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx] + next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k)) + next_hidden = next_hidden[range(batch_size), selected_idx, :] + last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) + + next_decoder_hidden_states = () + for layer in full_hidden_states: + layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :] + next_decoder_hidden_states += (layer,) + + # generate past_key_values cache of only the selected token + if sequential: + next_model_input = self.prepare_inputs_for_generation( + top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs + ) + + selected_outputs = self( + **next_model_input, + return_dict=True, + output_hidden_states=False, + output_attentions=False, + ) + next_past_key_values = selected_outputs["past_key_values"] + + else: + next_past_key_values = None + for possible_cache_name in ALL_CACHE_NAMES: + next_past_key_values = next_past_key_values or getattr(outputs, possible_cache_name, None) + # Do it in-place layer per layer to save memory + if isinstance(next_past_key_values, DynamicCache) or ( + isinstance(next_past_key_values, EncoderDecoderCache) + and isinstance(next_past_key_values.self_attention_cache, DynamicCache) + ): + next_past_key_values.batch_select_indices(augmented_idx) + else: + new_key_values = [] + for layer in next_past_key_values: + items = [] + # item is either the key or the value matrix + for item in layer: + items.append(item[augmented_idx, ...]) + new_key_values.append(tuple(items)) + + next_past_key_values = tuple(new_key_values) + + logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :] + logit_for_next_step = logit_for_next_step.to(input_ids.device) + + # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration + if self.config.is_encoder_decoder: + next_step_cross_attentions = () + next_step_decoder_attentions = () + if output_attentions: + for layer in outputs.cross_attentions: + layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] + next_step_cross_attentions += (layer,) + for layer in outputs.decoder_attentions: + layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
+ next_step_decoder_attentions += (layer,) + outputs = Seq2SeqLMOutput( + past_key_values=next_past_key_values, + decoder_hidden_states=next_decoder_hidden_states, + decoder_attentions=next_step_decoder_attentions or None, + cross_attentions=next_step_cross_attentions or None, + ) + else: + next_step_attentions = () + if output_attentions: + for layer in outputs.attentions: + layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] + next_step_attentions += (layer,) + outputs = CausalLMOutputWithPast( + past_key_values=next_past_key_values, + hidden_states=next_decoder_hidden_states, + attentions=next_step_attentions or None, + ) + # contrastive_search main logic end + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if synced_gpus and this_peer_finished: + continue + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + + # stop when each sentence is finished + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + # Contrastive search works by forward looking at the next token, so we need to exclude it from + # `past_key_values` to be consistent with the other decoding methods + if model_kwargs.get("past_key_values") is not None: + if isinstance(model_kwargs["past_key_values"], DynamicCache) or ( + isinstance(model_kwargs["past_key_values"], EncoderDecoderCache) + and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache) + ): + model_kwargs["past_key_values"].crop(-1) + else: + past_key_values = [] + for layer in model_kwargs["past_key_values"]: + layer_past_key_values = [] + for item in layer: + layer_past_key_values.append(item[..., :-1, :]) + past_key_values.append(tuple(layer_past_key_values)) + model_kwargs["past_key_values"] = tuple(past_key_values) + + if self.config.is_encoder_decoder: + return GenerateEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids + + def _sample( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial 
sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. 
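+
+        Example (an illustrative sketch; the checkpoint and sampling parameters are assumptions):
+
+        ```python
+        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+        >>> inputs = tokenizer("Today I believe we can finally", return_tensors="pt")
+        >>> # `do_sample=True` draws from the processed distribution; `do_sample=False` degenerates to greedy search
+        >>> outputs = model.generate(**inputs, do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=20)
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ```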
+ """ + # init values + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size, cur_len = input_ids.shape[:2] + this_peer_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) + + model_forward = self.__call__ + compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config) + if compile_forward: + os.environ["TOKENIZERS_PARALLELISM"] = "0" + # If we use FA2 and a static cache, we cannot compile with fullgraph + if self.config._attn_implementation == "flash_attention_2" and getattr( + model_kwargs.get("past_key_values"), "is_compileable", False + ): + if generation_config.compile_config is None: + generation_config.compile_config = CompileConfig(fullgraph=False) + # only raise warning if the user passed an explicit compile-config (otherwise, simply change the default without confusing the user) + elif generation_config.compile_config.fullgraph: + logger.warning_once( + "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as " + "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`." 
+ ) + generation_config.compile_config.fullgraph = False + model_forward = self.get_compiled_call(generation_config.compile_config) + + if generation_config.prefill_chunk_size is not None: + model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs) + is_prefill = False + else: + is_prefill = True + + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + if is_prefill: + outputs = self(**model_inputs, return_dict=True) + is_prefill = False + else: + outputs = model_forward(**model_inputs, return_dict=True) + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if synced_gpus and this_peer_finished: + continue + + # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device) + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # token selection + if do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + cur_len += 1 + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GenerateEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + 
encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids + + # Auxiliary functions for beam search + def _temporary_reorder_cache(self, past_key_values, beam_idx): + """ + Temporary function to handle the different types of cache reordering processes while we roll out `Cache`. + + TODO: standardize cache formats and make all models compatible with `Cache`. It would remove the need + for this function, with `Cache.reorder_cache` being the sole remaining code path + """ + model_class = self.__class__.__name__.lower() + # Exception 1: code path for models using the legacy cache format + if isinstance(past_key_values, (tuple, list)): + past_key_values = self._reorder_cache(past_key_values, beam_idx) + # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their + # cache format is standardized, to avoid adding complexity to the codebase. + elif "gptbigcode" in model_class: + if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)): + raise ValueError( + f"Using an unsupported cache format with {model_class}. Currently, it only supports the " + "legacy tuple format or `DynamicCache`" + ) + past_key_values = self._reorder_cache(past_key_values, beam_idx) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + # Standard code path: use the `Cache.reorder_cache` + else: + past_key_values.reorder_cache(beam_idx) + return past_key_values + + @staticmethod + def _flatten_beam_dim(tensor: torch.Tensor) -> torch.Tensor: + """[batch_size, num_beams, ...] -> [batch_size * num_beams, ...]""" + shape = list(tensor.shape) + return torch.reshape(tensor, [shape[0] * shape[1]] + shape[2:]) + + @staticmethod + def _unflatten_beam_dim(tensor: torch.Tensor, batch_size: int, num_beams: int) -> torch.Tensor: + """[batch_size * num_beams, ...] -> [batch_size, num_beams, ...]""" + shape = list(tensor.shape) + return torch.reshape(tensor, [batch_size, num_beams] + shape[1:]) + + @staticmethod + def _gather_beams(tensor: torch.Tensor, beam_indices: torch.Tensor) -> torch.Tensor: + """ + Gathers the beam slices indexed by beam_indices into new beam array. + + Args: + tensor (`torch.Tensor`): A tensor containing data to be gathered. The tensor is a 2D or a 3D tensor + with the two first dimensions depicting the batch and the beam dimensions. + beam_indices (`torch.Tensor` of shape `(batch_size, num_beams_to_select)`): The indices of the beams to + select . 
+ + Returns: + A tensor with the selected beams + """ + # `take_along_dim` requires its indices arg to have the same number of dims as `input` + while len(beam_indices.shape) < len(tensor.shape): + beam_indices = beam_indices.unsqueeze(-1) + gathered_tensor = torch.take_along_dim(input=tensor, indices=beam_indices, dim=1) + return gathered_tensor + + @staticmethod + def _beam_search_has_unfinished_sequences( + running_beam_scores: torch.Tensor, + beam_scores: torch.Tensor, + is_sent_finished: torch.Tensor, + next_token_hits_stopping_criteria: torch.Tensor, + cur_len: int, + max_length: int, + decoder_prompt_len: int, + early_stopping: Union[bool, str], + length_penalty: float, + ): + """ + Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False + """ + # a. Can the open beams improve the top completed scores? + # early_stopping == False -> apply heuristic = always get the best score from + # `cur_len - decoder_prompt_len`. See the discussion below for more details. + # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565 + # early_stopping == "never" -> compute the best score from `max_length` or `cur_len`, depending on the + # sign of `length_penalty`. Positive `length_penalty` favors longer sequences, thus we use + # `max_length` there. + if early_stopping == "never" and length_penalty > 0.0: + best_hypothetical_length = max_length - decoder_prompt_len + else: + best_hypothetical_length = cur_len - decoder_prompt_len + best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty) + worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9) + improvement_possible = torch.any(best_possible_running_score > worst_finished_score) + + # b. Is there still a beam without fully completed sequences? This is only relevant if early_stopping is + # enabled, where we want to finish as soon as all beams have a completed sequence. + exists_open_beam = ~(torch.all(is_sent_finished) & (early_stopping is True)) + + # c. Have we hit a stopping criteria with all running sequences and have no way to continue? e.g. we have + # reached `max_length`` + valid_continuations = ~torch.all(next_token_hits_stopping_criteria) + + return improvement_possible & exists_open_beam & valid_continuations + + def _get_top_k_continuations( + self, + accumulated_log_probs: torch.Tensor, + running_sequences: torch.Tensor, + running_beam_indices: torch.Tensor, + cur_len: int, + decoder_prompt_len: int, + do_sample: bool, + beams_to_keep: int, + num_beams: int, + vocab_size: int, + batch_size: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Get top-K continuations given the accumulated log probs on the next token. + + A few notes to understand what's going on: + 1. Each item in batch has `num_beams` * `vocab_size` candidate continuations. For each item, get the + top K [K = (number of EOS tokens + 1) * `num_beams`] candidates with the highest accumulated + log-probabilities, or sample them without replacement using the accumulated scores + 2. We gather the top K (as opposed to `num_beams`, or any number lower than K) here so that we have at + least `num_beams` sequences remaining to continue the live beam search. + 3. Note that other stopping criteria might result in impossible to continue beams, i.e. all continuations + selected in this step hit the stopping criteria. 
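+        For example (illustrative numbers): with `num_beams=3` and a single EOS token, K = (1 + 1) * 3 = 6
+        candidates are gathered per batch item, so even if the top 3 candidates all select EOS there are still at
+        least 3 non-EOS continuations available to keep the beam search alive.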
+ """ + # TODO (joao): This function should take an optional beam scorer function, to manipulate the scores after + # token selection. The function should be an argument exposed, so that custom scoring functions can be + # defined. + + # Gather the top K scores from _all_ beams. + if do_sample: + topk_indices = torch.multinomial( + nn.functional.softmax(accumulated_log_probs, dim=-1), num_samples=beams_to_keep + ) + topk_log_probs = torch.gather(input=accumulated_log_probs, dim=1, index=topk_indices) + else: + topk_log_probs, topk_indices = torch.topk(accumulated_log_probs, k=beams_to_keep) + + # Gather K top beams, recover the beam index by floor division and token id by modulo division + topk_current_beam_indices = topk_indices // vocab_size + topk_running_beam_indices = self._gather_beams(running_beam_indices, topk_current_beam_indices) + topk_running_sequences = self._gather_beams(running_sequences, topk_current_beam_indices) + topk_ids = topk_indices % vocab_size + + # Update sequences for the K top-k new sequences. + topk_running_sequences[:, :, cur_len] = topk_ids + + # we want to store the beam indices with batch information -> real beam index = beam index % num beams + batch_offset = torch.arange(batch_size, device=topk_ids.device).view(-1, 1) * num_beams + batch_modified_indices = topk_current_beam_indices + batch_offset + topk_running_beam_indices[:, :, cur_len - decoder_prompt_len] = batch_modified_indices + + return topk_log_probs, topk_running_sequences, topk_running_beam_indices + + def _get_running_beams_for_next_iteration( + self, + topk_log_probs: torch.Tensor, + topk_running_sequences: torch.Tensor, + topk_running_beam_indices: torch.Tensor, + next_token_hits_stopping_criteria: torch.Tensor, + num_beams: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given the top-K continuations, their scores, and whether they hit a stopping criteria, select the + best non-finished beams to continue beam search in the next iteration. + """ + # To prevent these just finished sequences from being used in subsequent iterations, set their log probs + # to a very large negative value + topk_running_log_probs = topk_log_probs + next_token_hits_stopping_criteria.to(torch.float32) * -1.0e9 + + next_topk_indices = torch.topk(topk_running_log_probs, k=num_beams)[1] + running_sequences = self._gather_beams(topk_running_sequences, next_topk_indices) + running_beam_scores = self._gather_beams(topk_running_log_probs, next_topk_indices) + running_beam_indices = self._gather_beams(topk_running_beam_indices, next_topk_indices) + return running_sequences, running_beam_scores, running_beam_indices + + def _update_finished_beams( + self, + sequences: torch.Tensor, + topk_running_sequences: torch.Tensor, + beam_scores: torch.Tensor, + topk_log_probs: torch.Tensor, + beam_indices: torch.Tensor, + topk_running_beam_indices: torch.Tensor, + is_sent_finished: torch.Tensor, + next_token_hits_stopping_criteria: torch.Tensor, + top_num_beam_mask: torch.Tensor, + num_beams: int, + cur_len: int, + decoder_prompt_len: int, + length_penalty: float, + early_stopping: Union[bool, str], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Updates the finished beams if (and only if) there are new completed sequences that have a higher score than + the current finished sequences. + """ + # Only the top `num_beam` sequences can be considered for the final returned sequences. 
Remember: the + # remaining sequences only exist as a backup to ensure that we have at least `num_beams` sequences to + # continue. + did_top_num_beams_just_finished = next_token_hits_stopping_criteria & top_num_beam_mask[None, :] + + # Further process topk logits for the finished beams + # - add length penalty + topk_log_probs = topk_log_probs / ((cur_len + 1 - decoder_prompt_len) ** length_penalty) + # - make sure no scores can be added anymore if beam is full and early stopping is on + beams_in_batch_are_full = torch.all(is_sent_finished, axis=-1, keepdims=True) & (early_stopping is True) + topk_log_probs += beams_in_batch_are_full.to(torch.float32) * -1.0e9 + # - make sure still running sequences cannot be chosen as finalized beam + topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9 + + # Get finalized `num_beam` sequences for the next generation step -- combine the previous finalized + # data with the new finalized sequences (if any, non-finalized sequences have a very large negative score + # in this step), and keep the best `num_beams` sequences. + merged_sequences = torch.cat((sequences, topk_running_sequences), dim=1) + merged_scores = torch.cat((beam_scores, topk_log_probs), dim=1) + merged_beam_indices = torch.cat((beam_indices, topk_running_beam_indices), dim=1) + merged_is_sent_finished = torch.cat((is_sent_finished, did_top_num_beams_just_finished), dim=1) + topk_merged_indices = torch.topk(merged_scores, k=num_beams)[1] + sequences = self._gather_beams(merged_sequences, topk_merged_indices) + beam_scores = self._gather_beams(merged_scores, topk_merged_indices) + beam_indices = self._gather_beams(merged_beam_indices, topk_merged_indices) + is_sent_finished = self._gather_beams(merged_is_sent_finished, topk_merged_indices) + return sequences, beam_scores, beam_indices, is_sent_finished + + # end of auxiliary functions for beam search + + def _beam_search( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + **model_kwargs, + ) -> Union[GenerateBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **beam search decoding** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + If it's the first time you're diving into Beam Search, we recommend you read the following blog post: + https://huggingface.co/blog/how-to-generate (especially the beam search section). + + You can recompute the sequence scores from the individual scores using the `compute_transition_scores` function + (https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores) + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`: + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. 
+ synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + """ + + # 1. init beam_search values + pad_token_id = generation_config._pad_token_tensor + eos_token_id = generation_config._eos_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + do_sample = generation_config.do_sample + early_stopping = generation_config.early_stopping + length_penalty = generation_config.length_penalty + max_length = generation_config.max_length + num_beams = generation_config.num_beams + num_return_sequences = generation_config.num_return_sequences + + batch_size_unflattened, cur_len = input_ids.shape[:2] + batch_size = batch_size_unflattened // num_beams + # TODO (joao): standardize special cases + if self.__class__.__name__ == "MoshiDepthDecoder": + vocab_size = self.config.audio_vocab_size + elif self.__class__.__name__ == "ImageGPTForCausalImageModeling": + vocab_size = self.get_output_embeddings().out_features + else: + vocab_size = self.config.get_text_config().vocab_size + decoder_prompt_len = cur_len + this_peer_finished = False + + # At each beam search step, we want to keep top K [K = (number of EOS tokens + 1) * `num_beams`] candidates + # with the highest log-probabilities, or sample K continuations without replacement. We gather the top K + # (as opposed to `num_beams`, or any number lower than K) so that we have at least `num_beams` sequences + # non-finished to continue the live beam search, in case the top `num_beams` all select an EOS token. + n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 + beams_to_keep = max(2, 1 + n_eos_tokens) * num_beams + top_num_beam_mask = torch.cat( + (torch.ones((num_beams), dtype=torch.bool), torch.zeros((beams_to_keep - num_beams), dtype=torch.bool)), + dim=0, + ).to(input_ids.device) + + model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) + + # (joao) feature lost in the refactor. Probably won't implement, hurts readability with minimal gains (there + # are newer low-memory alternatives like the offloaded cache) + sequential = generation_config.low_memory + if sequential: + raise ValueError( + "`low_memory=True` is not supported after the beam search refactor. Please check the discussion in " + "#35802 *after the PR got merged*, and add a comment there if your questions are not yet answered." + ) + + # 2. 
init output tuples + all_scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + beam_indices = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # 3. init running tensors and static-shaped placeholders + + # per batch, beam-item holding current token in loop and completed sequences + output_fill_value = pad_token_id or eos_token_id[0] if eos_token_id is not None else -1 + running_sequences = torch.full( + (batch_size, num_beams, max_length), + fill_value=output_fill_value, + dtype=torch.int64, + device=input_ids.device, + ) + running_sequences[:, :, :cur_len] = self._unflatten_beam_dim(input_ids, batch_size, num_beams) + sequences = running_sequences.detach().clone() + + # per batch, beam-item score, logprobs + # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens + # of the first beam are considered to avoid sampling the exact same tokens across all beams. + running_beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + running_beam_scores[:, 1:] = -1e9 + beam_scores = torch.full((batch_size, num_beams), fill_value=-1e9, dtype=torch.float, device=input_ids.device) + + # per batch, beam-item state bit indicating if sentence has finished. + is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device) + + # per batch, beam-item state bit indicating if there are valid continuations. + next_token_hits_stopping_criteria = torch.zeros( + (batch_size, num_beams), dtype=torch.bool, device=input_ids.device + ) + + # per batch selected beam indices + running_beam_indices = torch.full( + (batch_size, num_beams, max_length - cur_len), fill_value=-1, dtype=torch.int32, device=input_ids.device + ) + beam_indices = running_beam_indices.detach().clone() + + # 4. run the generation loop + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + # a. 
Forward current tokens, obtain the logits + flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len]) + model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + model_outputs = self(**model_inputs, return_dict=True) + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + model_outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if synced_gpus and this_peer_finished: + continue + + # Copy is needed to avoid keeping a hanging ref + logits = model_outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device) + + # b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.* + # `temperature`, ...), and add new logprobs to existing running logprobs scores. + log_probs = nn.functional.log_softmax(logits, dim=-1) + log_probs = logits_processor(flat_running_sequences, log_probs) + + # Store logits, attentions and hidden_states when required + if return_dict_in_generate: + if output_logits: + raw_logits += (logits.clone(),) + if return_dict_in_generate and output_scores: + all_scores += (log_probs.clone(),) + + if output_attentions: + decoder_attentions += ( + (model_outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (model_outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (model_outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (model_outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (model_outputs.hidden_states,) + ) + + # This is needed to properly delete logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del model_outputs + + log_probs = self._unflatten_beam_dim(log_probs, batch_size, num_beams) + log_probs = log_probs + running_beam_scores[:, :, None] + log_probs = torch.reshape(log_probs, (batch_size, num_beams * vocab_size)) + + # c. Retrieve top-K continuations, i.e. select the next token (greedy or sampling) and then keep the best + # continuations among all beams based on the accumulated scores. + topk_log_probs, topk_running_sequences, topk_running_beam_indices = self._get_top_k_continuations( + accumulated_log_probs=log_probs, + running_sequences=running_sequences, + running_beam_indices=running_beam_indices, + cur_len=cur_len, + decoder_prompt_len=decoder_prompt_len, + do_sample=do_sample, + beams_to_keep=beams_to_keep, + num_beams=num_beams, + vocab_size=vocab_size, + batch_size=batch_size, + ) + + # d. Check which running sequences have finished + next_token_hits_stopping_criteria = stopping_criteria( + self._flatten_beam_dim(topk_running_sequences[:, :, : cur_len + 1]), # remove unfilled token indexes + all_scores, + ) + next_token_hits_stopping_criteria = self._unflatten_beam_dim( + next_token_hits_stopping_criteria, batch_size, beams_to_keep + ) + + # e. 
Get the non-finished running `num_beams` sequences for the next generation step + running_sequences, running_beam_scores, running_beam_indices = self._get_running_beams_for_next_iteration( + topk_log_probs=topk_log_probs, + topk_running_sequences=topk_running_sequences, + topk_running_beam_indices=topk_running_beam_indices, + next_token_hits_stopping_criteria=next_token_hits_stopping_criteria, + num_beams=num_beams, + ) + + # f. Update the completed beams if a new high score in a finished sequence is found + sequences, beam_scores, beam_indices, is_sent_finished = self._update_finished_beams( + sequences=sequences, + topk_running_sequences=topk_running_sequences, + beam_scores=beam_scores, + topk_log_probs=topk_log_probs, + beam_indices=beam_indices, + topk_running_beam_indices=topk_running_beam_indices, + is_sent_finished=is_sent_finished, + next_token_hits_stopping_criteria=next_token_hits_stopping_criteria, + top_num_beam_mask=top_num_beam_mask, + num_beams=num_beams, + cur_len=cur_len, + decoder_prompt_len=decoder_prompt_len, + length_penalty=length_penalty, + early_stopping=early_stopping, + ) + + # g. Prepare remaining data for the next iteration, including computing the stopping condition for + # beam search as a whole (as opposed to individual beams, i.e. `stopping_criteria`) + + # pluck the cache from the beam indices that will be used in the next iteration + if model_kwargs.get("past_key_values", None) is not None: + model_kwargs["past_key_values"] = self._temporary_reorder_cache( + past_key_values=model_kwargs["past_key_values"], + beam_idx=self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len]), + ) + + cur_len = cur_len + 1 + this_peer_finished = not self._beam_search_has_unfinished_sequences( + running_beam_scores, + beam_scores, + is_sent_finished, + next_token_hits_stopping_criteria, + cur_len, + max_length, + decoder_prompt_len, + early_stopping, + length_penalty, + ) + + # 5. prepare outputs + # Take best beams for each batch (the score is sorted in descending order) + sequences = self._flatten_beam_dim(sequences[:, :num_return_sequences, :]) + beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences]) + beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :]) + + # Crop the static-shaped tensors to the actual size. + # `beam_indices` is initialized with -1s, and is updated with the beam index of the generated token at each + # step. We can use it to detect the generated length, which may be != `cur_len` (e.g. 
selected beam is from a + # previous decoding iteration) + max_generated_length = ((beam_indices + 1).bool()).sum(dim=1).max() + output_length = decoder_prompt_len + max_generated_length + sequences = sequences[:, :output_length] + beam_indices = beam_indices[:, :max_generated_length] + + if return_dict_in_generate: + if not output_scores: + beam_scores = None + + if self.config.is_encoder_decoder: + return GenerateBeamEncoderDecoderOutput( + sequences=sequences, + sequences_scores=beam_scores, + scores=all_scores, + logits=raw_logits, + beam_indices=beam_indices, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateBeamDecoderOnlyOutput( + sequences=sequences, + sequences_scores=beam_scores, + scores=all_scores, + logits=raw_logits, + beam_indices=beam_indices, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return sequences + + def _group_beam_search( + self, + input_ids: torch.LongTensor, + beam_scorer: BeamScorer, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + **model_kwargs, + ): + r""" + Generates sequences of token ids for models with a language modeling head using **diverse beam search + decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`): + The sequence used as a prompt for the generation. + beam_scorer (`BeamScorer`): + An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and + sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + model_kwargs: + Additional model specific kwargs that will be forwarded to the `forward` function of the model. If + model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. 
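+
+ A minimal sketch of how this method is usually reached through `generate()` (the checkpoint and prompt are placeholders):
+
+ ```python
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
+ model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")
+ inputs = tokenizer("translate English to French: I like cats.", return_tensors="pt")
+ # `num_beam_groups > 1` (together with a non-zero `diversity_penalty`) selects diverse beam search
+ outputs = model.generate(**inputs, num_beams=4, num_beam_groups=2, diversity_penalty=1.0, max_new_tokens=20)
+ print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+ ```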
+ """ + # init values + pad_token_id = generation_config._pad_token_tensor + eos_token_id = generation_config._eos_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + + num_beams = beam_scorer.num_beams + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups + batch_size = len(beam_scorer._beam_hyps) // num_beam_groups + device = input_ids.device + + batch_beam_size, cur_len = input_ids.shape + model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) + + if return_dict_in_generate and output_scores: + beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)] + else: + beam_indices = None + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in + # the same group don't produce same tokens every time. 
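+ # A small illustration of the layout produced below (values are only an example): with num_beams=4 and
+ # num_beam_groups=2, num_sub_beams == 2, so each batch row starts as [0.0, -1e9, 0.0, -1e9] -- only the
+ # first beam of every group contributes candidates on the first step.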
+ beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False + + decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + # predicted tokens in cur_len step + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + + # do one decoder step on all beams of all sentences in batch + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + outputs = self(**model_inputs, return_dict=True) + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue + + if output_scores: + processed_score = torch.zeros_like(outputs.logits[:, -1, :]) + if output_logits: + # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + raw_logit_score = outputs.logits[:, -1, :].to(copy=True, device=input_ids.device) + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of current group only + # No need to clone() the logits here as they will not retain outputs.logits at the end of the loop + # .float() is needed to retain precision for later logits manipulations + next_token_logits = outputs.logits[batch_group_indices, -1, :].to( + dtype=torch.float32, device=input_ids.device + ) + + next_token_scores = nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * group_size, vocab_size) + vocab_size = next_token_scores.shape[-1] + + next_token_scores_processed = logits_processor( + group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx + ) + next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as(next_token_scores_processed) + + if output_scores: + processed_score[batch_group_indices] = next_token_scores_processed + + # reshape for beam search + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. 
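+ # Worked example of the comment above (illustrative numbers): with group_size=2 and a single EOS token,
+ # max(2, 1 + n_eos_tokens) * group_size == 4 candidates are kept; at most 2 of them can be EOS
+ # continuations, so at least `group_size` non-EOS candidates always remain for the next step.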
+ n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 + next_token_scores, next_tokens = torch.topk( + next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=process_beam_indices, + group_index=beam_group_idx, + decoder_prompt_len=decoder_prompt_len, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + if return_dict_in_generate and output_scores: + beam_indices[beam_group_idx] = tuple( + beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) + ) + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + + group_start_idx + + (beam_idx % group_size) + ) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (processed_score,) + if output_logits: + raw_logits += (raw_logit_score,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory + # (that way the memory peak does not include outputs.logits) + del outputs + + if model_kwargs.get("past_key_values", None) is not None: + model_kwargs["past_key_values"] = self._temporary_reorder_cache( + model_kwargs["past_key_values"], reordering_indices + ) + + # increase cur_len + cur_len = cur_len + 1 + + if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): + this_peer_finished = True + + final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=final_beam_indices, + decoder_prompt_len=decoder_prompt_len, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + + if self.config.is_encoder_decoder: + return GenerateBeamEncoderDecoderOutput( + 
sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + logits=raw_logits, + beam_indices=sequence_outputs["beam_indices"], + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateBeamDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + logits=raw_logits, + beam_indices=sequence_outputs["beam_indices"], + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return sequence_outputs["sequences"] + + def _constrained_beam_search( + self, + input_ids: torch.LongTensor, + constrained_beam_scorer: ConstrainedBeamSearchScorer, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + **model_kwargs, + ) -> Union[GenerateBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **constrained beam search + decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`): + The sequence used as a prompt for the generation. + constrained_beam_scorer (`ConstrainedBeamSearchScorer`): + A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and + sorted during generation, while satisfying a list of positive constraints. For more information, the + documentation of [`ConstrainedBeamSearchScorer`] should be read. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. 
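+
+ A minimal sketch of how this method is typically reached through `generate()` (checkpoint, prompt and forced word are placeholders):
+
+ ```python
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
+ model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")
+ inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")
+ # Token ids that must appear in the output; `force_words_ids` together with `num_beams > 1`
+ # routes generation through constrained beam search.
+ force_words_ids = tokenizer(["Sie"], add_special_tokens=False).input_ids
+ outputs = model.generate(**inputs, num_beams=4, force_words_ids=force_words_ids, max_new_tokens=20)
+ print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+ ```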
+ """ + # init values + pad_token_id = generation_config._pad_token_tensor + eos_token_id = generation_config._eos_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + + batch_size = len(constrained_beam_scorer._beam_hyps) + num_beams = constrained_beam_scorer.num_beams + + batch_beam_size, cur_len = input_ids.shape[:2] + model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + beam_indices = ( + tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None + ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens + # of the first beam are considered to avoid sampling the exact same tokens across all beams. 
+ beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False + + decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + outputs = self(**model_inputs, return_dict=True) + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue + + # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + # .float() is needed to retain precision for later logits manipulations + next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device) + next_token_scores = nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * num_beams, vocab_size) + + next_token_scores_processed = logits_processor(input_ids, next_token_scores) + + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( + next_token_scores_processed + ) + + scores_for_all_vocab = next_token_scores.clone() + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + + # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. 
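+ # Note (illustrative): the top-k below keeps only max(2, 1 + n_eos_tokens) * num_beams candidates, but the
+ # `scores_for_all_vocab` tensor cloned above lets the constrained scorer still advance a constraint even when
+ # the forced token falls outside this top-k (e.g. the required word ranks 500th for every beam).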
+ n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 + next_token_scores, next_tokens = torch.topk( + next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True + ) + + next_indices = (next_tokens / vocab_size).long() + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = constrained_beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + scores_for_all_vocab, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, + ) + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory + # (that way the memory peak does not include outputs.logits) + del outputs + + if model_kwargs.get("past_key_values", None) is not None: + model_kwargs["past_key_values"] = self._temporary_reorder_cache( + model_kwargs["past_key_values"], beam_idx + ) + + if return_dict_in_generate and output_scores: + beam_indices = tuple(beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))) + + # increase cur_len + cur_len = cur_len + 1 + + if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): + this_peer_finished = True + + sequence_outputs = constrained_beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + if self.config.is_encoder_decoder: + return GenerateBeamEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + logits=raw_logits, + beam_indices=sequence_outputs["beam_indices"], + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateBeamDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + logits=raw_logits, + beam_indices=sequence_outputs["beam_indices"], + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return sequence_outputs["sequences"] + + def _assisted_decoding( + self, + input_ids: torch.LongTensor, + candidate_generator: CandidateGenerator, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **greedy decoding** or + **sample** (depending on 
`do_sample`), assisted by candidate sequences. Assisted generation is an example of a + candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text + models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + candidate_generator (`CandidateGenerator`): + A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For + more information, the documentation of [`CandidateGenerator`] should be read. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `forward` function of the model. + If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. 
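+
+ A minimal sketch of how assisted decoding is usually triggered through `generate()` (both checkpoints are placeholders; the assistant is typically a much smaller model with a compatible tokenizer):
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b-deduped")
+ model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1.4b-deduped")
+ assistant_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m-deduped")
+ inputs = tokenizer("Alice and Bob", return_tensors="pt")
+ # Passing `assistant_model` makes `generate()` run this assisted/speculative decoding loop
+ outputs = model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=20)
+ print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+ ```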
+ """ + # init values + do_sample = generation_config.do_sample + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size, cur_len = input_ids.shape[:2] + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) + + this_peer_finished = False + is_first_iteration = True # to preserve the same API in the output as other generation methods + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + cur_len = input_ids.shape[1] + + # 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device + candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) + candidate_input_ids = candidate_input_ids.to(self.device) + if candidate_logits is not None: + candidate_logits = candidate_logits.to(self.device) + + candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] + is_done_candidate = stopping_criteria(candidate_input_ids, None) + + # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain + # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, + # we use this forward pass to also pick the subsequent logits in the original model. + + # 2.1. Prepare the model inputs + candidate_kwargs = copy.copy(model_kwargs) + candidate_kwargs = _prepare_attention_mask( + candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder + ) + candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) + if "cache_position" in candidate_kwargs: + candidate_kwargs["cache_position"] = torch.cat( + ( + candidate_kwargs["cache_position"], + torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long), + ), + dim=0, + ) + + model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) + if "logits_to_keep" in model_inputs: + model_inputs["logits_to_keep"] = candidate_length + 1 + + # 2.2. 
Run a forward pass on the candidate sequence + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + outputs = self(**model_inputs) + + # 2.3. Process the new logits + # .float() is needed to retain precision for later logits manipulations + new_logits = outputs.logits[:, -candidate_length - 1 :].to( + dtype=torch.float32, device=input_ids.device + ) # excludes the input prompt if present + next_token_logits = new_logits.clone() + if len(logits_processor) > 0: + for i in range(candidate_length + 1): + new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) + + # 3. Select the accepted tokens. There are two possible cases: + # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) + # 👉 Apply algorithm 1 from the speculative decoding paper (https://huggingface.co/papers/2211.17192). + if do_sample and candidate_logits is not None: + valid_tokens, n_matches = _speculative_sampling( + candidate_input_ids, + candidate_logits, + candidate_length, + new_logits, + is_done_candidate, + ) + + # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the + # original model logits with the candidate tokens. We can keep the candidate tokens until the first + # mismatch, or until the max length is reached. + else: + if do_sample: + probs = new_logits.softmax(dim=-1) + selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] + else: + selected_tokens = new_logits.argmax(dim=-1) + + candidate_new_tokens = candidate_input_ids[:, cur_len:] + n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() + + # Ensure we don't generate beyond max_len or an EOS token + if is_done_candidate and n_matches == candidate_length: + n_matches -= 1 + valid_tokens = selected_tokens[:, : n_matches + 1] + + # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated + # by the model after the last candidate match is also valid, as it is generated from a correct sequence. + # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there + # is no match. + + # 4.1. Get the valid continuation, after the matching tokens + input_ids = torch.cat((input_ids, valid_tokens), dim=-1) + if streamer is not None: + streamer.put(valid_tokens.cpu()) + new_cur_len = input_ids.shape[1] + + # 4.2. Discard past key values relative to unused assistant tokens + new_cache_size = new_cur_len - 1 + outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) + + # 5. Update the candidate generation strategy if needed + candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + num_new_tokens=n_matches + 1, + ) + if synced_gpus and this_peer_finished: + continue + + # Store scores, attentions and hidden_states when required + # Assistant: modified to append one tuple element per token, as in the other generation methods. 
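+ # Illustrative shape note: if n_matches == 2, three entries are appended to `scores` below in this iteration,
+ # each of shape (batch_size, vocab_size), mirroring what greedy/sample decoding would have emitted one token
+ # at a time.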
+ if return_dict_in_generate: + newly_added_length = n_matches + 1 + if output_scores: + scores += tuple(new_logits[:, i, :] for i in range(newly_added_length)) + if output_logits: + raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length)) + + newly_added_length = new_cur_len if is_first_iteration else newly_added_length + if output_attentions: + if self.config.is_encoder_decoder: + cross_attentions = _split_model_outputs( + cross_attentions, outputs.cross_attentions, cur_len, newly_added_length + ) + decoder_attentions = _split_model_outputs( + decoder_attentions, + outputs.decoder_attentions, + cur_len, + newly_added_length, + is_decoder_attention=True, + ) + # some (V)LLMs have hard requirement on SDPA and thus never return attn + elif outputs.attentions[0] is not None: + decoder_attentions = _split_model_outputs( + decoder_attentions, + outputs.attentions, + cur_len, + newly_added_length, + is_decoder_attention=True, + ) + if output_hidden_states: + if self.config.is_encoder_decoder: + decoder_hidden_states = _split_model_outputs( + decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length + ) + else: + decoder_hidden_states = _split_model_outputs( + decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length + ) + + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + is_first_iteration = False + + if streamer is not None: + streamer.end() + + if ( + hasattr(candidate_generator, "assistant_model") + and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic" + ): + candidate_generator.assistant_model.generation_config.num_assistant_tokens = ( + candidate_generator.num_assistant_tokens + ) + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GenerateEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids + + def _prefill_chunking(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, **model_kwargs): + # Even if we are not compiling the forward, flex is always compiled when used. With chunk prefill, we may + # end up needing just a bit more graphs than the default (which is 8). 
Doing this avoids very cryptic warnings + torch._dynamo.config.cache_size_limit = 64 + + chunk_size = generation_config.prefill_chunk_size + # Only chunk up the token just before last, so that decoding is completely performed outside this function + # (here we simply prefill the cache) + input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1) + + if "past_key_values" not in model_kwargs: + raise ValueError("Cannot use prefill chunking without a cache") + + model_forward = self.forward + + compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config) + if compile_forward: + model_forward = self.get_compiled_call(generation_config.compile_config) + + attention_mask = model_kwargs.pop("attention_mask", None) + + past_length = 0 + for input_chunk in input_chunks: + current_length = past_length + input_chunk.shape[-1] + # Prepare inputs + if attention_mask is not None: + model_kwargs["attention_mask"] = attention_mask[:, :current_length] + model_kwargs["cache_position"] = torch.arange( + past_length, current_length, dtype=torch.long, device=input_chunk.device + ) + model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0) + model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs) + + outputs = model_forward(**model_inputs, return_dict=True) + + model_kwargs["past_key_values"] = outputs.past_key_values + past_length = current_length + + model_kwargs["attention_mask"] = attention_mask + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 + _ = model_kwargs.pop("position_ids", None) + + return model_kwargs + + +def _speculative_sampling( + candidate_input_ids, + candidate_logits, + candidate_length, + new_logits, + is_done_candidate, +): + """ + Applies sampling as in the speculative decoding paper (https://huggingface.co/papers/2211.17192, algorithm 1). Returns + the selected tokens, as well as the number of candidate matches. + + NOTE: Unless otherwise stated, the variable names match those in the paper. + """ + new_candidate_input_ids = candidate_input_ids[:, -candidate_length:] + # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens + # selected by the assistant, respectively. + q = candidate_logits.softmax(dim=-1) + q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1) + p = new_logits.softmax(dim=-1) + p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1) + probability_ratio = p_i / q_i + + # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller + # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio + # (= keep with p = probability_ratio). Keep all the tokens until the first rejection + r_i = torch.rand_like(probability_ratio) + is_accepted = r_i <= probability_ratio + n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1 + + # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior) + if is_done_candidate and n_matches == candidate_length: + # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model + # due to acceptance on EOS we fix `n_matches` + n_matches -= 1 + valid_tokens = new_candidate_input_ids[:, : n_matches + 1] + else: + # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling. 
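+ # In the rejection case the paper's residual distribution is sampled instead, i.e. (illustrative notation)
+ #     p'(x) = max(0, p(x) - q(x)) / sum_x max(0, p(x) - q(x))
+ # which is exactly the clamp-and-renormalize computed below.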
+ gamma = candidate_logits.shape[1] + p_n_plus_1 = p[:, n_matches, :] + if n_matches < gamma: + q_n_plus_1 = q[:, n_matches, :] + p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0) + p_prime.div_(p_prime.sum()) + else: + p_prime = p_n_plus_1 + t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :] + + # The selected tokens include the matches (if any) plus the next sampled tokens + if n_matches > 0: + valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1) + else: + valid_tokens = t + + return valid_tokens, n_matches + + +def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False): + """ + Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple + where each member corresponds to a single generated token. + """ + # Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the + # prompt. + if len(outputs) == 0: + new_tuple = () + for layer in new_outputs: + last_dim_size = cur_len if is_decoder_attention else layer.shape[-1] + new_tuple += (layer[..., :cur_len, :last_dim_size],) + outputs += (new_tuple,) + # The first iteration contains the prompt + 1 generated token, let's update the length variables accordingly + cur_len += 1 + added_len -= cur_len + + for i in range(added_len): + new_tuple = () + for layer in new_outputs: + last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1] + new_tuple += (layer[..., i : i + 1, :last_dim_size],) + outputs += (new_tuple,) + return outputs + + +def _ranking_fast( + context_hidden: torch.FloatTensor, + next_hidden: torch.FloatTensor, + next_top_k_probs: torch.FloatTensor, + cosine_matrix_mask: torch.LongTensor, + alpha: float, + beam_width: int, +) -> torch.FloatTensor: + """ + Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described + in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each + row in the batch. + """ + norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) + norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) + cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S] + + # Penalize cosine_matrix based on the cosine_matrix_mask (ignore padding positions) + # Using a large negative value for masked positions + cosine_matrix_mask = cosine_matrix_mask.to(dtype=cosine_matrix.dtype) + cosine_matrix_mask = (1 - cosine_matrix_mask) * torch.finfo(cosine_matrix.dtype).min + cosine_matrix = cosine_matrix + cosine_matrix_mask + + degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1) # [B*K] + next_top_k_probs = next_top_k_probs.view(-1) # [B*K] + contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty + contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K] + _, selected_idx = contrastive_score.max(dim=-1) # [B] + return selected_idx + + +def _split(data, full_batch_size: int, split_size: int): + """ + Takes care of three cases: + 1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim + 2. data is a tuple: e.g. hidden_states, attentions etc. Keep the tuple as it is and split each tensor in it and + return a list of tuples + 3. data is a tuple of tuples, e.g. past_key_values. 
Keep the tuple as it is and split each tuple in it and + return a list of tuples of tuples + (see documentation of ModelOutput) + """ + if data is None: + return [None] * (full_batch_size // split_size) + if isinstance(data, torch.Tensor): + return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)] + # New cache format + elif isinstance(data, DynamicCache) or ( + isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache) + ): + return data.batch_split(full_batch_size, split_size) + elif isinstance(data, tuple): + # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) + if isinstance(data[0], tuple): + return [ + tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data) + for i in range(0, full_batch_size, split_size) + ] + + else: + return [ + tuple(sub_tensor[i : i + split_size] for sub_tensor in data) + for i in range(0, full_batch_size, split_size) + ] + else: + raise TypeError(f"Unexpected attribute type: {type(data)}") + + +def _split_model_inputs( + model_input: Union[ModelOutput, dict], split_size: int, full_batch_size: int, config: PretrainedConfig +) -> list[Union[ModelOutput, dict]]: + """ + Split a ModelOutput object (or its subclasses) or Dict into a list of same-class objects based on a specified split + size. The input object is dict when it was prepared for forward pass and ModelOutput when it was returned from + previous forward pass. + """ + # Edge case: if model_input is None, return a list of Nones + # this happens with Whisper where encoder_outputs is None + if model_input is None: + return [model_input] * (full_batch_size // split_size) + # Infer the class from the object + model_output_cls = type(model_input) + if (full_batch_size % split_size) != 0: + raise ValueError("`full_batch_size` must be divisible by `split_size`") + + if split_size > full_batch_size: + raise ValueError("`split_size` must be smaller or equal to `full_batch_size`") + + # Helper function to split tensors or tuples of tensors + + # Find all the dataclass fields (e.g., last_hidden_state, pooler_output etc.) and split them + keys = ( + model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys() + ) + # We only keep keys that are in the model_input + keys = [k for k in keys if k in model_input] + # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a + # ModelOutput object. 
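+ # Illustrative example (hypothetical shapes): with `full_batch_size=4` and `split_size=2`, a tensor entry such
+ # as `input_ids` of shape (4, seq_len) is split into two tensors of shape (2, seq_len), while boolean flags,
+ # `cache_position` and `logits_to_keep` are not split and are instead replicated verbatim into each of the two
+ # resulting input dicts, as handled below.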
+ # bool should not be split but replicated for each split + bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] + keys_to_ignore = ["cache_position", "encoder_outputs", "logits_to_keep"] + non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] + + # we split the tensors and tuples of tensors + data_split_list = [ + {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys} + for i in range(full_batch_size // split_size) + ] + # bool values are the same and replicated for each split + bool_data = {k: model_input[k] for k in bool_keys} + # encoder_outputs is a ModelOutput object and should be split by its own + if "encoder_outputs" in model_input: + encoder_outputs_split = _split_model_inputs( + model_input["encoder_outputs"], split_size, full_batch_size, config.get_text_config() + ) + data_split_list = [ + {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list) + ] + # logits_to_keep should be replicated for each split, similar to bool values + if "logits_to_keep" in model_input: + data_split_list = [ + {**data_split, "logits_to_keep": model_input["logits_to_keep"]} for data_split in data_split_list + ] + + # Convert each dictionary in the list to an object of the inferred class + split_model_inputs: list[Union[ModelOutput, dict]] = [ + model_output_cls(**data_split, **bool_data) for data_split in data_split_list + ] + + return split_model_inputs + + +def stack_model_outputs(model_outputs: list[ModelOutput], config: PretrainedConfig) -> ModelOutput: + """ + Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the + specific ModelOutput subclass from the list provided. + """ + if not model_outputs: + raise ValueError("Input list is empty.") + + # Infer the class from the first object in the list + model_output_cls = type(model_outputs[0]) + + # Ensure all objects are of the same type + if not all(isinstance(obj, model_output_cls) for obj in model_outputs): + raise ValueError("All elements in the list should be of the same type.") + + # Helper function to concat tensors or tuples of tensors + def _concat(data): + """ + Reverse of `_split` function above. 
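+ For example, per-split `last_hidden_state` tensors of shape (split_size, ...) are concatenated back into a
+ single tensor of shape (full_batch_size, ...) along dim 0, while nested tuples and cache objects are merged
+ element-wise.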
+ """ + if any(data is None for data in data): + return None + if isinstance(data[0], torch.Tensor): + return torch.cat(data, dim=0) + # New cache format + elif isinstance(data[0], DynamicCache): + return DynamicCache.from_batch_splits(data) + elif isinstance(data[0], EncoderDecoderCache): + return EncoderDecoderCache.from_batch_splits(data) + elif isinstance(data[0], tuple): + # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) + if isinstance(data[0][0], tuple): + return tuple( + tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0]))) + for i in range(len(data[0])) + ) + else: + return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0]))) + elif isinstance(data[0], (int, float)): + # If the elements are integers or floats, return a tensor + return torch.tensor(data) + else: + raise TypeError(f"Unexpected attribute type: {type(data[0])}") + + # Use a dictionary comprehension to gather attributes from all objects and concatenate them + concatenated_data = { + k: _concat([getattr(model_output, k) for model_output in model_outputs]) + for k in model_output_cls.__dataclass_fields__.keys() + } + + # Return a new object of the inferred class with the concatenated attributes + return model_output_cls(**concatenated_data) + + +def _relative_top_filter( + scores: torch.FloatTensor, + baseline_scores: torch.FloatTensor, + relative_top: float = 0.1, + filter_value: float = -float("Inf"), + base_filter_value=-1e-3, + min_tokens_to_keep: int = 1, +) -> torch.FloatTensor: + """ + Reference: https://github.com/XiangLi1999/ContrastiveDecoding/blob/170e9142e92159c1237d731e240f5eb14aabf428/transformers/src/transformers/generation_logits_process.py#L235 + Apply filtering to only keep tokens with a probability above a certain threshold. The threshold is defined as `relative_top` * max probability in the distribution. + """ + scores_normalized = scores.log_softmax(dim=-1) + baseline_scores_normalized = baseline_scores.log_softmax(dim=-1) + sorted_logits, sorted_indices = torch.sort(scores_normalized, descending=True) + min_thresh = sorted_logits[..., min_tokens_to_keep - 1] + probs_max = torch.max(scores_normalized, dim=-1).values + probs_thresh = probs_max + np.log(relative_top) + probs_thresh = torch.min(min_thresh, probs_thresh) + probs_thresh = probs_thresh.unsqueeze(-1) + baseline_scores_normalized[scores_normalized < probs_thresh] = base_filter_value + scores_normalized[scores_normalized < probs_thresh] = filter_value + return scores_normalized, baseline_scores_normalized + + +def _dola_select_contrast( + candidate_premature_layers: list[int], + candidate_premature_logits: dict[int, torch.FloatTensor], + final_logits: torch.FloatTensor, +) -> torch.FloatTensor: + if len(candidate_premature_layers) == 1: + base_logits = candidate_premature_logits[candidate_premature_layers[0]] + final_logits, base_logits = _relative_top_filter(final_logits, base_logits) + logits = final_logits - base_logits + return logits + + # 1. Stacking all premature_layers into a new dimension + stacked_premature_layers = torch.stack([candidate_premature_logits[i] for i in candidate_premature_layers], dim=0) + + # 2. Calculate the softmax values for mature_layer and all premature_layers + # shape: (batch_size, vocab_size) + softmax_mature_layer = F.softmax(final_logits, dim=-1) + # shape: (num_premature_layers, batch_size, vocab_size) + softmax_premature_layers = F.softmax(stacked_premature_layers, dim=-1) + + # 3. 
Calculate the average distribution + # shape: (num_premature_layers, batch_size, vocab_size) + avg_dist = 0.5 * (softmax_mature_layer[None, :, :] + softmax_premature_layers) + + # 4. Calculate log-softmax for the KL divergence + # shape: (batch_size, vocab_size) + log_softmax_mature_layer = F.log_softmax(final_logits, dim=-1) + # shape: (num_premature_layers, batch_size, vocab_size) + log_softmax_premature_layers = F.log_softmax(stacked_premature_layers, dim=-1) + + # 5. Calculate the KL divergences and then the JS divergences + # shape: (num_premature_layers, batch_size) + kl1 = F.kl_div(log_softmax_mature_layer[None, :, :], avg_dist, reduction="none").mean(-1) + # shape: (num_premature_layers, batch_size) + kl2 = F.kl_div(log_softmax_premature_layers, avg_dist, reduction="none").mean(-1) + js_divs = 0.5 * (kl1 + kl2) # shape: (num_premature_layers, batch_size) + + # 6. Reduce the batchmean + js_divs = js_divs.mean(-1) # shape: (num_premature_layers,) + premature_layer = candidate_premature_layers[int(js_divs.argmax().item())] + + base_logits = candidate_premature_logits[premature_layer] + final_logits, base_logits = _relative_top_filter(final_logits, base_logits) + logits = final_logits - base_logits + return logits + +
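+ # Note on `_dola_select_contrast` above (a paraphrase of the steps, not part of the original algorithm
+ # statement): the premature layer whose output distribution diverges most from the final (mature) layer,
+ # as measured by the JS-style divergence computed in steps 2-6, is selected as the contrast layer; its
+ # relative-top-filtered logits are then subtracted from the mature logits to form the DoLa contrastive scores.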