"""
2025.7.4
2025.7.3
4.53.2
0.19.1
__UNSLOTH_VERSIONING__
"""
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
from trl.trainer.iterative_sft_trainer import (AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, DataCollator, DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, DataLoader, Dataset, EvalLoopOutput, FeatureExtractionMixin, IterativeSFTConfig, IterativeSFTTrainer, Optional, PPODecorators, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainingArguments, Union, generate_model_card, get_comet_experiment_url, is_peft_available, is_wandb_available, os, torch, wandb, warnings, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)


import os
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch.nn import functional as F
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling

# Options handed to `torch.compile` for the helper below.
torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : False,
    "shape_padding"     : True,
    "trace.enabled"     : False,
    "triton.cudagraphs" : False,
}

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """
    Return the log-softmax of `logits` evaluated only at the positions in `index`.

    The logits are upcast to float32, the entries selected by `index` are gathered
    along the last dimension, and `logsumexp` over that same dimension is
    subtracted, using the identity `log_softmax(x_i) = x_i - logsumexp(x)`.

    Args:
        logits: Tensor of scores; the last dimension is the one normalised over.
        index: Integer tensor with one fewer dimension than `logits`; it is
            unsqueezed on the last axis before `torch.gather`.

    Returns:
        Tensor shaped like `index` holding the selected log-probabilities.
    """
    logits = logits.to(torch.float32)
    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    # loop to reduce peak mem consumption
    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
    logsumexp_values = torch.logsumexp(logits, dim = -1)
    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
    return per_token_logps

@dataclass
class UnslothIterativeSFTConfig(IterativeSFTConfig):
    """
    Configuration class for the [`IterativeSFTTrainer`].

    This class includes only the parameters that are specific to Iterative SFT training. For a full list of training
    arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this
    class may differ from those in [`~transformers.TrainingArguments`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        > Parameters that control the model

        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
            argument of the [`IterativeSFTTrainer`] is provided as a string.

        > Parameters that control the data preprocessing

        max_length (`int` or `None`, *optional*, defaults to `None`):
            Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated.
        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
            The truncation mode to use, either `"keep_end"` or `"keep_start"`.
        optimize_device_cache (`bool`, *optional*, defaults to `False`):
            Whether to optimize accelerator cache for slightly more memory-efficient training.

    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        hub_revision = None,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = False,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        liger_kernel_config = None,
        eval_use_gather_object = False,
        average_tokens_across_devices = False,
        model_init_kwargs = None,
        max_length = None,
        truncation_mode = 'keep_end',
        optimize_device_cache = False,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        # Reject learning rates outside [1e-7, 1] before anything is constructed.
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        # When the caller left the output/saving settings at their defaults,
        # fall back to a fixed directory name and switch checkpoint saving off.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'

        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            hub_revision = hub_revision,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            liger_kernel_config = liger_kernel_config,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            model_init_kwargs = model_init_kwargs,
            max_length = max_length,
            truncation_mode = truncation_mode,
            optimize_device_cache = optimize_device_cache,**kwargs)
        # These two are not forwarded to the parent constructor; they are
        # stored on the instance afterwards.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
pass

class _UnslothIterativeSFTTrainer(Trainer):
    """
    Trainer whose optimisation is driven manually through `step`.

    The constructor resolves the config, tokenizer, model and data collator,
    then creates and prepares the optimizer and scheduler itself so that
    `step` can be called repeatedly with freshly supplied batches.
    """

    _tag_names = ["trl", "iterative-sft"]

    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        args: Optional[Union[IterativeSFTConfig, TrainingArguments]] = None,
        data_collator: Optional[DataCollator] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
        # Deprecated parameters
        max_length: Optional[int] = None,
        truncation_mode: Optional[str] = None,
        optimize_device_cache: Optional[bool] = None,
    ):
        """
        Build the trainer.

        Args:
            model: A model id / path (loaded with `AutoModelForCausalLM`) or an
                already instantiated model.
            args: Training configuration. `None` creates a default
                `IterativeSFTConfig`; a plain `TrainingArguments` is converted.
            data_collator: Collator used by `prepare_model_inputs`. When `None`,
                one is chosen based on `model.config.is_encoder_decoder`.
            eval_dataset: Dataset forwarded to the parent `Trainer`.
            processing_class: Tokenizer / processor. When `None`, it is loaded
                with `AutoTokenizer.from_pretrained`.
            optimizers: `(optimizer, lr_scheduler)` forwarded to the parent.
            preprocess_logits_for_metrics: Forwarded to the parent.
            compute_metrics: Forwarded to the parent.
            max_length: Deprecated; overrides `args.max_length` when given.
            truncation_mode: Deprecated; overrides `args.truncation_mode` when given.
            optimize_device_cache: Deprecated; overrides
                `args.optimize_device_cache` when given.

        Raises:
            AttributeError: If the parent constructor did not create an
                `accelerator` attribute.
        """
        # Handle deprecated parameters
        deprecated_params = {}
        if max_length is not None:
            deprecated_params["max_length"] = max_length
            warnings.warn(
                "The `max_length` parameter is deprecated and will be removed in version 0.20. "
                "Pass it through the `args` parameter using `IterativeSFTConfig(max_length=...)` instead.",
                DeprecationWarning,
            )
        if truncation_mode is not None:
            deprecated_params["truncation_mode"] = truncation_mode
            warnings.warn(
                "The `truncation_mode` parameter is deprecated and will be removed in version 0.20. "
                "Pass it through the `args` parameter using `IterativeSFTConfig(truncation_mode=...)` instead.",
                DeprecationWarning,
            )
        if optimize_device_cache is not None:
            deprecated_params["optimize_device_cache"] = optimize_device_cache
            warnings.warn(
                "The `optimize_device_cache` parameter is deprecated and will be removed in version 0.20 "
                "Pass it through the `args` parameter using `IterativeSFTConfig(optimize_device_cache=...)` instead.",
                DeprecationWarning,
            )

        # Args
        model_id = model if isinstance(model, str) else model.config._name_or_path
        if args is None:
            model_name = model_id.split("/")[-1]
            args = IterativeSFTConfig(f"{model_name}-IterativeSFT")
        elif isinstance(args, TrainingArguments) and not isinstance(args, IterativeSFTConfig):
            dict_args = args.to_dict()
            dict_args["hub_token"] = args.hub_token  # to_dict hides the hub_token
            dict_args.pop("push_to_hub_token")
            args = IterativeSFTConfig(**dict_args)

        # Update args with deprecated parameters if provided
        if deprecated_params:
            for key, value in deprecated_params.items():
                setattr(args, key, value)

        # Handle the tokenizer
        if processing_class is None:
            processing_class = AutoTokenizer.from_pretrained(model_id)

        # Model
        if args.model_init_kwargs is not None and not isinstance(model, str):
            warnings.warn(
                "You passed model_init_kwargs to the `IterativeSFTConfig`, but your model is already instantiated. "
                "The `model_init_kwargs` will be ignored."
            )
        if isinstance(model, str):
            model = self._create_model_from_path(model, args)

        # PEFT configuration and model wrapping
        if is_peft_available() and isinstance(model, PeftModel):
            self.is_peft_model = True
        else:
            self.is_peft_model = False

        self.processing_class = processing_class
        self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False)

        if data_collator is None:
            if self.is_encoder_decoder:
                self.data_collator = DataCollatorForSeq2Seq(
                    processing_class, label_pad_token_id=-100, pad_to_multiple_of=8
                )
            else:
                self.data_collator = DataCollatorForLanguageModeling(self.processing_class, mlm=False)
        else:
            self.data_collator = data_collator

        self.max_length = args.max_length
        self.truncation_mode = args.truncation_mode
        self.optimize_device_cache = args.optimize_device_cache

        super().__init__(
            model=model,
            args=args,
            data_collator=self.data_collator,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            compute_metrics=compute_metrics,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        self.create_optimizer_and_scheduler(self.args.max_steps)

        # Check for the accelerator BEFORE it is dereferenced below, so a
        # missing attribute surfaces with the explanatory message instead of a
        # bare attribute-lookup failure.
        if not hasattr(self, "accelerator"):
            raise AttributeError(
                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
            )

        # prepare model, optimizer and lr_scheduler
        self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
            self.model, self.optimizer, self.lr_scheduler
        )

        self.processing_class.truncation_side = "left" if self.truncation_mode == "keep_end" else "right"

        PPODecorators.optimize_device_cache = self.optimize_device_cache

    def _create_model_from_path(self, model_path: str, args: IterativeSFTConfig) -> PreTrainedModel:
        """Creates a model from a path or model identifier."""
        model_init_kwargs = args.model_init_kwargs or {}
        return AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)

    def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor):
        """
        Collate per-example tensors into a batch and truncate it to `max_length`.

        Args:
            input_ids: Iterable of per-example token id tensors.
            attention_mask: Iterable of per-example masks, or `None` to use
                all-ones masks shaped like each `input_ids` entry.
            labels: Iterable of per-example label tensors. Only passed to the
                collator on the encoder-decoder path.

        Returns:
            The collated batch. It is a plain `dict` when `self.max_length` is
            set, otherwise whatever the collator's `.to(...)` returned.

        Raises:
            ValueError: If `self.truncation_mode` is neither `"keep_start"` nor
                `"keep_end"` while `self.max_length` is set.
        """
        if attention_mask is None:
            attention_mask = [torch.ones_like(ids) for ids in input_ids]

        if self.is_encoder_decoder:
            input_data = self.data_collator(
                [
                    {"input_ids": ids, "attention_mask": att, "labels": lab}
                    for ids, att, lab in zip(input_ids, attention_mask, labels)
                ]
            ).to(self.model.device)

            input_data.pop("decoder_input_ids", None)  # This is directly computed inside the model

            input_data["labels"][input_data["labels"] == self.processing_class.pad_token_id] = -100

        else:
            input_data = self.data_collator(
                [{"input_ids": ids, "attention_mask": att} for ids, att in zip(input_ids, attention_mask)]
            ).to(self.model.device)

        # truncate in case the user has provided input_ids, attention_mask and labels
        if self.max_length is not None:
            # Slice the LAST axis. A plain `v[:n]` would slice the first axis
            # and drop whole examples rather than shorten them.
            # NOTE(review): assumes collated values are (batch, seq) — confirm
            # for custom collators.
            if self.truncation_mode == "keep_start":
                input_data = {k: v[..., : self.max_length] for k, v in input_data.items()}
            elif self.truncation_mode == "keep_end":
                input_data = {k: v[..., -self.max_length :] for k, v in input_data.items()}
            else:
                raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

        return input_data

    @staticmethod
    def _step_safety_checker(
        input_ids: list[torch.LongTensor],
        attention_mask: list[torch.LongTensor],
        labels: list[torch.LongTensor],
        texts: list[str],
        texts_labels: list[str],
    ):
        """
        Check if the input data is valid for training.

        When `texts` is `None` the tensor arguments are validated: `input_ids` is always checked, while
        `attention_mask` and `labels` are only checked when they are provided (both are optional for `step`).
        Otherwise `texts` (and `texts_labels`, if given) are validated as lists of strings.

        Args:
            input_ids (list[`torch.LongTensor`]):
                List of tensors containing the input_ids
            attention_mask (list[`torch.LongTensor`]):
                List of tensors containing the attention_mask
            labels (list[`torch.FloatTensor`]):
                List of tensors containing the labels
            texts (list[`str`]):
                List of string containing the text input.
            texts_labels (list[`str`]):
                List of string containing the text labels.

        Returns:
            `tuple`: The input data.

        Raises:
            ValueError: If a checked argument is not a list, or its first element has the wrong type.
        """
        if texts is None:
            # `input_ids` is mandatory on this path; the other two are optional.
            tensor_args = [("input_ids", input_ids)]
            if attention_mask is not None:
                tensor_args.append(("attention_mask", attention_mask))
            if labels is not None:
                tensor_args.append(("labels", labels))

            for name, tensor_list in tensor_args:
                if not isinstance(tensor_list, list):
                    raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
                if not isinstance(tensor_list[0], torch.Tensor):
                    raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
        else:
            if not isinstance(texts, list):
                raise ValueError(f"'text' must be a list of strings - got {type(texts)}")
            if not isinstance(texts[0], str):
                raise ValueError(f"Elements in 'text' must be strings - got {type(texts[0])}")
            if texts_labels is not None:
                if not isinstance(texts_labels, list):
                    raise ValueError(f"'text_labels' must be a list of strings - got {type(texts_labels)}")
                if not isinstance(texts_labels[0], str):
                    raise ValueError(f"Elements in 'text_labels' must be strings - got {type(texts_labels[0])}")

        return input_ids, attention_mask, labels, texts, texts_labels

    @PPODecorators.empty_device_cache()
    def step(
        self,
        input_ids: Optional[list[torch.LongTensor]] = None,
        attention_mask: Optional[list[torch.LongTensor]] = None,
        labels: Optional[list[torch.LongTensor]] = None,
        texts: Optional[list[str]] = None,
        texts_labels: Optional[list[str]] = None,
    ):
        """
        Run an optimisation step given a list of input_ids, attention_mask, and labels or a list of text and
        text_labels.

        Args:
            input_ids (list[`torch.LongTensor`]):
                List of tensors containing the input_ids (if not provided, text will be used)
            attention_mask (list[`torch.LongTensor`], , *optional*):
                List of tensors containing the attention_mask
            labels (list[`torch.FloatTensor`], *optional*):
                List of tensors containing the labels (if set to None, will default to input_ids)
            texts (list[`str`], *optional*):
                List of strings containing the text input (if not provided, input_ids will directly be used)
            texts_labels (list[`str`], *optional*):
                List of strings containing the text labels (if set to None, will default to text)

        Returns:
            `dict[str, Any]`: A summary of the training statistics

        Raises:
            ValueError: If neither `input_ids` nor `texts` is given, or if an encoder-decoder model is used
                without `labels` / `texts_labels`.
        """
        self.model.train()

        # First call: initialise the running loss and the logging bookmark.
        if self.state.global_step == 0:
            self.tr_loss = torch.tensor(0.0).to(self.args.device)
            self._globalstep_last_logged = self.state.global_step

        if input_ids is None and texts is None:
            raise ValueError("Step should include `input_ids` or `texts` as keyword arguments.")
        elif input_ids is not None and texts is not None:
            warnings.warn(
                "Both `input_ids` and `texts` argument are provided. `input_ids` will be ignored. "
                "Please provide only one of the two.",
                UserWarning,
            )

        if labels is None and texts_labels is None and self.is_encoder_decoder:
            raise ValueError(
                "No 'labels' or 'text_labels' are provided. When using an encoder-decoder architecture, 'labels' or 'text_labels' must be passed."
            )

        # Convert Column to list if not already
        input_ids = input_ids[:] if input_ids is not None else None
        attention_mask = attention_mask[:] if attention_mask is not None else None
        labels = labels[:] if labels is not None else None
        texts = texts[:] if texts is not None else None
        texts_labels = texts_labels[:] if texts_labels is not None else None

        input_ids, attention_mask, labels, texts, texts_labels = self._step_safety_checker(
            input_ids, attention_mask, labels, texts, texts_labels
        )

        if texts is not None:
            model_inputs = self.processing_class(
                texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
            )

            input_ids, attention_mask = model_inputs["input_ids"], model_inputs["attention_mask"]

        if texts_labels is not None:
            # Tokenize the label texts themselves (not `texts`), otherwise the
            # supplied targets would be silently replaced by the inputs.
            labels = self.processing_class(
                texts_labels, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
            )["input_ids"]

        if labels is None:
            labels = input_ids

        model_inputs = self.prepare_model_inputs(input_ids, attention_mask, labels)

        model_inputs_names = list(model_inputs.keys())

        batch_dict = {}
        batch_dict.update(model_inputs)

        def collator(data):
            # Stack the per-row tensors of the known keys and move them to the
            # model's device; any other key is dropped.
            return_dict = dict()
            for key in data[0]:
                if key in ["input_ids", "attention_mask", "labels"]:
                    return_dict[key] = torch.stack([d[key] for d in data]).to(self.model.device)
            return return_dict

        batch_data = Dataset.from_dict(batch_dict)
        batch_data.set_format("torch")

        step_dataloader = DataLoader(
            batch_data,
            batch_size=self.args.per_device_train_batch_size,
            shuffle=True,
            collate_fn=collator,
        )

        for _, batch in enumerate(step_dataloader):
            with self.accelerator.accumulate(self.model):
                model_inputs = {k: batch[k] for k in model_inputs_names}
                loss = self.compute_loss(self.model, model_inputs)

                if self.args.n_gpu > 1:
                    loss = loss.mean()

                tr_loss_step = loss.detach()

                self.accelerator.backward(loss)

                if self.accelerator.sync_gradients and self.args.max_grad_norm is not None:
                    self.accelerator.clip_grad_norm_(
                        self.model.parameters(),
                        self.args.max_grad_norm,
                    )

                self.optimizer.step()
                self.optimizer.zero_grad()
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()

                self.state.global_step += 1

                # update stats etc
                self.tr_loss += tr_loss_step

                self._maybe_log_save_evaluate()

    def _maybe_log_save_evaluate(self):
        """Run evaluation and/or emit a log entry when `global_step` hits the configured interval."""
        # check if eval is required
        if self.args.eval_steps is not None:
            if self.state.global_step % self.args.eval_steps == 0 and self.state.global_step != 0:
                self.evaluate(self.eval_dataset)

        # check if logging is required
        if self.args.logging_steps is not None:
            if self.state.global_step % self.args.logging_steps == 0 and self.state.global_step != 0:
                logs: dict[str, float] = {}

                tr_loss_scalar = self._nested_gather(self.tr_loss).mean().item()

                # reset tr_loss to zero
                self.tr_loss -= self.tr_loss

                # Average the accumulated loss over the steps since the last log.
                logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
                logs["learning_rate"] = self._get_learning_rate()

                self._globalstep_last_logged = self.state.global_step

                self.log(logs)

    # Ensure the model card is saved along with the checkpoint
    def _save_checkpoint(self, model, trial):
        """Write the model card, then delegate to the parent checkpoint routine."""
        if self.args.hub_model_id is None:
            model_name = Path(self.args.output_dir).name
        else:
            model_name = self.args.hub_model_id.split("/")[-1]
        self.create_model_card(model_name=model_name)
        super()._save_checkpoint(model, trial)

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Nothing is written unless `is_world_process_zero()` is true. The card is saved as `README.md` inside
        `args.output_dir`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        # Only report a base model when `_name_or_path` is not a local directory.
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        # normalize `tags` to a mutable set
        if tags is None:
            tags = set()
        elif isinstance(tags, str):
            tags = {tags}
        else:
            tags = set(tags)

        if hasattr(self.model.config, "unsloth_version"):
            tags.add("unsloth")

        tags.update(self._tag_names)

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="Iterative SFT",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
class UnslothIterativeSFTTrainer(_UnslothIterativeSFTTrainer):
    """
    The IterativeSFTTrainer can be used to finetune models with methods that requires some steps between optimization.

    Args:
        model (`Union[str, PreTrainedModel]`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
              path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
              using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
              `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
        args ([`IterativeSFTConfig`], *optional*, defaults to `None`):
            Configuration for this trainer. If `None`, a default configuration is used.
        data_collator (`DataCollator`, *optional*):
            Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
            Will default to [`~transformers.default_data_collator`] if no `processing_class` is provided, an instance
            of [`~transformers.DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or
            tokenizer.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
            Processing class used to process the data. If `None`, the processing class is loaded from the model's name
            with [`~transformers.AutoTokenizer.from_pretrained`].
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
            metric values.
        max_length (`int`, *optional*, deprecated):
            Maximum length of the tokenized sequence. Use `args.max_length` instead.
        truncation_mode (`str`, *optional*, deprecated):
            The truncation mode to use. Use `args.truncation_mode` instead.
        optimize_device_cache (`bool`, *optional*, deprecated):
            Whether to optimize accelerator cache. Use `args.optimize_device_cache` instead.

    """
    def __init__(
        self,
        model,
        args = None,
        data_collator = None,
        eval_dataset = None,
        processing_class = None,
        preprocess_logits_for_metrics = None,
        compute_metrics = None,
        max_length = None,
        truncation_mode = None,
        optimize_device_cache = None,
        **kwargs
    ):
        # NOTE: several checks below probe `locals()` by name, so the order of
        # assignments in this method is significant and is kept as-is.
        if args is None: args = UnslothIterativeSFTConfig()
        # Only genuine booleans count as an explicit precision request.
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        # Determine the model dtype from its config, falling back to the
        # input embedding weights.
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        # Refuse precision flags that contradict the model's dtype.
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Nothing was requested explicitly: follow the model's dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        # Keep the full-eval precision flags consistent with the training ones.
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        # Metric callbacks need logits, which is signalled via this env var.
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length  = getattr(args,  'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        # Force right padding on whichever tokenizer objects are present.
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('iterative_sft_trainer', other_metrics)

        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            compute_metrics = compute_metrics,
            max_length = max_length,
            truncation_mode = truncation_mode,
            optimize_device_cache = optimize_device_cache,**kwargs)
        # Remove any NEFTune hook left on the trainer and store the noise alpha
        # on the input embedding module instead.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass

pass