# self-forcing/pipeline/self_forcing_training.py
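"""
Self-Forcing training pipeline for the Wan video diffusion model.

Rolls the few-step causal generator forward block by block with a KV cache,
re-noising each block's prediction as context for the next one, and keeps the
computation graph only for one randomly selected denoising step per block
(synchronized across all ranks).
"""
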
from typing import List, Optional

import torch
import torch.distributed as dist

from utils.scheduler import SchedulerInterface
from utils.wan_wrapper import WanDiffusionWrapper


class SelfForcingTrainingPipeline:
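    """
    Autoregressive rollout used during Self-Forcing training.

    The generator denoises the video in blocks of `num_frame_per_block` frames,
    attending to previously generated frames through per-layer KV caches. Every
    forward pass runs under `torch.no_grad()` except the one at the denoising
    step selected by `generate_and_sync_list`, and even that pass only keeps
    gradients for frames inside the trailing gradient window.
    """
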
def __init__(self,
denoising_step_list: List[int],
scheduler: SchedulerInterface,
generator: WanDiffusionWrapper,
num_frame_per_block=3,
independent_first_frame: bool = False,
same_step_across_blocks: bool = False,
last_step_only: bool = False,
num_max_frames: int = 21,
context_noise: int = 0,
**kwargs):
super().__init__()
self.scheduler = scheduler
self.generator = generator
self.denoising_step_list = denoising_step_list
if self.denoising_step_list[-1] == 0:
self.denoising_step_list = self.denoising_step_list[:-1] # remove the zero timestep for inference
        # Wan-specific hyperparameters
        self.num_transformer_blocks = 30
        self.frame_seq_length = 1560  # attention tokens per latent frame
self.num_frame_per_block = num_frame_per_block
self.context_noise = context_noise
self.i2v = False
self.kv_cache1 = None
self.kv_cache2 = None
self.independent_first_frame = independent_first_frame
self.same_step_across_blocks = same_step_across_blocks
self.last_step_only = last_step_only
self.kv_cache_size = num_max_frames * self.frame_seq_length

    def generate_and_sync_list(self, num_blocks, num_denoising_steps, device):
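        """
        Sample one denoising-step index per block on rank 0 (or force the last
        step for every block when `last_step_only` is set) and broadcast the
        result so that all ranks backpropagate through the same steps.
        """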
rank = dist.get_rank() if dist.is_initialized() else 0
if rank == 0:
# Generate random indices
indices = torch.randint(
low=0,
high=num_denoising_steps,
size=(num_blocks,),
device=device
)
if self.last_step_only:
indices = torch.ones_like(indices) * (num_denoising_steps - 1)
else:
indices = torch.empty(num_blocks, dtype=torch.long, device=device)
dist.broadcast(indices, src=0) # Broadcast the random indices to all ranks
return indices.tolist()

    def inference_with_trajectory(
self,
noise: torch.Tensor,
initial_latent: Optional[torch.Tensor] = None,
return_sim_step: bool = False,
**conditional_dict
) -> torch.Tensor:
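        """
        Run the block-wise autoregressive denoising rollout.

        Returns the generated latents plus the scheduler-index range
        (`denoised_timestep_from`, `denoised_timestep_to`) of the step that kept
        its gradient (both None when `same_step_across_blocks` is False), and,
        if `return_sim_step` is True, the number of denoising steps simulated
        for the first block.
        """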
batch_size, num_frames, num_channels, height, width = noise.shape
if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
# If the first frame is independent and the first frame is provided, then the number of frames in the
# noise should still be a multiple of num_frame_per_block
assert num_frames % self.num_frame_per_block == 0
num_blocks = num_frames // self.num_frame_per_block
else:
# Using a [1, 4, 4, 4, 4, 4, ...] model to generate a video without image conditioning
assert (num_frames - 1) % self.num_frame_per_block == 0
num_blocks = (num_frames - 1) // self.num_frame_per_block
num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
num_output_frames = num_frames + num_input_frames # add the initial latent frames
output = torch.zeros(
[batch_size, num_output_frames, num_channels, height, width],
device=noise.device,
dtype=noise.dtype
)
        # Step 1: Initialize the KV cache and cross-attention cache to all zeros
self._initialize_kv_cache(
batch_size=batch_size, dtype=noise.dtype, device=noise.device
)
self._initialize_crossattn_cache(
batch_size=batch_size, dtype=noise.dtype, device=noise.device
)
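        # Alternative (currently disabled): reuse the caches across calls and only
        # reset their end indices / `is_init` flags, as sketched below.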
# if self.kv_cache1 is None:
# self._initialize_kv_cache(
# batch_size=batch_size,
# dtype=noise.dtype,
# device=noise.device,
# )
# self._initialize_crossattn_cache(
# batch_size=batch_size,
# dtype=noise.dtype,
# device=noise.device
# )
# else:
# # reset cross attn cache
# for block_index in range(self.num_transformer_blocks):
# self.crossattn_cache[block_index]["is_init"] = False
# # reset kv cache
# for block_index in range(len(self.kv_cache1)):
# self.kv_cache1[block_index]["global_end_index"] = torch.tensor(
# [0], dtype=torch.long, device=noise.device)
# self.kv_cache1[block_index]["local_end_index"] = torch.tensor(
# [0], dtype=torch.long, device=noise.device)
        # Step 2: Cache the context features of the initial latent, if provided
current_start_frame = 0
if initial_latent is not None:
            timestep = torch.zeros([batch_size, 1], device=noise.device, dtype=torch.int64)
# Assume num_input_frames is 1 + self.num_frame_per_block * num_input_blocks
output[:, :1] = initial_latent
with torch.no_grad():
self.generator(
noisy_image_or_video=initial_latent,
conditional_dict=conditional_dict,
                    timestep=timestep,
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=current_start_frame * self.frame_seq_length
)
current_start_frame += 1
# Step 3: Temporal denoising loop
all_num_frames = [self.num_frame_per_block] * num_blocks
if self.independent_first_frame and initial_latent is None:
all_num_frames = [1] + all_num_frames
num_denoising_steps = len(self.denoising_step_list)
exit_flags = self.generate_and_sync_list(len(all_num_frames), num_denoising_steps, device=noise.device)
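        # exit_flags[i] is the denoising-step index at which block i keeps its
        # gradient; it is identical on every rank thanks to the broadcast above.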
        # Only frames within the last 21 output positions (the default num_max_frames
        # window held by the KV cache) keep gradients; earlier blocks run under
        # torch.no_grad() to bound activation memory.
        start_gradient_frame_index = num_output_frames - 21
for block_index, current_num_frames in enumerate(all_num_frames):
noisy_input = noise[
:, current_start_frame - num_input_frames:current_start_frame + current_num_frames - num_input_frames]
# Step 3.1: Spatial denoising loop
for index, current_timestep in enumerate(self.denoising_step_list):
if self.same_step_across_blocks:
exit_flag = (index == exit_flags[0])
else:
exit_flag = (index == exit_flags[block_index]) # Only backprop at the randomly selected timestep (consistent across all ranks)
timestep = torch.ones(
[batch_size, current_num_frames],
device=noise.device,
dtype=torch.int64) * current_timestep
if not exit_flag:
with torch.no_grad():
_, denoised_pred = self.generator(
noisy_image_or_video=noisy_input,
conditional_dict=conditional_dict,
timestep=timestep,
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=current_start_frame * self.frame_seq_length
)
next_timestep = self.denoising_step_list[index + 1]
noisy_input = self.scheduler.add_noise(
denoised_pred.flatten(0, 1),
torch.randn_like(denoised_pred.flatten(0, 1)),
next_timestep * torch.ones(
[batch_size * current_num_frames], device=noise.device, dtype=torch.long)
).unflatten(0, denoised_pred.shape[:2])
else:
                    # Final pass at the selected timestep: this prediction is the one
                    # recorded in `output`; only frames inside the gradient window keep the graph.
                    # with torch.set_grad_enabled(current_start_frame >= start_gradient_frame_index):
if current_start_frame < start_gradient_frame_index:
with torch.no_grad():
_, denoised_pred = self.generator(
noisy_image_or_video=noisy_input,
conditional_dict=conditional_dict,
timestep=timestep,
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=current_start_frame * self.frame_seq_length
)
else:
_, denoised_pred = self.generator(
noisy_image_or_video=noisy_input,
conditional_dict=conditional_dict,
timestep=timestep,
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=current_start_frame * self.frame_seq_length
)
break
# Step 3.2: record the model's output
output[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred
            # Step 3.3: rerun the block at the context-noise timestep so its re-noised
            # prediction is written into the KV cache as context for later blocks
context_timestep = torch.ones_like(timestep) * self.context_noise
# add context noise
denoised_pred = self.scheduler.add_noise(
denoised_pred.flatten(0, 1),
torch.randn_like(denoised_pred.flatten(0, 1)),
context_timestep * torch.ones(
[batch_size * current_num_frames], device=noise.device, dtype=torch.long)
).unflatten(0, denoised_pred.shape[:2])
with torch.no_grad():
self.generator(
noisy_image_or_video=denoised_pred,
conditional_dict=conditional_dict,
timestep=context_timestep,
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=current_start_frame * self.frame_seq_length
)
# Step 3.4: update the start and end frame indices
current_start_frame += current_num_frames
        # Step 3.5: Map the selected denoising step to the scheduler's index range
        # (denoised_timestep_from -> denoised_timestep_to) covered by this rollout
if not self.same_step_across_blocks:
denoised_timestep_from, denoised_timestep_to = None, None
elif exit_flags[0] == len(self.denoising_step_list) - 1:
denoised_timestep_to = 0
denoised_timestep_from = 1000 - torch.argmin(
(self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0]].cuda()).abs(), dim=0).item()
else:
denoised_timestep_to = 1000 - torch.argmin(
(self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0] + 1].cuda()).abs(), dim=0).item()
denoised_timestep_from = 1000 - torch.argmin(
(self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0]].cuda()).abs(), dim=0).item()
if return_sim_step:
return output, denoised_timestep_from, denoised_timestep_to, exit_flags[0] + 1
return output, denoised_timestep_from, denoised_timestep_to

    def _initialize_kv_cache(self, batch_size, dtype, device):
"""
Initialize a Per-GPU KV cache for the Wan model.
"""
kv_cache1 = []
for _ in range(self.num_transformer_blocks):
kv_cache1.append({
"k": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
"v": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
"local_end_index": torch.tensor([0], dtype=torch.long, device=device)
})
self.kv_cache1 = kv_cache1 # always store the clean cache

    def _initialize_crossattn_cache(self, batch_size, dtype, device):
"""
Initialize a Per-GPU cross-attention cache for the Wan model.
"""
crossattn_cache = []
for _ in range(self.num_transformer_blocks):
crossattn_cache.append({
"k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
"v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
"is_init": False
})
self.crossattn_cache = crossattn_cache
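

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the generator, scheduler, and
# conditional_dict are built elsewhere in the training code, and the latent
# shape below (16 channels, 60x104, i.e. 1560 tokens per frame after 2x2
# patchification) is an assumption, not something this file specifies.
# `denoising_step_list` is shown as a LongTensor so that the `.cuda()` calls
# in Step 3.5 work on its elements.
#
#   pipeline = SelfForcingTrainingPipeline(
#       denoising_step_list=torch.tensor([1000, 750, 500, 250], dtype=torch.long),
#       scheduler=scheduler,      # a SchedulerInterface implementation
#       generator=generator,      # a WanDiffusionWrapper instance
#       num_frame_per_block=3,
#   )
#   noise = torch.randn(2, 21, 16, 60, 104, device="cuda", dtype=torch.bfloat16)
#   output, t_from, t_to = pipeline.inference_with_trajectory(
#       noise=noise, **conditional_dict
#   )
# ---------------------------------------------------------------------------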