import os
import torch
import random
import cv2
import fnmatch
import torch.nn.functional as F
from torchvision import transforms
import torchvision.transforms.functional as TF
from diffusers.optimization import get_scheduler
from einops import rearrange, repeat
from omegaconf import OmegaConf
from dataset import *
from models.unet.motion_embeddings import *
from .lora import *
from .lora_handler import *

def find_videos(directory, extensions=('.mp4', '.mkv', '.avi', '.mov', '.flv', '.wmv', '.gif')):
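    """Recursively collect video files under `directory` matching the given extensions."""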
    video_files = []
    for root, dirs, files in os.walk(directory):
        for extension in extensions:
            for filename in fnmatch.filter(files, '*' + extension):
                video_files.append(os.path.join(root, filename))
    return video_files

def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
    # Treat a missing or empty extra_params dict as None.
    extra_params = extra_params if extra_params else None
    return {
        "model": model,
        "condition": condition,
        'extra_params': extra_params,
        'is_lora': is_lora,
        "negation": negation
    }

def create_optim_params(name='param', params=None, lr=5e-6, extra_params=None):
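    """Build a single optimizer param-group dict, merging any extra keyword options (e.g. an lr override)."""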
    params = {
        "name": name,
        "params": params,
        "lr": lr
    }
    if extra_params is not None:
        for k, v in extra_params.items():
            params[k] = v

    return params

def create_optimizer_params(model_list, lr):
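    """Convert a list of param_optim() dicts into optimizer param groups,
    handling LoRA parameter lists, LoRA-named parameters, and plain trainable models."""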
    import itertools
    optimizer_params = []

    for optim in model_list:
        model, condition, extra_params, is_lora, negation = optim.values()
        # Check if we are doing LoRA training.
        if is_lora and condition and isinstance(model, list):
            params = create_optim_params(
                params=itertools.chain(*model),
                extra_params=extra_params
            )
            optimizer_params.append(params)
            continue

        if is_lora and condition and not isinstance(model, list):
            for n, p in model.named_parameters():
                if 'lora' in n:
                    params = create_optim_params(n, p, lr, extra_params)
                    optimizer_params.append(params)
            continue

        # Otherwise, train this model's parameters directly, skipping LoRA
        # weights unless LoRA training was requested.
        if condition:
            for n, p in model.named_parameters():
                should_negate = 'lora' in n and not is_lora
                if should_negate: continue

                params = create_optim_params(n, p, lr, extra_params)
                optimizer_params.append(params)

    return optimizer_params

def get_optimizer(use_8bit_adam):
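    """Return the optimizer class: bitsandbytes AdamW8bit when requested, otherwise torch.optim.AdamW."""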
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        return bnb.optim.AdamW8bit
    else:
        return torch.optim.AdamW
    
# Initialize optimizers and learning-rate schedulers
def prepare_optimizers(params, config, **extra_params):   
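    """Build the temporal optimizer and LR scheduler; for the DebiasedHybrid loss,
    also build one spatial LoRA optimizer/scheduler per spatial LoRA parameter set."""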
    optimizer_cls = get_optimizer(config.train.use_8bit_adam)

    optimizer_temporal = optimizer_cls(
        params,
        lr=config.loss.learning_rate
    )

    lr_scheduler_temporal = get_scheduler(
        config.loss.lr_scheduler,
        optimizer=optimizer_temporal,
        num_warmup_steps=config.loss.lr_warmup_steps * config.train.gradient_accumulation_steps,
        num_training_steps=config.train.max_train_steps * config.train.gradient_accumulation_steps,
    )

    # Insert Spatial LoRAs
    if config.loss.type == 'DebiasedHybrid':
        unet_lora_params_spatial_list = extra_params.get('unet_lora_params_spatial_list', [])
        spatial_lora_num = extra_params.get('spatial_lora_num', 1)

        optimizer_spatial_list = []
        lr_scheduler_spatial_list = []
        for i in range(spatial_lora_num):
            unet_lora_params_spatial = unet_lora_params_spatial_list[i]

            optimizer_spatial = optimizer_cls(
                create_optimizer_params(
                    [
                        param_optim(
                            unet_lora_params_spatial, 
                            config.loss.use_unet_lora, 
                            is_lora=True,
                            extra_params={**{"lr": config.loss.learning_rate_spatial}}
                        )
                    ], 
                    config.loss.learning_rate_spatial
                ),
                lr=config.loss.learning_rate_spatial
            )
            optimizer_spatial_list.append(optimizer_spatial)

            # Scheduler
            lr_scheduler_spatial = get_scheduler(
                config.loss.lr_scheduler,
                optimizer=optimizer_spatial,
                num_warmup_steps=config.loss.lr_warmup_steps * config.train.gradient_accumulation_steps,
                num_training_steps=config.train.max_train_steps * config.train.gradient_accumulation_steps,
            )
            lr_scheduler_spatial_list.append(lr_scheduler_spatial)

    else:
        optimizer_spatial_list = []
        lr_scheduler_spatial_list = []

    return [optimizer_temporal] + optimizer_spatial_list, [lr_scheduler_temporal] + lr_scheduler_spatial_list

def sample_noise(latents, noise_strength, use_offset_noise=False):
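    """Draw Gaussian noise shaped like `latents`; optionally add per-frame offset noise
    (shared across the spatial dimensions) scaled by `noise_strength`."""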
    b, c, f, *_ = latents.shape
    noise_latents = torch.randn_like(latents, device=latents.device)

    if use_offset_noise:
        offset_noise = torch.randn(b, c, f, 1, 1, device=latents.device)
        noise_latents = noise_latents + noise_strength * offset_noise

    return noise_latents

@torch.no_grad()
def tensor_to_vae_latent(t, vae):
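    """Encode a (b, f, c, h, w) video tensor into VAE latents of shape (b, c, f, h, w)."""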
    video_length = t.shape[1]

    t = rearrange(t, "b f c h w -> (b f) c h w")
    latents = vae.encode(t).latent_dist.sample()
    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
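    # Scale by the Stable Diffusion VAE latent scaling factor.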
    latents = latents * 0.18215

    return latents

def prepare_data(config, tokenizer):
    # Get the training dataset based on types (json, single_video, image)

    # Assuming config.dataset is a DictConfig object
    dataset_params_dict = OmegaConf.to_container(config.dataset, resolve=True)

    # Remove the 'type' key
    dataset_params_dict.pop('type', None)  # 'None' ensures no error if 'type' key doesn't exist

    train_datasets = []

    # Loop through all available datasets, get the name, then add to list of data to process.
    for DataSet in [VideoJsonDataset, SingleVideoDataset, ImageDataset, VideoFolderDataset]:
        for dataset in config.dataset.type:
            if dataset == DataSet.__getname__():
                train_datasets.append(DataSet(**dataset_params_dict, tokenizer=tokenizer))

    if len(train_datasets) == 0:
        raise ValueError("Dataset type not found: 'json', 'single_video', 'folder', 'image'")

    # Only the first matching dataset is used for training.
    train_dataset = train_datasets[0]

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train.train_batch_size,
        shuffle=True
    )

    return train_dataloader, train_dataset

# create parameters for optimization
def prepare_params(unet, config, train_dataset):
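    """Inject motion embeddings into the UNet (and spatial LoRAs for the DebiasedHybrid loss);
    return the trainable params plus extra inputs for prepare_optimizers()."""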
    extra_params = {}

    params, embedding_layers = inject_motion_embeddings(
        unet, 
        combinations=config.model.motion_embeddings.combinations,
        config=config
    )

    config.model.embedding_layers = embedding_layers
    if config.loss.type == "DebiasedHybrid":
        if config.loss.spatial_lora_num == -1:
            config.loss.spatial_lora_num = len(train_dataset)

        lora_managers_spatial, unet_lora_params_spatial_list, unet_negation_all = inject_spatial_loras(
            unet=unet, 
            use_unet_lora=True,
            lora_unet_dropout=0.1,
            lora_path='',
            lora_rank=32,
            spatial_lora_num=config.loss.spatial_lora_num,
        )
        
        extra_params['lora_managers_spatial'] = lora_managers_spatial
        extra_params['unet_lora_params_spatial_list'] = unet_lora_params_spatial_list
        extra_params['unet_negation_all'] = unet_negation_all
        # Pass the spatial LoRA count along so prepare_optimizers builds one
        # optimizer/scheduler per injected spatial LoRA set.
        extra_params['spatial_lora_num'] = config.loss.spatial_lora_num

    return params, extra_params
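
# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not executed on import). `tokenizer`
# and `unet` are assumed to be loaded elsewhere (e.g. from a diffusers
# pipeline), and 'config.yaml' is a placeholder path for an OmegaConf file
# containing the train/loss/model/dataset sections referenced above.
#
#   config = OmegaConf.load('config.yaml')
#   train_dataloader, train_dataset = prepare_data(config, tokenizer)
#   params, extra_params = prepare_params(unet, config, train_dataset)
#   optimizers, lr_schedulers = prepare_optimizers(params, config, **extra_params)
# ---------------------------------------------------------------------------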