<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Text generation strategies

Text generation is essential to many NLP tasks, such as open-ended text generation, summarization, translation, and
more. It also plays a role in a variety of mixed-modality applications that have text as an output, like speech-to-text
and vision-to-text. Some of the models that can generate text include
GPT2, XLNet, OpenAI GPT, CTRL, TransformerXL, XLM, Bart, T5, GIT, and Whisper.

Check out a few examples that use the [`~transformers.generation_utils.GenerationMixin.generate`] method to produce
text outputs for different tasks:
* [Text summarization](./tasks/summarization#inference)
* [Image captioning](./model_doc/git#transformers.GitForCausalLM.forward.example)
* [Audio transcription](./model_doc/whisper#transformers.WhisperForConditionalGeneration.forward.example)

Note that the inputs to the generate method depend on the model's modality. They are returned by the model's preprocessor
class, such as AutoTokenizer or AutoProcessor. If a model's preprocessor creates more than one kind of input, pass all
the inputs to generate(). You can learn more about an individual model's preprocessor in the corresponding model's documentation.

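For example, here is a minimal speech-to-text sketch with Whisper, where the preprocessor is an `AutoProcessor` rather than a tokenizer. The checkpoint choice and the silent placeholder waveform are assumptions made only to keep the sketch self-contained; in practice you would pass a real 16 kHz mono recording:

```python
>>> import numpy as np
>>> from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

>>> # Placeholder input: one second of silence at 16 kHz, just to keep the sketch runnable.
>>> audio_array = np.zeros(16000, dtype=np.float32)

>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
>>> model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")

>>> # The processor returns the model-specific inputs (here, `input_features`); pass them all to generate().
>>> inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
>>> generated_ids = model.generate(**inputs, max_new_tokens=20)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
```
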
The process of selecting output tokens to generate text is known as decoding, and you can customize the decoding strategy
that the `generate()` method will use. Modifying a decoding strategy does not change the values of any trainable parameters.
However, it can have a noticeable impact on the quality of the generated output. It can help reduce repetition in the text
and make it more coherent.

This guide describes:
* default generation configuration
* common decoding strategies and their main parameters
* saving and sharing custom generation configurations with your fine-tuned model on 🤗 Hub

## Default text generation configuration

A decoding strategy for a model is defined in its generation configuration. When using pre-trained models for inference
within a [`pipeline`], the models call the `PreTrainedModel.generate()` method that applies a default generation
configuration under the hood. The default configuration is also used when no custom configuration has been saved with
the model.

When you load a model explicitly, you can inspect the generation configuration that comes with it through
`model.generation_config`:

```python
>>> from transformers import AutoModelForCausalLM

>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> model.generation_config
GenerationConfig {
  "_from_model_config": true,
  "bos_token_id": 50256,
  "eos_token_id": 50256,
  "transformers_version": "4.26.0.dev0"
}
```

Printing out `model.generation_config` reveals only the values that differ from the default generation
configuration, and does not list any of the default values.

The default generation configuration limits the size of the output combined with the input prompt to a maximum of 20
tokens to avoid running into resource limitations. The default decoding strategy is greedy search, the simplest decoding
strategy, which picks the token with the highest probability as the next token. For many tasks
and small output sizes this works well. However, when used to generate longer outputs, greedy search can start
producing highly repetitive results.

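If you want to see those library-wide defaults explicitly, one option (sketched below) is to instantiate a fresh [`GenerationConfig`], which is populated with the default values:

```python
>>> from transformers import GenerationConfig

>>> default_config = GenerationConfig()
>>> default_config.max_length  # prompt plus output capped at 20 tokens by default
20
>>> default_config.num_beams  # a single beam...
1
>>> default_config.do_sample  # ...and no sampling, i.e. greedy search
False
```
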
## Customize text generation

You can override any `generation_config` parameter by passing it, along with its value, directly to the [`generate`] method:

```python
>>> my_model.generate(**inputs, num_beams=4, do_sample=True)
```

Even if the default decoding strategy mostly works for your task, you can still tweak a few things. Some of the
commonly adjusted parameters include the following (see the combined example after this list):

- `max_new_tokens`: the maximum number of tokens to generate. In other words, the size of the output sequence, not
including the tokens in the prompt.
- `num_beams`: by specifying a number of beams higher than 1, you are effectively switching from greedy search to
beam search. This strategy evaluates several hypotheses at each time step and eventually chooses the hypothesis that
has the overall highest probability for the entire sequence. This has the advantage of identifying high-probability
sequences that start with lower probability initial tokens and would've been ignored by the greedy search.
- `do_sample`: if set to `True`, this parameter enables decoding strategies such as multinomial sampling, beam-search
multinomial sampling, Top-K sampling and Top-p sampling. All these strategies select the next token from the probability
distribution over the entire vocabulary with various strategy-specific adjustments.
- `num_return_sequences`: the number of sequence candidates to return for each input. This option is only available for
the decoding strategies that support multiple sequence candidates, e.g. variations of beam search and sampling. Decoding
strategies like greedy search and contrastive search return a single output sequence.

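Here is a minimal sketch combining these parameters in a single call. The checkpoint and prompt are arbitrary choices for illustration, and because `do_sample=True`, the exact outputs will vary from run to run:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> checkpoint = "distilgpt2"  # any causal LM checkpoint works here
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> inputs = tokenizer("The best thing about open source is", return_tensors="pt")

>>> # Beam-search multinomial sampling: 4 beams, at most 40 new tokens, 2 of the finished beams returned per input.
>>> outputs = model.generate(
...     **inputs, max_new_tokens=40, num_beams=4, do_sample=True, num_return_sequences=2
... )
>>> candidates = tokenizer.batch_decode(outputs, skip_special_tokens=True)  # two candidate continuations
```
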
## Save a custom decoding strategy with your model

If you would like to share your fine-tuned model with a specific generation configuration, you can:
* Create a [`GenerationConfig`] class instance
* Specify the decoding strategy parameters
* Save your generation configuration with [`GenerationConfig.save_pretrained`], making sure to leave its `config_file_name` argument empty
* Set `push_to_hub` to `True` to upload your config to the model's repo

```python
>>> from transformers import AutoModelForCausalLM, GenerationConfig

>>> model = AutoModelForCausalLM.from_pretrained("my_account/my_model")
>>> generation_config = GenerationConfig(
...     max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id
... )
>>> generation_config.save_pretrained("my_account/my_model", push_to_hub=True)
```

You can also store several generation configurations in a single directory, making use of the `config_file_name`
argument in [`GenerationConfig.save_pretrained`]. You can later instantiate them with [`GenerationConfig.from_pretrained`]. This is useful if you want to
store several generation configurations for a single model (e.g. one for creative text generation with sampling, and
one for summarization with beam search). You must have the right Hub permissions to add configuration files to a model.

```python
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

>>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
>>> translation_generation_config = GenerationConfig(
...     num_beams=4,
...     early_stopping=True,
...     decoder_start_token_id=0,
...     eos_token_id=model.config.eos_token_id,
...     pad_token_id=model.config.pad_token_id,
... )

>>> translation_generation_config.save_pretrained("t5-small", "translation_generation_config.json", push_to_hub=True)

>>> # You could then use the named generation config file to parameterize generation
>>> generation_config = GenerationConfig.from_pretrained("t5-small", "translation_generation_config.json")
>>> inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
>>> outputs = model.generate(**inputs, generation_config=generation_config)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['Les fichiers de configuration sont faciles à utiliser !']
```

## Streaming

The `generate()` method supports streaming through its `streamer` input. The `streamer` input is compatible with any instance
of a class that has the following methods: `put()` and `end()`. Internally, `put()` is used to push new tokens and
`end()` is used to flag the end of text generation.

<Tip warning={true}>

The API for the streamer classes is still under development and may change in the future.

</Tip>

In practice, you can craft your own streaming class for all sorts of purposes! We also have basic streaming classes
ready for you to use. For example, you can use the [`TextStreamer`] class to stream the output of `generate()` to
your screen, one word at a time:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

>>> tok = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
>>> streamer = TextStreamer(tok)

>>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
```

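As a toy illustration of the interface, the sketch below uses a hypothetical `ListStreamer` (not a class shipped with 🤗 Transformers) that simply collects the streamed token ids instead of printing them. Any object exposing `put()` and `end()` can be passed as `streamer` in the same way:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> tok = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")

>>> class ListStreamer:
...     """Hypothetical streamer: any object with put() and end() is accepted by generate()."""
...
...     def __init__(self):
...         self.token_ids = []
...
...     def put(self, value):
...         # `value` is a tensor of token ids: the prompt on the first call, then the new tokens step by step.
...         self.token_ids.extend(value.reshape(-1).tolist())
...
...     def end(self):
...         # Called once when generation is finished; nothing to clean up in this toy example.
...         pass

>>> streamer = ListStreamer()
>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
>>> len(streamer.token_ids) > 0  # prompt tokens plus the newly generated tokens
True
```
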
## Decoding strategies

Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
decoding strategies. If you are new to this concept, we recommend reading [this blog post that illustrates how common decoding strategies work](https://huggingface.co/blog/how-to-generate).

Here, we'll show some of the parameters that control the decoding strategies and illustrate how you can use them.

### Greedy search

[`generate`] uses greedy search decoding by default so you don't have to pass any parameters to enable it. This means that `num_beams` is set to 1 and `do_sample=False`.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "I look forward to"
>>> checkpoint = "distilgpt2"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> outputs = model.generate(**inputs)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['I look forward to seeing you all again!\n\n\n\n\n\n\n\n\n\n\n']
```

### Contrastive search

The contrastive search decoding strategy was proposed in the 2022 paper [A Contrastive Framework for Neural Text Generation](https://arxiv.org/abs/2202.06417).
It demonstrates superior results for generating non-repetitive yet coherent long outputs. To learn how contrastive search
works, check out [this blog post](https://huggingface.co/blog/introducing-csearch).

The two main parameters that enable and control the behavior of contrastive search are `penalty_alpha` and `top_k`:

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> checkpoint = "gpt2-large"
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)

>>> prompt = "Hugging Face Company is"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=100)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Hugging Face Company is a family owned and operated business. \
We pride ourselves on being the best in the business and our customer service is second to none.\
\n\nIf you have any questions about our products or services, feel free to contact us at any time.\
We look forward to hearing from you!']
```

### Multinomial sampling

As opposed to greedy search, which always chooses the token with the highest probability as the
next token, multinomial sampling (also called ancestral sampling) randomly selects the next token based on the probability distribution over the entire
vocabulary given by the model. Every token with a non-zero probability has a chance of being selected, thus reducing the
risk of repetition.

To enable multinomial sampling set `do_sample=True` and `num_beams=1`.

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> checkpoint = "gpt2-large"
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)

>>> prompt = "Today was an amazing day because"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> outputs = model.generate(**inputs, do_sample=True, num_beams=1, max_new_tokens=100)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today was an amazing day because we are now in the final stages of our trip to New York City which was very tough. \
It is a difficult schedule and a challenging part of the year but still worth it. I have been taking things easier and \
I feel stronger and more motivated to be out there on their tour. Hopefully, that experience is going to help them with \
their upcoming events which are currently scheduled in Australia.\n\nWe love that they are here. They want to make a \
name for themselves and become famous for what they']
```

### Beam-search decoding

Unlike greedy search, beam-search decoding keeps several hypotheses at each time step and eventually chooses
the hypothesis that has the overall highest probability for the entire sequence. This has the advantage of identifying high-probability
sequences that start with lower probability initial tokens and would've been ignored by the greedy search.

To enable this decoding strategy, specify a `num_beams` (a.k.a. the number of hypotheses to keep track of) greater than 1.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "It is astonishing how one can"
>>> checkpoint = "gpt2-medium"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> outputs = model.generate(**inputs, num_beams=5, max_new_tokens=50)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['It is astonishing how one can have such a profound impact on the lives of so many people in such a short period of \
time."\n\nHe added: "I am very proud of the work I have been able to do in the last few years.\n\n"I have']
```

### Beam-search multinomial sampling

As the name implies, this decoding strategy combines beam search with multinomial sampling. To use this decoding
strategy, specify `num_beams` greater than 1 and set `do_sample=True`.

```python
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

>>> prompt = "translate English to German: The house is wonderful."
>>> checkpoint = "t5-small"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
>>> outputs = model.generate(**inputs, num_beams=5, do_sample=True)
>>> tokenizer.decode(outputs[0], skip_special_tokens=True)
'Das Haus ist wunderbar.'
```

### Diverse beam search decoding

The diverse beam search decoding strategy is an extension of the beam search strategy that allows for generating a more diverse
set of beam sequences to choose from. To learn how it works, refer to [Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence Models](https://arxiv.org/pdf/1610.02424.pdf).

This approach has three main parameters: `num_beams`, `num_beam_groups`, and `diversity_penalty`.
The diversity penalty ensures the outputs are distinct across groups, and beam search is used within each group.

```python
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

>>> checkpoint = "google/pegasus-xsum"
>>> prompt = "The Permaculture Design Principles are a set of universal design principles \
... that can be applied to any location, climate and culture, and they allow us to design \
... the most efficient and sustainable human habitation and food production systems. \
... Permaculture is a design system that encompasses a wide variety of disciplines, such \
... as ecology, landscape design, environmental science and energy conservation, and the \
... Permaculture design principles are drawn from these various disciplines. Each individual \
... design principle itself embodies a complete conceptual framework based on sound \
... scientific principles. When we bring all these separate principles together, we can \
... create a design system that both looks at whole systems, the parts that these systems \
... consist of, and how those parts interact with each other to create a complex, dynamic, \
... living system. Each design principle serves as a tool that allows us to integrate all \
... the separate parts of a design, referred to as elements, into a functional, synergistic, \
... whole system, where the elements harmoniously interact and work together in the most \
... efficient way possible."

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

>>> outputs = model.generate(**inputs, num_beams=5, num_beam_groups=5, max_new_tokens=30, diversity_penalty=1.0)
>>> tokenizer.decode(outputs[0], skip_special_tokens=True)
'The aim of this project is to create a new type of living system, one that is more sustainable and efficient than the current one.'
```

This guide illustrates the main parameters that enable various decoding strategies. More advanced parameters exist for the
[`generate`] method, giving you even further control over its behavior.
For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.md).

### Assisted decoding

Assisted decoding is a modification of the decoding strategies above that uses an assistant model with the same
tokenizer (ideally a much smaller model) to greedily generate a few candidate tokens. The main model then validates
the candidate tokens in a single forward pass, which speeds up the decoding process. Currently, only greedy search
and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs. To learn more about assisted
decoding, check out [this blog post](https://huggingface.co/blog/assisted-generation).

To enable assisted decoding, set the `assistant_model` argument to a model.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```

When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
just like in multinomial sampling. However, in assisted decoding, reducing the temperature will help improve latency.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
["Alice and Bob are sitting on the sofa. Alice says, 'I'm going to my room"]
```