SFT Trainer
Overview
TRL supports the Supervised Fine-Tuning (SFT) Trainer for training language models.
This post-training method was contributed by Younes Belkada.
Quick start
This example demonstrates how to train a language model using the SFTTrainer from TRL. We train a Qwen 3 0.6B model on the Capybara dataset, a compact and diverse multi-turn dataset for benchmarking reasoning and generalization.
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",
    train_dataset=load_dataset("trl-lib/Capybara", split="train"),
)
trainer.train()
Expected dataset type and format
SFT supports both language modeling and prompt-completion datasets. The SFTTrainer is compatible with both standard and conversational dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
# Standard language modeling
{"text": "The sky is blue."}
# Conversational language modeling
{"messages": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}]}
# Standard prompt-completion
{"prompt": "The sky is",
"completion": " blue."}
# Conversational prompt-completion
{"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}]}
If your dataset is not in one of these formats, you can preprocess it to convert it into the expected format. Here is an example with the FreedomIntelligence/medical-o1-reasoning-SFT dataset:
from datasets import load_dataset
dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en")
def preprocess_function(example):
    return {
        "prompt": [{"role": "user", "content": example["Question"]}],
        "completion": [
            {"role": "assistant", "content": f"<think>{example['Complex_CoT']}</think>{example['Response']}"}
        ],
    }

dataset = dataset.map(preprocess_function, remove_columns=["Question", "Response", "Complex_CoT"])
print(next(iter(dataset["train"])))
{
    "prompt": [
        {
            "content": "Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings?",
            "role": "user",
        }
    ],
    "completion": [
        {
            "content": "<think>Okay, let's see what's going on here. We've got sudden weakness [...] clicks into place!</think>The specific cardiac abnormality most likely to be found in [...] the presence of a PFO facilitating a paradoxical embolism.",
            "role": "assistant",
        }
    ],
}
Looking deeper into the SFT method
Supervised Fine-Tuning (SFT) is the simplest and most commonly used method to adapt a language model to a target dataset. The model is trained in a fully supervised fashion using pairs of input and output sequences. The goal is to minimize the negative log-likelihood (NLL) of the target sequence, conditioning on the input.
This section breaks down how SFT works in practice, covering the key steps: preprocessing, tokenization and loss computation.
Preprocessing and tokenization
During training, each example is expected to contain a text field or a (prompt, completion) pair, depending on the dataset format. For more details on the expected formats, see Dataset formats.
The SFTTrainer tokenizes each input using the model's tokenizer. If both prompt and completion are provided separately, they are concatenated before tokenization.
Computing the loss
The loss used in SFT is the token-level cross-entropy loss, defined as:

$$\mathcal{L}_{\text{SFT}}(\theta) = - \sum_{t=1}^{T} \log p_\theta(y_t \mid y_{<t}),$$

where $y_t$ is the target token at timestep $t$, and the model is trained to predict the next token given the previous ones. In practice, padding tokens are masked out during loss computation.
Label shifting and masking
During training, the loss is computed using a one-token shift: the model is trained to predict each token in the sequence based on all previous tokens. Specifically, the input sequence is shifted right by one position to form the target labels.
Padding tokens (if present) are ignored in the loss computation by applying an ignore index (default: -100) to the corresponding positions. This ensures that the loss focuses only on meaningful, non-padding tokens.
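Below is a minimal, illustrative sketch of this shift-and-mask logic (not the trainer's actual implementation). It assumes the labels are a copy of the input IDs with padding positions already set to -100:

import torch
import torch.nn.functional as F

def shifted_cross_entropy(logits, labels, ignore_index=-100):
    # logits: (batch, seq_len, vocab_size), labels: (batch, seq_len)
    # Predict token t+1 from tokens <= t: drop the last logit and the first label
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    # Positions whose label equals ignore_index (e.g. padding) do not contribute to the loss
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=ignore_index,
    )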
Logged metrics
- global_step: The total number of optimizer steps taken so far.
- epoch: The current epoch number, based on dataset iteration.
- num_tokens: The total number of tokens processed so far.
- loss: The average cross-entropy loss computed over non-masked tokens in the current logging interval.
- mean_token_accuracy: The proportion of non-masked tokens for which the model's top-1 prediction matches the ground truth token.
- learning_rate: The current learning rate, which may change dynamically if a scheduler is used.
- grad_norm: The L2 norm of the gradients, computed before gradient clipping.
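For illustration, mean_token_accuracy can be thought of as the following computation over the shifted logits and labels (a simplified sketch, not the exact trainer code):

import torch

def mean_token_accuracy(logits, labels, ignore_index=-100):
    # Apply the same one-token shift as the loss
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    # Only count positions that are not masked out
    mask = shift_labels != ignore_index
    predictions = shift_logits.argmax(dim=-1)
    correct = (predictions == shift_labels) & mask
    return correct.sum().float() / mask.sum().float()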
Customization
Model initialization
You can directly pass the kwargs of the from_pretrained() method to the SFTConfig. For example, if you want to load a model in a different precision, analogous to
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.bfloat16)
you can do so by passing the model_init_kwargs={"torch_dtype": torch.bfloat16} argument to the SFTConfig.
import torch
from trl import SFTConfig

training_args = SFTConfig(
    model_init_kwargs={"torch_dtype": torch.bfloat16},
)
Note that all keyword arguments of from_pretrained() are supported.
Packing
SFTTrainer supports example packing, where multiple examples are packed in the same input sequence to increase training efficiency. To enable packing, simply pass packing=True to the SFTConfig constructor.
training_args = SFTConfig(packing=True)
For more details on packing, see Packing.
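As a rough illustration of the idea behind the "wrapped" strategy (a simplified sketch, not TRL's actual implementation, which defaults to best-fit decreasing), packing concatenates tokenized examples into a single stream and cuts it into fixed-length blocks:

def wrap_pack(tokenized_examples, max_length):
    # Concatenate all token sequences and cut the stream into fixed-size blocks
    stream = [tok for example in tokenized_examples for tok in example["input_ids"]]
    return [
        {"input_ids": stream[i : i + max_length]}
        for i in range(0, len(stream), max_length)
    ]

# e.g. wrap_pack([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}], max_length=4)
# -> [{"input_ids": [1, 2, 3, 4]}, {"input_ids": [5]}]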
Train on assistant messages only
To train on assistant messages only, use a conversational dataset and set assistant_only_loss=True in the SFTConfig. This setting ensures that loss is computed only on the assistant responses, ignoring user or system messages.
training_args = SFTConfig(assistant_only_loss=True)
This functionality is only available for chat templates that support returning the assistant tokens mask via the {% generation %} and {% endgeneration %} keywords. For an example of such a template, see HuggingFaceTB/SmolLM3-3B.
Train on completion only
To train on completion only, use a prompt-completion dataset. By default, the trainer computes the loss on the completion tokens only, ignoring the prompt tokens. If you want to train on the full sequence, set completion_only_loss=False in the SFTConfig.
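For example, to compute the loss over the full sequence:

training_args = SFTConfig(completion_only_loss=False)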
Train adapters with PEFT
We support tight integration with the 🤗 PEFT library, allowing any user to conveniently train adapters and share them on the Hub, rather than training the entire model.
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
dataset = load_dataset("trl-lib/Capybara", split="train")
trainer = SFTTrainer(
    "Qwen/Qwen3-0.6B",
    train_dataset=dataset,
    peft_config=LoraConfig(),
)
trainer.train()
You can also continue training your peft.PeftModel. To do so, first load a PeftModel outside SFTTrainer and pass it directly to the trainer, without passing the peft_config argument.
from datasets import load_dataset
from trl import SFTTrainer
from peft import AutoPeftModelForCausalLM
model = AutoPeftModelForCausalLM.from_pretrained("trl-lib/Qwen3-4B-LoRA", is_trainable=True)
dataset = load_dataset("trl-lib/Capybara", split="train")
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
)
trainer.train()
When training adapters, you typically use a higher learning rate (≈1e-4) since only new parameters are being learned.
SFTConfig(learning_rate=1e-4, ...)
Train with Liger Kernel
Liger Kernel is a collection of Triton kernels for LLM training that boosts multi-GPU throughput by 20%, cuts memory use by 60% (enabling up to 4× longer context), and works seamlessly with tools like Flash Attention, PyTorch FSDP, and DeepSpeed. For more information, see Liger Kernel Integration.
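Assuming the liger-kernel package is installed, enabling it in TRL is a matter of setting the corresponding flag in the SFTConfig:

training_args = SFTConfig(use_liger_kernel=True)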
Train with Unsloth
Unsloth is an open-source framework for fine-tuning and reinforcement learning that trains LLMs (like Llama, Mistral, Gemma, DeepSeek, and more) up to 2× faster with up to 70% less VRAM, while providing a streamlined, Hugging Face-compatible workflow for training, evaluation, and deployment. For more information, see Unsloth Integration.
Instruction tuning example
Instruction tuning teaches a base language model to follow user instructions and engage in conversations. This requires:
- Chat template: Defines how to structure conversations into text sequences, including role markers (user/assistant), special tokens, and turn boundaries. Read more about chat templates in Chat templates.
- Conversational dataset: Contains instruction-response pairs
This example shows how to transform the Qwen 3 0.6B Base model into an instruction-following model using the Capybara dataset and a chat template from HuggingFaceTB/SmolLM3-3B. The SFT Trainer automatically handles tokenizer updates and special token configuration.
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B-Base",
    args=SFTConfig(
        output_dir="Qwen3-0.6B-Instruct",
        chat_template_path="HuggingFaceTB/SmolLM3-3B",
    ),
    train_dataset=load_dataset("trl-lib/Capybara", split="train"),
)
trainer.train()
Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases, it is not necessary to apply clone_chat_template(), as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify eos_token in SFTConfig; for example, for Qwen/Qwen2.5-1.5B, one should set eos_token="<|im_end|>".
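A minimal configuration for that case (the output_dir value here is only illustrative):

training_args = SFTConfig(
    output_dir="Qwen2.5-1.5B-Instruct",
    eos_token="<|im_end|>",
)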
Once trained, your model can now follow instructions and engage in conversations using its new chat template.
>>> from transformers import pipeline
>>> pipe = pipeline("text-generation", model="Qwen3-0.6B-Instruct/checkpoint-5000")
>>> prompt = "<|im_start|>user\nWhat is the capital of France? Answer in one word.<|im_end|>\n<|im_start|>assistant\n"
>>> response = pipe(prompt)
>>> response[0]["generated_text"]
'<|im_start|>user\nWhat is the capital of France? Answer in one word.<|im_end|>\n<|im_start|>assistant\nThe capital of France is Paris.'
Alternatively, use the structured conversation format (recommended):
>>> prompt = [{"role": "user", "content": "What is the capital of France? Answer in one word."}]
>>> response = pipe(prompt)
>>> response[0]["generated_text"]
[{'role': 'user', 'content': 'What is the capital of France? Answer in one word.'}, {'role': 'assistant', 'content': 'The capital of France is Paris.'}]
Tool Calling with SFT
The SFT trainer fully supports fine-tuning models with tool calling capabilities. In this case, each dataset example should include:
- The conversation messages, including any tool calls (tool_calls) and tool responses (tool role messages)
- The list of available tools in the tools column, typically provided as JSON schemas
For details on the expected dataset structure, see the Dataset Format — Tool Calling section.
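As a rough illustration only (the field names below follow the standard chat-template tool-calling convention and are an assumption here, not the canonical TRL schema), a single example might look like:

example = {
    "messages": [
        {"role": "user", "content": "What is the weather in Paris?"},
        # Assistant turn that requests a tool call (hypothetical get_weather function)
        {"role": "assistant", "tool_calls": [
            {"type": "function", "function": {"name": "get_weather", "arguments": {"city": "Paris"}}}
        ]},
        # Tool response fed back to the model
        {"role": "tool", "name": "get_weather", "content": "It is sunny and 24°C."},
        {"role": "assistant", "content": "It is currently sunny and 24°C in Paris."},
    ],
    # Available tools, described as JSON schemas
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
}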
Extending SFTTrainer for Vision Language Models
SFTTrainer does not yet inherently support vision-language data. However, this guide shows how to tweak the trainer to handle it: specifically, you need to use a custom data collator that is compatible with vision-language data. For a concrete example, refer to the script examples/scripts/sft_vlm.py, which demonstrates how to fine-tune the LLaVA 1.5 model on the HuggingFaceH4/llava-instruct-mix-vsft dataset.
Preparing the Data
The data format is flexible, provided it is compatible with the custom collator that we will define later. A common approach is to use conversational data. Given that the data includes both text and images, the format needs to be adjusted accordingly. Below is an example of a conversational data format involving both text and images:
images = ["obama.png"]
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Who is this?"},
            {"type": "image"},
        ],
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "Barack Obama"},
        ],
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is he famous for?"},
        ],
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "He is the 44th President of the United States."},
        ],
    },
]
To illustrate how this data format will be processed using the LLaVA model, you can use the following code:
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(processor.apply_chat_template(messages, tokenize=False))
The output will be formatted as follows:
USER: <image>
Who is this? ASSISTANT: Barack Obama USER: What is he famous for? ASSISTANT: He is the 44th President of the United States.
A custom collator for processing multi-modal data
Unlike the default behavior of SFTTrainer, processing multi-modal data is done on the fly during the data collation process. To do this, you need to define a custom collator that processes both the text and images. This collator must take a list of examples as input (see the previous section for an example of the data format) and return a batch of processed data. Below is an example of such a collator:
def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
    images = [example["images"][0] for example in examples]

    # Tokenize the texts and process the images
    batch = processor(images=images, text=texts, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels

    return batch
We can verify that the collator works as expected by running the following code:
from datasets import load_dataset
dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train")
examples = [dataset[0], dataset[1]] # Just two examples for the sake of the example
collated_data = collate_fn(examples)
print(collated_data.keys()) # dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'labels'])
Training the vision-language model
Now that we have prepared the data and defined the collator, we can proceed with training the model. To ensure that the data is not processed as text-only, we need to set a couple of arguments in the SFTConfig: remove_unused_columns=False and dataset_kwargs={"skip_prepare_dataset": True}, which disable the trainer's default dataset preparation. Below is an example of how to set up the SFTTrainer.
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_dataset,
    processing_class=processor,
)
A full example of training LLaVA 1.5 on the HuggingFaceH4/llava-instruct-mix-vsft dataset can be found in the script examples/scripts/sft_vlm.py.
SFTTrainer
class trl.SFTTrainer
< source >( model: typing.Union[str, torch.nn.modules.module.Module, transformers.modeling_utils.PreTrainedModel] args: typing.Union[trl.trainer.sft_config.SFTConfig, transformers.training_args.TrainingArguments, NoneType] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None train_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, NoneType] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None compute_loss_func: typing.Optional[typing.Callable] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], dict]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) optimizer_cls_and_kwargs: typing.Optional[tuple[type[torch.optim.optimizer.Optimizer], dict[str, typing.Any]]] = None preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None peft_config: typing.Optional[ForwardRef('PeftConfig')] = None formatting_func: typing.Optional[typing.Callable[[dict], str]] = None )
Parameters
- model (Union[str, PreTrainedModel]) — Model to be trained. Can be either:
  - A string, being the model id of a pretrained model hosted inside a model repo on huggingface.co, or a path to a directory containing model weights saved using save_pretrained, e.g., './my_model_directory/'. The model is loaded using from_pretrained with the keyword arguments in args.model_init_kwargs.
  - A PreTrainedModel object. Only causal language models are supported.
- args (SFTConfig, optional, defaults to None) — Configuration for this trainer. If None, a default configuration is used.
- data_collator (DataCollator, optional) — Function to use to form a batch from a list of elements of the processed train_dataset or eval_dataset. Will default to a custom DataCollatorForLanguageModeling.
- train_dataset (Dataset or IterableDataset) — Dataset to use for training. SFT supports both language modeling type and prompt-completion type. The format of the samples can be either:
  - Standard: Each sample contains plain text.
  - Conversational: Each sample contains structured messages (e.g., role and content).
  The trainer also supports processed (tokenized) datasets as long as they contain an input_ids field.
- eval_dataset (Dataset, IterableDataset or dict[str, Union[Dataset, IterableDataset]]) — Dataset to use for evaluation. It must meet the same requirements as train_dataset.
- processing_class (PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin or ProcessorMixin, optional, defaults to None) — Processing class used to process the data. If None, the processing class is loaded from the model's name with from_pretrained.
- callbacks (list of TrainerCallback, optional, defaults to None) — List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed here. If you want to remove one of the default callbacks used, use the remove_callback method.
- optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR], optional, defaults to (None, None)) — A tuple containing the optimizer and the scheduler to use. Will default to an instance of AdamW on your model and a scheduler given by get_linear_schedule_with_warmup controlled by args.
- optimizer_cls_and_kwargs (Tuple[Type[torch.optim.Optimizer], Dict[str, Any]], optional, defaults to None) — A tuple containing the optimizer class and keyword arguments to use. Overrides optim and optim_args in args. Incompatible with the optimizers argument. Unlike optimizers, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
- preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optional, defaults to None) — A function that preprocesses the logits right before caching them at each evaluation step. Must take two tensors, the logits and the labels, and return the logits once processed as desired. The modifications made by this function will be reflected in the predictions received by compute_metrics. Note that the labels (second parameter) will be None if the dataset does not have them.
- peft_config (~peft.PeftConfig, optional, defaults to None) — PEFT configuration used to wrap the model. If None, the model is not wrapped.
- formatting_func (Optional[Callable]) — Formatting function applied to the dataset before tokenization. Applying the formatting function explicitly converts the dataset into a language modeling type.
Trainer for Supervised Fine-Tuning (SFT) method.
This class is a wrapper around the transformers.Trainer class and inherits all of its attributes and methods.
Example:
from datasets import load_dataset
from trl import SFTTrainer
dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")
trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
trainer.train()
compute_loss
Compute training loss and additionally compute token accuracies.
create_model_card
< source >( model_name: typing.Optional[str] = None dataset_name: typing.Optional[str] = None tags: typing.Union[str, list[str], NoneType] = None )
Creates a draft of a model card using the information available to the Trainer.
SFTConfig
class trl.SFTConfig
< source >( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 2e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None 
ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: bool = True model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None chat_template_path: typing.Optional[str] = None dataset_text_field: str = 'text' dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None dataset_num_proc: typing.Optional[int] = None eos_token: typing.Optional[str] = None pad_token: typing.Optional[str] = None max_length: typing.Optional[int] = 1024 packing: bool = False packing_strategy: str = 'bfd' padding_free: bool = False pad_to_multiple_of: typing.Optional[int] = None eval_packing: typing.Optional[bool] = None completion_only_loss: typing.Optional[bool] = None assistant_only_loss: bool = False activation_offloading: bool = False )
Parameters that control the model
- model_init_kwargs (dict[str, Any] or None, optional, defaults to None) — Keyword arguments for from_pretrained, used when the model argument of the SFTTrainer is provided as a string.
- chat_template_path (str or None, optional, defaults to None) — If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must ensure that any special tokens referenced in the template are added to the tokenizer and that the model's embedding layer is resized accordingly.
Parameters that control the data preprocessing
- dataset_text_field (str, optional, defaults to "text") — Name of the column that contains text data in the dataset.
- dataset_kwargs (dict[str, Any] or None, optional, defaults to None) — Dictionary of optional keyword arguments for the dataset preparation. The only supported key is skip_prepare_dataset.
- dataset_num_proc (int or None, optional, defaults to None) — Number of processes to use for processing the dataset.
- eos_token (str or None, optional, defaults to None) — Token used to indicate the end of a turn or sequence. If None, it defaults to processing_class.eos_token.
- pad_token (int or None, optional, defaults to None) — Token used for padding. If None, it defaults to processing_class.pad_token, or if that is also None, it falls back to processing_class.eos_token.
- max_length (int or None, optional, defaults to 1024) — Maximum length of the tokenized sequence. Sequences longer than max_length are truncated from the right. If None, no truncation is applied. When packing is enabled, this value sets the sequence length.
- packing (bool, optional, defaults to False) — Whether to group multiple sequences into fixed-length blocks to improve computational efficiency and reduce padding. Uses max_length to define sequence length.
- packing_strategy (str, optional, defaults to "bfd") — Strategy for packing sequences. Can be either "bfd" (best-fit decreasing, default) or "wrapped".
- padding_free (bool, optional, defaults to False) — Whether to perform forward passes without padding by flattening all sequences in the batch into a single continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only supported with the flash_attention_2 attention implementation, which can efficiently handle the flattened batch structure. When packing is enabled with strategy "bfd", padding-free is enabled regardless of the value of this parameter.
- pad_to_multiple_of (int or None, optional, defaults to None) — If set, the sequences will be padded to a multiple of this value.
- eval_packing (bool or None, optional, defaults to None) — Whether to pack the eval dataset. If None, uses the same value as packing.
Parameters that control the training
- completion_only_loss (bool or None, optional, defaults to None) — Whether to compute loss only on the completion part of the sequence. If set to True, loss is computed only on the completion, which is supported only for prompt-completion datasets. If False, loss is computed on the entire sequence. If None (default), the behavior depends on the dataset: loss is computed on the completion for prompt-completion datasets, and on the full sequence for language modeling datasets.
- assistant_only_loss (bool, optional, defaults to False) — Whether to compute loss only on the assistant part of the sequence. If set to True, loss is computed only on the assistant responses, which is supported only for conversational datasets. If False, loss is computed on the entire sequence.
- activation_offloading (bool, optional, defaults to False) — Whether to offload the activations to the CPU.
Configuration class for the SFTTrainer.
This class includes only the parameters that are specific to SFT training. For a full list of training arguments, please refer to the TrainingArguments documentation. Note that default values in this class may differ from those in TrainingArguments.

Using HfArgumentParser we can turn this class into argparse arguments that can be specified on the command line.