""" | |
2025.3.13 | |
2025.3.15 | |
4.48.3 | |
0.15.2 | |
__UNSLOTH_VERSIONING__ | |
""" | |
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from trl.trainer.ddpo_trainer import (Accelerator, Any, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, Optional, PerPromptStatTracker, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, futures, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
import os
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
import numpy as np
from contextlib import nullcontext
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : False,
    "shape_padding"     : True,
    "trace.enabled"     : False,
    "triton.cudagraphs" : False,
}
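# Hedged usage sketch (assumption, not exercised in this file): these inductor
# options are meant to be handed to torch.compile's "inductor" backend.
def _compile_with_options(module: nn.Module):
    return torch.compile(module, backend = "inductor", options = torch_compile_options)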
def selective_log_softmax(logits, index):
    logits = logits.to(torch.float32)
    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    # Loop-based alternative (commented out) trades speed for lower peak memory:
    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
    logsumexp_values = torch.logsumexp(logits, dim = -1)
    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
    return per_token_logps
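# Hedged sanity check (not part of the generated file): gathering from a full
# F.log_softmax must match the logsumexp formulation above. Shapes are
# illustrative (batch=2, seq=3, vocab=5).
def _check_selective_log_softmax():
    logits = torch.randn(2, 3, 5)
    index = torch.randint(0, 5, (2, 3))
    expected = F.log_softmax(logits.float(), dim = -1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
    assert torch.allclose(selective_log_softmax(logits, index), expected, atol = 1e-6)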
class UnslothDDPOConfig(DDPOConfig):
    """
    Configuration class for the [`DDPOTrainer`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        exp_name (`str`, *optional*, defaults to `'main'`):
            Name of this experiment.
        run_name (`str`, *optional*, defaults to `""`):
            Name of this run.
        seed (`int`, *optional*, defaults to `3407`):
            Random seed.
        log_with (`Literal["wandb", "tensorboard"]` or `None`, *optional*, defaults to `None`):
            Log with either 'wandb' or 'tensorboard'; see
            https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
        tracker_kwargs (`Dict`, *optional*, defaults to `{}`):
            Keyword arguments for the tracker (e.g. wandb_project).
        accelerator_kwargs (`Dict`, *optional*, defaults to `{}`):
            Keyword arguments for the accelerator.
        project_kwargs (`Dict`, *optional*, defaults to `{}`):
            Keyword arguments for the accelerator project config (e.g. `logging_dir`).
        tracker_project_name (`str`, *optional*, defaults to `"trl"`):
            Name of project to use for tracking.
        logdir (`str`, *optional*, defaults to `"logs"`):
            Top-level logging directory for checkpoint saving.
        num_epochs (`int`, *optional*, defaults to `100`):
            Number of epochs to train.
        save_freq (`int`, *optional*, defaults to `1`):
            Number of epochs between saving model checkpoints.
        num_checkpoint_limit (`int`, *optional*, defaults to `5`):
            Number of checkpoints to keep before overwriting old ones.
        mixed_precision (`str`, *optional*, defaults to `"fp16"`):
            Mixed precision training.
        allow_tf32 (`bool`, *optional*, defaults to `True`):
            Allow `tf32` on Ampere GPUs.
        resume_from (`str`, *optional*, defaults to `""`):
            Path to resume training from a checkpoint.
        sample_num_steps (`int`, *optional*, defaults to `50`):
            Number of sampler inference steps.
        sample_eta (`float`, *optional*, defaults to `1.0`):
            Eta parameter for the DDIM sampler.
        sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
            Classifier-free guidance weight.
        sample_batch_size (`int`, *optional*, defaults to `1`):
            Batch size (per GPU) to use for sampling.
        sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
            Number of batches to sample per epoch.
        train_batch_size (`int`, *optional*, defaults to `1`):
            Batch size (per GPU) to use for training.
        train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
            Use the 8bit Adam optimizer from bitsandbytes.
        train_learning_rate (`float`, *optional*, defaults to `5e-05`):
            Learning rate.
        train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
            Adam beta1.
        train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
            Adam beta2.
        train_adam_weight_decay (`float`, *optional*, defaults to `0.01`):
            Adam weight decay.
        train_adam_epsilon (`float`, *optional*, defaults to `1e-08`):
            Adam epsilon.
        train_gradient_accumulation_steps (`int`, *optional*, defaults to `2`):
            Number of gradient accumulation steps.
        train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
            Maximum gradient norm for gradient clipping.
        train_num_inner_epochs (`int`, *optional*, defaults to `1`):
            Number of inner epochs per outer epoch.
        train_cfg (`bool`, *optional*, defaults to `True`):
            Whether to use classifier-free guidance during training.
        train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
            Advantages are clipped to the range `[-train_adv_clip_max, train_adv_clip_max]`.
        train_clip_range (`float`, *optional*, defaults to `1e-4`):
            PPO clip range.
        train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
            Fraction of timesteps to train on.
        per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
            Whether to track statistics for each prompt separately.
        per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
            Number of reward values to store in the buffer for each prompt.
        per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
            Minimum number of reward values per prompt required before per-prompt statistics are used.
        async_reward_computation (`bool`, *optional*, defaults to `False`):
            Whether to compute rewards asynchronously.
        max_workers (`int`, *optional*, defaults to `2`):
            Maximum number of workers to use for async reward computation.
        negative_prompts (`str`, *optional*, defaults to `""`):
            Comma-separated list of prompts to use as negative examples.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the final model checkpoint to the Hub.
    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        exp_name = 'main',
        run_name = '',
        seed = 3407,
        log_with = None,
        tracker_project_name = 'trl',
        logdir = 'logs',
        num_epochs = 100,
        save_freq = 1,
        num_checkpoint_limit = 5,
        mixed_precision = 'fp16',
        allow_tf32 = True,
        resume_from = '',
        sample_num_steps = 50,
        sample_eta = 1.0,
        sample_guidance_scale = 5.0,
        sample_batch_size = 1,
        sample_num_batches_per_epoch = 2,
        train_batch_size = 1,
        train_use_8bit_adam = False,
        train_learning_rate = 5e-05,
        train_adam_beta1 = 0.9,
        train_adam_beta2 = 0.999,
        train_adam_weight_decay = 0.01,
        train_adam_epsilon = 1e-08,
        train_gradient_accumulation_steps = 2,
        train_max_grad_norm = 1.0,
        train_num_inner_epochs = 1,
        train_cfg = True,
        train_adv_clip_max = 5.0,
        train_clip_range = 0.0001,
        train_timestep_fraction = 1.0,
        per_prompt_stat_tracking = False,
        per_prompt_stat_tracking_buffer_size = 16,
        per_prompt_stat_tracking_min_count = 16,
        async_reward_computation = False,
        max_workers = 2,
        negative_prompts = '',
        push_to_hub = False,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        super().__init__(
            exp_name = exp_name,
            run_name = run_name,
            seed = seed,
            log_with = log_with,
            tracker_project_name = tracker_project_name,
            logdir = logdir,
            num_epochs = num_epochs,
            save_freq = save_freq,
            num_checkpoint_limit = num_checkpoint_limit,
            mixed_precision = mixed_precision,
            allow_tf32 = allow_tf32,
            resume_from = resume_from,
            sample_num_steps = sample_num_steps,
            sample_eta = sample_eta,
            sample_guidance_scale = sample_guidance_scale,
            sample_batch_size = sample_batch_size,
            sample_num_batches_per_epoch = sample_num_batches_per_epoch,
            train_batch_size = train_batch_size,
            train_use_8bit_adam = train_use_8bit_adam,
            train_learning_rate = train_learning_rate,
            train_adam_beta1 = train_adam_beta1,
            train_adam_beta2 = train_adam_beta2,
            train_adam_weight_decay = train_adam_weight_decay,
            train_adam_epsilon = train_adam_epsilon,
            train_gradient_accumulation_steps = train_gradient_accumulation_steps,
            train_max_grad_norm = train_max_grad_norm,
            train_num_inner_epochs = train_num_inner_epochs,
            train_cfg = train_cfg,
            train_adv_clip_max = train_adv_clip_max,
            train_clip_range = train_clip_range,
            train_timestep_fraction = train_timestep_fraction,
            per_prompt_stat_tracking = per_prompt_stat_tracking,
            per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
            per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
            async_reward_computation = async_reward_computation,
            max_workers = max_workers,
            negative_prompts = negative_prompts,
            push_to_hub = push_to_hub,
            **kwargs,
        )
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
    pass
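# Hedged usage sketch (assumption, not part of the original file): the config
# is built like a plain DDPOConfig, with the two extra Unsloth fields exposed
# as keyword arguments.
def _example_config() -> UnslothDDPOConfig:
    return UnslothDDPOConfig(
        sample_batch_size = 1,
        train_batch_size = 1,
        train_gradient_accumulation_steps = 2,
        unsloth_num_chunks = -1,
    )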
class _UnslothDDPOTrainer(PyTorchModelHubMixin):
    """"""

    _tag_names = ["trl", "ddpo"]

    def __init__(
        self,
        config: DDPOConfig,
        reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
        prompt_function: Callable[[], tuple[str, Any]],
        sd_pipeline: DDPOStableDiffusionPipeline,
        image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
    ):
        if image_samples_hook is None:
            warn("No image_samples_hook provided; no images will be logged")

        self.prompt_fn = prompt_function
        self.reward_fn = reward_function
        self.config = config
        self.image_samples_callback = image_samples_hook

        accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)

        if self.config.resume_from:
            self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
            if "checkpoint_" not in os.path.basename(self.config.resume_from):
                # get the most recent checkpoint in this directory
                checkpoints = list(
                    filter(
                        lambda x: "checkpoint_" in x,
                        os.listdir(self.config.resume_from),
                    )
                )
                if len(checkpoints) == 0:
                    raise ValueError(f"No checkpoints found in {self.config.resume_from}")
                checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
                self.config.resume_from = os.path.join(
                    self.config.resume_from,
                    f"checkpoint_{checkpoint_numbers[-1]}",
                )
                accelerator_project_config.iteration = checkpoint_numbers[-1] + 1

        # number of timesteps within each trajectory to train on
        self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)

        self.accelerator = Accelerator(
            log_with=self.config.log_with,
            mixed_precision=self.config.mixed_precision,
            project_config=accelerator_project_config,
            # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
            # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
            # the total number of optimizer steps to accumulate across.
            gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
            **self.config.accelerator_kwargs,
        )
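        # With this module's defaults (train_gradient_accumulation_steps = 2,
        # sample_num_steps = 50, train_timestep_fraction = 1.0), the accelerator
        # therefore accumulates 2 * 50 = 100 backward passes per optimizer step.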
        is_okay, message = self._config_check()
        if not is_okay:
            raise ValueError(message)

        is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
        if self.accelerator.is_main_process:
            self.accelerator.init_trackers(
                self.config.tracker_project_name,
                config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
                init_kwargs=self.config.tracker_kwargs,
            )

        logger.info(f"\n{config}")

        set_seed(self.config.seed, device_specific=True)

        self.sd_pipeline = sd_pipeline
        self.sd_pipeline.set_progress_bar_config(
            position=1,
            disable=not self.accelerator.is_local_main_process,
            leave=False,
            desc="Timestep",
            dynamic_ncols=True,
        )

        # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision,
        # as these weights are only used for inference; keeping them in full precision is not required.
        if self.accelerator.mixed_precision == "fp16":
            inference_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            inference_dtype = torch.bfloat16
        else:
            inference_dtype = torch.float32

        self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
        self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
        self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)

        trainable_layers = self.sd_pipeline.get_trainable_layers()

        self.accelerator.register_save_state_pre_hook(self._save_model_hook)
        self.accelerator.register_load_state_pre_hook(self._load_model_hook)

        # Enable TF32 for faster training on Ampere GPUs,
        # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.config.allow_tf32:
            torch.backends.cuda.matmul.allow_tf32 = True

        self.optimizer = self._setup_optimizer(
            trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
        )

        self.neg_prompt_embed = self.sd_pipeline.text_encoder(
            self.sd_pipeline.tokenizer(
                [""] if self.config.negative_prompts is None else self.config.negative_prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.sd_pipeline.tokenizer.model_max_length,
            ).input_ids.to(self.accelerator.device)
        )[0]

        if config.per_prompt_stat_tracking:
            self.stat_tracker = PerPromptStatTracker(
                config.per_prompt_stat_tracking_buffer_size,
                config.per_prompt_stat_tracking_min_count,
            )

        # NOTE: for some reason, autocast is necessary for non-lora training, but for lora training it isn't necessary and uses
        # more memory
        self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast

        if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
            unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
            self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
        else:
            self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)

        if self.config.async_reward_computation:
            self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)

        if config.resume_from:
            logger.info(f"Resuming from {config.resume_from}")
            self.accelerator.load_state(config.resume_from)
            self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
        else:
            self.first_epoch = 0
    def compute_rewards(self, prompt_image_pairs, is_async=False):
        if not is_async:
            rewards = []
            for images, prompts, prompt_metadata in prompt_image_pairs:
                reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
                rewards.append(
                    (
                        torch.as_tensor(reward, device=self.accelerator.device),
                        reward_metadata,
                    )
                )
        else:
            rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
            rewards = [
                (torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
                for reward, reward_metadata in rewards
            ]
        return zip(*rewards)
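    # Hedged sketch (assumption, not part of the original file) of a reward
    # function matching the signature compute_rewards expects: it receives
    # (images, prompts, prompt_metadata) and returns a per-image reward tensor
    # plus arbitrary metadata, e.g.
    #
    #     def brightness_reward(images, prompts, prompt_metadata):
    #         return images.float().mean(dim=(1, 2, 3)), {}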
    def step(self, epoch: int, global_step: int):
        """
        Perform a single step of training.

        Args:
            epoch (int): The current epoch.
            global_step (int): The current global step.

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.
            - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.

        Returns:
            global_step (int): The updated global step.
        """
        samples, prompt_image_data = self._generate_samples(
            iterations=self.config.sample_num_batches_per_epoch,
            batch_size=self.config.sample_batch_size,
        )

        # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
        samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
        rewards, rewards_metadata = self.compute_rewards(
            prompt_image_data, is_async=self.config.async_reward_computation
        )

        for i, image_data in enumerate(prompt_image_data):
            image_data.extend([rewards[i], rewards_metadata[i]])

        if self.image_samples_callback is not None:
            self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])

        rewards = torch.cat(rewards)
        rewards = self.accelerator.gather(rewards).cpu().numpy()

        self.accelerator.log(
            {
                "reward": rewards,
                "epoch": epoch,
                "reward_mean": rewards.mean(),
                "reward_std": rewards.std(),
            },
            step=global_step,
        )

        if self.config.per_prompt_stat_tracking:
            # gather the prompts across processes
            prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
            prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
            advantages = self.stat_tracker.update(prompts, rewards)
        else:
            advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

        # ungather advantages; keep the entries corresponding to the samples on this process
        samples["advantages"] = (
            torch.as_tensor(advantages)
            .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
            .to(self.accelerator.device)
        )

        del samples["prompt_ids"]

        total_batch_size, num_timesteps = samples["timesteps"].shape

        for inner_epoch in range(self.config.train_num_inner_epochs):
            # shuffle samples along batch dimension
            perm = torch.randperm(total_batch_size, device=self.accelerator.device)
            samples = {k: v[perm] for k, v in samples.items()}

            # shuffle along the time dimension independently for each sample, so each
            # trajectory's timesteps are visited in their own random order
            perms = torch.stack(
                [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
            )
            for key in ["timesteps", "latents", "next_latents", "log_probs"]:
                samples[key] = samples[key][
                    torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
                    perms,
                ]

            original_keys = samples.keys()
            original_values = samples.values()
            # rebatch them as user defined train_batch_size is different from sample_batch_size
            reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]

            # Transpose the list of original values
            transposed_values = zip(*reshaped_values)
            # Create new dictionaries for each row of transposed values
            samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]

            self.sd_pipeline.unet.train()
            global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
            # ensure optimization step at the end of the inner epoch
            if not self.accelerator.sync_gradients:
                raise ValueError(
                    "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
                )

        if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
            self.accelerator.save_state()

        return global_step
    def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
        """
        Calculate the loss for a batch of an unpacked sample

        Args:
            latents (torch.Tensor):
                The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
            timesteps (torch.Tensor):
                The timesteps sampled from the diffusion model, shape: [batch_size]
            next_latents (torch.Tensor):
                The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
            log_probs (torch.Tensor):
                The log probabilities of the latents, shape: [batch_size]
            advantages (torch.Tensor):
                The advantages of the latents, shape: [batch_size]
            embeds (torch.Tensor):
                The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
                Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds

        Returns:
            loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
            (all of these are of shape (1,))
        """
        with self.autocast():
            if self.config.train_cfg:
                noise_pred = self.sd_pipeline.unet(
                    torch.cat([latents] * 2),
                    torch.cat([timesteps] * 2),
                    embeds,
                ).sample
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
            else:
                noise_pred = self.sd_pipeline.unet(
                    latents,
                    timesteps,
                    embeds,
                ).sample
            # compute the log prob of next_latents given latents under the current model
            scheduler_step_output = self.sd_pipeline.scheduler_step(
                noise_pred,
                timesteps,
                latents,
                eta=self.config.sample_eta,
                prev_sample=next_latents,
            )
            log_prob = scheduler_step_output.log_probs

        advantages = torch.clamp(
            advantages,
            -self.config.train_adv_clip_max,
            self.config.train_adv_clip_max,
        )

        ratio = torch.exp(log_prob - log_probs)
        loss = self.loss(advantages, self.config.train_clip_range, ratio)

        approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
        clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())

        return loss, approx_kl, clipfrac
    def loss(
        self,
        advantages: torch.Tensor,
        clip_range: float,
        ratio: torch.Tensor,
    ):
        unclipped_loss = -advantages * ratio
        clipped_loss = -advantages * torch.clamp(
            ratio,
            1.0 - clip_range,
            1.0 + clip_range,
        )
        return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
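    # Hedged numeric illustration (not part of the original file) of the clipped
    # surrogate above: with clip_range = 0.2 and advantage = 1.0, a ratio of 1.5
    # gives unclipped_loss = -1.5 and clipped_loss = -clamp(1.5, 0.8, 1.2) = -1.2;
    # torch.maximum keeps -1.2, i.e. the more pessimistic, clipped objective.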
    def _setup_optimizer(self, trainable_layers_parameters):
        if self.config.train_use_8bit_adam:
            import bitsandbytes

            optimizer_cls = bitsandbytes.optim.AdamW8bit
        else:
            optimizer_cls = torch.optim.AdamW

        return optimizer_cls(
            trainable_layers_parameters,
            lr=self.config.train_learning_rate,
            betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
            weight_decay=self.config.train_adam_weight_decay,
            eps=self.config.train_adam_epsilon,
        )

    def _save_model_hook(self, models, weights, output_dir):
        self.sd_pipeline.save_checkpoint(models, weights, output_dir)
        weights.pop()  # ensures that accelerate doesn't try to handle saving of the model

    def _load_model_hook(self, models, input_dir):
        self.sd_pipeline.load_checkpoint(models, input_dir)
        models.pop()  # ensures that accelerate doesn't try to handle loading of the model
    def _generate_samples(self, iterations, batch_size):
        """
        Generate samples from the model

        Args:
            iterations (int): Number of iterations to generate samples for
            batch_size (int): Batch size to use for sampling

        Returns:
            samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
        """
        samples = []
        prompt_image_pairs = []
        self.sd_pipeline.unet.eval()

        sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)

        for _ in range(iterations):
            prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])

            prompt_ids = self.sd_pipeline.tokenizer(
                prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.sd_pipeline.tokenizer.model_max_length,
            ).input_ids.to(self.accelerator.device)
            prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]

            with self.autocast():
                sd_output = self.sd_pipeline(
                    prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=sample_neg_prompt_embeds,
                    num_inference_steps=self.config.sample_num_steps,
                    guidance_scale=self.config.sample_guidance_scale,
                    eta=self.config.sample_eta,
                    output_type="pt",
                )

                images = sd_output.images
                latents = sd_output.latents
                log_probs = sd_output.log_probs

            latents = torch.stack(latents, dim=1)  # (batch_size, num_steps + 1, ...)
            log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
            timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1)  # (batch_size, num_steps)

            samples.append(
                {
                    "prompt_ids": prompt_ids,
                    "prompt_embeds": prompt_embeds,
                    "timesteps": timesteps,
                    "latents": latents[:, :-1],  # each entry is the latent before timestep t
                    "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
                    "log_probs": log_probs,
                    "negative_prompt_embeds": sample_neg_prompt_embeds,
                }
            )
            prompt_image_pairs.append([images, prompts, prompt_metadata])

        return samples, prompt_image_pairs
    def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
        """
        Train on a batch of samples. Main training segment

        Args:
            inner_epoch (int): The current inner epoch
            epoch (int): The current epoch
            global_step (int): The current global step
            batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.

        Returns:
            global_step (int): The updated global step
        """
        info = defaultdict(list)
        for _i, sample in enumerate(batched_samples):
            if self.config.train_cfg:
                # concat negative prompts to sample prompts to avoid two forward passes
                embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
            else:
                embeds = sample["prompt_embeds"]

            for j in range(self.num_train_timesteps):
                with self.accelerator.accumulate(self.sd_pipeline.unet):
                    loss, approx_kl, clipfrac = self.calculate_loss(
                        sample["latents"][:, j],
                        sample["timesteps"][:, j],
                        sample["next_latents"][:, j],
                        sample["log_probs"][:, j],
                        sample["advantages"],
                        embeds,
                    )
                    info["approx_kl"].append(approx_kl)
                    info["clipfrac"].append(clipfrac)
                    info["loss"].append(loss)

                    self.accelerator.backward(loss)
                    if self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(
                            self.trainable_layers.parameters()
                            if not isinstance(self.trainable_layers, list)
                            else self.trainable_layers,
                            self.config.train_max_grad_norm,
                        )
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                # Checks if the accelerator has performed an optimization step behind the scenes
                if self.accelerator.sync_gradients:
                    # log training-related stuff
                    info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                    info = self.accelerator.reduce(info, reduction="mean")
                    info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                    self.accelerator.log(info, step=global_step)
                    global_step += 1
                    info = defaultdict(list)
        return global_step
    def _config_check(self) -> tuple[bool, str]:
        samples_per_epoch = (
            self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
        )
        total_train_batch_size = (
            self.config.train_batch_size
            * self.accelerator.num_processes
            * self.config.train_gradient_accumulation_steps
        )

        if not self.config.sample_batch_size >= self.config.train_batch_size:
            return (
                False,
                f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
            )
        if not self.config.sample_batch_size % self.config.train_batch_size == 0:
            return (
                False,
                f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
            )
        if not samples_per_epoch % total_train_batch_size == 0:
            return (
                False,
                f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
            )
        return True, ""
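    # Worked example of the checks above (single process assumed, module
    # defaults): samples_per_epoch = 1 * 1 * 2 = 2 and total_train_batch_size
    # = 1 * 1 * 2 = 2, so 1 >= 1, 1 % 1 == 0 and 2 % 2 == 0 all hold and
    # _config_check passes.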
    def train(self, epochs: Optional[int] = None):
        """
        Train the model for a given number of epochs
        """
        global_step = 0
        if epochs is None:
            epochs = self.config.num_epochs
        for epoch in range(self.first_epoch, epochs):
            global_step = self.step(epoch, global_step)

    def _save_pretrained(self, save_directory):
        self.sd_pipeline.save_pretrained(save_directory)
        self.create_model_card()

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        tags = tags or []
        if isinstance(tags, str):
            tags = [tags]

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent("""\
        @inproceedings{black2024training,
            title        = {{Training Diffusion Models with Reinforcement Learning}},
            author       = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
            year         = 2024,
            booktitle    = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
            publisher    = {OpenReview.net},
            url          = {https://openreview.net/forum?id=YCWjhGrJFD},
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="DDPO",
            trainer_citation=citation,
            paper_title="Training Diffusion Models with Reinforcement Learning",
            paper_id="2305.13301",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
class UnslothDDPOTrainer(_UnslothDDPOTrainer):
    """
    The DDPOTrainer uses Denoising Diffusion Policy Optimization to optimise diffusion models.
    Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
    As of now only Stable Diffusion based pipelines are supported

    Attributes:
        **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more
            details.
        **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used
        **prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
        **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
        **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
    """
    def __init__(
        self,
        config,
        reward_function,
        prompt_function,
        sd_pipeline,
        image_samples_hook = None,
        **kwargs
    ):
        if config is None: config = UnslothDDPOConfig()
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('ddpo_trainer', other_metrics)

        super().__init__(
            config = config,
            reward_function = reward_function,
            prompt_function = prompt_function,
            sd_pipeline = sd_pipeline,
            image_samples_hook = image_samples_hook,
            **kwargs
        )
    pass
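# Hedged end-to-end sketch (assumptions: TRL's DefaultDDPOStableDiffusionPipeline,
# plus the illustrative model id, prompt function, and reward heuristic below;
# none of these are part of the original file).
def _example_training_run():
    from trl import DefaultDDPOStableDiffusionPipeline

    def prompt_fn():
        # returns (prompt, prompt_metadata)
        return "a photo of a cat", {}

    def reward_fn(images, prompts, prompt_metadata):
        # toy reward: mean pixel brightness, one scalar per image
        return images.float().mean(dim = (1, 2, 3)), {}

    pipeline = DefaultDDPOStableDiffusionPipeline(
        "runwayml/stable-diffusion-v1-5", use_lora = True
    )
    trainer = UnslothDDPOTrainer(
        config = UnslothDDPOConfig(),
        reward_function = reward_fn,
        prompt_function = prompt_fn,
        sd_pipeline = pipeline,
    )
    trainer.train()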