# Dependencies (editable installs from the CompVis repos):
# pip install -e git+https://github.com/CompVis/stable-diffusion.git@master#egg=latent-diffusion
# pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
#
# To convert the resulting checkpoint to the diffusers format, see
# convert_original_stable_diffusion_to_diffusers.py in the diffusers repo.

# TODO:
# - support grayscale (single-channel) spectrogram images
# - convert to / train directly with Hugging Face diffusers

import argparse

import torch
import torchvision
import numpy as np
from PIL import Image
import pytorch_lightning as pl
from omegaconf import OmegaConf
from datasets import load_dataset
from librosa.util import normalize
from ldm.util import instantiate_from_config
from pytorch_lightning.trainer import Trainer
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning.callbacks import Callback, ModelCheckpoint

from audiodiffusion.mel import Mel


class AudioDiffusion(Dataset):
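    """Torch Dataset of mel spectrogram images loaded from a Hugging Face dataset.

    Each item is an HxWx3 float array scaled from [0, 255] to [-1, 1], the
    range expected by the ldm autoencoder.
    """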

    def __init__(self, model_id):
        super().__init__()
        self.hf_dataset = load_dataset(model_id)['train']

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        image = self.hf_dataset[idx]['image'].convert('RGB')
        image = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
            (image.height, image.width, 3))
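        # scale uint8 pixel values from [0, 255] to [-1, 1]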
        image = ((image / 255) * 2 - 1)
        return {'image': image}


class AudioDiffusionDataModule(pl.LightningDataModule):
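    """LightningDataModule that wraps AudioDiffusion in a training DataLoader."""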

    def __init__(self, model_id, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.dataset = AudioDiffusion(model_id)
        self.num_workers = 1

    def train_dataloader(self):
        return DataLoader(self.dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers)


class ImageLogger(Callback):
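    """Periodically log reconstructed spectrogram images and the audio
    resynthesized from them. Assumes a TensorBoard-style logger exposing
    add_image and add_audio on `logger.experiment`.
    """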

    def __init__(self, every=1000, resolution=256, hop_length=512):
        super().__init__()
        self.mel = Mel(x_res=resolution,
                       y_res=resolution,
                       hop_length=hop_length)
        self.every = every

    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
                           batch_idx):
        if (batch_idx + 1) % self.every != 0:
            return

        pl_module.eval()
        with torch.no_grad():
            images = pl_module.log_images(batch, split='train')
        pl_module.train()

        for k in images:
            images[k] = images[k].detach().cpu()
            images[k] = torch.clamp(images[k], -1., 1.)
            images[k] = (images[k] + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = torchvision.utils.make_grid(images[k])

            tag = f"train/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid, global_step=pl_module.global_step)

            images[k] = (images[k].numpy() *
                         255).round().astype("uint8").transpose(0, 2, 3, 1)
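            # convert each logged image back to a mono spectrogram and invert it to audio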
            for i, image in enumerate(images[k]):
                audio = self.mel.image_to_audio(
                    Image.fromarray(image, mode='RGB').convert('L'))
                pl_module.logger.experiment.add_audio(
                    tag + f"/{i}",
                    normalize(audio),
                    global_step=pl_module.global_step,
                    sample_rate=self.mel.get_sample_rate())


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train VAE using ldm.")
    parser.add_argument("--batch_size", type=int, default=1)
    args = parser.parse_args()

    config = OmegaConf.load('ldm_autoencoder_kl.yaml')
    lightning_config = config.pop("lightning", OmegaConf.create())
    trainer_config = lightning_config.get("trainer", OmegaConf.create())
    trainer_opt = argparse.Namespace(**trainer_config)
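    # build the Lightning Trainer from the `lightning.trainer` section of the config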
    trainer = Trainer.from_argparse_args(trainer_opt,
                                         callbacks=[
                                             ImageLogger(),
                                             ModelCheckpoint(
                                                 dirpath='checkpoints',
                                                 filename='{epoch:06}',
                                                 verbose=True,
                                                 save_last=True)
                                         ])
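    # instantiate the autoencoder defined in ldm_autoencoder_kl.yaml and set its learning rate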
    model = instantiate_from_config(config.model)
    model.learning_rate = config.model.base_learning_rate
    data = AudioDiffusionDataModule('teticio/audio-diffusion-256',
                                    batch_size=args.batch_size)
    trainer.fit(model, data)
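
# Example usage (assuming this script is saved as train_vae.py and that
# ldm_autoencoder_kl.yaml is in the working directory):
#   python train_vae.py --batch_size 2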