import os
import fnmatch
import itertools
import random

import cv2
import torch
import torch.nn.functional as F
from torchvision import transforms
import torchvision.transforms.functional as TF
from diffusers.optimization import get_scheduler
from einops import rearrange, repeat
from omegaconf import OmegaConf

from dataset import *
from models.unet.motion_embeddings import *
from .lora import *
from .lora_handler import *
def find_videos(directory, extensions=('.mp4', '.mkv', '.avi', '.mov', '.flv', '.wmv', '.gif')):
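    """Recursively collect all video files under `directory` matching `extensions`.

    Example (hypothetical paths): find_videos('./data') -> ['./data/clip.mp4', ...]
    """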
video_files = []
for root, dirs, files in os.walk(directory):
for extension in extensions:
for filename in fnmatch.filter(files, '*' + extension):
video_files.append(os.path.join(root, filename))
return video_files
def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
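    """Bundle a model (or LoRA param list) with its training flags into the dict format consumed by create_optimizer_params()."""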
    extra_params = extra_params if extra_params else None  # treat None and an empty dict uniformly
return {
"model": model,
"condition": condition,
'extra_params': extra_params,
'is_lora': is_lora,
"negation": negation
}
def create_optim_params(name='param', params=None, lr=5e-6, extra_params=None):
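    """Build a single optimizer param-group dict, merging any extra per-group options (e.g. a group-specific 'lr')."""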
params = {
"name": name,
"params": params,
"lr": lr
}
if extra_params is not None:
for k, v in extra_params.items():
params[k] = v
return params
def create_optimizer_params(model_list, lr):
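    """Flatten param_optim() entries into optimizer param groups, keeping LoRA and non-LoRA parameters separate."""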
optimizer_params = []
for optim in model_list:
model, condition, extra_params, is_lora, negation = optim.values()
# Check if we are doing LoRA training.
if is_lora and condition and isinstance(model, list):
params = create_optim_params(
params=itertools.chain(*model),
extra_params=extra_params
)
optimizer_params.append(params)
continue
if is_lora and condition and not isinstance(model, list):
for n, p in model.named_parameters():
if 'lora' in n:
params = create_optim_params(n, p, lr, extra_params)
optimizer_params.append(params)
continue
# If this is true, we can train it.
if condition:
for n, p in model.named_parameters():
should_negate = 'lora' in n and not is_lora
if should_negate: continue
params = create_optim_params(n, p, lr, extra_params)
optimizer_params.append(params)
return optimizer_params
def get_optimizer(use_8bit_adam):
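    """Return bnb.optim.AdamW8bit when `use_8bit_adam` is set (requires bitsandbytes), else torch.optim.AdamW."""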
if use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
)
return bnb.optim.AdamW8bit
else:
return torch.optim.AdamW
# Build the optimizers and LR schedulers: one temporal pair always, plus one spatial LoRA pair per LoRA for DebiasedHybrid.
def prepare_optimizers(params, config, **extra_params):
optimizer_cls = get_optimizer(config.train.use_8bit_adam)
optimizer_temporal = optimizer_cls(
params,
lr=config.loss.learning_rate
)
lr_scheduler_temporal = get_scheduler(
config.loss.lr_scheduler,
optimizer=optimizer_temporal,
num_warmup_steps=config.loss.lr_warmup_steps * config.train.gradient_accumulation_steps,
num_training_steps=config.train.max_train_steps * config.train.gradient_accumulation_steps,
)
# Insert Spatial LoRAs
if config.loss.type == 'DebiasedHybrid':
unet_lora_params_spatial_list = extra_params.get('unet_lora_params_spatial_list', [])
spatial_lora_num = extra_params.get('spatial_lora_num', 1)
optimizer_spatial_list = []
lr_scheduler_spatial_list = []
for i in range(spatial_lora_num):
unet_lora_params_spatial = unet_lora_params_spatial_list[i]
optimizer_spatial = optimizer_cls(
create_optimizer_params(
[
param_optim(
unet_lora_params_spatial,
config.loss.use_unet_lora,
is_lora=True,
                            extra_params={"lr": config.loss.learning_rate_spatial}
)
],
config.loss.learning_rate_spatial
),
lr=config.loss.learning_rate_spatial
)
optimizer_spatial_list.append(optimizer_spatial)
# Scheduler
lr_scheduler_spatial = get_scheduler(
config.loss.lr_scheduler,
optimizer=optimizer_spatial,
num_warmup_steps=config.loss.lr_warmup_steps * config.train.gradient_accumulation_steps,
num_training_steps=config.train.max_train_steps * config.train.gradient_accumulation_steps,
)
lr_scheduler_spatial_list.append(lr_scheduler_spatial)
else:
optimizer_spatial_list = []
lr_scheduler_spatial_list = []
return [optimizer_temporal] + optimizer_spatial_list, [lr_scheduler_temporal] + lr_scheduler_spatial_list
def sample_noise(latents, noise_strength, use_offset_noise=False):
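    """Sample Gaussian noise shaped like `latents`; offset noise adds a component shared across each frame's spatial dims."""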
b, c, f, *_ = latents.shape
noise_latents = torch.randn_like(latents, device=latents.device)
if use_offset_noise:
offset_noise = torch.randn(b, c, f, 1, 1, device=latents.device)
noise_latents = noise_latents + noise_strength * offset_noise
return noise_latents
@torch.no_grad()
def tensor_to_vae_latent(t, vae):
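    """Encode a (b, f, c, h, w) video tensor into scaled VAE latents of shape (b, c, f, h, w)."""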
video_length = t.shape[1]
t = rearrange(t, "b f c h w -> (b f) c h w")
latents = vae.encode(t).latent_dist.sample()
latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
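    # 0.18215 is the Stable Diffusion VAE latent scaling factor.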
latents = latents * 0.18215
return latents
def prepare_data(config, tokenizer):
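    """Build the dataset(s) named in config.dataset.type and return a DataLoader over the first match, plus the dataset."""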
    # Get the training dataset(s) based on type (json, single_video, image, folder)
# Assuming config.dataset is a DictConfig object
dataset_params_dict = OmegaConf.to_container(config.dataset, resolve=True)
# Remove the 'type' key
dataset_params_dict.pop('type', None) # 'None' ensures no error if 'type' key doesn't exist
train_datasets = []
# Loop through all available datasets, get the name, then add to list of data to process.
for DataSet in [VideoJsonDataset, SingleVideoDataset, ImageDataset, VideoFolderDataset]:
for dataset in config.dataset.type:
if dataset == DataSet.__getname__():
train_datasets.append(DataSet(**dataset_params_dict, tokenizer=tokenizer))
    if len(train_datasets) == 0:
        raise ValueError("Dataset type not found; expected one of 'json', 'single_video', 'folder', 'image'")
train_dataset = train_datasets[0]
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=config.train.train_batch_size,
shuffle=True
)
return train_dataloader, train_dataset
# Create parameters for optimization.
def prepare_params(unet, config, train_dataset):
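    """Inject motion embeddings (and, for DebiasedHybrid, spatial LoRAs) into the UNet; return trainable params and extras."""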
extra_params = {}
    params, embedding_layers = inject_motion_embeddings(
unet,
combinations=config.model.motion_embeddings.combinations,
config=config
)
config.model.embedding_layers = embedding_layers
if config.loss.type == "DebiasedHybrid":
if config.loss.spatial_lora_num == -1:
config.loss.spatial_lora_num = train_dataset.__len__()
lora_managers_spatial, unet_lora_params_spatial_list, unet_negation_all = inject_spatial_loras(
unet=unet,
use_unet_lora=True,
lora_unet_dropout=0.1,
lora_path='',
lora_rank=32,
spatial_lora_num=1,
)
extra_params['lora_managers_spatial'] = lora_managers_spatial
extra_params['unet_lora_params_spatial_list'] = unet_lora_params_spatial_list
extra_params['unet_negation_all'] = unet_negation_all
return params, extra_params |
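

if __name__ == "__main__":
    # Minimal smoke test for the tensor-level helpers above; a sketch only.
    # The config-driven functions (prepare_data, prepare_params,
    # prepare_optimizers) need a real OmegaConf config and model stack,
    # so they are not exercised here. The shapes below are hypothetical.
    dummy_latents = torch.randn(1, 4, 8, 32, 32)  # (batch, channels, frames, height, width)
    noise = sample_noise(dummy_latents, noise_strength=0.1, use_offset_noise=True)
    assert noise.shape == dummy_latents.shape
    print("sample_noise OK:", tuple(noise.shape))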