# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

import torch

from ..cache_utils import DynamicCache
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel
    from .configuration_utils import GenerationConfig

class CandidateGenerator:
    """Abstract base class for all candidate generators that can be applied during assisted generation."""

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
            assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` containing the logits associated with each candidate.
        """
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
        )

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary token when
                not using beam search or log softmax for each vocabulary token when using beam search.
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can call "
            "`update_candidate_strategy`."
        )
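

# Illustrative sketch (not part of the library API): a custom generator only needs to implement the two
# methods above and return the full sequence (prompt plus drafted tokens). A trivial subclass that always
# proposes a fixed, hypothetical `fixed_token_id` could look roughly like the commented code below:
#
#     class FixedTokenCandidateGenerator(CandidateGenerator):
#         def __init__(self, fixed_token_id: int, num_candidates: int = 5):
#             self.fixed_token_id = fixed_token_id
#             self.num_candidates = num_candidates
#
#         def get_candidates(self, input_ids):
#             # draft `num_candidates` copies of the fixed token and append them to the prompt
#             new_tokens = input_ids.new_full((input_ids.shape[0], self.num_candidates), self.fixed_token_id)
#             return torch.cat([input_ids, new_tokens], dim=-1), None
#
#         def update_candidate_strategy(self, input_ids, scores, num_matches):
#             pass  # no adaptive behavior in this sketch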


class AssistedCandidateGenerator(CandidateGenerator):
    """
    `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
    candidates through the use of a smaller model. Read the following blog post for more information:
    https://huggingface.co/blog/assisted-generation

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
        assistant_model (`PreTrainedModel`):
            The model to be used for generating candidates. This model should be smaller than the main model.
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        model_kwargs (`Dict`):
            The keyword arguments that will be passed to the main model, and are used as base inputs for the
            assistant model as well.
        inputs_tensor (`torch.Tensor`, *optional*):
            The model input tensor. In encoder-decoder models, this is the encoder input.
    """
    def __init__(
        self,
        input_ids: torch.LongTensor,
        assistant_model: "PreTrainedModel",
        generation_config: "GenerationConfig",
        model_kwargs: Dict,
        inputs_tensor: Optional[torch.Tensor] = None,
        logits_processor: "LogitsProcessorList" = None,
    ):
        # Make sure all data is on the same device as the assistant model
        device = assistant_model.device
        input_ids = input_ids.to(device)
        if inputs_tensor is not None:
            inputs_tensor = inputs_tensor.to(device)

        # Prepare the assistant and the starting number of candidate tokens
        self.assistant_model = assistant_model
        self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens

        # Set the eos token id of the assistant to the same value as in the target model
        self.assistant_model.generation_config.eos_token_id = generation_config.eos_token_id

        # Prepare the kwargs for the assistant model
        assistant_kwargs = {}
        for key, value in model_kwargs.items():  # deepcopy crashes if we attempt to copy encoder outputs with grads
            if key not in ("encoder_outputs", "assistant_encoder_outputs", "past_key_values"):
                assistant_kwargs[key] = (
                    value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
                )

        if "assistant_encoder_outputs" in model_kwargs:
            assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
        elif assistant_model.config.is_encoder_decoder:
            inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs(
                inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs
            )
            assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, assistant_kwargs, model_input_name, assistant_model.generation_config
            )
        elif "encoder_outputs" in model_kwargs:
            assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"]
        self.assistant_kwargs = assistant_kwargs

        # Prepare the assistant model's input keys
        if assistant_model.config.is_encoder_decoder:
            # both are encoder-decoder
            self.input_ids_key = "decoder_input_ids"
        elif "encoder_outputs" in assistant_kwargs:
            # special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
            self.input_ids_key = "input_ids"
            self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get(
                "decoder_attention_mask",
                torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
            )
        else:
            # both are decoder-only
            self.input_ids_key = "input_ids"

        # Prepare generation-related options.
        self.logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        self.generation_config = copy.deepcopy(generation_config)
        self.generation_config.return_dict_in_generate = True
        self.generation_config.output_scores = True

        # Disable sampling -- this implementation of assisted generation/speculative decoding uses the assistant
        # greedily to maximize matches. Disable sampling-related flags to prevent warnings.
        self.generation_config.do_sample = False
        for attr in ("temperature", "top_p", "min_p", "typical_p", "top_k", "epsilon_cutoff", "eta_cutoff"):
            setattr(self.generation_config, attr, None)

        # Avoid unnecessary warnings that min_length is larger than max_new_tokens. Passing a
        # `MinLengthLogitsProcessor` directly is rejected below (NOTE: no need to check for
        # `MinNewTokensLogitsProcessor`).
        self.main_model_min_length = self.generation_config.min_length
        self.generation_config.min_length = 0
        self.generation_config.min_new_tokens = None
        for processor in self.logits_processor:
            if isinstance(processor, MinLengthLogitsProcessor):
                raise ValueError(
                    "Passing `MinLengthLogitsProcessor` when using assisted generation is disabled. "
                    "Please pass `min_length` to `.generate()` instead."
                )

        # We need to roll back the cache in assisted generation; only DynamicCache is supported.
        self.generation_config.cache_implementation = None

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
            assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` containing the logits associated with each candidate.
        """
        input_ids = input_ids.to(self.assistant_model.device)

        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
        new_cur_len = input_ids.shape[-1]
        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
        if max_new_tokens == 0:
            return input_ids, None
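
        # Worked example of the budget above (illustrative numbers only): with `max_length=50`, a current
        # length of 40 and `num_assistant_tokens=20`, the assistant may draft at most
        # min(20, 50 - 40 - 1) = 9 tokens, because the target model always appends one extra token of its
        # own after the verification pass.
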
        # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
        # (which implicitly contains the number of accepted candidates from the previous round)
        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
        if has_past_key_values:
            new_cache_size = new_cur_len - 1
            self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
            )  # the assistant does not have the token after the last match, hence the -1

            self.assistant_kwargs = _prepare_attention_mask(
                self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
            )
            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

        # 2. Forecast next N tokens using the assistant model.
        assistant_generation_kwargs = {
            self.input_ids_key: input_ids,
            "min_new_tokens": min_new_tokens,
            "max_new_tokens": max_new_tokens,
            "generation_config": self.generation_config,
            "logits_processor": self.logits_processor,
        }

        assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)

        # 3. Update variables for the next round of candidate generation
        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

        # 4. Prepare variables for output
        candidate_logits = torch.stack(assistant_output.scores, dim=1)
        candidate_ids = assistant_output.sequences
        return candidate_ids, candidate_logits

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary token when
                not using beam search or log softmax for each vocabulary token when using beam search.
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic
        # that could probably be improved -- we want to balance the benefit of getting assistant tokens right
        # against the cost of forecasting incorrect assistant tokens.
        if self.assistant_model.generation_config.num_assistant_tokens_schedule in {
            "heuristic",
            "heuristic_transient",
        }:
            if num_matches == int(self.num_assistant_tokens):
                self.num_assistant_tokens += 2.0
            else:
                self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
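
        # Worked example of the heuristic (illustrative numbers only): starting from `num_assistant_tokens=5`,
        # if all 5 drafted tokens are accepted the budget grows to 7; if any are rejected it shrinks to 4,
        # never dropping below 1. Under the "heuristic_transient" schedule the adjusted value is not meant
        # to persist across separate `generate` calls.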


class PromptLookupCandidateGenerator(CandidateGenerator):
    """
    `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking
    up likely continuations in the provided prompt (input_ids) itself.
    Read the following for more information: https://github.com/apoorvumang/prompt-lookup-decoding

    Args:
        eos_token_id (`torch.Tensor`, *optional*):
            The end-of-sequence token id(s); candidates are truncated at the first occurrence of one of them.
        max_matching_ngram_size (`int`):
            The maximum ngram size to be considered for matching in the prompt.
        num_output_tokens (`int`):
            The number of tokens to be output as candidate tokens.
        max_length (`int`):
            The maximum total number of tokens that can be generated. For decoder-only models this includes the
            prompt length. Defaults to 20, which is the default max length in the generation config.
    """
    def __init__(
        self,
        eos_token_id: torch.Tensor = None,
        num_output_tokens: int = 10,
        max_matching_ngram_size: int = None,
        max_length: int = 20,
    ):
        self.num_output_tokens = num_output_tokens
        self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
        self.max_length = max_length
        self.eos_token_id = eos_token_id

        if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
            raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
        """
        input_length = input_ids.size(1)

        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
        if self.max_length == input_length + 1:
            return input_ids, None

        chosen_ids = None
        match_found = False
        for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
            # Create sliding windows of size ngram_size
            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)

            # Convert ngram to a tensor for comparison
            ngram_tensor = input_ids[0, -ngram_size:]

            # Find where the windows match the ngram
            matches = (windows == ngram_tensor).all(dim=2)

            # Get the indices of matches
            match_indices = matches.nonzero(as_tuple=True)[1]

            # Iterate through match indices to find a valid continuation
            for idx in match_indices:
                start_idx = idx + ngram_size
                end_idx = start_idx + self.num_output_tokens
                end_idx = min(end_idx, input_length, self.max_length)

                if start_idx < end_idx:
                    chosen_ids = input_ids[0, start_idx:end_idx]
                    match_found = True

                    # Remove remaining candidate ids if an "eos" token is found, otherwise the target model may
                    # accept eos and the rest as valid, thus not stopping generation after "eos".
                    # NOTE: the code below relies on assisted decoding supporting only a batch size of 1.
                    mask = torch.isin(chosen_ids, self.eos_token_id)
                    match_indices_eos = torch.nonzero(mask)
                    if match_indices_eos.numel() > 0:
                        first_eos_index = match_indices_eos[0].item()
                        chosen_ids = chosen_ids[:first_eos_index]
                    break
            if match_found:
                break

        if chosen_ids is None or len(chosen_ids) == 0:
            # In case we didn't find a match, return the input sequence unchanged, reverting to autoregressive decoding
            return input_ids, None

        # Now we need to extend input_ids with chosen_ids
        chosen_ids = chosen_ids.unsqueeze(0)
        candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
        # assisted_generation expects logits as well, but we don't have those here, so returning None
        return candidate_input_ids, None

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary token when
                not using beam search or log softmax for each vocabulary token when using beam search.
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        # Currently does nothing
        return


def _crop_past_key_values(model, past_key_values, max_length):
    """Crops the past key values up to a certain maximum length."""
    new_past = []
    if model.config.is_encoder_decoder:
        for idx in range(len(past_key_values)):
            new_past.append(
                (
                    past_key_values[idx][0][:, :, :max_length, :],
                    past_key_values[idx][1][:, :, :max_length, :],
                    past_key_values[idx][2],
                    past_key_values[idx][3],
                )
            )
        past_key_values = tuple(new_past)
    # gptbigcode is special and stores kv in shape (batch_size, seq_len, dim) if it's a multi_query model
    elif "gptbigcode" in model.__class__.__name__.lower() or (
        model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
    ):
        if model.config.multi_query:
            for idx in range(len(past_key_values)):
                past_key_values[idx] = past_key_values[idx][:, :max_length, :]
        else:
            for idx in range(len(past_key_values)):
                past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
    elif isinstance(past_key_values, DynamicCache):
        past_key_values.crop(max_length)
    elif past_key_values is not None:
        for idx in range(len(past_key_values)):
            new_past.append(
                (
                    past_key_values[idx][0][:, :, :max_length, :],
                    past_key_values[idx][1][:, :, :max_length, :],
                )
            )
        past_key_values = tuple(new_past)
    return past_key_values
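

# Illustrative example (not executed here): for a legacy tuple cache from a decoder-only model, each layer
# holds key/value tensors of shape (batch_size, num_heads, seq_len, head_dim), so cropping slices the
# sequence dimension:
#
#     cropped = _crop_past_key_values(model, past_key_values, max_length=7)
#     # cropped[0][0].shape -> (batch_size, num_heads, 7, head_dim), assuming seq_len >= 7
#
# When a `DynamicCache` is passed instead, its own `crop` method performs the equivalent truncation in place.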


def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
    """Expands or crops the model's mask for decoding purposes, to the defined length."""

    mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
    if mask_key not in model_kwargs:
        return model_kwargs

    mask = model_kwargs[mask_key]
    mask_length_diff = new_length - mask.shape[1]

    if mask_length_diff < 0:
        model_kwargs[mask_key] = mask[:, :mask_length_diff]
    elif mask_length_diff > 0:
        model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
    return model_kwargs
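

# Worked example (illustrative shapes only): if `attention_mask` has shape (1, 5) and `new_length` is 8,
# three columns of ones are appended, giving shape (1, 8); if `new_length` is 3 instead, the negative slice
# `mask[:, :-2]` above crops the mask to its first 3 columns.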


def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
    """Expands or crops the model's token_type_ids for decoding purposes, to the defined length."""
    if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
        return model_kwargs

    token_type_ids = model_kwargs["token_type_ids"]
    final_token_type = token_type_ids[:, -1].unsqueeze(-1)
    type_length_diff = new_length - token_type_ids.shape[1]

    if type_length_diff < 0:
        model_kwargs["token_type_ids"] = token_type_ids[:, :type_length_diff]
    elif type_length_diff > 0:
        token_type_copies = final_token_type.repeat(1, type_length_diff)
        model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
    return model_kwargs
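

# Worked example (illustrative values only): for token_type_ids [[0, 0, 1]] and new_length=5, the final
# type id (1) is repeated twice and appended, giving [[0, 0, 1, 1, 1]]; for new_length=2 the tensor is
# cropped to [[0, 0]].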