teticio committed on
Commit 96e8f55
•
2 Parent(s): 58bc92a 2561128

Merge pull request #9 from teticio/latent-audio-diffusion

.gitignore CHANGED
@@ -1,8 +1,11 @@
1
  .vscode
2
  __pycache__
3
  .ipynb_checkpoints
4
- data*
5
- ddpm-ema-audio-*
6
  flagged
7
  build
8
  audiodiffusion.egg-info
 
 
 
 
1
  .vscode
2
  __pycache__
3
  .ipynb_checkpoints
4
+ data
5
+ models
6
  flagged
7
  build
8
  audiodiffusion.egg-info
9
+ lightning_logs
10
+ taming
11
+ checkpoints
README.md CHANGED
@@ -15,7 +15,10 @@ license: gpl-3.0
15
 
16
  ---
17
 
18
- **UPDATES**:
 
 
 
19
 
20
  4/10/2022
21
  It is now possible to mask parts of the input audio during generation, which means you can stitch several samples together (think "out-painting").
@@ -45,35 +48,39 @@ You can play around with some pretrained models on [Google Colab](https://colab.
45
  ---
46
 
47
  ## Generate Mel spectrogram dataset from directory of audio files
 
 
 
 
 
48
  #### Training can be run with Mel spectrograms of resolution 64x64 on a single commercial grade GPU (e.g. RTX 2080 Ti). The `hop_length` should be set to 1024 for better results.
49
 
50
  ```bash
51
- python audio_to_images.py \
52
  --resolution 64 \
53
  --hop_length 1024 \
54
  --input_dir path-to-audio-files \
55
- --output_dir data-test
56
  ```
57
 
58
  #### Generate dataset of 256x256 Mel spectrograms and push to hub (you will need to be authenticated with `huggingface-cli login`).
59
-
60
  ```bash
61
- python audio_to_images.py \
62
  --resolution 256 \
63
  --input_dir path-to-audio-files \
64
- --output_dir data-256 \
65
  --push_to_hub teticio/audio-diffusion-256
66
  ```
 
67
  ## Train model
68
  #### Run training on local machine.
69
-
70
  ```bash
71
- accelerate launch --config_file accelerate_local.yaml \
72
- train_unconditional.py \
73
- --dataset_name data-64 \
74
  --resolution 64 \
75
  --hop_length 1024 \
76
- --output_dir ddpm-ema-audio-64 \
77
  --train_batch_size 16 \
78
  --num_epochs 100 \
79
  --gradient_accumulation_steps 1 \
@@ -83,13 +90,12 @@ accelerate launch --config_file accelerate_local.yaml \
83
  ```
84
 
85
  #### Run training on a local machine with a `batch_size` of 2 and `gradient_accumulation_steps` of 8 to compensate, so that the 256x256 resolution model fits on a commercial grade GPU, and push to hub.
86
-
87
  ```bash
88
- accelerate launch --config_file accelerate_local.yaml \
89
- train_unconditional.py \
90
  --dataset_name teticio/audio-diffusion-256 \
91
  --resolution 256 \
92
- --output_dir ddpm-ema-audio-256 \
93
  --num_epochs 100 \
94
  --train_batch_size 2 \
95
  --eval_batch_size 2 \
@@ -103,13 +109,12 @@ accelerate launch --config_file accelerate_local.yaml \
103
  ```
104
 
105
  #### Run training on SageMaker.
106
-
107
  ```bash
108
- accelerate launch --config_file accelerate_sagemaker.yaml \
109
- strain_unconditional.py \
110
  --dataset_name teticio/audio-diffusion-256 \
111
  --resolution 256 \
112
- --output_dir ddpm-ema-audio-256 \
113
  --train_batch_size 16 \
114
  --num_epochs 100 \
115
  --gradient_accumulation_steps 1 \
@@ -117,3 +122,22 @@ accelerate launch --config_file accelerate_sagemaker.yaml \
117
  --lr_warmup_steps 500 \
118
  --mixed_precision no
119
  ```
15
 
16
  ---
17
 
18
+ **UPDATES**:
19
+
20
+ 15/10/2022
21
+ Added latent audio diffusion (see below). Also added the possibility to train a model to use DDIM ([Denoising Diffusion Implicit Models](https://arxiv.org/pdf/2010.02502.pdf)) by setting `--scheduler ddim`. These have the benefit that samples can be generated with far fewer steps (~50) than were used in training.
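As a rough illustration of what this enables at inference time (a sketch only, assuming a model trained with `--scheduler ddim` has been pushed under the hypothetical id `teticio/audio-diffusion-ddim-256`), the new `steps` argument of `generate_spectrogram_and_audio` controls the number of de-noising steps:

```python
# Sketch only: sample with ~50 DDIM de-noising steps instead of the ~1000 used
# in training. The model id below is a placeholder, not a published checkpoint.
from audiodiffusion import AudioDiffusion

audio_diffusion = AudioDiffusion(model_id="teticio/audio-diffusion-ddim-256")
image, (sample_rate, audio) = audio_diffusion.generate_spectrogram_and_audio(steps=50)
image.save("mel_spectrogram.png")  # PIL image of the generated mel spectrogram
```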
22
 
23
  4/10/2022
24
  It is now possible to mask parts of the input audio during generation, which means you can stitch several samples together (think "out-painting").
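A hedged sketch of how the masking can be used, assuming a local file `track.wav` and the `teticio/audio-diffusion-256` model; the `mask_*_secs` arguments keep the indicated seconds of the input untouched while the rest is re-generated:

```python
# Sketch only: "out-paint" from an existing clip, keeping its first 2.5 seconds.
from audiodiffusion import AudioDiffusion

audio_diffusion = AudioDiffusion(model_id="teticio/audio-diffusion-256")
image, (sample_rate, audio) = audio_diffusion.generate_spectrogram_and_audio_from_audio(
    audio_file="track.wav",   # placeholder path to an existing audio file
    mask_start_secs=2.5,      # seconds of the original audio to preserve at the start
)
```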
 
48
  ---
49
 
50
  ## Generate Mel spectrogram dataset from directory of audio files
51
+ #### Install
52
+ ```bash
53
+ pip install .
54
+ ```
55
+
56
  #### Training can be run with Mel spectrograms of resolution 64x64 on a single commercial grade GPU (e.g. RTX 2080 Ti). The `hop_length` should be set to 1024 for better results.
57
 
58
  ```bash
59
+ python scripts/audio_to_images.py \
60
  --resolution 64 \
61
  --hop_length 1024 \
62
  --input_dir path-to-audio-files \
63
+ --output_dir path-to-output-data
64
  ```
65
 
66
  #### Generate dataset of 256x256 Mel spectrograms and push to hub (you will need to be authenticated with `huggingface-cli login`).
 
67
  ```bash
68
+ python scripts/audio_to_images.py \
69
  --resolution 256 \
70
  --input_dir path-to-audio-files \
71
+ --output_dir data/audio-diffusion-256 \
72
  --push_to_hub teticio/audio-diffusion-256
73
  ```
74
+
75
  ## Train model
76
  #### Run training on local machine.
 
77
  ```bash
78
+ accelerate launch --config_file config/accelerate_local.yaml \
79
+ scripts/train_unconditional.py \
80
+ --dataset_name data/audio-diffusion-64 \
81
  --resolution 64 \
82
  --hop_length 1024 \
83
+ --output_dir models/ddpm-ema-audio-64 \
84
  --train_batch_size 16 \
85
  --num_epochs 100 \
86
  --gradient_accumulation_steps 1 \
 
90
  ```
91
 
92
  #### Run training on a local machine with a `batch_size` of 2 and `gradient_accumulation_steps` of 8 to compensate, so that the 256x256 resolution model fits on a commercial grade GPU, and push to hub.
 
93
  ```bash
94
+ accelerate launch --config_file config/accelerate_local.yaml \
95
+ scripts/train_unconditional.py \
96
  --dataset_name teticio/audio-diffusion-256 \
97
  --resolution 256 \
98
+ --output_dir models/audio-diffusion-256 \
99
  --num_epochs 100 \
100
  --train_batch_size 2 \
101
  --eval_batch_size 2 \
 
109
  ```
110
 
111
  #### Run training on SageMaker.
 
112
  ```bash
113
+ accelerate launch --config_file config/accelerate_sagemaker.yaml \
114
+ scripts/train_unconditional.py \
115
  --dataset_name teticio/audio-diffusion-256 \
116
  --resolution 256 \
117
+ --output_dir models/ddpm-ema-audio-256 \
118
  --train_batch_size 16 \
119
  --num_epochs 100 \
120
  --gradient_accumulation_steps 1 \
 
122
  --lr_warmup_steps 500 \
123
  --mixed_precision no
124
  ```
125
+ ## Latent Audio Diffusion
126
+ Rather than denoising images directly, it is interesting to work in the "latent space" after first encoding images using an autoencoder. This has a number of advantages. Firstly, the information in the images is compressed into a latent space of much lower dimension, so it is much faster to train denoising diffusion models and run inference with them. Secondly, similar images tend to cluster together, and interpolating between two images in latent space can produce meaningful combinations.
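For intuition, a minimal sketch of that round trip, assuming a converted VAE has been saved to `models/autoencoder-kl` (the `notebooks/test_vae.ipynb` notebook does this end to end):

```python
# Sketch only: encode a spectrogram-sized image into the much smaller latent
# space and decode it back; models/autoencoder-kl is assumed to exist locally.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("models/autoencoder-kl")
x = torch.randn(1, 3, 256, 256)                 # stand-in for a mel spectrogram scaled to [-1, 1]
latents = vae.encode(x).latent_dist.sample()    # e.g. (1, 4, 32, 32) with the config in this repo
reconstruction = vae.decode(latents)["sample"]  # back to (1, 3, 256, 256)
print(latents.shape, reconstruction.shape)
```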
127
+
128
+ At the time of writing, the Hugging Face `diffusers` library is geared towards inference and lacking in training functionality, rather like its cousin `transformers` in the early days of development. In order to train a VAE (Variational Autoencoder), I use the [stable-diffusion](https://github.com/CompVis/stable-diffusion) repo from CompVis and convert the checkpoints to `diffusers` format. Note that it uses a perceptual loss function for images; it would be nice to try a perceptual *audio* loss function.
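The conversion itself is handled by `convert_ldm_to_hf_vae` in `audiodiffusion/utils.py` and is run automatically at the end of each epoch by `scripts/train_vae.py`. A hedged sketch of calling it by hand (the checkpoint filename is an assumption; use whatever `train_vae.py` wrote to `models/ldm-autoencoder-kl`):

```python
# Sketch only: convert a CompVis/ldm VAE checkpoint to diffusers format.
from omegaconf import OmegaConf
from audiodiffusion.utils import convert_ldm_to_hf_vae

ldm_config = OmegaConf.load("config/ldm_autoencoder_kl.yaml")
convert_ldm_to_hf_vae(ldm_checkpoint="models/ldm-autoencoder-kl/last.ckpt",  # assumed filename
                      ldm_config=ldm_config,
                      hf_checkpoint="models/autoencoder-kl")
```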
129
+
130
+ #### Train an autoencoder.
131
+ ```bash
132
+ python scripts/train_vae.py \
133
+ --dataset_name teticio/audio-diffusion-256 \
134
+ --batch_size 2 \
135
+ --gradient_accumulation_steps 12
136
+ ```
137
+
138
+ #### Train latent diffusion model.
139
+ ```bash
140
+ accelerate launch ...
141
+ --vae models/autoencoder-kl
142
+ --latent_resolution 32
143
+ ```
audiodiffusion/__init__.py CHANGED
@@ -1,15 +1,16 @@
1
- from typing import Iterable, Tuple
2
 
3
  import torch
4
  import numpy as np
5
  from PIL import Image
6
  from tqdm.auto import tqdm
7
  from librosa.beat import beat_track
8
- from diffusers import DDPMPipeline, DDPMScheduler
 
9
 
10
  from .mel import Mel
11
 
12
- VERSION = "1.1.5"
13
 
14
 
15
  class AudioDiffusion:
@@ -42,29 +43,35 @@ class AudioDiffusion:
42
  hop_length=hop_length,
43
  top_db=top_db)
44
  self.model_id = model_id
45
- self.ddpm = DDPMPipeline.from_pretrained(self.model_id)
 
 
 
 
46
  if cuda:
47
- self.ddpm.to("cuda")
48
  self.progress_bar = progress_bar or (lambda _: _)
49
 
50
  def generate_spectrogram_and_audio(
51
  self,
 
52
  generator: torch.Generator = None
53
  ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
54
  """Generate random mel spectrogram and convert to audio.
55
 
56
  Args:
 
57
  generator (torch.Generator): random number generator or None
58
 
59
  Returns:
60
  PIL Image: mel spectrogram
61
  (float, np.ndarray): sample rate and raw audio
62
  """
63
- images = self.ddpm(output_type="numpy", generator=generator)["sample"]
64
- images = (images * 255).round().astype("uint8").transpose(0, 3, 1, 2)
65
- image = Image.fromarray(images[0][0])
66
- audio = self.mel.image_to_audio(image)
67
- return image, (self.mel.get_sample_rate(), audio)
68
 
69
  @torch.no_grad()
70
  def generate_spectrogram_and_audio_from_audio(
@@ -95,44 +102,124 @@ class AudioDiffusion:
95
  (float, np.ndarray): sample rate and raw audio
96
  """
97
 
98
- # It would be better to derive a class from DDPMDiffusionPipeline
99
- # but currently the return type ImagePipelineOutput cannot be imported.
100
  if steps is None:
101
- steps = self.ddpm.scheduler.num_train_timesteps
102
- scheduler = DDPMScheduler(num_train_timesteps=steps)
 
103
  scheduler.set_timesteps(steps)
104
  mask = None
105
  images = noise = torch.randn(
106
- (1, self.ddpm.unet.in_channels, self.ddpm.unet.sample_size,
107
- self.ddpm.unet.sample_size),
108
  generator=generator)
109
 
110
  if audio_file is not None or raw_audio is not None:
111
- self.mel.load_audio(audio_file, raw_audio)
112
- input_image = self.mel.audio_slice_to_image(slice)
113
  input_image = np.frombuffer(input_image.tobytes(),
114
  dtype="uint8").reshape(
115
  (input_image.height,
116
  input_image.width))
117
  input_image = ((input_image / 255) * 2 - 1)
 
 
 
 
 
 
118
 
119
  if start_step > 0:
120
  images[0, 0] = scheduler.add_noise(
121
- torch.tensor(input_image[np.newaxis, np.newaxis, :]),
122
  noise, torch.tensor(steps - start_step))
123
 
124
- mask_start = int(mask_start_secs * self.mel.get_sample_rate() /
125
- self.mel.hop_length)
126
- mask_end = int(mask_end_secs * self.mel.get_sample_rate() /
127
- self.mel.hop_length)
 
128
  mask = scheduler.add_noise(
129
- torch.tensor(input_image[np.newaxis, np.newaxis, :]), noise,
130
  torch.tensor(scheduler.timesteps[start_step:]))
131
 
132
- images = images.to(self.ddpm.device)
133
  for step, t in enumerate(
134
  self.progress_bar(scheduler.timesteps[start_step:])):
135
- model_output = self.ddpm.unet(images, t)['sample']
136
  images = scheduler.step(model_output,
137
  t,
138
  images,
@@ -140,35 +227,36 @@ class AudioDiffusion:
140
 
141
  if mask is not None:
142
  if mask_start > 0:
143
- images[0, 0, :, :mask_start] = mask[step,
144
- 0, :, :mask_start]
145
  if mask_end > 0:
146
- images[0, 0, :, -mask_end:] = mask[step, 0, :, -mask_end:]
 
 
 
 
 
 
147
 
148
  images = (images / 2 + 0.5).clamp(0, 1)
149
  images = images.cpu().permute(0, 2, 3, 1).numpy()
 
 
 
 
 
150
 
151
- images = (images * 255).round().astype("uint8").transpose(0, 3, 1, 2)
152
- image = Image.fromarray(images[0][0])
153
- audio = self.mel.image_to_audio(image)
154
- return image, (self.mel.get_sample_rate(), audio)
155
 
156
- @staticmethod
157
- def loop_it(audio: np.ndarray,
158
- sample_rate: int,
159
- loops: int = 12) -> np.ndarray:
160
- """Loop audio
161
 
162
- Args:
163
- audio (np.ndarray): audio as numpy array
164
- sample_rate (int): sample rate of audio
165
- loops (int): number of times to loop
166
 
167
- Returns:
168
- (float, np.ndarray): sample rate and raw audio or None
169
- """
170
- _, beats = beat_track(y=audio, sr=sample_rate, units='samples')
171
- for beats_in_bar in [16, 12, 8, 4]:
172
- if len(beats) > beats_in_bar:
173
- return np.tile(audio[beats[0]:beats[beats_in_bar]], loops)
174
- return None
 
1
+ from typing import Iterable, Tuple, Union, List
2
 
3
  import torch
4
  import numpy as np
5
  from PIL import Image
6
  from tqdm.auto import tqdm
7
  from librosa.beat import beat_track
8
+ from diffusers import (DiffusionPipeline, DDPMPipeline, UNet2DConditionModel,
9
+ DDIMScheduler, DDPMScheduler, AutoencoderKL)
10
 
11
  from .mel import Mel
12
 
13
+ VERSION = "1.2.0"
14
 
15
 
16
  class AudioDiffusion:
 
43
  hop_length=hop_length,
44
  top_db=top_db)
45
  self.model_id = model_id
46
+ try: # a bit hacky
47
+ self.pipe = LatentAudioDiffusionPipeline.from_pretrained(self.model_id)
48
+ except:
49
+ self.pipe = AudioDiffusionPipeline.from_pretrained(self.model_id)
50
+
51
  if cuda:
52
+ self.pipe.to("cuda")
53
  self.progress_bar = progress_bar or (lambda _: _)
54
 
55
  def generate_spectrogram_and_audio(
56
  self,
57
+ steps: int = None,
58
  generator: torch.Generator = None
59
  ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
60
  """Generate random mel spectrogram and convert to audio.
61
 
62
  Args:
63
+ steps (int): number of de-noising steps to perform (defaults to num_train_timesteps)
64
  generator (torch.Generator): random number generator or None
65
 
66
  Returns:
67
  PIL Image: mel spectrogram
68
  (float, np.ndarray): sample rate and raw audio
69
  """
70
+ images, (sample_rate, audios) = self.pipe(mel=self.mel,
71
+ batch_size=1,
72
+ steps=steps,
73
+ generator=generator)
74
+ return images[0], (sample_rate, audios[0])
75
 
76
  @torch.no_grad()
77
  def generate_spectrogram_and_audio_from_audio(
 
102
  (float, np.ndarray): sample rate and raw audio
103
  """
104
 
105
+ images, (sample_rate,
106
+ audios) = self.pipe(mel=self.mel,
107
+ batch_size=1,
108
+ audio_file=audio_file,
109
+ raw_audio=raw_audio,
110
+ slice=slice,
111
+ start_step=start_step,
112
+ steps=steps,
113
+ generator=generator,
114
+ mask_start_secs=mask_start_secs,
115
+ mask_end_secs=mask_end_secs)
116
+ return images[0], (sample_rate, audios[0])
117
+
118
+ @staticmethod
119
+ def loop_it(audio: np.ndarray,
120
+ sample_rate: int,
121
+ loops: int = 12) -> np.ndarray:
122
+ """Loop audio
123
+
124
+ Args:
125
+ audio (np.ndarray): audio as numpy array
126
+ sample_rate (int): sample rate of audio
127
+ loops (int): number of times to loop
128
+
129
+ Returns:
130
+ (float, np.ndarray): sample rate and raw audio or None
131
+ """
132
+ _, beats = beat_track(y=audio, sr=sample_rate, units='samples')
133
+ for beats_in_bar in [16, 12, 8, 4]:
134
+ if len(beats) > beats_in_bar:
135
+ return np.tile(audio[beats[0]:beats[beats_in_bar]], loops)
136
+ return None
137
+
138
+
139
+ class AudioDiffusionPipeline(DiffusionPipeline):
140
+
141
+ def __init__(self, unet: UNet2DConditionModel,
142
+ scheduler: Union[DDIMScheduler, DDPMScheduler]):
143
+ super().__init__()
144
+ self.register_modules(unet=unet, scheduler=scheduler)
145
+
146
+ @torch.no_grad()
147
+ def __call__(
148
+ self,
149
+ mel: Mel,
150
+ batch_size: int = 1,
151
+ audio_file: str = None,
152
+ raw_audio: np.ndarray = None,
153
+ slice: int = 0,
154
+ start_step: int = 0,
155
+ steps: int = None,
156
+ generator: torch.Generator = None,
157
+ mask_start_secs: float = 0,
158
+ mask_end_secs: float = 0
159
+ ) -> Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]]:
160
+ """Generate random mel spectrogram from audio input and convert to audio.
161
+
162
+ Args:
163
+ mel (Mel): instance of Mel class to perform image <-> audio
164
+ batch_size (int): number of samples to generate
165
+ audio_file (str): must be a file on disk due to Librosa limitation or
166
+ raw_audio (np.ndarray): audio as numpy array
167
+ slice (int): slice number of audio to convert
168
+ start_step (int): step to start from
169
+ steps (int): number of de-noising steps to perform (defaults to num_train_timesteps)
170
+ generator (torch.Generator): random number generator or None
171
+ mask_start_secs (float): number of seconds of audio to mask (not generate) at start
172
+ mask_end_secs (float): number of seconds of audio to mask (not generate) at end
173
+
174
+ Returns:
175
+ List[PIL Image]: mel spectrograms
176
+ (float, List[np.ndarray]): sample rate and raw audios
177
+ """
178
+
179
  if steps is None:
180
+ steps = self.scheduler.num_train_timesteps
181
+ # Unfortunately, the schedule is set up in the constructor
182
+ scheduler = self.scheduler.__class__(num_train_timesteps=steps)
183
  scheduler.set_timesteps(steps)
184
  mask = None
185
  images = noise = torch.randn(
186
+ (batch_size, self.unet.in_channels, self.unet.sample_size,
187
+ self.unet.sample_size),
188
  generator=generator)
189
 
190
  if audio_file is not None or raw_audio is not None:
191
+ mel.load_audio(audio_file, raw_audio)
192
+ input_image = mel.audio_slice_to_image(slice)
193
  input_image = np.frombuffer(input_image.tobytes(),
194
  dtype="uint8").reshape(
195
  (input_image.height,
196
  input_image.width))
197
  input_image = ((input_image / 255) * 2 - 1)
198
+ input_images = np.tile(input_image, (batch_size, 1, 1, 1))
199
+
200
+ if hasattr(self, 'vqvae'):
201
+ input_images = self.vqvae.encode(
202
+ input_images).latent_dist.sample(generator=generator)
203
+ input_images = 0.18215 * input_images
204
 
205
  if start_step > 0:
206
  images[0, 0] = scheduler.add_noise(
207
+ torch.tensor(input_images[:, np.newaxis, np.newaxis, :]),
208
  noise, torch.tensor(steps - start_step))
209
 
210
+ pixels_per_second = (mel.get_sample_rate() *
211
+ self.unet.sample_size / mel.hop_length /
212
+ mel.x_res)
213
+ mask_start = int(mask_start_secs * pixels_per_second)
214
+ mask_end = int(mask_end_secs * pixels_per_second)
215
  mask = scheduler.add_noise(
216
+ torch.tensor(input_images[:, np.newaxis, :]), noise,
217
  torch.tensor(scheduler.timesteps[start_step:]))
218
 
219
+ images = images.to(self.device)
220
  for step, t in enumerate(
221
  self.progress_bar(scheduler.timesteps[start_step:])):
222
+ model_output = self.unet(images, t)['sample']
223
  images = scheduler.step(model_output,
224
  t,
225
  images,
 
227
 
228
  if mask is not None:
229
  if mask_start > 0:
230
+ images[:, :, :, :mask_start] = mask[
231
+ step, :, :, :, :mask_start]
232
  if mask_end > 0:
233
+ images[:, :, :, -mask_end:] = mask[step, :, :, :,
234
+ -mask_end:]
235
+
236
+ if hasattr(self, 'vqvae'):
237
+ # 0.18215 was scaling factor used in training to ensure unit variance
238
+ images = 1 / 0.18215 * images
239
+ images = self.vqvae.decode(images)['sample']
240
 
241
  images = (images / 2 + 0.5).clamp(0, 1)
242
  images = images.cpu().permute(0, 2, 3, 1).numpy()
243
+ images = (images * 255).round().astype("uint8")
244
+ images = list(
245
+ map(lambda _: Image.fromarray(_[:, :, 0]), images) if images.
246
+ shape[3] == 1 else map(
247
+ lambda _: Image.fromarray(_, mode='RGB').convert('L'), images))
248
 
249
+ audios = list(map(lambda _: mel.image_to_audio(_), images))
250
+ return images, (mel.get_sample_rate(), audios)
 
 
251
 
 
 
 
 
 
252
 
253
+ class LatentAudioDiffusionPipeline(AudioDiffusionPipeline):
 
 
 
254
 
255
+ def __init__(self, unet: UNet2DConditionModel,
256
+ scheduler: Union[DDIMScheduler,
257
+ DDPMScheduler], vqvae: AutoencoderKL):
258
+ super().__init__(unet=unet, scheduler=scheduler)
259
+ self.register_modules(vqvae=vqvae)
260
+
261
+ def __call__(self, *args, **kwargs):
262
+ return super().__call__(*args, **kwargs)
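A hedged usage sketch of the pipeline classes defined above, driving `AudioDiffusionPipeline` directly rather than through the `AudioDiffusion` wrapper (the model id is the one used elsewhere in this README; any pushed audio-diffusion model should work):

```python
# Minimal sketch, not part of the commit: load a pretrained model as an
# AudioDiffusionPipeline and generate one mel spectrogram plus its audio.
import torch

from audiodiffusion import AudioDiffusionPipeline
from audiodiffusion.mel import Mel

mel = Mel()  # image <-> audio conversion helper
pipe = AudioDiffusionPipeline.from_pretrained("teticio/audio-diffusion-256")
images, (sample_rate, audios) = pipe(mel=mel,
                                     batch_size=1,
                                     generator=torch.manual_seed(42))
images[0].save("sample.png")  # PIL image; audios[0] is the raw waveform
```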
audiodiffusion/utils.py ADDED
@@ -0,0 +1,363 @@
1
+ # adapted from https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py
2
+
3
+ import torch
4
+ from diffusers import AutoencoderKL
5
+
6
+
7
+ def shave_segments(path, n_shave_prefix_segments=1):
8
+ """
9
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
10
+ """
11
+ if n_shave_prefix_segments >= 0:
12
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
13
+ else:
14
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
15
+
16
+
17
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
18
+ """
19
+ Updates paths inside resnets to the new naming scheme (local renaming)
20
+ """
21
+ mapping = []
22
+ for old_item in old_list:
23
+ new_item = old_item
24
+
25
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
26
+ new_item = shave_segments(
27
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments)
28
+
29
+ mapping.append({"old": old_item, "new": new_item})
30
+
31
+ return mapping
32
+
33
+
34
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
35
+ """
36
+ Updates paths inside attentions to the new naming scheme (local renaming)
37
+ """
38
+ mapping = []
39
+ for old_item in old_list:
40
+ new_item = old_item
41
+
42
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
43
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
44
+
45
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
46
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
47
+
48
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
49
+
50
+ mapping.append({"old": old_item, "new": new_item})
51
+
52
+ return mapping
53
+
54
+
55
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
56
+ """
57
+ Updates paths inside attentions to the new naming scheme (local renaming)
58
+ """
59
+ mapping = []
60
+ for old_item in old_list:
61
+ new_item = old_item
62
+
63
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
64
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
65
+
66
+ new_item = new_item.replace("q.weight", "query.weight")
67
+ new_item = new_item.replace("q.bias", "query.bias")
68
+
69
+ new_item = new_item.replace("k.weight", "key.weight")
70
+ new_item = new_item.replace("k.bias", "key.bias")
71
+
72
+ new_item = new_item.replace("v.weight", "value.weight")
73
+ new_item = new_item.replace("v.bias", "value.bias")
74
+
75
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
76
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
77
+
78
+ new_item = shave_segments(
79
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments)
80
+
81
+ mapping.append({"old": old_item, "new": new_item})
82
+
83
+ return mapping
84
+
85
+
86
+ def assign_to_checkpoint(paths,
87
+ checkpoint,
88
+ old_checkpoint,
89
+ attention_paths_to_split=None,
90
+ additional_replacements=None,
91
+ config=None):
92
+ """
93
+ This does the final conversion step: take locally converted weights and apply a global renaming
94
+ to them. It splits attention layers, and takes into account additional replacements
95
+ that may arise.
96
+
97
+ Assigns the weights to the new checkpoint.
98
+ """
99
+ assert isinstance(
100
+ paths, list
101
+ ), "Paths should be a list of dicts containing 'old' and 'new' keys."
102
+
103
+ # Splits the attention layers into three variables.
104
+ if attention_paths_to_split is not None:
105
+ for path, path_map in attention_paths_to_split.items():
106
+ old_tensor = old_checkpoint[path]
107
+ channels = old_tensor.shape[0] // 3
108
+
109
+ target_shape = (-1,
110
+ channels) if len(old_tensor.shape) == 3 else (-1)
111
+
112
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
113
+
114
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels //
115
+ num_heads) + old_tensor.shape[1:])
116
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
117
+
118
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
119
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
120
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
121
+
122
+ for path in paths:
123
+ new_path = path["new"]
124
+
125
+ # These have already been assigned
126
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
127
+ continue
128
+
129
+ # Global renaming happens here
130
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
131
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
132
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
133
+
134
+ if additional_replacements is not None:
135
+ for replacement in additional_replacements:
136
+ new_path = new_path.replace(replacement["old"],
137
+ replacement["new"])
138
+
139
+ # proj_attn.weight has to be converted from conv 1D to linear
140
+ if "proj_attn.weight" in new_path:
141
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
142
+ else:
143
+ checkpoint[new_path] = old_checkpoint[path["old"]]
144
+
145
+
146
+ def conv_attn_to_linear(checkpoint):
147
+ keys = list(checkpoint.keys())
148
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
149
+ for key in keys:
150
+ if ".".join(key.split(".")[-2:]) in attn_keys:
151
+ if checkpoint[key].ndim > 2:
152
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
153
+ elif "proj_attn.weight" in key:
154
+ if checkpoint[key].ndim > 2:
155
+ checkpoint[key] = checkpoint[key][:, :, 0]
156
+
157
+
158
+ def create_vae_diffusers_config(original_config):
159
+ """
160
+ Creates a config for the diffusers based on the config of the LDM model.
161
+ """
162
+ vae_params = original_config.model.params.ddconfig
163
+ _ = original_config.model.params.embed_dim
164
+
165
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
166
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
167
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
168
+
169
+ config = dict(
170
+ sample_size=vae_params.resolution,
171
+ in_channels=vae_params.in_channels,
172
+ out_channels=vae_params.out_ch,
173
+ down_block_types=tuple(down_block_types),
174
+ up_block_types=tuple(up_block_types),
175
+ block_out_channels=tuple(block_out_channels),
176
+ latent_channels=vae_params.z_channels,
177
+ layers_per_block=vae_params.num_res_blocks,
178
+ )
179
+ return config
180
+
181
+
182
+ def convert_ldm_vae_checkpoint(checkpoint, config):
183
+ # extract state dict for VAE
184
+ vae_state_dict = checkpoint
185
+
186
+ new_checkpoint = {}
187
+
188
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict[
189
+ "encoder.conv_in.weight"]
190
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict[
191
+ "encoder.conv_in.bias"]
192
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[
193
+ "encoder.conv_out.weight"]
194
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict[
195
+ "encoder.conv_out.bias"]
196
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[
197
+ "encoder.norm_out.weight"]
198
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[
199
+ "encoder.norm_out.bias"]
200
+
201
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict[
202
+ "decoder.conv_in.weight"]
203
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict[
204
+ "decoder.conv_in.bias"]
205
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[
206
+ "decoder.conv_out.weight"]
207
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict[
208
+ "decoder.conv_out.bias"]
209
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[
210
+ "decoder.norm_out.weight"]
211
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[
212
+ "decoder.norm_out.bias"]
213
+
214
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
215
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
216
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict[
217
+ "post_quant_conv.weight"]
218
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict[
219
+ "post_quant_conv.bias"]
220
+
221
+ # Retrieves the keys for the encoder down blocks only
222
+ num_down_blocks = len({
223
+ ".".join(layer.split(".")[:3])
224
+ for layer in vae_state_dict if "encoder.down" in layer
225
+ })
226
+ down_blocks = {
227
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key]
228
+ for layer_id in range(num_down_blocks)
229
+ }
230
+
231
+ # Retrieves the keys for the decoder up blocks only
232
+ num_up_blocks = len({
233
+ ".".join(layer.split(".")[:3])
234
+ for layer in vae_state_dict if "decoder.up" in layer
235
+ })
236
+ up_blocks = {
237
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key]
238
+ for layer_id in range(num_up_blocks)
239
+ }
240
+
241
+ for i in range(num_down_blocks):
242
+ resnets = [
243
+ key for key in down_blocks[i]
244
+ if f"down.{i}" in key and f"down.{i}.downsample" not in key
245
+ ]
246
+
247
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
248
+ new_checkpoint[
249
+ f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
250
+ f"encoder.down.{i}.downsample.conv.weight")
251
+ new_checkpoint[
252
+ f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
253
+ f"encoder.down.{i}.downsample.conv.bias")
254
+
255
+ paths = renew_vae_resnet_paths(resnets)
256
+ meta_path = {
257
+ "old": f"down.{i}.block",
258
+ "new": f"down_blocks.{i}.resnets"
259
+ }
260
+ assign_to_checkpoint(paths,
261
+ new_checkpoint,
262
+ vae_state_dict,
263
+ additional_replacements=[meta_path],
264
+ config=config)
265
+
266
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
267
+ num_mid_res_blocks = 2
268
+ for i in range(1, num_mid_res_blocks + 1):
269
+ resnets = [
270
+ key for key in mid_resnets if f"encoder.mid.block_{i}" in key
271
+ ]
272
+
273
+ paths = renew_vae_resnet_paths(resnets)
274
+ meta_path = {
275
+ "old": f"mid.block_{i}",
276
+ "new": f"mid_block.resnets.{i - 1}"
277
+ }
278
+ assign_to_checkpoint(paths,
279
+ new_checkpoint,
280
+ vae_state_dict,
281
+ additional_replacements=[meta_path],
282
+ config=config)
283
+
284
+ mid_attentions = [
285
+ key for key in vae_state_dict if "encoder.mid.attn" in key
286
+ ]
287
+ paths = renew_vae_attention_paths(mid_attentions)
288
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
289
+ assign_to_checkpoint(paths,
290
+ new_checkpoint,
291
+ vae_state_dict,
292
+ additional_replacements=[meta_path],
293
+ config=config)
294
+ conv_attn_to_linear(new_checkpoint)
295
+
296
+ for i in range(num_up_blocks):
297
+ block_id = num_up_blocks - 1 - i
298
+ resnets = [
299
+ key for key in up_blocks[block_id]
300
+ if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
301
+ ]
302
+
303
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
304
+ new_checkpoint[
305
+ f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
306
+ f"decoder.up.{block_id}.upsample.conv.weight"]
307
+ new_checkpoint[
308
+ f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
309
+ f"decoder.up.{block_id}.upsample.conv.bias"]
310
+
311
+ paths = renew_vae_resnet_paths(resnets)
312
+ meta_path = {
313
+ "old": f"up.{block_id}.block",
314
+ "new": f"up_blocks.{i}.resnets"
315
+ }
316
+ assign_to_checkpoint(paths,
317
+ new_checkpoint,
318
+ vae_state_dict,
319
+ additional_replacements=[meta_path],
320
+ config=config)
321
+
322
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
323
+ num_mid_res_blocks = 2
324
+ for i in range(1, num_mid_res_blocks + 1):
325
+ resnets = [
326
+ key for key in mid_resnets if f"decoder.mid.block_{i}" in key
327
+ ]
328
+
329
+ paths = renew_vae_resnet_paths(resnets)
330
+ meta_path = {
331
+ "old": f"mid.block_{i}",
332
+ "new": f"mid_block.resnets.{i - 1}"
333
+ }
334
+ assign_to_checkpoint(paths,
335
+ new_checkpoint,
336
+ vae_state_dict,
337
+ additional_replacements=[meta_path],
338
+ config=config)
339
+
340
+ mid_attentions = [
341
+ key for key in vae_state_dict if "decoder.mid.attn" in key
342
+ ]
343
+ paths = renew_vae_attention_paths(mid_attentions)
344
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
345
+ assign_to_checkpoint(paths,
346
+ new_checkpoint,
347
+ vae_state_dict,
348
+ additional_replacements=[meta_path],
349
+ config=config)
350
+ conv_attn_to_linear(new_checkpoint)
351
+ return new_checkpoint
352
+
353
+ def convert_ldm_to_hf_vae(ldm_checkpoint, ldm_config, hf_checkpoint):
354
+ checkpoint = torch.load(ldm_checkpoint)["state_dict"]
355
+
356
+ # Convert the VAE model.
357
+ vae_config = create_vae_diffusers_config(ldm_config)
358
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(
359
+ checkpoint, vae_config)
360
+
361
+ vae = AutoencoderKL(**vae_config)
362
+ vae.load_state_dict(converted_vae_checkpoint)
363
+ vae.save_pretrained(hf_checkpoint)
accelerate_deepspeed.yaml β†’ config/accelerate_deepspeed.yaml RENAMED
File without changes
accelerate_local.yaml β†’ config/accelerate_local.yaml RENAMED
File without changes
accelerate_sagemaker.yaml β†’ config/accelerate_sagemaker.yaml RENAMED
File without changes
config/ldm_autoencoder_kl.yaml ADDED
@@ -0,0 +1,31 @@
1
+
2
+ model:
3
+ base_learning_rate: 4.5e-6
4
+ target: ldm.models.autoencoder.AutoencoderKL
5
+ params:
6
+ monitor: "val/rec_loss"
7
+ embed_dim: 3
8
+ lossconfig:
9
+ target: ldm.modules.losses.LPIPSWithDiscriminator
10
+ params:
11
+ disc_start: 50001
12
+ kl_weight: 0.000001
13
+ disc_weight: 0.5
14
+
15
+ ddconfig:
16
+ double_z: True
17
+ z_channels: 4
18
+ resolution: 256
19
+ in_channels: 3
20
+ out_ch: 3
21
+ ch: 128
22
+ ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
23
+ num_res_blocks: 2
24
+ attn_resolutions: [ ]
25
+ dropout: 0.0
26
+
27
+ lightning:
28
+ trainer:
29
+ benchmark: True
30
+ accelerator: gpu
31
+ devices: 1
notebooks/test_vae.ipynb ADDED
@@ -0,0 +1,169 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "bcbbe26c",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os\n",
11
+ "import sys\n",
12
+ "sys.path.insert(0, os.path.dirname(os.path.abspath(\"\")))"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "id": "b451ab22",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "import torch\n",
23
+ "import random\n",
24
+ "import numpy as np\n",
25
+ "from PIL import Image\n",
26
+ "from datasets import load_dataset\n",
27
+ "from IPython.display import Audio\n",
28
+ "from diffusers import AutoencoderKL\n",
29
+ "from audiodiffusion.mel import Mel"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "id": "324cef44",
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "mel = Mel()\n",
40
+ "vae = AutoencoderKL.from_pretrained('../models/autoencoder-kl')"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "id": "da55ce79",
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "vae.config"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "id": "5fea99ff",
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "ds = load_dataset('teticio/audio-diffusion-256')"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "id": "426c6edd",
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "image = random.choice(ds['train'])['image']\n",
71
+ "display(image)\n",
72
+ "Audio(data=mel.image_to_audio(image), rate=mel.get_sample_rate())"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "id": "d123f8a0",
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "# encode\n",
83
+ "input_image = np.frombuffer(image.convert('RGB').tobytes(), dtype=\"uint8\").reshape(\n",
84
+ " (image.height, image.width, 3))\n",
85
+ "input_image = ((input_image / 255) * 2 - 1).transpose(2, 0, 1)\n",
86
+ "posterior = vae.encode(torch.tensor([input_image], dtype=torch.float32)).latent_dist\n",
87
+ "latents = posterior.sample()"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "id": "482c458f",
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "# reconstruct\n",
98
+ "output_image = vae.decode(latents)['sample']\n",
99
+ "output_image = torch.clamp(output_image, -1., 1.)\n",
100
+ "output_image = (output_image + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w\n",
101
+ "output_image = (output_image.detach().cpu().numpy() *\n",
102
+ " 255).round().astype(\"uint8\").transpose(0, 2, 3, 1)[0]\n",
103
+ "output_image = Image.fromarray(output_image).convert('L')\n",
104
+ "display(output_image)\n",
105
+ "Audio(data=mel.image_to_audio(output_image), rate=mel.get_sample_rate())"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "id": "f10db020",
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "# sample\n",
116
+ "output_image = vae.decode(torch.randn_like(posterior.sample()))['sample']\n",
117
+ "output_image = torch.clamp(output_image, -1., 1.)\n",
118
+ "output_image = (output_image + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w\n",
119
+ "output_image = (output_image.detach().cpu().numpy() *\n",
120
+ " 255).round().astype(\"uint8\").transpose(0, 2, 3, 1)[0]\n",
121
+ "output_image = Image.fromarray(output_image).convert('L')\n",
122
+ "display(output_image)\n",
123
+ "Audio(data=mel.image_to_audio(output_image), rate=mel.get_sample_rate())"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": null,
129
+ "id": "46019770",
130
+ "metadata": {},
131
+ "outputs": [],
132
+ "source": []
133
+ }
134
+ ],
135
+ "metadata": {
136
+ "kernelspec": {
137
+ "display_name": "huggingface",
138
+ "language": "python",
139
+ "name": "huggingface"
140
+ },
141
+ "language_info": {
142
+ "codemirror_mode": {
143
+ "name": "ipython",
144
+ "version": 3
145
+ },
146
+ "file_extension": ".py",
147
+ "mimetype": "text/x-python",
148
+ "name": "python",
149
+ "nbconvert_exporter": "python",
150
+ "pygments_lexer": "ipython3",
151
+ "version": "3.10.6"
152
+ },
153
+ "toc": {
154
+ "base_numbering": 1,
155
+ "nav_menu": {},
156
+ "number_sections": true,
157
+ "sideBar": true,
158
+ "skip_h1_title": false,
159
+ "title_cell": "Table of Contents",
160
+ "title_sidebar": "Contents",
161
+ "toc_cell": false,
162
+ "toc_position": {},
163
+ "toc_section_display": true,
164
+ "toc_window_display": false
165
+ }
166
+ },
167
+ "nbformat": 4,
168
+ "nbformat_minor": 5
169
+ }
audio_to_images.py β†’ scripts/audio_to_images.py RENAMED
File without changes
train_unconditional.py β†’ scripts/train_unconditional.py RENAMED
@@ -5,12 +5,12 @@ import os
5
 
6
  import torch
7
  import torch.nn.functional as F
8
- from PIL import Image
9
 
10
  from accelerate import Accelerator
11
  from accelerate.logging import get_logger
12
  from datasets import load_from_disk, load_dataset
13
- from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 
14
  from diffusers.hub_utils import init_git_repo, push_to_hub
15
  from diffusers.optimization import get_scheduler
16
  from diffusers.training_utils import EMAModel
@@ -22,10 +22,12 @@ from torchvision.transforms import (
22
  Resize,
23
  ToTensor,
24
  )
 
25
  from tqdm.auto import tqdm
26
  from librosa.util import normalize
27
 
28
  from audiodiffusion.mel import Mel
 
29
 
30
  logger = get_logger(__name__)
31
 
@@ -34,18 +36,25 @@ def main(args):
34
  output_dir = os.environ.get("SM_MODEL_DIR", None) or args.output_dir
35
  logging_dir = os.path.join(output_dir, args.logging_dir)
36
  accelerator = Accelerator(
 
37
  mixed_precision=args.mixed_precision,
38
  log_with="tensorboard",
39
  logging_dir=logging_dir,
40
  )
41
 
 
 
 
42
  if args.from_pretrained is not None:
43
- model = DDPMPipeline.from_pretrained(args.from_pretrained).unet
44
  else:
45
  model = UNet2DModel(
46
- sample_size=args.resolution,
47
- in_channels=1,
48
- out_channels=1,
 
 
 
49
  layers_per_block=2,
50
  block_out_channels=(128, 128, 256, 256, 512, 512),
51
  down_block_types=(
@@ -65,8 +74,14 @@ def main(args):
65
  "UpBlock2D",
66
  ),
67
  )
68
- noise_scheduler = DDPMScheduler(num_train_timesteps=1000,
69
- tensor_format="pt")
 
 
 
 
 
 
70
  optimizer = torch.optim.AdamW(
71
  model.parameters(),
72
  lr=args.learning_rate,
@@ -103,7 +118,13 @@ def main(args):
103
  )
104
 
105
  def transforms(examples):
106
- images = [augmentations(image) for image in examples["image"]]
 
 
 
 
 
 
107
  return {"input": images}
108
 
109
  dataset.set_transform(transforms)
@@ -158,6 +179,15 @@ def main(args):
158
  model.train()
159
  for step, batch in enumerate(train_dataloader):
160
  clean_images = batch["input"]
 
 
 
 
 
 
 
 
 
161
  # Sample noise that we'll add to the images
162
  noise = torch.randn(clean_images.shape).to(clean_images.device)
163
  bsz = clean_images.shape[0]
@@ -180,7 +210,8 @@ def main(args):
180
  loss = F.mse_loss(noise_pred, noise)
181
  accelerator.backward(loss)
182
 
183
- accelerator.clip_grad_norm_(model.parameters(), 1.0)
 
184
  optimizer.step()
185
  lr_scheduler.step()
186
  if args.use_ema:
@@ -188,6 +219,8 @@ def main(args):
188
  optimizer.zero_grad()
189
 
190
  progress_bar.update(1)
 
 
191
  logs = {
192
  "loss": loss.detach().item(),
193
  "lr": lr_scheduler.get_last_lr()[0],
@@ -197,7 +230,6 @@ def main(args):
197
  logs["ema_decay"] = ema_model.decay
198
  progress_bar.set_postfix(**logs)
199
  accelerator.log(logs, step=global_step)
200
- global_step += 1
201
  progress_bar.close()
202
 
203
  accelerator.wait_for_everyone()
@@ -205,11 +237,20 @@ def main(args):
205
  # Generate sample images for visual inspection
206
  if accelerator.is_main_process:
207
  if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
208
- pipeline = DDPMPipeline(
209
- unet=accelerator.unwrap_model(
210
- ema_model.averaged_model if args.use_ema else model),
211
- scheduler=noise_scheduler,
212
- )
 
 
 
 
 
 
 
 
 
213
 
214
  # save the model
215
  if args.push_to_hub:
@@ -226,27 +267,30 @@ def main(args):
226
  else:
227
  pipeline.save_pretrained(output_dir)
228
 
229
- generator = torch.manual_seed(0)
 
230
  # run pipeline in inference (sample random noise and denoise)
231
- images = pipeline(
 
232
  generator=generator,
233
  batch_size=args.eval_batch_size,
234
- output_type="numpy",
235
- )["sample"]
236
 
237
  # denormalize the images and save to tensorboard
238
- images_processed = ((images *
239
- 255).round().astype("uint8").transpose(
240
- 0, 3, 1, 2))
 
 
241
  accelerator.trackers[0].writer.add_images(
242
- "test_samples", images_processed, epoch)
243
- for _, image in enumerate(images_processed):
244
- audio = mel.image_to_audio(Image.fromarray(image[0]))
245
  accelerator.trackers[0].writer.add_audio(
246
  f"test_audio_{_}",
247
  normalize(audio),
248
  epoch,
249
- sample_rate=mel.get_sample_rate(),
250
  )
251
  accelerator.wait_for_everyone()
252
 
@@ -268,7 +312,7 @@ if __name__ == "__main__":
268
  parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
269
  parser.add_argument("--overwrite_output_dir", type=bool, default=False)
270
  parser.add_argument("--cache_dir", type=str, default=None)
271
- parser.add_argument("--resolution", type=int, default=64)
272
  parser.add_argument("--train_batch_size", type=int, default=16)
273
  parser.add_argument("--eval_batch_size", type=int, default=16)
274
  parser.add_argument("--num_epochs", type=int, default=100)
@@ -305,6 +349,16 @@ if __name__ == "__main__":
305
  parser.add_argument("--hop_length", type=int, default=512)
306
  parser.add_argument("--from_pretrained", type=str, default=None)
307
  parser.add_argument("--start_epoch", type=int, default=0)
 
 
 
 
 
 
 
 
 
 
308
 
309
  args = parser.parse_args()
310
  env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
 
5
 
6
  import torch
7
  import torch.nn.functional as F
 
8
 
9
  from accelerate import Accelerator
10
  from accelerate.logging import get_logger
11
  from datasets import load_from_disk, load_dataset
12
+ from diffusers import (DiffusionPipeline, DDPMScheduler, UNet2DModel,
13
+ DDIMScheduler, AutoencoderKL)
14
  from diffusers.hub_utils import init_git_repo, push_to_hub
15
  from diffusers.optimization import get_scheduler
16
  from diffusers.training_utils import EMAModel
 
22
  Resize,
23
  ToTensor,
24
  )
25
+ import numpy as np
26
  from tqdm.auto import tqdm
27
  from librosa.util import normalize
28
 
29
  from audiodiffusion.mel import Mel
30
+ from audiodiffusion import LatentAudioDiffusionPipeline, AudioDiffusionPipeline
31
 
32
  logger = get_logger(__name__)
33
 
 
36
  output_dir = os.environ.get("SM_MODEL_DIR", None) or args.output_dir
37
  logging_dir = os.path.join(output_dir, args.logging_dir)
38
  accelerator = Accelerator(
39
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
40
  mixed_precision=args.mixed_precision,
41
  log_with="tensorboard",
42
  logging_dir=logging_dir,
43
  )
44
 
45
+ if args.vae is not None:
46
+ vqvae = AutoencoderKL.from_pretrained(args.vae)
47
+
48
  if args.from_pretrained is not None:
49
+ model = DiffusionPipeline.from_pretrained(args.from_pretrained).unet
50
  else:
51
  model = UNet2DModel(
52
+ sample_size=args.resolution
53
+ if args.vae is None else args.latent_resolution,
54
+ in_channels=1
55
+ if args.vae is None else vqvae.config['latent_channels'],
56
+ out_channels=1
57
+ if args.vae is None else vqvae.config['latent_channels'],
58
  layers_per_block=2,
59
  block_out_channels=(128, 128, 256, 256, 512, 512),
60
  down_block_types=(
 
74
  "UpBlock2D",
75
  ),
76
  )
77
+
78
+ if args.scheduler == "ddpm":
79
+ noise_scheduler = DDPMScheduler(
80
+ num_train_timesteps=args.num_train_steps, tensor_format="pt")
81
+ else:
82
+ noise_scheduler = DDIMScheduler(
83
+ num_train_timesteps=args.num_train_steps, tensor_format="pt")
84
+
85
  optimizer = torch.optim.AdamW(
86
  model.parameters(),
87
  lr=args.learning_rate,
 
118
  )
119
 
120
  def transforms(examples):
121
+ if args.vae is not None and vqvae.config['in_channels'] == 3:
122
+ images = [
123
+ augmentations(image.convert('RGB'))
124
+ for image in examples["image"]
125
+ ]
126
+ else:
127
+ images = [augmentations(image) for image in examples["image"]]
128
  return {"input": images}
129
 
130
  dataset.set_transform(transforms)
 
179
  model.train()
180
  for step, batch in enumerate(train_dataloader):
181
  clean_images = batch["input"]
182
+
183
+ if args.vae is not None:
184
+ vqvae.to(clean_images.device)
185
+ with torch.no_grad():
186
+ clean_images = vqvae.encode(
187
+ clean_images).latent_dist.sample()
188
+ # Scale latent images to ensure approximately unit variance
189
+ clean_images = clean_images * 0.18215
190
+
191
  # Sample noise that we'll add to the images
192
  noise = torch.randn(clean_images.shape).to(clean_images.device)
193
  bsz = clean_images.shape[0]
 
210
  loss = F.mse_loss(noise_pred, noise)
211
  accelerator.backward(loss)
212
 
213
+ if accelerator.sync_gradients:
214
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
215
  optimizer.step()
216
  lr_scheduler.step()
217
  if args.use_ema:
 
219
  optimizer.zero_grad()
220
 
221
  progress_bar.update(1)
222
+ global_step += 1
223
+
224
  logs = {
225
  "loss": loss.detach().item(),
226
  "lr": lr_scheduler.get_last_lr()[0],
 
230
  logs["ema_decay"] = ema_model.decay
231
  progress_bar.set_postfix(**logs)
232
  accelerator.log(logs, step=global_step)
 
233
  progress_bar.close()
234
 
235
  accelerator.wait_for_everyone()
 
237
  # Generate sample images for visual inspection
238
  if accelerator.is_main_process:
239
  if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
240
+ if args.vae is not None:
241
+ pipeline = LatentAudioDiffusionPipeline(
242
+ unet=accelerator.unwrap_model(
243
+ ema_model.averaged_model if args.use_ema else model
244
+ ),
245
+ vqvae=vqvae,
246
+ scheduler=noise_scheduler)
247
+ else:
248
+ pipeline = AudioDiffusionPipeline(
249
+ unet=accelerator.unwrap_model(
250
+ ema_model.averaged_model if args.use_ema else model
251
+ ),
252
+ scheduler=noise_scheduler,
253
+ )
254
 
255
  # save the model
256
  if args.push_to_hub:
 
267
  else:
268
  pipeline.save_pretrained(output_dir)
269
 
270
+ if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
271
+ generator = torch.manual_seed(42)
272
  # run pipeline in inference (sample random noise and denoise)
273
+ images, (sample_rate, audios) = pipeline(
274
+ mel=mel,
275
  generator=generator,
276
  batch_size=args.eval_batch_size,
277
+ steps=args.num_train_steps,
278
+ )
279
 
280
  # denormalize the images and save to tensorboard
281
+ images = np.array([
282
+ np.frombuffer(image.tobytes(), dtype="uint8").reshape(
283
+ (len(image.getbands()), image.height, image.width))
284
+ for image in images
285
+ ])
286
  accelerator.trackers[0].writer.add_images(
287
+ "test_samples", images, epoch)
288
+ for _, audio in enumerate(audios):
 
289
  accelerator.trackers[0].writer.add_audio(
290
  f"test_audio_{_}",
291
  normalize(audio),
292
  epoch,
293
+ sample_rate=sample_rate,
294
  )
295
  accelerator.wait_for_everyone()
296
 
 
312
  parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
313
  parser.add_argument("--overwrite_output_dir", type=bool, default=False)
314
  parser.add_argument("--cache_dir", type=str, default=None)
315
+ parser.add_argument("--resolution", type=int, default=256)
316
  parser.add_argument("--train_batch_size", type=int, default=16)
317
  parser.add_argument("--eval_batch_size", type=int, default=16)
318
  parser.add_argument("--num_epochs", type=int, default=100)
 
349
  parser.add_argument("--hop_length", type=int, default=512)
350
  parser.add_argument("--from_pretrained", type=str, default=None)
351
  parser.add_argument("--start_epoch", type=int, default=0)
352
+ parser.add_argument("--num_train_steps", type=int, default=1000)
353
+ parser.add_argument("--latent_resolution", type=int, default=None)
354
+ parser.add_argument("--scheduler",
355
+ type=str,
356
+ default="ddpm",
357
+ help="ddpm or ddim")
358
+ parser.add_argument("--vae",
359
+ type=str,
360
+ default=None,
361
+ help="pretrained VAE model for latent diffusion")
362
 
363
  args = parser.parse_args()
364
  env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
scripts/train_vae.py ADDED
@@ -0,0 +1,166 @@
1
+ # pip install -e git+https://github.com/CompVis/stable-diffusion.git@master
2
+ # pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
3
+
4
+ # TODO
5
+ # grayscale
6
+
7
+ import os
8
+ import argparse
9
+
10
+ import torch
11
+ import torchvision
12
+ import numpy as np
13
+ from PIL import Image
14
+ import pytorch_lightning as pl
15
+ from omegaconf import OmegaConf
16
+ from librosa.util import normalize
17
+ from ldm.util import instantiate_from_config
18
+ from pytorch_lightning.trainer import Trainer
19
+ from torch.utils.data import DataLoader, Dataset
20
+ from datasets import load_from_disk, load_dataset
21
+ from pytorch_lightning.callbacks import Callback, ModelCheckpoint
22
+ from pytorch_lightning.utilities.distributed import rank_zero_only
23
+
24
+ from audiodiffusion.mel import Mel
25
+ from audiodiffusion.utils import convert_ldm_to_hf_vae
26
+
27
+
28
+ class AudioDiffusion(Dataset):
29
+
30
+ def __init__(self, model_id):
31
+ super().__init__()
32
+ if os.path.exists(model_id):
33
+ self.hf_dataset = load_from_disk(model_id)['train']
34
+ else:
35
+ self.hf_dataset = load_dataset(model_id)['train']
36
+
37
+ def __len__(self):
38
+ return len(self.hf_dataset)
39
+
40
+ def __getitem__(self, idx):
41
+ image = self.hf_dataset[idx]['image'].convert('RGB')
42
+ image = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
43
+ (image.height, image.width, 3))
44
+ image = ((image / 255) * 2 - 1)
45
+ return {'image': image}
46
+
47
+
48
+ class AudioDiffusionDataModule(pl.LightningDataModule):
49
+
50
+ def __init__(self, model_id, batch_size):
51
+ super().__init__()
52
+ self.batch_size = batch_size
53
+ self.dataset = AudioDiffusion(model_id)
54
+ self.num_workers = 1
55
+
56
+ def train_dataloader(self):
57
+ return DataLoader(self.dataset,
58
+ batch_size=self.batch_size,
59
+ num_workers=self.num_workers)
60
+
61
+
62
+ class ImageLogger(Callback):
63
+
64
+ def __init__(self, every=1000, resolution=256, hop_length=512):
65
+ super().__init__()
66
+ self.mel = Mel(x_res=resolution,
67
+ y_res=resolution,
68
+ hop_length=hop_length)
69
+ self.every = every
70
+
71
+ @rank_zero_only
72
+ def log_images_and_audios(self, pl_module, batch):
73
+ pl_module.eval()
74
+ with torch.no_grad():
75
+ images = pl_module.log_images(batch, split='train')
76
+ pl_module.train()
77
+
78
+ for k in images:
79
+ images[k] = images[k].detach().cpu()
80
+ images[k] = torch.clamp(images[k], -1., 1.)
81
+ images[k] = (images[k] + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
82
+ grid = torchvision.utils.make_grid(images[k])
83
+
84
+ tag = f"train/{k}"
85
+ pl_module.logger.experiment.add_image(
86
+ tag, grid, global_step=pl_module.global_step)
87
+
88
+ images[k] = (images[k].numpy() *
89
+ 255).round().astype("uint8").transpose(0, 2, 3, 1)
90
+ for _, image in enumerate(images[k]):
91
+ audio = self.mel.image_to_audio(
92
+ Image.fromarray(image, mode='RGB').convert('L'))
93
+ pl_module.logger.experiment.add_audio(
94
+ tag + f"/{_}",
95
+ normalize(audio),
96
+ global_step=pl_module.global_step,
97
+ sample_rate=self.mel.get_sample_rate())
98
+
99
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch,
100
+ batch_idx):
101
+ if (batch_idx + 1) % self.every != 0:
102
+ return
103
+ self.log_images_and_audios(pl_module, batch)
104
+
105
+
106
+ class HFModelCheckpoint(ModelCheckpoint):
107
+
108
+ def __init__(self, ldm_config, hf_checkpoint, *args, **kwargs):
109
+ super().__init__(*args, **kwargs)
110
+ self.ldm_config = ldm_config
111
+ self.hf_checkpoint = hf_checkpoint
112
+
113
+ def on_train_epoch_end(self, trainer, pl_module):
114
+ super().on_train_epoch_end(trainer, pl_module)
115
+ ldm_checkpoint = self.format_checkpoint_name(
116
+ {'epoch': trainer.current_epoch})
117
+ convert_ldm_to_hf_vae(ldm_checkpoint, self.ldm_config,
118
+ self.hf_checkpoint)
119
+
120
+
121
+ if __name__ == "__main__":
122
+ parser = argparse.ArgumentParser(description="Train VAE using ldm.")
123
+ parser.add_argument("-d", "--dataset_name", type=str, default=None)
124
+ parser.add_argument("-b", "--batch_size", type=int, default=1)
125
+ parser.add_argument("-c",
126
+ "--ldm_config_file",
127
+ type=str,
128
+ default="config/ldm_autoencoder_kl.yaml")
129
+ parser.add_argument("--ldm_checkpoint_dir",
130
+ type=str,
131
+ default="models/ldm-autoencoder-kl")
132
+ parser.add_argument("--hf_checkpoint_dir",
133
+ type=str,
134
+ default="models/autoencoder-kl")
135
+ parser.add_argument("-r",
136
+ "--resume_from_checkpoint",
137
+ type=str,
138
+ default=None)
139
+ parser.add_argument("-g",
140
+ "--gradient_accumulation_steps",
141
+ type=int,
142
+ default=1)
143
+ args = parser.parse_args()
144
+
145
+ config = OmegaConf.load(args.ldm_config_file)
146
+ lightning_config = config.pop("lightning", OmegaConf.create())
147
+ trainer_config = lightning_config.get("trainer", OmegaConf.create())
148
+ trainer_config.accumulate_grad_batches = args.gradient_accumulation_steps
149
+ trainer_opt = argparse.Namespace(**trainer_config)
150
+ trainer = Trainer.from_argparse_args(
151
+ trainer_opt,
152
+ resume_from_checkpoint=args.resume_from_checkpoint,
153
+ callbacks=[
154
+ ImageLogger(),
155
+ HFModelCheckpoint(ldm_config=config,
156
+ hf_checkpoint=args.hf_checkpoint_dir,
157
+ dirpath=args.ldm_checkpoint_dir,
158
+ filename='{epoch:06}',
159
+ verbose=True,
160
+ save_last=True)
161
+ ])
162
+ model = instantiate_from_config(config.model)
163
+ model.learning_rate = config.model.base_learning_rate
164
+ data = AudioDiffusionDataModule(args.dataset_name,
165
+ batch_size=args.batch_size)
166
+ trainer.fit(model, data)