# pip install -e git+https://github.com/CompVis/stable-diffusion.git@master
# pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
# convert_original_stable_diffusion_to_diffusers.py

# TODO
# grayscale
# convert to huggingface / train huggingface
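
# A typical invocation (sketch; 'train_vae.py' is the assumed file name of this
# script, adjust to your checkout):
#
#   python train_vae.py --batch_size 2
#
# It expects ldm_autoencoder_kl.yaml in the working directory (see __main__ below).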

import os
import argparse

import torch
import torchvision
import numpy as np
from PIL import Image
import pytorch_lightning as pl
from omegaconf import OmegaConf
from datasets import load_dataset
from librosa.util import normalize
from ldm.util import instantiate_from_config
from pytorch_lightning.trainer import Trainer
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning.callbacks import Callback, ModelCheckpoint

from audiodiffusion.mel import Mel


class AudioDiffusion(Dataset):
    """Wraps a Hugging Face dataset of mel spectrogram images as a torch Dataset."""

    def __init__(self, model_id):
        super().__init__()
        self.hf_dataset = load_dataset(model_id)['train']

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        image = self.hf_dataset[idx]['image'].convert('RGB')
        image = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
            (image.height, image.width, 3))
        # Scale pixel values from [0, 255] to [-1, 1].
        image = (image / 255) * 2 - 1
        return {'image': image}
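
# Minimal sanity check (sketch, assuming the 'teticio/audio-diffusion-256'
# dataset used below is reachable): each item is a float64 array in [-1, 1]
# of shape (height, width, 3).
#
#   ds = AudioDiffusion('teticio/audio-diffusion-256')
#   item = ds[0]['image']
#   print(item.shape, item.min(), item.max())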


class AudioDiffusionDataModule(pl.LightningDataModule):

    def __init__(self, model_id, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.dataset = AudioDiffusion(model_id)
        self.num_workers = 1

    def train_dataloader(self):
        return DataLoader(self.dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers)
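
# Sketch of how Lightning consumes this data module: train_dataloader() yields
# dicts whose 'image' entry the default collate turns into a tensor of shape
# (batch_size, height, width, 3); the ldm autoencoder permutes it to
# channels-first internally.
#
#   dm = AudioDiffusionDataModule('teticio/audio-diffusion-256', batch_size=2)
#   batch = next(iter(dm.train_dataloader()))
#   print(batch['image'].shape)  # e.g. torch.Size([2, 256, 256, 3])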


class ImageLogger(Callback):
    """Logs reconstructed spectrogram images and their audio to TensorBoard every `every` batches."""

    def __init__(self, every=1000, resolution=256, hop_length=512):
        super().__init__()
        self.mel = Mel(x_res=resolution,
                       y_res=resolution,
                       hop_length=hop_length)
        self.every = every

    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
                           batch_idx):
        if (batch_idx + 1) % self.every != 0:
            return
        pl_module.eval()
        with torch.no_grad():
            images = pl_module.log_images(batch, split='train')
        pl_module.train()

        for k in images:
            images[k] = images[k].detach().cpu()
            images[k] = torch.clamp(images[k], -1., 1.)
            images[k] = (images[k] + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = torchvision.utils.make_grid(images[k])

            tag = f"train/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid, global_step=pl_module.global_step)

            # Convert back to uint8 HWC images and reconstruct audio from each spectrogram.
            images[k] = (images[k].numpy() *
                         255).round().astype("uint8").transpose(0, 2, 3, 1)
            for idx, image in enumerate(images[k]):
                audio = self.mel.image_to_audio(
                    Image.fromarray(image, mode='RGB').convert('L'))
                pl_module.logger.experiment.add_audio(
                    tag + f"/{idx}",
                    normalize(audio),
                    global_step=pl_module.global_step,
                    sample_rate=self.mel.get_sample_rate())
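
# The Mel helper (audiodiffusion.mel) handles the image <-> audio round trip.
# Only image_to_audio and get_sample_rate are used here; a rough standalone
# sketch (the other method names are assumptions about the audiodiffusion API):
#
#   mel = Mel(x_res=256, y_res=256, hop_length=512)
#   mel.load_audio('example.wav')        # assumed helper for loading a file
#   img = mel.audio_slice_to_image(0)    # assumed: grayscale spectrogram slice
#   audio = mel.image_to_audio(img)      # waveform at mel.get_sample_rate()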


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train VAE using ldm.")
    parser.add_argument("--batch_size", type=int, default=1)
    args = parser.parse_args()

    config = OmegaConf.load('ldm_autoencoder_kl.yaml')
    lightning_config = config.pop("lightning", OmegaConf.create())
    trainer_config = lightning_config.get("trainer", OmegaConf.create())
    trainer_opt = argparse.Namespace(**trainer_config)
    trainer = Trainer.from_argparse_args(trainer_opt,
                                         callbacks=[
                                             ImageLogger(),
                                             ModelCheckpoint(
                                                 dirpath='checkpoints',
                                                 filename='{epoch:06}',
                                                 verbose=True,
                                                 save_last=True)
                                         ])

    model = instantiate_from_config(config.model)
    model.learning_rate = config.model.base_learning_rate
    data = AudioDiffusionDataModule('teticio/audio-diffusion-256',
                                    batch_size=args.batch_size)
    trainer.fit(model, data)
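
# Checkpoints are written to ./checkpoints (epoch-numbered files plus last.ckpt
# from save_last=True). As the comments at the top note, the trained autoencoder
# still needs converting for use with diffusers / Hugging Face
# (convert_original_stable_diffusion_to_diffusers.py); that step is not part of
# this script.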