import argparse
import copy
import logging
import math
import os
import os.path as osp
import random
import time
import warnings
from collections import OrderedDict
from datetime import datetime
from pathlib import Path
from tempfile import TemporaryDirectory

import diffusers
import mlflow
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPVisionModelWithProjection

from src.dataset.dance_video import HumanDanceVideoDataset
from src.models.mutual_self_attention import ReferenceAttentionControl
from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid import Pose2VideoPipeline
from src.utils.util import (
    delete_additional_ckpt,
    import_filename,
    read_frames,
    save_videos_grid,
    seed_everything,
)

warnings.filterwarnings("ignore")

# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")

logger = get_logger(__name__, log_level="INFO")
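

# Stage-2 wrapper module: the reference UNet runs a forward pass on the reference-image
# latents, its attention features are shared with the 3D denoising UNet through the
# ReferenceAttentionControl writer/reader pair, and the pose guider injects pose features
# into the denoising UNet.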
class Net(nn.Module):
    def __init__(
        self,
        reference_unet: UNet2DConditionModel,
        denoising_unet: UNet3DConditionModel,
        pose_guider: PoseGuider,
        reference_control_writer,
        reference_control_reader,
    ):
        super().__init__()
        self.reference_unet = reference_unet
        self.denoising_unet = denoising_unet
        self.pose_guider = pose_guider
        self.reference_control_writer = reference_control_writer
        self.reference_control_reader = reference_control_reader

    def forward(
        self,
        noisy_latents,
        timesteps,
        ref_image_latents,
        clip_image_embeds,
        pose_img,
        uncond_fwd: bool = False,
    ):
        pose_cond_tensor = pose_img.to(device="cuda")
        pose_fea = self.pose_guider(pose_cond_tensor)

        if not uncond_fwd:
            ref_timesteps = torch.zeros_like(timesteps)
            self.reference_unet(
                ref_image_latents,
                ref_timesteps,
                encoder_hidden_states=clip_image_embeds,
                return_dict=False,
            )
            self.reference_control_reader.update(self.reference_control_writer)

        model_pred = self.denoising_unet(
            noisy_latents,
            timesteps,
            pose_cond_fea=pose_fea,
            encoder_hidden_states=clip_image_embeds,
        ).sample

        return model_pred


def compute_snr(noise_scheduler, timesteps):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # Expand the tensors.
    # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
        timesteps
    ].float()
    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
        device=timesteps.device
    )[timesteps].float()
    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

    # Compute SNR.
    snr = (alpha / sigma) ** 2
    return snr
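

# Validation: build a temporary fp16 copy of the denoising UNet, assemble a
# Pose2VideoPipeline, and render two fixed reference-image / pose-video test cases.
# Each generated clip is concatenated with its pose video so they can be inspected
# side by side in the saved grid.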
def log_validation(
    vae,
    image_enc,
    net,
    scheduler,
    accelerator,
    width,
    height,
    clip_length=24,
    generator=None,
):
    logger.info("Running validation... ")

    ori_net = accelerator.unwrap_model(net)
    reference_unet = ori_net.reference_unet
    denoising_unet = ori_net.denoising_unet
    pose_guider = ori_net.pose_guider

    if generator is None:
        generator = torch.manual_seed(42)
    tmp_denoising_unet = copy.deepcopy(denoising_unet)
    tmp_denoising_unet = tmp_denoising_unet.to(dtype=torch.float16)

    pipe = Pose2VideoPipeline(
        vae=vae,
        image_encoder=image_enc,
        reference_unet=reference_unet,
        denoising_unet=tmp_denoising_unet,
        pose_guider=pose_guider,
        scheduler=scheduler,
    )
    pipe = pipe.to(accelerator.device)

    test_cases = [
        (
            "./configs/inference/ref_images/anyone-3.png",
            "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
        ),
        (
            "./configs/inference/ref_images/anyone-2.png",
            "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
        ),
    ]

    results = []
    for test_case in test_cases:
        ref_image_path, pose_video_path = test_case
        ref_name = Path(ref_image_path).stem
        pose_name = Path(pose_video_path).stem
        ref_image_pil = Image.open(ref_image_path).convert("RGB")

        pose_list = []
        pose_tensor_list = []
        pose_images = read_frames(pose_video_path)
        pose_transform = transforms.Compose(
            [transforms.Resize((height, width)), transforms.ToTensor()]
        )
        for pose_image_pil in pose_images[:clip_length]:
            pose_tensor_list.append(pose_transform(pose_image_pil))
            pose_list.append(pose_image_pil)

        pose_tensor = torch.stack(pose_tensor_list, dim=0)  # (f, c, h, w)
        pose_tensor = pose_tensor.transpose(0, 1)

        pipeline_output = pipe(
            ref_image_pil,
            pose_list,
            width,
            height,
            clip_length,
            20,
            3.5,
            generator=generator,
        )
        video = pipeline_output.videos

        # Concat it with pose tensor
        pose_tensor = pose_tensor.unsqueeze(0)
        video = torch.cat([video, pose_tensor], dim=0)

        results.append({"name": f"{ref_name}_{pose_name}", "vid": video})

    del tmp_denoising_unet
    del pipe
    torch.cuda.empty_cache()

    return results
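

# Stage-2 training entry point: load the stage-1 weights (reference UNet, denoising UNet,
# pose guider), freeze them, and train only the temporal motion modules of the 3D UNet on
# video clips, logging metrics and validation samples to MLflow via Accelerate.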
def main(cfg):
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
        mixed_precision=cfg.solver.mixed_precision,
        log_with="mlflow",
        project_dir="./mlruns",
        kwargs_handlers=[kwargs],
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if cfg.seed is not None:
        seed_everything(cfg.seed)

    exp_name = cfg.exp_name
    save_dir = f"{cfg.output_dir}/{exp_name}"
    if accelerator.is_main_process:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

    inference_config_path = "./configs/inference/inference_v2.yaml"
    infer_config = OmegaConf.load(inference_config_path)

    if cfg.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif cfg.weight_dtype == "fp32":
        weight_dtype = torch.float32
    else:
        raise ValueError(
            f"Unsupported weight dtype for training: {cfg.weight_dtype}"
        )
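
    # Both DDIM schedulers share the config kwargs (plus zero-terminal-SNR settings when
    # enabled); the training scheduler additionally forces a scaled_linear beta schedule.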
    sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
    if cfg.enable_zero_snr:
        sched_kwargs.update(
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            prediction_type="v_prediction",
        )
    val_noise_scheduler = DDIMScheduler(**sched_kwargs)
    sched_kwargs.update({"beta_schedule": "scaled_linear"})
    train_noise_scheduler = DDIMScheduler(**sched_kwargs)

    image_enc = CLIPVisionModelWithProjection.from_pretrained(
        cfg.image_encoder_path,
    ).to(dtype=weight_dtype, device="cuda")
    vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
        "cuda", dtype=weight_dtype
    )
    reference_unet = UNet2DConditionModel.from_pretrained(
        cfg.base_model_path,
        subfolder="unet",
    ).to(device="cuda", dtype=weight_dtype)
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        cfg.base_model_path,
        cfg.mm_path,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(
            infer_config.unet_additional_kwargs
        ),
    ).to(device="cuda")
    pose_guider = PoseGuider(
        conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)
    ).to(device="cuda", dtype=weight_dtype)

    stage1_ckpt_dir = cfg.stage1_ckpt_dir
    stage1_ckpt_step = cfg.stage1_ckpt_step
    denoising_unet.load_state_dict(
        torch.load(
            os.path.join(stage1_ckpt_dir, f"denoising_unet-{stage1_ckpt_step}.pth"),
            map_location="cpu",
        ),
        strict=False,
    )
    reference_unet.load_state_dict(
        torch.load(
            os.path.join(stage1_ckpt_dir, f"reference_unet-{stage1_ckpt_step}.pth"),
            map_location="cpu",
        ),
        strict=False,
    )
    pose_guider.load_state_dict(
        torch.load(
            os.path.join(stage1_ckpt_dir, f"pose_guider-{stage1_ckpt_step}.pth"),
            map_location="cpu",
        ),
        strict=False,
    )
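
    # Everything loaded from stage 1 stays frozen; only the motion modules enabled below are trained.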
    # Freeze
    vae.requires_grad_(False)
    image_enc.requires_grad_(False)
    reference_unet.requires_grad_(False)
    denoising_unet.requires_grad_(False)
    pose_guider.requires_grad_(False)

    # Set motion module learnable
    for name, module in denoising_unet.named_modules():
        if "motion_modules" in name:
            for params in module.parameters():
                params.requires_grad = True

    reference_control_writer = ReferenceAttentionControl(
        reference_unet,
        do_classifier_free_guidance=False,
        mode="write",
        fusion_blocks="full",
    )
    reference_control_reader = ReferenceAttentionControl(
        denoising_unet,
        do_classifier_free_guidance=False,
        mode="read",
        fusion_blocks="full",
    )

    net = Net(
        reference_unet,
        denoising_unet,
        pose_guider,
        reference_control_writer,
        reference_control_reader,
    )

    if cfg.solver.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            reference_unet.enable_xformers_memory_efficient_attention()
            denoising_unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    if cfg.solver.gradient_checkpointing:
        reference_unet.enable_gradient_checkpointing()
        denoising_unet.enable_gradient_checkpointing()

    if cfg.solver.scale_lr:
        learning_rate = (
            cfg.solver.learning_rate
            * cfg.solver.gradient_accumulation_steps
            * cfg.data.train_bs
            * accelerator.num_processes
        )
    else:
        learning_rate = cfg.solver.learning_rate

    # Initialize the optimizer
    if cfg.solver.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    trainable_params = list(filter(lambda p: p.requires_grad, net.parameters()))
    logger.info(f"Total trainable parameter tensors: {len(trainable_params)}")
    optimizer = optimizer_cls(
        trainable_params,
        lr=learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )

    # Scheduler
    lr_scheduler = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.solver.lr_warmup_steps
        * cfg.solver.gradient_accumulation_steps,
        num_training_steps=cfg.solver.max_train_steps
        * cfg.solver.gradient_accumulation_steps,
    )
    train_dataset = HumanDanceVideoDataset(
        width=cfg.data.train_width,
        height=cfg.data.train_height,
        n_sample_frames=cfg.data.n_sample_frames,
        sample_rate=cfg.data.sample_rate,
        img_scale=(1.0, 1.0),
        data_meta_paths=cfg.data.meta_paths,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4
    )

    # Prepare everything with our `accelerator`.
    (
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / cfg.solver.gradient_accumulation_steps
    )
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(
        cfg.solver.max_train_steps / num_update_steps_per_epoch
    )

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers are initialized automatically on the main process.
    if accelerator.is_main_process:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        accelerator.init_trackers(
            exp_name,
            init_kwargs={"mlflow": {"run_name": run_time}},
        )
        # dump config file
        mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml")

    # Train!
    total_batch_size = (
        cfg.data.train_bs
        * accelerator.num_processes
        * cfg.solver.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}")
    logger.info(
        f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}"
    )
    logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if cfg.resume_from_checkpoint:
        if cfg.resume_from_checkpoint != "latest":
            resume_dir = cfg.resume_from_checkpoint
        else:
            resume_dir = save_dir
        # Get the most recent checkpoint
        dirs = os.listdir(resume_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1]
        accelerator.load_state(os.path.join(resume_dir, path))
        accelerator.print(f"Resuming from checkpoint {path}")
        global_step = int(path.split("-")[1])

        first_epoch = global_step // num_update_steps_per_epoch
        resume_step = global_step % num_update_steps_per_epoch
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(global_step, cfg.solver.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, num_train_epochs):
        train_loss = 0.0
        t_data_start = time.time()
        for step, batch in enumerate(train_dataloader):
            t_data = time.time() - t_data_start
            with accelerator.accumulate(net):
                # Convert videos to latent space
                pixel_values_vid = batch["pixel_values_vid"].to(weight_dtype)
                with torch.no_grad():
                    video_length = pixel_values_vid.shape[1]
                    pixel_values_vid = rearrange(
                        pixel_values_vid, "b f c h w -> (b f) c h w"
                    )
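                    # Encode frames with the frozen VAE and scale by 0.18215,
                    # the standard Stable Diffusion latent scaling factor.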
                    latents = vae.encode(pixel_values_vid).latent_dist.sample()
                    latents = rearrange(
                        latents, "(b f) c h w -> b c f h w", f=video_length
                    )
                    latents = latents * 0.18215

                noise = torch.randn_like(latents)
                if cfg.noise_offset > 0:
                    noise += cfg.noise_offset * torch.randn(
                        (latents.shape[0], latents.shape[1], 1, 1, 1),
                        device=latents.device,
                    )
                bsz = latents.shape[0]
                # Sample a random timestep for each video
                timesteps = torch.randint(
                    0,
                    train_noise_scheduler.num_train_timesteps,
                    (bsz,),
                    device=latents.device,
                )
                timesteps = timesteps.long()

                pixel_values_pose = batch["pixel_values_pose"]  # (bs, f, c, H, W)
                pixel_values_pose = pixel_values_pose.transpose(
                    1, 2
                )  # (bs, c, f, H, W)

                uncond_fwd = random.random() < cfg.uncond_ratio
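                # uncond_fwd drops the CLIP image condition for this step: the embeddings
                # are zeroed below so the model also learns the unconditional branch used
                # for classifier-free guidance.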
                clip_image_list = []
                ref_image_list = []
                for batch_idx, (ref_img, clip_img) in enumerate(
                    zip(
                        batch["pixel_values_ref_img"],
                        batch["clip_ref_img"],
                    )
                ):
                    if uncond_fwd:
                        clip_image_list.append(torch.zeros_like(clip_img))
                    else:
                        clip_image_list.append(clip_img)
                    ref_image_list.append(ref_img)

                with torch.no_grad():
                    ref_img = torch.stack(ref_image_list, dim=0).to(
                        dtype=vae.dtype, device=vae.device
                    )
                    ref_image_latents = vae.encode(
                        ref_img
                    ).latent_dist.sample()  # (bs, d, 64, 64)
                    ref_image_latents = ref_image_latents * 0.18215

                    clip_img = torch.stack(clip_image_list, dim=0).to(
                        dtype=image_enc.dtype, device=image_enc.device
                    )
                    clip_img = clip_img.to(device="cuda", dtype=weight_dtype)
                    clip_image_embeds = image_enc(
                        clip_img.to("cuda", dtype=weight_dtype)
                    ).image_embeds
                    clip_image_embeds = clip_image_embeds.unsqueeze(1)  # (bs, 1, d)

                # add noise
                noisy_latents = train_noise_scheduler.add_noise(
                    latents, noise, timesteps
                )

                # Get the target for loss depending on the prediction type
                if train_noise_scheduler.prediction_type == "epsilon":
                    target = noise
                elif train_noise_scheduler.prediction_type == "v_prediction":
                    target = train_noise_scheduler.get_velocity(
                        latents, noise, timesteps
                    )
                else:
                    raise ValueError(
                        f"Unknown prediction type {train_noise_scheduler.prediction_type}"
                    )

                # ---- Forward!!! -----
                model_pred = net(
                    noisy_latents,
                    timesteps,
                    ref_image_latents,
                    clip_image_embeds,
                    pixel_values_pose,
                    uncond_fwd=uncond_fwd,
                )

                if cfg.snr_gamma == 0:
                    loss = F.mse_loss(
                        model_pred.float(), target.float(), reduction="mean"
                    )
                else:
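                    # Min-SNR weighting: clamp the per-sample loss weight to
                    # min(SNR, snr_gamma) / SNR (see the Min-SNR-Diffusion-Training
                    # reference in compute_snr above).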
                    snr = compute_snr(train_noise_scheduler, timesteps)
                    if train_noise_scheduler.config.prediction_type == "v_prediction":
                        # Velocity objective requires that we add one to SNR values before we divide by them.
                        snr = snr + 1
                    mse_loss_weights = (
                        torch.stack(
                            [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
                        ).min(dim=1)[0]
                        / snr
                    )
                    loss = F.mse_loss(
                        model_pred.float(), target.float(), reduction="none"
                    )
                    loss = (
                        loss.mean(dim=list(range(1, len(loss.shape))))
                        * mse_loss_weights
                    )
                    loss = loss.mean()

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
                train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        trainable_params,
                        cfg.solver.max_grad_norm,
                    )
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                reference_control_reader.clear()
                reference_control_writer.clear()
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0
                if global_step % cfg.val.validation_steps == 0:
                    if accelerator.is_main_process:
                        generator = torch.Generator(device=accelerator.device)
                        generator.manual_seed(cfg.seed)
                        sample_dicts = log_validation(
                            vae=vae,
                            image_enc=image_enc,
                            net=net,
                            scheduler=val_noise_scheduler,
                            accelerator=accelerator,
                            width=cfg.data.train_width,
                            height=cfg.data.train_height,
                            clip_length=cfg.data.n_sample_frames,
                            generator=generator,
                        )
                        for sample_id, sample_dict in enumerate(sample_dicts):
                            sample_name = sample_dict["name"]
                            vid = sample_dict["vid"]
                            with TemporaryDirectory() as temp_dir:
                                out_file = Path(
                                    f"{temp_dir}/{global_step:06d}-{sample_name}.gif"
                                )
                                save_videos_grid(vid, out_file, n_rows=2)
                                mlflow.log_artifact(out_file)

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "td": f"{t_data:.2f}s",
            }
            t_data_start = time.time()
            progress_bar.set_postfix(**logs)

            if global_step >= cfg.solver.max_train_steps:
                break
        # save model after each epoch
        if accelerator.is_main_process:
            save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
            delete_additional_ckpt(save_dir, 1)
            accelerator.save_state(save_path)
            # save motion module only
            unwrap_net = accelerator.unwrap_model(net)
            save_checkpoint(
                unwrap_net.denoising_unet,
                save_dir,
                "motion_module",
                global_step,
                total_limit=3,
            )

    # Wait for all processes and finish the tracked run.
    accelerator.wait_for_everyone()
    accelerator.end_training()
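

# Save only the motion-module weights of the denoising UNet, keeping at most
# `total_limit` motion-module checkpoints on disk.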
def save_checkpoint(model, save_dir, prefix, ckpt_num, total_limit=None):
    save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")

    if total_limit is not None:
        checkpoints = os.listdir(save_dir)
        checkpoints = [d for d in checkpoints if d.startswith(prefix)]
        checkpoints = sorted(
            checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
        )

        if len(checkpoints) >= total_limit:
            num_to_remove = len(checkpoints) - total_limit + 1
            removing_checkpoints = checkpoints[0:num_to_remove]
            logger.info(
                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
            )
            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

            for removing_checkpoint in removing_checkpoints:
                removing_checkpoint = os.path.join(save_dir, removing_checkpoint)
                os.remove(removing_checkpoint)

    mm_state_dict = OrderedDict()
    state_dict = model.state_dict()
    for key in state_dict:
        if "motion_module" in key:
            mm_state_dict[key] = state_dict[key]

    torch.save(mm_state_dict, save_path)
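

# Utility: decode video latents back to pixel space frame by frame (not called in the
# training loop above).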
def decode_latents(vae, latents):
    video_length = latents.shape[2]
    latents = 1 / 0.18215 * latents
    latents = rearrange(latents, "b c f h w -> (b f) c h w")
    video = []
    for frame_idx in tqdm(range(latents.shape[0])):
        video.append(vae.decode(latents[frame_idx : frame_idx + 1]).sample)
    video = torch.cat(video)
    video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
    video = (video / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    video = video.cpu().float().numpy()
    return video
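

# The config may be either an OmegaConf YAML file or a Python module exposing `cfg`.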
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/training/stage2.yaml")
    args = parser.parse_args()

    if args.config[-5:] == ".yaml":
        config = OmegaConf.load(args.config)
    elif args.config[-3:] == ".py":
        config = import_filename(args.config).cfg
    else:
        raise ValueError("Unsupported config file format (expected .yaml or .py)")
    main(config)