deepanway commited on
Commit
6b448ad
1 Parent(s): 0a02aec

add required files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +1 -1
  2. audioldm/__init__.py +8 -0
  3. audioldm/__main__.py +183 -0
  4. audioldm/__pycache__/__init__.cpython-39.pyc +0 -0
  5. audioldm/__pycache__/ldm.cpython-39.pyc +0 -0
  6. audioldm/__pycache__/pipeline.cpython-39.pyc +0 -0
  7. audioldm/__pycache__/utils.cpython-39.pyc +0 -0
  8. audioldm/audio/__init__.py +2 -0
  9. audioldm/audio/__pycache__/__init__.cpython-39.pyc +0 -0
  10. audioldm/audio/__pycache__/audio_processing.cpython-39.pyc +0 -0
  11. audioldm/audio/__pycache__/mix.cpython-39.pyc +0 -0
  12. audioldm/audio/__pycache__/stft.cpython-39.pyc +0 -0
  13. audioldm/audio/__pycache__/tools.cpython-39.pyc +0 -0
  14. audioldm/audio/__pycache__/torch_tools.cpython-39.pyc +0 -0
  15. audioldm/audio/audio_processing.py +100 -0
  16. audioldm/audio/stft.py +186 -0
  17. audioldm/audio/tools.py +85 -0
  18. audioldm/hifigan/__init__.py +7 -0
  19. audioldm/hifigan/__pycache__/__init__.cpython-39.pyc +0 -0
  20. audioldm/hifigan/__pycache__/models.cpython-39.pyc +0 -0
  21. audioldm/hifigan/__pycache__/utilities.cpython-39.pyc +0 -0
  22. audioldm/hifigan/models.py +174 -0
  23. audioldm/hifigan/utilities.py +86 -0
  24. audioldm/latent_diffusion/__init__.py +0 -0
  25. audioldm/latent_diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
  26. audioldm/latent_diffusion/__pycache__/attention.cpython-39.pyc +0 -0
  27. audioldm/latent_diffusion/__pycache__/ddim.cpython-39.pyc +0 -0
  28. audioldm/latent_diffusion/__pycache__/ddpm.cpython-39.pyc +0 -0
  29. audioldm/latent_diffusion/__pycache__/ema.cpython-39.pyc +0 -0
  30. audioldm/latent_diffusion/__pycache__/openaimodel.cpython-39.pyc +0 -0
  31. audioldm/latent_diffusion/__pycache__/util.cpython-39.pyc +0 -0
  32. audioldm/latent_diffusion/attention.py +469 -0
  33. audioldm/latent_diffusion/ddim.py +377 -0
  34. audioldm/latent_diffusion/ddpm.py +441 -0
  35. audioldm/latent_diffusion/ema.py +82 -0
  36. audioldm/latent_diffusion/openaimodel.py +1069 -0
  37. audioldm/latent_diffusion/util.py +295 -0
  38. audioldm/ldm.py +818 -0
  39. audioldm/pipeline.py +301 -0
  40. audioldm/utils.py +281 -0
  41. audioldm/variational_autoencoder/__init__.py +1 -0
  42. audioldm/variational_autoencoder/__pycache__/__init__.cpython-39.pyc +0 -0
  43. audioldm/variational_autoencoder/__pycache__/autoencoder.cpython-39.pyc +0 -0
  44. audioldm/variational_autoencoder/__pycache__/distributions.cpython-39.pyc +0 -0
  45. audioldm/variational_autoencoder/__pycache__/modules.cpython-39.pyc +0 -0
  46. audioldm/variational_autoencoder/autoencoder.py +135 -0
  47. audioldm/variational_autoencoder/distributions.py +102 -0
  48. audioldm/variational_autoencoder/modules.py +1066 -0
  49. diffusers/CITATION.cff +40 -0
  50. diffusers/CODE_OF_CONDUCT.md +130 -0
app.py CHANGED
@@ -94,7 +94,7 @@ gr_interface = gr.Interface(
94
  inputs=input_text,
95
  outputs=[output_audio],
96
  title="Tango Audio Generator",
97
- description="Generate audio using Tango model by providing a text prompt.",
98
  allow_flagging=False,
99
  examples=[
100
  ["A Dog Barking"],
 
94
  inputs=input_text,
95
  outputs=[output_audio],
96
  title="Tango Audio Generator",
97
+ description="Generate audio using Tango by providing a text prompt.",
98
  allow_flagging=False,
99
  examples=[
100
  ["A Dog Barking"],
audioldm/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from .ldm import LatentDiffusion
2
+ from .utils import seed_everything, save_wave, get_time, get_duration
3
+ from .pipeline import *
4
+
5
+
6
+
7
+
8
+
audioldm/__main__.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ import os
3
+ from audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time, round_up_duration, get_duration
4
+ import argparse
5
+
6
+ CACHE_DIR = os.getenv(
7
+ "AUDIOLDM_CACHE_DIR",
8
+ os.path.join(os.path.expanduser("~"), ".cache/audioldm"))
9
+
10
+ parser = argparse.ArgumentParser()
11
+
12
+ parser.add_argument(
13
+ "--mode",
14
+ type=str,
15
+ required=False,
16
+ default="generation",
17
+ help="generation: text-to-audio generation; transfer: style transfer",
18
+ choices=["generation", "transfer"]
19
+ )
20
+
21
+ parser.add_argument(
22
+ "-t",
23
+ "--text",
24
+ type=str,
25
+ required=False,
26
+ default="",
27
+ help="Text prompt to the model for audio generation",
28
+ )
29
+
30
+ parser.add_argument(
31
+ "-f",
32
+ "--file_path",
33
+ type=str,
34
+ required=False,
35
+ default=None,
36
+ help="(--mode transfer): Original audio file for style transfer; Or (--mode generation): the guidance audio file for generating simialr audio",
37
+ )
38
+
39
+ parser.add_argument(
40
+ "--transfer_strength",
41
+ type=float,
42
+ required=False,
43
+ default=0.5,
44
+ help="A value between 0 and 1. 0 means original audio without transfer, 1 means completely transfer to the audio indicated by text",
45
+ )
46
+
47
+ parser.add_argument(
48
+ "-s",
49
+ "--save_path",
50
+ type=str,
51
+ required=False,
52
+ help="The path to save model output",
53
+ default="./output",
54
+ )
55
+
56
+ parser.add_argument(
57
+ "--model_name",
58
+ type=str,
59
+ required=False,
60
+ help="The checkpoint you gonna use",
61
+ default="audioldm-s-full",
62
+ choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2"]
63
+ )
64
+
65
+ parser.add_argument(
66
+ "-ckpt",
67
+ "--ckpt_path",
68
+ type=str,
69
+ required=False,
70
+ help="The path to the pretrained .ckpt model",
71
+ default=None,
72
+ )
73
+
74
+ parser.add_argument(
75
+ "-b",
76
+ "--batchsize",
77
+ type=int,
78
+ required=False,
79
+ default=1,
80
+ help="Generate how many samples at the same time",
81
+ )
82
+
83
+ parser.add_argument(
84
+ "--ddim_steps",
85
+ type=int,
86
+ required=False,
87
+ default=200,
88
+ help="The sampling step for DDIM",
89
+ )
90
+
91
+ parser.add_argument(
92
+ "-gs",
93
+ "--guidance_scale",
94
+ type=float,
95
+ required=False,
96
+ default=2.5,
97
+ help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
98
+ )
99
+
100
+ parser.add_argument(
101
+ "-dur",
102
+ "--duration",
103
+ type=float,
104
+ required=False,
105
+ default=10.0,
106
+ help="The duration of the samples",
107
+ )
108
+
109
+ parser.add_argument(
110
+ "-n",
111
+ "--n_candidate_gen_per_text",
112
+ type=int,
113
+ required=False,
114
+ default=3,
115
+ help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
116
+ )
117
+
118
+ parser.add_argument(
119
+ "--seed",
120
+ type=int,
121
+ required=False,
122
+ default=42,
123
+ help="Change this value (any integer number) will lead to a different generation result.",
124
+ )
125
+
126
+ args = parser.parse_args()
127
+
128
+ if(args.ckpt_path is not None):
129
+ print("Warning: ckpt_path has no effect after version 0.0.20.")
130
+
131
+ assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5"
132
+
133
+ mode = args.mode
134
+ if(mode == "generation" and args.file_path is not None):
135
+ mode = "generation_audio_to_audio"
136
+ if(len(args.text) > 0):
137
+ print("Warning: You have specified the --file_path. --text will be ignored")
138
+ args.text = ""
139
+
140
+ save_path = os.path.join(args.save_path, mode)
141
+
142
+ if(args.file_path is not None):
143
+ save_path = os.path.join(save_path, os.path.basename(args.file_path.split(".")[0]))
144
+
145
+ text = args.text
146
+ random_seed = args.seed
147
+ duration = args.duration
148
+ guidance_scale = args.guidance_scale
149
+ n_candidate_gen_per_text = args.n_candidate_gen_per_text
150
+
151
+ os.makedirs(save_path, exist_ok=True)
152
+ audioldm = build_model(model_name=args.model_name)
153
+
154
+ if(args.mode == "generation"):
155
+ waveform = text_to_audio(
156
+ audioldm,
157
+ text,
158
+ args.file_path,
159
+ random_seed,
160
+ duration=duration,
161
+ guidance_scale=guidance_scale,
162
+ ddim_steps=args.ddim_steps,
163
+ n_candidate_gen_per_text=n_candidate_gen_per_text,
164
+ batchsize=args.batchsize,
165
+ )
166
+
167
+ elif(args.mode == "transfer"):
168
+ assert args.file_path is not None
169
+ assert os.path.exists(args.file_path), "The original audio file \'%s\' for style transfer does not exist." % args.file_path
170
+ waveform = style_transfer(
171
+ audioldm,
172
+ text,
173
+ args.file_path,
174
+ args.transfer_strength,
175
+ random_seed,
176
+ duration=duration,
177
+ guidance_scale=guidance_scale,
178
+ ddim_steps=args.ddim_steps,
179
+ batchsize=args.batchsize,
180
+ )
181
+ waveform = waveform[:,None,:]
182
+
183
+ save_wave(waveform, save_path, name="%s_%s" % (get_time(), text))
audioldm/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (322 Bytes). View file
 
audioldm/__pycache__/ldm.cpython-39.pyc ADDED
Binary file (16 kB). View file
 
audioldm/__pycache__/pipeline.cpython-39.pyc ADDED
Binary file (6.54 kB). View file
 
audioldm/__pycache__/utils.cpython-39.pyc ADDED
Binary file (7.35 kB). View file
 
audioldm/audio/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .tools import wav_to_fbank, read_wav_file
2
+ from .stft import TacotronSTFT
audioldm/audio/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (260 Bytes). View file
 
audioldm/audio/__pycache__/audio_processing.cpython-39.pyc ADDED
Binary file (2.78 kB). View file
 
audioldm/audio/__pycache__/mix.cpython-39.pyc ADDED
Binary file (1.7 kB). View file
 
audioldm/audio/__pycache__/stft.cpython-39.pyc ADDED
Binary file (4.99 kB). View file
 
audioldm/audio/__pycache__/tools.cpython-39.pyc ADDED
Binary file (2.19 kB). View file
 
audioldm/audio/__pycache__/torch_tools.cpython-39.pyc ADDED
Binary file (3.79 kB). View file
 
audioldm/audio/audio_processing.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import librosa.util as librosa_util
4
+ from scipy.signal import get_window
5
+
6
+
7
+ def window_sumsquare(
8
+ window,
9
+ n_frames,
10
+ hop_length,
11
+ win_length,
12
+ n_fft,
13
+ dtype=np.float32,
14
+ norm=None,
15
+ ):
16
+ """
17
+ # from librosa 0.6
18
+ Compute the sum-square envelope of a window function at a given hop length.
19
+
20
+ This is used to estimate modulation effects induced by windowing
21
+ observations in short-time fourier transforms.
22
+
23
+ Parameters
24
+ ----------
25
+ window : string, tuple, number, callable, or list-like
26
+ Window specification, as in `get_window`
27
+
28
+ n_frames : int > 0
29
+ The number of analysis frames
30
+
31
+ hop_length : int > 0
32
+ The number of samples to advance between frames
33
+
34
+ win_length : [optional]
35
+ The length of the window function. By default, this matches `n_fft`.
36
+
37
+ n_fft : int > 0
38
+ The length of each analysis frame.
39
+
40
+ dtype : np.dtype
41
+ The data type of the output
42
+
43
+ Returns
44
+ -------
45
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
46
+ The sum-squared envelope of the window function
47
+ """
48
+ if win_length is None:
49
+ win_length = n_fft
50
+
51
+ n = n_fft + hop_length * (n_frames - 1)
52
+ x = np.zeros(n, dtype=dtype)
53
+
54
+ # Compute the squared window at the desired length
55
+ win_sq = get_window(window, win_length, fftbins=True)
56
+ win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
57
+ win_sq = librosa_util.pad_center(win_sq, n_fft)
58
+
59
+ # Fill the envelope
60
+ for i in range(n_frames):
61
+ sample = i * hop_length
62
+ x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
63
+ return x
64
+
65
+
66
+ def griffin_lim(magnitudes, stft_fn, n_iters=30):
67
+ """
68
+ PARAMS
69
+ ------
70
+ magnitudes: spectrogram magnitudes
71
+ stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
72
+ """
73
+
74
+ angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
75
+ angles = angles.astype(np.float32)
76
+ angles = torch.autograd.Variable(torch.from_numpy(angles))
77
+ signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
78
+
79
+ for i in range(n_iters):
80
+ _, angles = stft_fn.transform(signal)
81
+ signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
82
+ return signal
83
+
84
+
85
+ def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
86
+ """
87
+ PARAMS
88
+ ------
89
+ C: compression factor
90
+ """
91
+ return normalize_fun(torch.clamp(x, min=clip_val) * C)
92
+
93
+
94
+ def dynamic_range_decompression(x, C=1):
95
+ """
96
+ PARAMS
97
+ ------
98
+ C: compression factor used to compress
99
+ """
100
+ return torch.exp(x) / C
audioldm/audio/stft.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from scipy.signal import get_window
5
+ from librosa.util import pad_center, tiny
6
+ from librosa.filters import mel as librosa_mel_fn
7
+
8
+ from audioldm.audio.audio_processing import (
9
+ dynamic_range_compression,
10
+ dynamic_range_decompression,
11
+ window_sumsquare,
12
+ )
13
+
14
+
15
+ class STFT(torch.nn.Module):
16
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
17
+
18
+ def __init__(self, filter_length, hop_length, win_length, window="hann"):
19
+ super(STFT, self).__init__()
20
+ self.filter_length = filter_length
21
+ self.hop_length = hop_length
22
+ self.win_length = win_length
23
+ self.window = window
24
+ self.forward_transform = None
25
+ scale = self.filter_length / self.hop_length
26
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
27
+
28
+ cutoff = int((self.filter_length / 2 + 1))
29
+ fourier_basis = np.vstack(
30
+ [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
31
+ )
32
+
33
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
34
+ inverse_basis = torch.FloatTensor(
35
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :]
36
+ )
37
+
38
+ if window is not None:
39
+ assert filter_length >= win_length
40
+ # get window and zero center pad it to filter_length
41
+ fft_window = get_window(window, win_length, fftbins=True)
42
+ fft_window = pad_center(fft_window, filter_length)
43
+ fft_window = torch.from_numpy(fft_window).float()
44
+
45
+ # window the bases
46
+ forward_basis *= fft_window
47
+ inverse_basis *= fft_window
48
+
49
+ self.register_buffer("forward_basis", forward_basis.float())
50
+ self.register_buffer("inverse_basis", inverse_basis.float())
51
+
52
+ def transform(self, input_data):
53
+ device = self.forward_basis.device
54
+ input_data = input_data.to(device)
55
+
56
+ num_batches = input_data.size(0)
57
+ num_samples = input_data.size(1)
58
+
59
+ self.num_samples = num_samples
60
+
61
+ # similar to librosa, reflect-pad the input
62
+ input_data = input_data.view(num_batches, 1, num_samples)
63
+ input_data = F.pad(
64
+ input_data.unsqueeze(1),
65
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
66
+ mode="reflect",
67
+ )
68
+ input_data = input_data.squeeze(1)
69
+
70
+ forward_transform = F.conv1d(
71
+ input_data,
72
+ torch.autograd.Variable(self.forward_basis, requires_grad=False),
73
+ stride=self.hop_length,
74
+ padding=0,
75
+ )#.cpu()
76
+
77
+ cutoff = int((self.filter_length / 2) + 1)
78
+ real_part = forward_transform[:, :cutoff, :]
79
+ imag_part = forward_transform[:, cutoff:, :]
80
+
81
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
82
+ phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
83
+
84
+ return magnitude, phase
85
+
86
+ def inverse(self, magnitude, phase):
87
+ device = self.forward_basis.device
88
+ magnitude, phase = magnitude.to(device), phase.to(device)
89
+
90
+ recombine_magnitude_phase = torch.cat(
91
+ [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
92
+ )
93
+
94
+ inverse_transform = F.conv_transpose1d(
95
+ recombine_magnitude_phase,
96
+ torch.autograd.Variable(self.inverse_basis, requires_grad=False),
97
+ stride=self.hop_length,
98
+ padding=0,
99
+ )
100
+
101
+ if self.window is not None:
102
+ window_sum = window_sumsquare(
103
+ self.window,
104
+ magnitude.size(-1),
105
+ hop_length=self.hop_length,
106
+ win_length=self.win_length,
107
+ n_fft=self.filter_length,
108
+ dtype=np.float32,
109
+ )
110
+ # remove modulation effects
111
+ approx_nonzero_indices = torch.from_numpy(
112
+ np.where(window_sum > tiny(window_sum))[0]
113
+ )
114
+ window_sum = torch.autograd.Variable(
115
+ torch.from_numpy(window_sum), requires_grad=False
116
+ )
117
+ window_sum = window_sum
118
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
119
+ approx_nonzero_indices
120
+ ]
121
+
122
+ # scale by hop ratio
123
+ inverse_transform *= float(self.filter_length) / self.hop_length
124
+
125
+ inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
126
+ inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
127
+
128
+ return inverse_transform
129
+
130
+ def forward(self, input_data):
131
+ self.magnitude, self.phase = self.transform(input_data)
132
+ reconstruction = self.inverse(self.magnitude, self.phase)
133
+ return reconstruction
134
+
135
+
136
+ class TacotronSTFT(torch.nn.Module):
137
+ def __init__(
138
+ self,
139
+ filter_length,
140
+ hop_length,
141
+ win_length,
142
+ n_mel_channels,
143
+ sampling_rate,
144
+ mel_fmin,
145
+ mel_fmax,
146
+ ):
147
+ super(TacotronSTFT, self).__init__()
148
+ self.n_mel_channels = n_mel_channels
149
+ self.sampling_rate = sampling_rate
150
+ self.stft_fn = STFT(filter_length, hop_length, win_length)
151
+ mel_basis = librosa_mel_fn(
152
+ sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
153
+ )
154
+ mel_basis = torch.from_numpy(mel_basis).float()
155
+ self.register_buffer("mel_basis", mel_basis)
156
+
157
+ def spectral_normalize(self, magnitudes, normalize_fun):
158
+ output = dynamic_range_compression(magnitudes, normalize_fun)
159
+ return output
160
+
161
+ def spectral_de_normalize(self, magnitudes):
162
+ output = dynamic_range_decompression(magnitudes)
163
+ return output
164
+
165
+ def mel_spectrogram(self, y, normalize_fun=torch.log):
166
+ """Computes mel-spectrograms from a batch of waves
167
+ PARAMS
168
+ ------
169
+ y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
170
+
171
+ RETURNS
172
+ -------
173
+ mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
174
+ """
175
+ assert torch.min(y.data) >= -1, torch.min(y.data)
176
+ assert torch.max(y.data) <= 1, torch.max(y.data)
177
+
178
+ magnitudes, phases = self.stft_fn.transform(y)
179
+ magnitudes = magnitudes.data
180
+ mel_output = torch.matmul(self.mel_basis, magnitudes)
181
+ mel_output = self.spectral_normalize(mel_output, normalize_fun)
182
+ energy = torch.norm(magnitudes, dim=1)
183
+
184
+ log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
185
+
186
+ return mel_output, log_magnitudes, energy
audioldm/audio/tools.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torchaudio
4
+
5
+
6
+ def get_mel_from_wav(audio, _stft):
7
+ audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
8
+ audio = torch.autograd.Variable(audio, requires_grad=False)
9
+ melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
10
+ melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
11
+ log_magnitudes_stft = (
12
+ torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32)
13
+ )
14
+ energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
15
+ return melspec, log_magnitudes_stft, energy
16
+
17
+
18
+ def _pad_spec(fbank, target_length=1024):
19
+ n_frames = fbank.shape[0]
20
+ p = target_length - n_frames
21
+ # cut and pad
22
+ if p > 0:
23
+ m = torch.nn.ZeroPad2d((0, 0, 0, p))
24
+ fbank = m(fbank)
25
+ elif p < 0:
26
+ fbank = fbank[0:target_length, :]
27
+
28
+ if fbank.size(-1) % 2 != 0:
29
+ fbank = fbank[..., :-1]
30
+
31
+ return fbank
32
+
33
+
34
+ def pad_wav(waveform, segment_length):
35
+ waveform_length = waveform.shape[-1]
36
+ assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
37
+ if segment_length is None or waveform_length == segment_length:
38
+ return waveform
39
+ elif waveform_length > segment_length:
40
+ return waveform[:segment_length]
41
+ elif waveform_length < segment_length:
42
+ temp_wav = np.zeros((1, segment_length))
43
+ temp_wav[:, :waveform_length] = waveform
44
+ return temp_wav
45
+
46
+ def normalize_wav(waveform):
47
+ waveform = waveform - np.mean(waveform)
48
+ waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
49
+ return waveform * 0.5
50
+
51
+
52
+ def read_wav_file(filename, segment_length):
53
+ # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
54
+ waveform, sr = torchaudio.load(filename) # Faster!!!
55
+ waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
56
+ waveform = waveform.numpy()[0, ...]
57
+ waveform = normalize_wav(waveform)
58
+ waveform = waveform[None, ...]
59
+ waveform = pad_wav(waveform, segment_length)
60
+
61
+ waveform = waveform / np.max(np.abs(waveform))
62
+ waveform = 0.5 * waveform
63
+
64
+ return waveform
65
+
66
+
67
+ def wav_to_fbank(filename, target_length=1024, fn_STFT=None):
68
+ assert fn_STFT is not None
69
+
70
+ # mixup
71
+ waveform = read_wav_file(filename, target_length * 160) # hop size is 160
72
+
73
+ waveform = waveform[0, ...]
74
+ waveform = torch.FloatTensor(waveform)
75
+
76
+ fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
77
+
78
+ fbank = torch.FloatTensor(fbank.T)
79
+ log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)
80
+
81
+ fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
82
+ log_magnitudes_stft, target_length
83
+ )
84
+
85
+ return fbank, log_magnitudes_stft, waveform
audioldm/hifigan/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .models import Generator
2
+
3
+
4
+ class AttrDict(dict):
5
+ def __init__(self, *args, **kwargs):
6
+ super(AttrDict, self).__init__(*args, **kwargs)
7
+ self.__dict__ = self
audioldm/hifigan/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (574 Bytes). View file
 
audioldm/hifigan/__pycache__/models.cpython-39.pyc ADDED
Binary file (3.73 kB). View file
 
audioldm/hifigan/__pycache__/utilities.cpython-39.pyc ADDED
Binary file (2.37 kB). View file
 
audioldm/hifigan/models.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import Conv1d, ConvTranspose1d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm
6
+
7
+ LRELU_SLOPE = 0.1
8
+
9
+
10
+ def init_weights(m, mean=0.0, std=0.01):
11
+ classname = m.__class__.__name__
12
+ if classname.find("Conv") != -1:
13
+ m.weight.data.normal_(mean, std)
14
+
15
+
16
+ def get_padding(kernel_size, dilation=1):
17
+ return int((kernel_size * dilation - dilation) / 2)
18
+
19
+
20
+ class ResBlock(torch.nn.Module):
21
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
22
+ super(ResBlock, self).__init__()
23
+ self.h = h
24
+ self.convs1 = nn.ModuleList(
25
+ [
26
+ weight_norm(
27
+ Conv1d(
28
+ channels,
29
+ channels,
30
+ kernel_size,
31
+ 1,
32
+ dilation=dilation[0],
33
+ padding=get_padding(kernel_size, dilation[0]),
34
+ )
35
+ ),
36
+ weight_norm(
37
+ Conv1d(
38
+ channels,
39
+ channels,
40
+ kernel_size,
41
+ 1,
42
+ dilation=dilation[1],
43
+ padding=get_padding(kernel_size, dilation[1]),
44
+ )
45
+ ),
46
+ weight_norm(
47
+ Conv1d(
48
+ channels,
49
+ channels,
50
+ kernel_size,
51
+ 1,
52
+ dilation=dilation[2],
53
+ padding=get_padding(kernel_size, dilation[2]),
54
+ )
55
+ ),
56
+ ]
57
+ )
58
+ self.convs1.apply(init_weights)
59
+
60
+ self.convs2 = nn.ModuleList(
61
+ [
62
+ weight_norm(
63
+ Conv1d(
64
+ channels,
65
+ channels,
66
+ kernel_size,
67
+ 1,
68
+ dilation=1,
69
+ padding=get_padding(kernel_size, 1),
70
+ )
71
+ ),
72
+ weight_norm(
73
+ Conv1d(
74
+ channels,
75
+ channels,
76
+ kernel_size,
77
+ 1,
78
+ dilation=1,
79
+ padding=get_padding(kernel_size, 1),
80
+ )
81
+ ),
82
+ weight_norm(
83
+ Conv1d(
84
+ channels,
85
+ channels,
86
+ kernel_size,
87
+ 1,
88
+ dilation=1,
89
+ padding=get_padding(kernel_size, 1),
90
+ )
91
+ ),
92
+ ]
93
+ )
94
+ self.convs2.apply(init_weights)
95
+
96
+ def forward(self, x):
97
+ for c1, c2 in zip(self.convs1, self.convs2):
98
+ xt = F.leaky_relu(x, LRELU_SLOPE)
99
+ xt = c1(xt)
100
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
101
+ xt = c2(xt)
102
+ x = xt + x
103
+ return x
104
+
105
+ def remove_weight_norm(self):
106
+ for l in self.convs1:
107
+ remove_weight_norm(l)
108
+ for l in self.convs2:
109
+ remove_weight_norm(l)
110
+
111
+
112
+ class Generator(torch.nn.Module):
113
+ def __init__(self, h):
114
+ super(Generator, self).__init__()
115
+ self.h = h
116
+ self.num_kernels = len(h.resblock_kernel_sizes)
117
+ self.num_upsamples = len(h.upsample_rates)
118
+ self.conv_pre = weight_norm(
119
+ Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
120
+ )
121
+ resblock = ResBlock
122
+
123
+ self.ups = nn.ModuleList()
124
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
125
+ self.ups.append(
126
+ weight_norm(
127
+ ConvTranspose1d(
128
+ h.upsample_initial_channel // (2**i),
129
+ h.upsample_initial_channel // (2 ** (i + 1)),
130
+ k,
131
+ u,
132
+ padding=(k - u) // 2,
133
+ )
134
+ )
135
+ )
136
+
137
+ self.resblocks = nn.ModuleList()
138
+ for i in range(len(self.ups)):
139
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
140
+ for j, (k, d) in enumerate(
141
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
142
+ ):
143
+ self.resblocks.append(resblock(h, ch, k, d))
144
+
145
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
146
+ self.ups.apply(init_weights)
147
+ self.conv_post.apply(init_weights)
148
+
149
+ def forward(self, x):
150
+ x = self.conv_pre(x)
151
+ for i in range(self.num_upsamples):
152
+ x = F.leaky_relu(x, LRELU_SLOPE)
153
+ x = self.ups[i](x)
154
+ xs = None
155
+ for j in range(self.num_kernels):
156
+ if xs is None:
157
+ xs = self.resblocks[i * self.num_kernels + j](x)
158
+ else:
159
+ xs += self.resblocks[i * self.num_kernels + j](x)
160
+ x = xs / self.num_kernels
161
+ x = F.leaky_relu(x)
162
+ x = self.conv_post(x)
163
+ x = torch.tanh(x)
164
+
165
+ return x
166
+
167
+ def remove_weight_norm(self):
168
+ # print("Removing weight norm...")
169
+ for l in self.ups:
170
+ remove_weight_norm(l)
171
+ for l in self.resblocks:
172
+ l.remove_weight_norm()
173
+ remove_weight_norm(self.conv_pre)
174
+ remove_weight_norm(self.conv_post)
audioldm/hifigan/utilities.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ import torch
5
+ import numpy as np
6
+
7
+ import audioldm.hifigan as hifigan
8
+
9
+ HIFIGAN_16K_64 = {
10
+ "resblock": "1",
11
+ "num_gpus": 6,
12
+ "batch_size": 16,
13
+ "learning_rate": 0.0002,
14
+ "adam_b1": 0.8,
15
+ "adam_b2": 0.99,
16
+ "lr_decay": 0.999,
17
+ "seed": 1234,
18
+ "upsample_rates": [5, 4, 2, 2, 2],
19
+ "upsample_kernel_sizes": [16, 16, 8, 4, 4],
20
+ "upsample_initial_channel": 1024,
21
+ "resblock_kernel_sizes": [3, 7, 11],
22
+ "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
23
+ "segment_size": 8192,
24
+ "num_mels": 64,
25
+ "num_freq": 1025,
26
+ "n_fft": 1024,
27
+ "hop_size": 160,
28
+ "win_size": 1024,
29
+ "sampling_rate": 16000,
30
+ "fmin": 0,
31
+ "fmax": 8000,
32
+ "fmax_for_loss": None,
33
+ "num_workers": 4,
34
+ "dist_config": {
35
+ "dist_backend": "nccl",
36
+ "dist_url": "tcp://localhost:54321",
37
+ "world_size": 1,
38
+ },
39
+ }
40
+
41
+
42
+ def get_available_checkpoint_keys(model, ckpt):
43
+ print("==> Attemp to reload from %s" % ckpt)
44
+ state_dict = torch.load(ckpt)["state_dict"]
45
+ current_state_dict = model.state_dict()
46
+ new_state_dict = {}
47
+ for k in state_dict.keys():
48
+ if (
49
+ k in current_state_dict.keys()
50
+ and current_state_dict[k].size() == state_dict[k].size()
51
+ ):
52
+ new_state_dict[k] = state_dict[k]
53
+ else:
54
+ print("==> WARNING: Skipping %s" % k)
55
+ print(
56
+ "%s out of %s keys are matched"
57
+ % (len(new_state_dict.keys()), len(state_dict.keys()))
58
+ )
59
+ return new_state_dict
60
+
61
+
62
+ def get_param_num(model):
63
+ num_param = sum(param.numel() for param in model.parameters())
64
+ return num_param
65
+
66
+
67
+ def get_vocoder(config, device):
68
+ config = hifigan.AttrDict(HIFIGAN_16K_64)
69
+ vocoder = hifigan.Generator(config)
70
+ vocoder.eval()
71
+ vocoder.remove_weight_norm()
72
+ vocoder.to(device)
73
+ return vocoder
74
+
75
+
76
+ def vocoder_infer(mels, vocoder, lengths=None):
77
+ vocoder.eval()
78
+ with torch.no_grad():
79
+ wavs = vocoder(mels).squeeze(1)
80
+
81
+ wavs = (wavs.cpu().numpy() * 32768).astype("int16")
82
+
83
+ if lengths is not None:
84
+ wavs = wavs[:, :lengths]
85
+
86
+ return wavs
audioldm/latent_diffusion/__init__.py ADDED
File without changes
audioldm/latent_diffusion/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (164 Bytes). View file
 
audioldm/latent_diffusion/__pycache__/attention.cpython-39.pyc ADDED
Binary file (11.4 kB). View file
 
audioldm/latent_diffusion/__pycache__/ddim.cpython-39.pyc ADDED
Binary file (7.11 kB). View file
 
audioldm/latent_diffusion/__pycache__/ddpm.cpython-39.pyc ADDED
Binary file (11 kB). View file
 
audioldm/latent_diffusion/__pycache__/ema.cpython-39.pyc ADDED
Binary file (3 kB). View file
 
audioldm/latent_diffusion/__pycache__/openaimodel.cpython-39.pyc ADDED
Binary file (23.7 kB). View file
 
audioldm/latent_diffusion/__pycache__/util.cpython-39.pyc ADDED
Binary file (9.6 kB). View file
 
audioldm/latent_diffusion/attention.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from einops import rearrange
7
+
8
+ from audioldm.latent_diffusion.util import checkpoint
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def uniq(arr):
16
+ return {el: True for el in arr}.keys()
17
+
18
+
19
+ def default(val, d):
20
+ if exists(val):
21
+ return val
22
+ return d() if isfunction(d) else d
23
+
24
+
25
+ def max_neg_value(t):
26
+ return -torch.finfo(t.dtype).max
27
+
28
+
29
+ def init_(tensor):
30
+ dim = tensor.shape[-1]
31
+ std = 1 / math.sqrt(dim)
32
+ tensor.uniform_(-std, std)
33
+ return tensor
34
+
35
+
36
+ # feedforward
37
+ class GEGLU(nn.Module):
38
+ def __init__(self, dim_in, dim_out):
39
+ super().__init__()
40
+ self.proj = nn.Linear(dim_in, dim_out * 2)
41
+
42
+ def forward(self, x):
43
+ x, gate = self.proj(x).chunk(2, dim=-1)
44
+ return x * F.gelu(gate)
45
+
46
+
47
+ class FeedForward(nn.Module):
48
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
49
+ super().__init__()
50
+ inner_dim = int(dim * mult)
51
+ dim_out = default(dim_out, dim)
52
+ project_in = (
53
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
54
+ if not glu
55
+ else GEGLU(dim, inner_dim)
56
+ )
57
+
58
+ self.net = nn.Sequential(
59
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
60
+ )
61
+
62
+ def forward(self, x):
63
+ return self.net(x)
64
+
65
+
66
+ def zero_module(module):
67
+ """
68
+ Zero out the parameters of a module and return it.
69
+ """
70
+ for p in module.parameters():
71
+ p.detach().zero_()
72
+ return module
73
+
74
+
75
+ def Normalize(in_channels):
76
+ return torch.nn.GroupNorm(
77
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
78
+ )
79
+
80
+
81
+ class LinearAttention(nn.Module):
82
+ def __init__(self, dim, heads=4, dim_head=32):
83
+ super().__init__()
84
+ self.heads = heads
85
+ hidden_dim = dim_head * heads
86
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
87
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
88
+
89
+ def forward(self, x):
90
+ b, c, h, w = x.shape
91
+ qkv = self.to_qkv(x)
92
+ q, k, v = rearrange(
93
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
94
+ )
95
+ k = k.softmax(dim=-1)
96
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
97
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
98
+ out = rearrange(
99
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
100
+ )
101
+ return self.to_out(out)
102
+
103
+
104
+ class SpatialSelfAttention(nn.Module):
105
+ def __init__(self, in_channels):
106
+ super().__init__()
107
+ self.in_channels = in_channels
108
+
109
+ self.norm = Normalize(in_channels)
110
+ self.q = torch.nn.Conv2d(
111
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
112
+ )
113
+ self.k = torch.nn.Conv2d(
114
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
115
+ )
116
+ self.v = torch.nn.Conv2d(
117
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
118
+ )
119
+ self.proj_out = torch.nn.Conv2d(
120
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
121
+ )
122
+
123
+ def forward(self, x):
124
+ h_ = x
125
+ h_ = self.norm(h_)
126
+ q = self.q(h_)
127
+ k = self.k(h_)
128
+ v = self.v(h_)
129
+
130
+ # compute attention
131
+ b, c, h, w = q.shape
132
+ q = rearrange(q, "b c h w -> b (h w) c")
133
+ k = rearrange(k, "b c h w -> b c (h w)")
134
+ w_ = torch.einsum("bij,bjk->bik", q, k)
135
+
136
+ w_ = w_ * (int(c) ** (-0.5))
137
+ w_ = torch.nn.functional.softmax(w_, dim=2)
138
+
139
+ # attend to values
140
+ v = rearrange(v, "b c h w -> b c (h w)")
141
+ w_ = rearrange(w_, "b i j -> b j i")
142
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
143
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
144
+ h_ = self.proj_out(h_)
145
+
146
+ return x + h_
147
+
148
+
149
+ class CrossAttention(nn.Module):
150
+ """
151
+ ### Cross Attention Layer
152
+ This falls-back to self-attention when conditional embeddings are not specified.
153
+ """
154
+
155
+ # use_flash_attention: bool = True
156
+ use_flash_attention: bool = False
157
+
158
+ def __init__(
159
+ self,
160
+ query_dim,
161
+ context_dim=None,
162
+ heads=8,
163
+ dim_head=64,
164
+ dropout=0.0,
165
+ is_inplace: bool = True,
166
+ ):
167
+ # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
168
+ """
169
+ :param d_model: is the input embedding size
170
+ :param n_heads: is the number of attention heads
171
+ :param d_head: is the size of a attention head
172
+ :param d_cond: is the size of the conditional embeddings
173
+ :param is_inplace: specifies whether to perform the attention softmax computation inplace to
174
+ save memory
175
+ """
176
+ super().__init__()
177
+
178
+ self.is_inplace = is_inplace
179
+ self.n_heads = heads
180
+ self.d_head = dim_head
181
+
182
+ # Attention scaling factor
183
+ self.scale = dim_head**-0.5
184
+
185
+ # The normal self-attention layer
186
+ if context_dim is None:
187
+ context_dim = query_dim
188
+
189
+ # Query, key and value mappings
190
+ d_attn = dim_head * heads
191
+ self.to_q = nn.Linear(query_dim, d_attn, bias=False)
192
+ self.to_k = nn.Linear(context_dim, d_attn, bias=False)
193
+ self.to_v = nn.Linear(context_dim, d_attn, bias=False)
194
+
195
+ # Final linear layer
196
+ self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout))
197
+
198
+ # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
199
+ # Flash attention is only used if it's installed
200
+ # and `CrossAttention.use_flash_attention` is set to `True`.
201
+ try:
202
+ # You can install flash attention by cloning their Github repo,
203
+ # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
204
+ # and then running `python setup.py install`
205
+ from flash_attn.flash_attention import FlashAttention
206
+
207
+ self.flash = FlashAttention()
208
+ # Set the scale for scaled dot-product attention.
209
+ self.flash.softmax_scale = self.scale
210
+ # Set to `None` if it's not installed
211
+ except ImportError:
212
+ self.flash = None
213
+
214
+ def forward(self, x, context=None, mask=None):
215
+ """
216
+ :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
217
+ :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
218
+ """
219
+
220
+ # If `cond` is `None` we perform self attention
221
+ has_cond = context is not None
222
+ if not has_cond:
223
+ context = x
224
+
225
+ # Get query, key and value vectors
226
+ q = self.to_q(x)
227
+ k = self.to_k(context)
228
+ v = self.to_v(context)
229
+
230
+ # Use flash attention if it's available and the head size is less than or equal to `128`
231
+ if (
232
+ CrossAttention.use_flash_attention
233
+ and self.flash is not None
234
+ and not has_cond
235
+ and self.d_head <= 128
236
+ ):
237
+ return self.flash_attention(q, k, v)
238
+ # Otherwise, fallback to normal attention
239
+ else:
240
+ return self.normal_attention(q, k, v)
241
+
242
+ def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
243
+ """
244
+ #### Flash Attention
245
+ :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
246
+ :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
247
+ :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
248
+ """
249
+
250
+ # Get batch size and number of elements along sequence axis (`width * height`)
251
+ batch_size, seq_len, _ = q.shape
252
+
253
+ # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
254
+ # shape `[batch_size, seq_len, 3, n_heads * d_head]`
255
+ qkv = torch.stack((q, k, v), dim=2)
256
+ # Split the heads
257
+ qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
258
+
259
+ # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
260
+ # fit this size.
261
+ if self.d_head <= 32:
262
+ pad = 32 - self.d_head
263
+ elif self.d_head <= 64:
264
+ pad = 64 - self.d_head
265
+ elif self.d_head <= 128:
266
+ pad = 128 - self.d_head
267
+ else:
268
+ raise ValueError(f"Head size ${self.d_head} too large for Flash Attention")
269
+
270
+ # Pad the heads
271
+ if pad:
272
+ qkv = torch.cat(
273
+ (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
274
+ )
275
+
276
+ # Compute attention
277
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
278
+ # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
279
+ # TODO here I add the dtype changing
280
+ out, _ = self.flash(qkv.type(torch.float16))
281
+ # Truncate the extra head size
282
+ out = out[:, :, :, : self.d_head].float()
283
+ # Reshape to `[batch_size, seq_len, n_heads * d_head]`
284
+ out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
285
+
286
+ # Map to `[batch_size, height * width, d_model]` with a linear layer
287
+ return self.to_out(out)
288
+
289
+ def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
290
+ """
291
+ #### Normal Attention
292
+
293
+ :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
294
+ :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
295
+ :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
296
+ """
297
+
298
+ # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
299
+ q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32]
300
+ k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32]
301
+ v = v.view(*v.shape[:2], self.n_heads, -1)
302
+
303
+ # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
304
+ attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
305
+
306
+ # Compute softmax
307
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
308
+ if self.is_inplace:
309
+ half = attn.shape[0] // 2
310
+ attn[half:] = attn[half:].softmax(dim=-1)
311
+ attn[:half] = attn[:half].softmax(dim=-1)
312
+ else:
313
+ attn = attn.softmax(dim=-1)
314
+
315
+ # Compute attention output
316
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
317
+ # attn: [bs, 20, 64, 1]
318
+ # v: [bs, 1, 20, 32]
319
+ out = torch.einsum("bhij,bjhd->bihd", attn, v)
320
+ # Reshape to `[batch_size, height * width, n_heads * d_head]`
321
+ out = out.reshape(*out.shape[:2], -1)
322
+ # Map to `[batch_size, height * width, d_model]` with a linear layer
323
+ return self.to_out(out)
324
+
325
+
326
+ # class CrossAttention(nn.Module):
327
+ # def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
328
+ # super().__init__()
329
+ # inner_dim = dim_head * heads
330
+ # context_dim = default(context_dim, query_dim)
331
+
332
+ # self.scale = dim_head ** -0.5
333
+ # self.heads = heads
334
+
335
+ # self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
336
+ # self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
337
+ # self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
338
+
339
+ # self.to_out = nn.Sequential(
340
+ # nn.Linear(inner_dim, query_dim),
341
+ # nn.Dropout(dropout)
342
+ # )
343
+
344
+ # def forward(self, x, context=None, mask=None):
345
+ # h = self.heads
346
+
347
+ # q = self.to_q(x)
348
+ # context = default(context, x)
349
+ # k = self.to_k(context)
350
+ # v = self.to_v(context)
351
+
352
+ # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
353
+
354
+ # sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
355
+
356
+ # if exists(mask):
357
+ # mask = rearrange(mask, 'b ... -> b (...)')
358
+ # max_neg_value = -torch.finfo(sim.dtype).max
359
+ # mask = repeat(mask, 'b j -> (b h) () j', h=h)
360
+ # sim.masked_fill_(~mask, max_neg_value)
361
+
362
+ # # attention, what we cannot get enough of
363
+ # attn = sim.softmax(dim=-1)
364
+
365
+ # out = einsum('b i j, b j d -> b i d', attn, v)
366
+ # out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
367
+ # return self.to_out(out)
368
+
369
+
370
+ class BasicTransformerBlock(nn.Module):
371
+ def __init__(
372
+ self,
373
+ dim,
374
+ n_heads,
375
+ d_head,
376
+ dropout=0.0,
377
+ context_dim=None,
378
+ gated_ff=True,
379
+ checkpoint=True,
380
+ ):
381
+ super().__init__()
382
+ self.attn1 = CrossAttention(
383
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
384
+ ) # is a self-attention
385
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
386
+ self.attn2 = CrossAttention(
387
+ query_dim=dim,
388
+ context_dim=context_dim,
389
+ heads=n_heads,
390
+ dim_head=d_head,
391
+ dropout=dropout,
392
+ ) # is self-attn if context is none
393
+ self.norm1 = nn.LayerNorm(dim)
394
+ self.norm2 = nn.LayerNorm(dim)
395
+ self.norm3 = nn.LayerNorm(dim)
396
+ self.checkpoint = checkpoint
397
+
398
+ def forward(self, x, context=None):
399
+ if context is None:
400
+ return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
401
+ else:
402
+ return checkpoint(
403
+ self._forward, (x, context), self.parameters(), self.checkpoint
404
+ )
405
+
406
+ def _forward(self, x, context=None):
407
+ x = self.attn1(self.norm1(x)) + x
408
+ x = self.attn2(self.norm2(x), context=context) + x
409
+ x = self.ff(self.norm3(x)) + x
410
+ return x
411
+
412
+
413
+ class SpatialTransformer(nn.Module):
414
+ """
415
+ Transformer block for image-like data.
416
+ First, project the input (aka embedding)
417
+ and reshape to b, t, d.
418
+ Then apply standard transformer action.
419
+ Finally, reshape to image
420
+ """
421
+
422
+ def __init__(
423
+ self,
424
+ in_channels,
425
+ n_heads,
426
+ d_head,
427
+ depth=1,
428
+ dropout=0.0,
429
+ context_dim=None,
430
+ no_context=False,
431
+ ):
432
+ super().__init__()
433
+
434
+ if no_context:
435
+ context_dim = None
436
+
437
+ self.in_channels = in_channels
438
+ inner_dim = n_heads * d_head
439
+ self.norm = Normalize(in_channels)
440
+
441
+ self.proj_in = nn.Conv2d(
442
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
443
+ )
444
+
445
+ self.transformer_blocks = nn.ModuleList(
446
+ [
447
+ BasicTransformerBlock(
448
+ inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
449
+ )
450
+ for d in range(depth)
451
+ ]
452
+ )
453
+
454
+ self.proj_out = zero_module(
455
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
456
+ )
457
+
458
+ def forward(self, x, context=None):
459
+ # note: if no context is given, cross-attention defaults to self-attention
460
+ b, c, h, w = x.shape
461
+ x_in = x
462
+ x = self.norm(x)
463
+ x = self.proj_in(x)
464
+ x = rearrange(x, "b c h w -> b (h w) c")
465
+ for block in self.transformer_blocks:
466
+ x = block(x, context=context)
467
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
468
+ x = self.proj_out(x)
469
+ return x + x_in
audioldm/latent_diffusion/ddim.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from audioldm.latent_diffusion.util import (
8
+ make_ddim_sampling_parameters,
9
+ make_ddim_timesteps,
10
+ noise_like,
11
+ extract_into_tensor,
12
+ )
13
+
14
+
15
+ class DDIMSampler(object):
16
+ def __init__(self, model, schedule="linear", **kwargs):
17
+ super().__init__()
18
+ self.model = model
19
+ self.ddpm_num_timesteps = model.num_timesteps
20
+ self.schedule = schedule
21
+
22
+ def register_buffer(self, name, attr):
23
+ if type(attr) == torch.Tensor:
24
+ if attr.device != torch.device("cuda"):
25
+ attr = attr.to(torch.device("cuda"))
26
+ setattr(self, name, attr)
27
+
28
+ def make_schedule(
29
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
30
+ ):
31
+ self.ddim_timesteps = make_ddim_timesteps(
32
+ ddim_discr_method=ddim_discretize,
33
+ num_ddim_timesteps=ddim_num_steps,
34
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
35
+ verbose=verbose,
36
+ )
37
+ alphas_cumprod = self.model.alphas_cumprod
38
+ assert (
39
+ alphas_cumprod.shape[0] == self.ddpm_num_timesteps
40
+ ), "alphas have to be defined for each timestep"
41
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
42
+
43
+ self.register_buffer("betas", to_torch(self.model.betas))
44
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
45
+ self.register_buffer(
46
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
47
+ )
48
+
49
+ # calculations for diffusion q(x_t | x_{t-1}) and others
50
+ self.register_buffer(
51
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
52
+ )
53
+ self.register_buffer(
54
+ "sqrt_one_minus_alphas_cumprod",
55
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
56
+ )
57
+ self.register_buffer(
58
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
59
+ )
60
+ self.register_buffer(
61
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
62
+ )
63
+ self.register_buffer(
64
+ "sqrt_recipm1_alphas_cumprod",
65
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
66
+ )
67
+
68
+ # ddim sampling parameters
69
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
70
+ alphacums=alphas_cumprod.cpu(),
71
+ ddim_timesteps=self.ddim_timesteps,
72
+ eta=ddim_eta,
73
+ verbose=verbose,
74
+ )
75
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
76
+ self.register_buffer("ddim_alphas", ddim_alphas)
77
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
78
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
79
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
80
+ (1 - self.alphas_cumprod_prev)
81
+ / (1 - self.alphas_cumprod)
82
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
83
+ )
84
+ self.register_buffer(
85
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
86
+ )
87
+
88
+ @torch.no_grad()
89
+ def sample(
90
+ self,
91
+ S,
92
+ batch_size,
93
+ shape,
94
+ conditioning=None,
95
+ callback=None,
96
+ normals_sequence=None,
97
+ img_callback=None,
98
+ quantize_x0=False,
99
+ eta=0.0,
100
+ mask=None,
101
+ x0=None,
102
+ temperature=1.0,
103
+ noise_dropout=0.0,
104
+ score_corrector=None,
105
+ corrector_kwargs=None,
106
+ verbose=True,
107
+ x_T=None,
108
+ log_every_t=100,
109
+ unconditional_guidance_scale=1.0,
110
+ unconditional_conditioning=None,
111
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
112
+ **kwargs,
113
+ ):
114
+ if conditioning is not None:
115
+ if isinstance(conditioning, dict):
116
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
117
+ if cbs != batch_size:
118
+ print(
119
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
120
+ )
121
+ else:
122
+ if conditioning.shape[0] != batch_size:
123
+ print(
124
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
125
+ )
126
+
127
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
128
+ # sampling
129
+ C, H, W = shape
130
+ size = (batch_size, C, H, W)
131
+ samples, intermediates = self.ddim_sampling(
132
+ conditioning,
133
+ size,
134
+ callback=callback,
135
+ img_callback=img_callback,
136
+ quantize_denoised=quantize_x0,
137
+ mask=mask,
138
+ x0=x0,
139
+ ddim_use_original_steps=False,
140
+ noise_dropout=noise_dropout,
141
+ temperature=temperature,
142
+ score_corrector=score_corrector,
143
+ corrector_kwargs=corrector_kwargs,
144
+ x_T=x_T,
145
+ log_every_t=log_every_t,
146
+ unconditional_guidance_scale=unconditional_guidance_scale,
147
+ unconditional_conditioning=unconditional_conditioning,
148
+ )
149
+ return samples, intermediates
150
+
151
+ @torch.no_grad()
152
+ def ddim_sampling(
153
+ self,
154
+ cond,
155
+ shape,
156
+ x_T=None,
157
+ ddim_use_original_steps=False,
158
+ callback=None,
159
+ timesteps=None,
160
+ quantize_denoised=False,
161
+ mask=None,
162
+ x0=None,
163
+ img_callback=None,
164
+ log_every_t=100,
165
+ temperature=1.0,
166
+ noise_dropout=0.0,
167
+ score_corrector=None,
168
+ corrector_kwargs=None,
169
+ unconditional_guidance_scale=1.0,
170
+ unconditional_conditioning=None,
171
+ ):
172
+ device = self.model.betas.device
173
+ b = shape[0]
174
+ if x_T is None:
175
+ img = torch.randn(shape, device=device)
176
+ else:
177
+ img = x_T
178
+
179
+ if timesteps is None:
180
+ timesteps = (
181
+ self.ddpm_num_timesteps
182
+ if ddim_use_original_steps
183
+ else self.ddim_timesteps
184
+ )
185
+ elif timesteps is not None and not ddim_use_original_steps:
186
+ subset_end = (
187
+ int(
188
+ min(timesteps / self.ddim_timesteps.shape[0], 1)
189
+ * self.ddim_timesteps.shape[0]
190
+ )
191
+ - 1
192
+ )
193
+ timesteps = self.ddim_timesteps[:subset_end]
194
+
195
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
196
+ time_range = (
197
+ reversed(range(0, timesteps))
198
+ if ddim_use_original_steps
199
+ else np.flip(timesteps)
200
+ )
201
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
202
+ # print(f"Running DDIM Sampling with {total_steps} timesteps")
203
+
204
+ # iterator = gr.Progress().tqdm(time_range, desc="DDIM Sampler", total=total_steps)
205
+ iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps, leave=False)
206
+
207
+ for i, step in enumerate(iterator):
208
+ index = total_steps - i - 1
209
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
210
+ if mask is not None:
211
+ assert x0 is not None
212
+ img_orig = self.model.q_sample(
213
+ x0, ts
214
+ ) # TODO deterministic forward pass?
215
+ img = (
216
+ img_orig * mask + (1.0 - mask) * img
217
+ ) # In the first sampling step, img is pure gaussian noise
218
+
219
+ outs = self.p_sample_ddim(
220
+ img,
221
+ cond,
222
+ ts,
223
+ index=index,
224
+ use_original_steps=ddim_use_original_steps,
225
+ quantize_denoised=quantize_denoised,
226
+ temperature=temperature,
227
+ noise_dropout=noise_dropout,
228
+ score_corrector=score_corrector,
229
+ corrector_kwargs=corrector_kwargs,
230
+ unconditional_guidance_scale=unconditional_guidance_scale,
231
+ unconditional_conditioning=unconditional_conditioning,
232
+ )
233
+ img, pred_x0 = outs
234
+ if callback:
235
+ callback(i)
236
+ if img_callback:
237
+ img_callback(pred_x0, i)
238
+
239
+ if index % log_every_t == 0 or index == total_steps - 1:
240
+ intermediates["x_inter"].append(img)
241
+ intermediates["pred_x0"].append(pred_x0)
242
+
243
+ return img, intermediates
244
+
245
+ @torch.no_grad()
246
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
247
+ # fast, but does not allow for exact reconstruction
248
+ # t serves as an index to gather the correct alphas
249
+ if use_original_steps:
250
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
251
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
252
+ else:
253
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
254
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
255
+
256
+ if noise is None:
257
+ noise = torch.randn_like(x0)
258
+
259
+ return (
260
+ extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
261
+ + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
262
+ )
263
+
264
+ @torch.no_grad()
265
+ def decode(
266
+ self,
267
+ x_latent,
268
+ cond,
269
+ t_start,
270
+ unconditional_guidance_scale=1.0,
271
+ unconditional_conditioning=None,
272
+ use_original_steps=False,
273
+ ):
274
+
275
+ timesteps = (
276
+ np.arange(self.ddpm_num_timesteps)
277
+ if use_original_steps
278
+ else self.ddim_timesteps
279
+ )
280
+ timesteps = timesteps[:t_start]
281
+
282
+ time_range = np.flip(timesteps)
283
+ total_steps = timesteps.shape[0]
284
+ # print(f"Running DDIM Sampling with {total_steps} timesteps")
285
+
286
+ # iterator = gr.Progress().tqdm(time_range, desc="Decoding image", total=total_steps)
287
+ iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
288
+ x_dec = x_latent
289
+
290
+ for i, step in enumerate(iterator):
291
+ index = total_steps - i - 1
292
+ ts = torch.full(
293
+ (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
294
+ )
295
+ x_dec, _ = self.p_sample_ddim(
296
+ x_dec,
297
+ cond,
298
+ ts,
299
+ index=index,
300
+ use_original_steps=use_original_steps,
301
+ unconditional_guidance_scale=unconditional_guidance_scale,
302
+ unconditional_conditioning=unconditional_conditioning,
303
+ )
304
+ return x_dec
305
+
306
+ @torch.no_grad()
307
+ def p_sample_ddim(
308
+ self,
309
+ x,
310
+ c,
311
+ t,
312
+ index,
313
+ repeat_noise=False,
314
+ use_original_steps=False,
315
+ quantize_denoised=False,
316
+ temperature=1.0,
317
+ noise_dropout=0.0,
318
+ score_corrector=None,
319
+ corrector_kwargs=None,
320
+ unconditional_guidance_scale=1.0,
321
+ unconditional_conditioning=None,
322
+ ):
323
+ b, *_, device = *x.shape, x.device
324
+
325
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
326
+ e_t = self.model.apply_model(x, t, c)
327
+ else:
328
+ x_in = torch.cat([x] * 2)
329
+ t_in = torch.cat([t] * 2)
330
+ c_in = torch.cat([unconditional_conditioning, c])
331
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
332
+ # When unconditional_guidance_scale == 1: only e_t
333
+ # When unconditional_guidance_scale == 0: only unconditional
334
+ # When unconditional_guidance_scale > 1: add more unconditional guidance
335
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
336
+
337
+ if score_corrector is not None:
338
+ assert self.model.parameterization == "eps"
339
+ e_t = score_corrector.modify_score(
340
+ self.model, e_t, x, t, c, **corrector_kwargs
341
+ )
342
+
343
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
344
+ alphas_prev = (
345
+ self.model.alphas_cumprod_prev
346
+ if use_original_steps
347
+ else self.ddim_alphas_prev
348
+ )
349
+ sqrt_one_minus_alphas = (
350
+ self.model.sqrt_one_minus_alphas_cumprod
351
+ if use_original_steps
352
+ else self.ddim_sqrt_one_minus_alphas
353
+ )
354
+ sigmas = (
355
+ self.model.ddim_sigmas_for_original_num_steps
356
+ if use_original_steps
357
+ else self.ddim_sigmas
358
+ )
359
+ # select parameters corresponding to the currently considered timestep
360
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
361
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
362
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
363
+ sqrt_one_minus_at = torch.full(
364
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
365
+ )
366
+
367
+ # current prediction for x_0
368
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
369
+ if quantize_denoised:
370
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
371
+ # direction pointing to x_t
372
+ dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
373
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
374
+ if noise_dropout > 0.0:
375
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
376
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise # TODO
377
+ return x_prev, pred_x0
audioldm/latent_diffusion/ddpm.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ wild mixture of
3
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+ import sys
9
+ import os
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import numpy as np
14
+ from contextlib import contextmanager
15
+ from functools import partial
16
+ from tqdm import tqdm
17
+
18
+ from audioldm.utils import exists, default, count_params, instantiate_from_config
19
+ from audioldm.latent_diffusion.ema import LitEma
20
+ from audioldm.latent_diffusion.util import (
21
+ make_beta_schedule,
22
+ extract_into_tensor,
23
+ noise_like,
24
+ )
25
+ import soundfile as sf
26
+ import os
27
+
28
+
29
+ __conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
30
+
31
+
32
+ def disabled_train(self, mode=True):
33
+ """Overwrite model.train with this function to make sure train/eval mode
34
+ does not change anymore."""
35
+ return self
36
+
37
+
38
+ def uniform_on_device(r1, r2, shape, device):
39
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
40
+
41
+
42
+ class DiffusionWrapper(nn.Module):
43
+ def __init__(self, diff_model_config, conditioning_key):
44
+ super().__init__()
45
+ self.diffusion_model = instantiate_from_config(diff_model_config)
46
+ self.conditioning_key = conditioning_key
47
+ assert self.conditioning_key in [
48
+ None,
49
+ "concat",
50
+ "crossattn",
51
+ "hybrid",
52
+ "adm",
53
+ "film",
54
+ ]
55
+
56
+ def forward(
57
+ self, x, t, c_concat: list = None, c_crossattn: list = None, c_film: list = None
58
+ ):
59
+ x = x.contiguous()
60
+ t = t.contiguous()
61
+
62
+ if self.conditioning_key is None:
63
+ out = self.diffusion_model(x, t)
64
+ elif self.conditioning_key == "concat":
65
+ xc = torch.cat([x] + c_concat, dim=1)
66
+ out = self.diffusion_model(xc, t)
67
+ elif self.conditioning_key == "crossattn":
68
+ cc = torch.cat(c_crossattn, 1)
69
+ out = self.diffusion_model(x, t, context=cc)
70
+ elif self.conditioning_key == "hybrid":
71
+ xc = torch.cat([x] + c_concat, dim=1)
72
+ cc = torch.cat(c_crossattn, 1)
73
+ out = self.diffusion_model(xc, t, context=cc)
74
+ elif (
75
+ self.conditioning_key == "film"
76
+ ): # The condition is assumed to be a global token, which wil pass through a linear layer and added with the time embedding for the FILM
77
+ cc = c_film[0].squeeze(1) # only has one token
78
+ out = self.diffusion_model(x, t, y=cc)
79
+ elif self.conditioning_key == "adm":
80
+ cc = c_crossattn[0]
81
+ out = self.diffusion_model(x, t, y=cc)
82
+ else:
83
+ raise NotImplementedError()
84
+
85
+ return out
86
+
87
+
88
+ class DDPM(nn.Module):
89
+ # classic DDPM with Gaussian diffusion, in image space
90
+ def __init__(
91
+ self,
92
+ unet_config,
93
+ timesteps=1000,
94
+ beta_schedule="linear",
95
+ loss_type="l2",
96
+ ckpt_path=None,
97
+ ignore_keys=[],
98
+ load_only_unet=False,
99
+ monitor="val/loss",
100
+ use_ema=True,
101
+ first_stage_key="image",
102
+ latent_t_size=256,
103
+ latent_f_size=16,
104
+ channels=3,
105
+ log_every_t=100,
106
+ clip_denoised=True,
107
+ linear_start=1e-4,
108
+ linear_end=2e-2,
109
+ cosine_s=8e-3,
110
+ given_betas=None,
111
+ original_elbo_weight=0.0,
112
+ v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
113
+ l_simple_weight=1.0,
114
+ conditioning_key=None,
115
+ parameterization="eps", # all assuming fixed variance schedules
116
+ scheduler_config=None,
117
+ use_positional_encodings=False,
118
+ learn_logvar=False,
119
+ logvar_init=0.0,
120
+ ):
121
+ super().__init__()
122
+ assert parameterization in [
123
+ "eps",
124
+ "x0",
125
+ ], 'currently only supporting "eps" and "x0"'
126
+ self.parameterization = parameterization
127
+ self.state = None
128
+ # print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
129
+ self.cond_stage_model = None
130
+ self.clip_denoised = clip_denoised
131
+ self.log_every_t = log_every_t
132
+ self.first_stage_key = first_stage_key
133
+
134
+ self.latent_t_size = latent_t_size
135
+ self.latent_f_size = latent_f_size
136
+
137
+ self.channels = channels
138
+ self.use_positional_encodings = use_positional_encodings
139
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
140
+ count_params(self.model, verbose=True)
141
+ self.use_ema = use_ema
142
+ if self.use_ema:
143
+ self.model_ema = LitEma(self.model)
144
+ # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
145
+
146
+ self.use_scheduler = scheduler_config is not None
147
+ if self.use_scheduler:
148
+ self.scheduler_config = scheduler_config
149
+
150
+ self.v_posterior = v_posterior
151
+ self.original_elbo_weight = original_elbo_weight
152
+ self.l_simple_weight = l_simple_weight
153
+
154
+ if monitor is not None:
155
+ self.monitor = monitor
156
+
157
+ self.register_schedule(
158
+ given_betas=given_betas,
159
+ beta_schedule=beta_schedule,
160
+ timesteps=timesteps,
161
+ linear_start=linear_start,
162
+ linear_end=linear_end,
163
+ cosine_s=cosine_s,
164
+ )
165
+
166
+ self.loss_type = loss_type
167
+
168
+ self.learn_logvar = learn_logvar
169
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
170
+ if self.learn_logvar:
171
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
172
+ else:
173
+ self.logvar = nn.Parameter(self.logvar, requires_grad=False)
174
+
175
+ self.logger_save_dir = None
176
+ self.logger_project = None
177
+ self.logger_version = None
178
+ self.label_indices_total = None
179
+ # To avoid the system cannot find metric value for checkpoint
180
+ self.metrics_buffer = {
181
+ "val/kullback_leibler_divergence_sigmoid": 15.0,
182
+ "val/kullback_leibler_divergence_softmax": 10.0,
183
+ "val/psnr": 0.0,
184
+ "val/ssim": 0.0,
185
+ "val/inception_score_mean": 1.0,
186
+ "val/inception_score_std": 0.0,
187
+ "val/kernel_inception_distance_mean": 0.0,
188
+ "val/kernel_inception_distance_std": 0.0,
189
+ "val/frechet_inception_distance": 133.0,
190
+ "val/frechet_audio_distance": 32.0,
191
+ }
192
+ self.initial_learning_rate = None
193
+
194
+ def get_log_dir(self):
195
+ if (
196
+ self.logger_save_dir is None
197
+ and self.logger_project is None
198
+ and self.logger_version is None
199
+ ):
200
+ return os.path.join(
201
+ self.logger.save_dir, self.logger._project, self.logger.version
202
+ )
203
+ else:
204
+ return os.path.join(
205
+ self.logger_save_dir, self.logger_project, self.logger_version
206
+ )
207
+
208
+ def set_log_dir(self, save_dir, project, version):
209
+ self.logger_save_dir = save_dir
210
+ self.logger_project = project
211
+ self.logger_version = version
212
+
213
+ def register_schedule(
214
+ self,
215
+ given_betas=None,
216
+ beta_schedule="linear",
217
+ timesteps=1000,
218
+ linear_start=1e-4,
219
+ linear_end=2e-2,
220
+ cosine_s=8e-3,
221
+ ):
222
+ if exists(given_betas):
223
+ betas = given_betas
224
+ else:
225
+ betas = make_beta_schedule(
226
+ beta_schedule,
227
+ timesteps,
228
+ linear_start=linear_start,
229
+ linear_end=linear_end,
230
+ cosine_s=cosine_s,
231
+ )
232
+ alphas = 1.0 - betas
233
+ alphas_cumprod = np.cumprod(alphas, axis=0)
234
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
235
+
236
+ (timesteps,) = betas.shape
237
+ self.num_timesteps = int(timesteps)
238
+ self.linear_start = linear_start
239
+ self.linear_end = linear_end
240
+ assert (
241
+ alphas_cumprod.shape[0] == self.num_timesteps
242
+ ), "alphas have to be defined for each timestep"
243
+
244
+ to_torch = partial(torch.tensor, dtype=torch.float32)
245
+
246
+ self.register_buffer("betas", to_torch(betas))
247
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
248
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
249
+
250
+ # calculations for diffusion q(x_t | x_{t-1}) and others
251
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
252
+ self.register_buffer(
253
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
254
+ )
255
+ self.register_buffer(
256
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
257
+ )
258
+ self.register_buffer(
259
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
260
+ )
261
+ self.register_buffer(
262
+ "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
263
+ )
264
+
265
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
266
+ posterior_variance = (1 - self.v_posterior) * betas * (
267
+ 1.0 - alphas_cumprod_prev
268
+ ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
269
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
270
+ self.register_buffer("posterior_variance", to_torch(posterior_variance))
271
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
272
+ self.register_buffer(
273
+ "posterior_log_variance_clipped",
274
+ to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
275
+ )
276
+ self.register_buffer(
277
+ "posterior_mean_coef1",
278
+ to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
279
+ )
280
+ self.register_buffer(
281
+ "posterior_mean_coef2",
282
+ to_torch(
283
+ (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
284
+ ),
285
+ )
286
+
287
+ if self.parameterization == "eps":
288
+ lvlb_weights = self.betas**2 / (
289
+ 2
290
+ * self.posterior_variance
291
+ * to_torch(alphas)
292
+ * (1 - self.alphas_cumprod)
293
+ )
294
+ elif self.parameterization == "x0":
295
+ lvlb_weights = (
296
+ 0.5
297
+ * np.sqrt(torch.Tensor(alphas_cumprod))
298
+ / (2.0 * 1 - torch.Tensor(alphas_cumprod))
299
+ )
300
+ else:
301
+ raise NotImplementedError("mu not supported")
302
+ # TODO how to choose this term
303
+ lvlb_weights[0] = lvlb_weights[1]
304
+ self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
305
+ assert not torch.isnan(self.lvlb_weights).all()
306
+
307
+ @contextmanager
308
+ def ema_scope(self, context=None):
309
+ if self.use_ema:
310
+ self.model_ema.store(self.model.parameters())
311
+ self.model_ema.copy_to(self.model)
312
+ if context is not None:
313
+ # print(f"{context}: Switched to EMA weights")
314
+ pass
315
+ try:
316
+ yield None
317
+ finally:
318
+ if self.use_ema:
319
+ self.model_ema.restore(self.model.parameters())
320
+ if context is not None:
321
+ # print(f"{context}: Restored training weights")
322
+ pass
323
+
324
+ def q_mean_variance(self, x_start, t):
325
+ """
326
+ Get the distribution q(x_t | x_0).
327
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
328
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
329
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
330
+ """
331
+ mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
332
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
333
+ log_variance = extract_into_tensor(
334
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
335
+ )
336
+ return mean, variance, log_variance
337
+
338
+ def predict_start_from_noise(self, x_t, t, noise):
339
+ return (
340
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
341
+ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
342
+ * noise
343
+ )
344
+
345
+ def q_posterior(self, x_start, x_t, t):
346
+ posterior_mean = (
347
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
348
+ + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
349
+ )
350
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
351
+ posterior_log_variance_clipped = extract_into_tensor(
352
+ self.posterior_log_variance_clipped, t, x_t.shape
353
+ )
354
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
355
+
356
+ def p_mean_variance(self, x, t, clip_denoised: bool):
357
+ model_out = self.model(x, t)
358
+ if self.parameterization == "eps":
359
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
360
+ elif self.parameterization == "x0":
361
+ x_recon = model_out
362
+ if clip_denoised:
363
+ x_recon.clamp_(-1.0, 1.0)
364
+
365
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
366
+ x_start=x_recon, x_t=x, t=t
367
+ )
368
+ return model_mean, posterior_variance, posterior_log_variance
369
+
370
+ @torch.no_grad()
371
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
372
+ b, *_, device = *x.shape, x.device
373
+ model_mean, _, model_log_variance = self.p_mean_variance(
374
+ x=x, t=t, clip_denoised=clip_denoised
375
+ )
376
+ noise = noise_like(x.shape, device, repeat_noise)
377
+ # no noise when t == 0
378
+ nonzero_mask = (
379
+ (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous()
380
+ )
381
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
382
+
383
+ @torch.no_grad()
384
+ def p_sample_loop(self, shape, return_intermediates=False):
385
+ device = self.betas.device
386
+ b = shape[0]
387
+ img = torch.randn(shape, device=device)
388
+ intermediates = [img]
389
+ for i in tqdm(
390
+ reversed(range(0, self.num_timesteps)),
391
+ desc="Sampling t",
392
+ total=self.num_timesteps,
393
+ ):
394
+ img = self.p_sample(
395
+ img,
396
+ torch.full((b,), i, device=device, dtype=torch.long),
397
+ clip_denoised=self.clip_denoised,
398
+ )
399
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
400
+ intermediates.append(img)
401
+ if return_intermediates:
402
+ return img, intermediates
403
+ return img
404
+
405
+ @torch.no_grad()
406
+ def sample(self, batch_size=16, return_intermediates=False):
407
+ shape = (batch_size, channels, self.latent_t_size, self.latent_f_size)
408
+ channels = self.channels
409
+ return self.p_sample_loop(shape, return_intermediates=return_intermediates)
410
+
411
+ def q_sample(self, x_start, t, noise=None):
412
+ noise = default(noise, lambda: torch.randn_like(x_start))
413
+ return (
414
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
415
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
416
+ * noise
417
+ )
418
+
419
+ def forward(self, x, *args, **kwargs):
420
+ t = torch.randint(
421
+ 0, self.num_timesteps, (x.shape[0],), device=self.device
422
+ ).long()
423
+ return self.p_losses(x, t, *args, **kwargs)
424
+
425
+ def get_input(self, batch, k):
426
+ # fbank, log_magnitudes_stft, label_indices, fname, waveform, clip_label, text = batch
427
+ fbank, log_magnitudes_stft, label_indices, fname, waveform, text = batch
428
+ ret = {}
429
+
430
+ ret["fbank"] = (
431
+ fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float()
432
+ )
433
+ ret["stft"] = log_magnitudes_stft.to(
434
+ memory_format=torch.contiguous_format
435
+ ).float()
436
+ # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float()
437
+ ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float()
438
+ ret["text"] = list(text)
439
+ ret["fname"] = fname
440
+
441
+ return ret[k]
audioldm/latent_diffusion/ema.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError("Decay must be between 0 and 1")
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer(
14
+ "num_updates",
15
+ torch.tensor(0, dtype=torch.int)
16
+ if use_num_upates
17
+ else torch.tensor(-1, dtype=torch.int),
18
+ )
19
+
20
+ for name, p in model.named_parameters():
21
+ if p.requires_grad:
22
+ # remove as '.'-character is not allowed in buffers
23
+ s_name = name.replace(".", "")
24
+ self.m_name2s_name.update({name: s_name})
25
+ self.register_buffer(s_name, p.clone().detach().data)
26
+
27
+ self.collected_params = []
28
+
29
+ def forward(self, model):
30
+ decay = self.decay
31
+
32
+ if self.num_updates >= 0:
33
+ self.num_updates += 1
34
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35
+
36
+ one_minus_decay = 1.0 - decay
37
+
38
+ with torch.no_grad():
39
+ m_param = dict(model.named_parameters())
40
+ shadow_params = dict(self.named_buffers())
41
+
42
+ for key in m_param:
43
+ if m_param[key].requires_grad:
44
+ sname = self.m_name2s_name[key]
45
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46
+ shadow_params[sname].sub_(
47
+ one_minus_decay * (shadow_params[sname] - m_param[key])
48
+ )
49
+ else:
50
+ assert not key in self.m_name2s_name
51
+
52
+ def copy_to(self, model):
53
+ m_param = dict(model.named_parameters())
54
+ shadow_params = dict(self.named_buffers())
55
+ for key in m_param:
56
+ if m_param[key].requires_grad:
57
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
58
+ else:
59
+ assert not key in self.m_name2s_name
60
+
61
+ def store(self, parameters):
62
+ """
63
+ Save the current parameters for restoring later.
64
+ Args:
65
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
66
+ temporarily stored.
67
+ """
68
+ self.collected_params = [param.clone() for param in parameters]
69
+
70
+ def restore(self, parameters):
71
+ """
72
+ Restore the parameters stored with the `store` method.
73
+ Useful to validate the model with EMA parameters without affecting the
74
+ original optimization process. Store the parameters before the
75
+ `copy_to` method. After validation (or model saving), use this to
76
+ restore the former parameters.
77
+ Args:
78
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
79
+ updated with the stored parameters.
80
+ """
81
+ for c_param, param in zip(self.collected_params, parameters):
82
+ param.data.copy_(c_param.data)
audioldm/latent_diffusion/openaimodel.py ADDED
@@ -0,0 +1,1069 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch as th
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from audioldm.latent_diffusion.util import (
10
+ checkpoint,
11
+ conv_nd,
12
+ linear,
13
+ avg_pool_nd,
14
+ zero_module,
15
+ normalization,
16
+ timestep_embedding,
17
+ )
18
+ from audioldm.latent_diffusion.attention import SpatialTransformer
19
+
20
+
21
+ # dummy replace
22
+ def convert_module_to_f16(x):
23
+ pass
24
+
25
+
26
+ def convert_module_to_f32(x):
27
+ pass
28
+
29
+
30
+ ## go
31
+ class AttentionPool2d(nn.Module):
32
+ """
33
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ spacial_dim: int,
39
+ embed_dim: int,
40
+ num_heads_channels: int,
41
+ output_dim: int = None,
42
+ ):
43
+ super().__init__()
44
+ self.positional_embedding = nn.Parameter(
45
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
46
+ )
47
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
48
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
49
+ self.num_heads = embed_dim // num_heads_channels
50
+ self.attention = QKVAttention(self.num_heads)
51
+
52
+ def forward(self, x):
53
+ b, c, *_spatial = x.shape
54
+ x = x.reshape(b, c, -1).contiguous() # NC(HW)
55
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
56
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
57
+ x = self.qkv_proj(x)
58
+ x = self.attention(x)
59
+ x = self.c_proj(x)
60
+ return x[:, :, 0]
61
+
62
+
63
+ class TimestepBlock(nn.Module):
64
+ """
65
+ Any module where forward() takes timestep embeddings as a second argument.
66
+ """
67
+
68
+ @abstractmethod
69
+ def forward(self, x, emb):
70
+ """
71
+ Apply the module to `x` given `emb` timestep embeddings.
72
+ """
73
+
74
+
75
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
76
+ """
77
+ A sequential module that passes timestep embeddings to the children that
78
+ support it as an extra input.
79
+ """
80
+
81
+ def forward(self, x, emb, context=None):
82
+ for layer in self:
83
+ if isinstance(layer, TimestepBlock):
84
+ x = layer(x, emb)
85
+ elif isinstance(layer, SpatialTransformer):
86
+ x = layer(x, context)
87
+ else:
88
+ x = layer(x)
89
+ return x
90
+
91
+
92
+ class Upsample(nn.Module):
93
+ """
94
+ An upsampling layer with an optional convolution.
95
+ :param channels: channels in the inputs and outputs.
96
+ :param use_conv: a bool determining if a convolution is applied.
97
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
98
+ upsampling occurs in the inner-two dimensions.
99
+ """
100
+
101
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
102
+ super().__init__()
103
+ self.channels = channels
104
+ self.out_channels = out_channels or channels
105
+ self.use_conv = use_conv
106
+ self.dims = dims
107
+ if use_conv:
108
+ self.conv = conv_nd(
109
+ dims, self.channels, self.out_channels, 3, padding=padding
110
+ )
111
+
112
+ def forward(self, x):
113
+ assert x.shape[1] == self.channels
114
+ if self.dims == 3:
115
+ x = F.interpolate(
116
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
117
+ )
118
+ else:
119
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
120
+ if self.use_conv:
121
+ x = self.conv(x)
122
+ return x
123
+
124
+
125
+ class TransposedUpsample(nn.Module):
126
+ "Learned 2x upsampling without padding"
127
+
128
+ def __init__(self, channels, out_channels=None, ks=5):
129
+ super().__init__()
130
+ self.channels = channels
131
+ self.out_channels = out_channels or channels
132
+
133
+ self.up = nn.ConvTranspose2d(
134
+ self.channels, self.out_channels, kernel_size=ks, stride=2
135
+ )
136
+
137
+ def forward(self, x):
138
+ return self.up(x)
139
+
140
+
141
+ class Downsample(nn.Module):
142
+ """
143
+ A downsampling layer with an optional convolution.
144
+ :param channels: channels in the inputs and outputs.
145
+ :param use_conv: a bool determining if a convolution is applied.
146
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
147
+ downsampling occurs in the inner-two dimensions.
148
+ """
149
+
150
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
151
+ super().__init__()
152
+ self.channels = channels
153
+ self.out_channels = out_channels or channels
154
+ self.use_conv = use_conv
155
+ self.dims = dims
156
+ stride = 2 if dims != 3 else (1, 2, 2)
157
+ if use_conv:
158
+ self.op = conv_nd(
159
+ dims,
160
+ self.channels,
161
+ self.out_channels,
162
+ 3,
163
+ stride=stride,
164
+ padding=padding,
165
+ )
166
+ else:
167
+ assert self.channels == self.out_channels
168
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
169
+
170
+ def forward(self, x):
171
+ assert x.shape[1] == self.channels
172
+ return self.op(x)
173
+
174
+
175
+ class ResBlock(TimestepBlock):
176
+ """
177
+ A residual block that can optionally change the number of channels.
178
+ :param channels: the number of input channels.
179
+ :param emb_channels: the number of timestep embedding channels.
180
+ :param dropout: the rate of dropout.
181
+ :param out_channels: if specified, the number of out channels.
182
+ :param use_conv: if True and out_channels is specified, use a spatial
183
+ convolution instead of a smaller 1x1 convolution to change the
184
+ channels in the skip connection.
185
+ :param dims: determines if the signal is 1D, 2D, or 3D.
186
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
187
+ :param up: if True, use this block for upsampling.
188
+ :param down: if True, use this block for downsampling.
189
+ """
190
+
191
+ def __init__(
192
+ self,
193
+ channels,
194
+ emb_channels,
195
+ dropout,
196
+ out_channels=None,
197
+ use_conv=False,
198
+ use_scale_shift_norm=False,
199
+ dims=2,
200
+ use_checkpoint=False,
201
+ up=False,
202
+ down=False,
203
+ ):
204
+ super().__init__()
205
+ self.channels = channels
206
+ self.emb_channels = emb_channels
207
+ self.dropout = dropout
208
+ self.out_channels = out_channels or channels
209
+ self.use_conv = use_conv
210
+ self.use_checkpoint = use_checkpoint
211
+ self.use_scale_shift_norm = use_scale_shift_norm
212
+
213
+ self.in_layers = nn.Sequential(
214
+ normalization(channels),
215
+ nn.SiLU(),
216
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
217
+ )
218
+
219
+ self.updown = up or down
220
+
221
+ if up:
222
+ self.h_upd = Upsample(channels, False, dims)
223
+ self.x_upd = Upsample(channels, False, dims)
224
+ elif down:
225
+ self.h_upd = Downsample(channels, False, dims)
226
+ self.x_upd = Downsample(channels, False, dims)
227
+ else:
228
+ self.h_upd = self.x_upd = nn.Identity()
229
+
230
+ self.emb_layers = nn.Sequential(
231
+ nn.SiLU(),
232
+ linear(
233
+ emb_channels,
234
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
235
+ ),
236
+ )
237
+ self.out_layers = nn.Sequential(
238
+ normalization(self.out_channels),
239
+ nn.SiLU(),
240
+ nn.Dropout(p=dropout),
241
+ zero_module(
242
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
243
+ ),
244
+ )
245
+
246
+ if self.out_channels == channels:
247
+ self.skip_connection = nn.Identity()
248
+ elif use_conv:
249
+ self.skip_connection = conv_nd(
250
+ dims, channels, self.out_channels, 3, padding=1
251
+ )
252
+ else:
253
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
254
+
255
+ def forward(self, x, emb):
256
+ """
257
+ Apply the block to a Tensor, conditioned on a timestep embedding.
258
+ :param x: an [N x C x ...] Tensor of features.
259
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
260
+ :return: an [N x C x ...] Tensor of outputs.
261
+ """
262
+ return checkpoint(
263
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
264
+ )
265
+
266
+ def _forward(self, x, emb):
267
+ if self.updown:
268
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
269
+ h = in_rest(x)
270
+ h = self.h_upd(h)
271
+ x = self.x_upd(x)
272
+ h = in_conv(h)
273
+ else:
274
+ h = self.in_layers(x)
275
+ emb_out = self.emb_layers(emb).type(h.dtype)
276
+ while len(emb_out.shape) < len(h.shape):
277
+ emb_out = emb_out[..., None]
278
+ if self.use_scale_shift_norm:
279
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
280
+ scale, shift = th.chunk(emb_out, 2, dim=1)
281
+ h = out_norm(h) * (1 + scale) + shift
282
+ h = out_rest(h)
283
+ else:
284
+ h = h + emb_out
285
+ h = self.out_layers(h)
286
+ return self.skip_connection(x) + h
287
+
288
+
289
+ class AttentionBlock(nn.Module):
290
+ """
291
+ An attention block that allows spatial positions to attend to each other.
292
+ Originally ported from here, but adapted to the N-d case.
293
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
294
+ """
295
+
296
+ def __init__(
297
+ self,
298
+ channels,
299
+ num_heads=1,
300
+ num_head_channels=-1,
301
+ use_checkpoint=False,
302
+ use_new_attention_order=False,
303
+ ):
304
+ super().__init__()
305
+ self.channels = channels
306
+ if num_head_channels == -1:
307
+ self.num_heads = num_heads
308
+ else:
309
+ assert (
310
+ channels % num_head_channels == 0
311
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
312
+ self.num_heads = channels // num_head_channels
313
+ self.use_checkpoint = use_checkpoint
314
+ self.norm = normalization(channels)
315
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
316
+ if use_new_attention_order:
317
+ # split qkv before split heads
318
+ self.attention = QKVAttention(self.num_heads)
319
+ else:
320
+ # split heads before split qkv
321
+ self.attention = QKVAttentionLegacy(self.num_heads)
322
+
323
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
324
+
325
+ def forward(self, x):
326
+ return checkpoint(
327
+ self._forward, (x,), self.parameters(), True
328
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
329
+ # return pt_checkpoint(self._forward, x) # pytorch
330
+
331
+ def _forward(self, x):
332
+ b, c, *spatial = x.shape
333
+ x = x.reshape(b, c, -1).contiguous()
334
+ qkv = self.qkv(self.norm(x)).contiguous()
335
+ h = self.attention(qkv).contiguous()
336
+ h = self.proj_out(h).contiguous()
337
+ return (x + h).reshape(b, c, *spatial).contiguous()
338
+
339
+
340
+ def count_flops_attn(model, _x, y):
341
+ """
342
+ A counter for the `thop` package to count the operations in an
343
+ attention operation.
344
+ Meant to be used like:
345
+ macs, params = thop.profile(
346
+ model,
347
+ inputs=(inputs, timestamps),
348
+ custom_ops={QKVAttention: QKVAttention.count_flops},
349
+ )
350
+ """
351
+ b, c, *spatial = y[0].shape
352
+ num_spatial = int(np.prod(spatial))
353
+ # We perform two matmuls with the same number of ops.
354
+ # The first computes the weight matrix, the second computes
355
+ # the combination of the value vectors.
356
+ matmul_ops = 2 * b * (num_spatial**2) * c
357
+ model.total_ops += th.DoubleTensor([matmul_ops])
358
+
359
+
360
+ class QKVAttentionLegacy(nn.Module):
361
+ """
362
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
363
+ """
364
+
365
+ def __init__(self, n_heads):
366
+ super().__init__()
367
+ self.n_heads = n_heads
368
+
369
+ def forward(self, qkv):
370
+ """
371
+ Apply QKV attention.
372
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
373
+ :return: an [N x (H * C) x T] tensor after attention.
374
+ """
375
+ bs, width, length = qkv.shape
376
+ assert width % (3 * self.n_heads) == 0
377
+ ch = width // (3 * self.n_heads)
378
+ q, k, v = (
379
+ qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1)
380
+ )
381
+ scale = 1 / math.sqrt(math.sqrt(ch))
382
+ weight = th.einsum(
383
+ "bct,bcs->bts", q * scale, k * scale
384
+ ) # More stable with f16 than dividing afterwards
385
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
386
+ a = th.einsum("bts,bcs->bct", weight, v)
387
+ return a.reshape(bs, -1, length).contiguous()
388
+
389
+ @staticmethod
390
+ def count_flops(model, _x, y):
391
+ return count_flops_attn(model, _x, y)
392
+
393
+
394
+ class QKVAttention(nn.Module):
395
+ """
396
+ A module which performs QKV attention and splits in a different order.
397
+ """
398
+
399
+ def __init__(self, n_heads):
400
+ super().__init__()
401
+ self.n_heads = n_heads
402
+
403
+ def forward(self, qkv):
404
+ """
405
+ Apply QKV attention.
406
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
407
+ :return: an [N x (H * C) x T] tensor after attention.
408
+ """
409
+ bs, width, length = qkv.shape
410
+ assert width % (3 * self.n_heads) == 0
411
+ ch = width // (3 * self.n_heads)
412
+ q, k, v = qkv.chunk(3, dim=1)
413
+ scale = 1 / math.sqrt(math.sqrt(ch))
414
+ weight = th.einsum(
415
+ "bct,bcs->bts",
416
+ (q * scale).view(bs * self.n_heads, ch, length),
417
+ (k * scale).view(bs * self.n_heads, ch, length),
418
+ ) # More stable with f16 than dividing afterwards
419
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
420
+ a = th.einsum(
421
+ "bts,bcs->bct",
422
+ weight,
423
+ v.reshape(bs * self.n_heads, ch, length).contiguous(),
424
+ )
425
+ return a.reshape(bs, -1, length).contiguous()
426
+
427
+ @staticmethod
428
+ def count_flops(model, _x, y):
429
+ return count_flops_attn(model, _x, y)
430
+
431
+
432
+ class UNetModel(nn.Module):
433
+ """
434
+ The full UNet model with attention and timestep embedding.
435
+ :param in_channels: channels in the input Tensor.
436
+ :param model_channels: base channel count for the model.
437
+ :param out_channels: channels in the output Tensor.
438
+ :param num_res_blocks: number of residual blocks per downsample.
439
+ :param attention_resolutions: a collection of downsample rates at which
440
+ attention will take place. May be a set, list, or tuple.
441
+ For example, if this contains 4, then at 4x downsampling, attention
442
+ will be used.
443
+ :param dropout: the dropout probability.
444
+ :param channel_mult: channel multiplier for each level of the UNet.
445
+ :param conv_resample: if True, use learned convolutions for upsampling and
446
+ downsampling.
447
+ :param dims: determines if the signal is 1D, 2D, or 3D.
448
+ :param num_classes: if specified (as an int), then this model will be
449
+ class-conditional with `num_classes` classes.
450
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
451
+ :param num_heads: the number of attention heads in each attention layer.
452
+ :param num_heads_channels: if specified, ignore num_heads and instead use
453
+ a fixed channel width per attention head.
454
+ :param num_heads_upsample: works with num_heads to set a different number
455
+ of heads for upsampling. Deprecated.
456
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
457
+ :param resblock_updown: use residual blocks for up/downsampling.
458
+ :param use_new_attention_order: use a different attention pattern for potentially
459
+ increased efficiency.
460
+ """
461
+
462
+ def __init__(
463
+ self,
464
+ image_size,
465
+ in_channels,
466
+ model_channels,
467
+ out_channels,
468
+ num_res_blocks,
469
+ attention_resolutions,
470
+ dropout=0,
471
+ channel_mult=(1, 2, 4, 8),
472
+ conv_resample=True,
473
+ dims=2,
474
+ num_classes=None,
475
+ extra_film_condition_dim=None,
476
+ use_checkpoint=False,
477
+ use_fp16=False,
478
+ num_heads=-1,
479
+ num_head_channels=-1,
480
+ num_heads_upsample=-1,
481
+ use_scale_shift_norm=False,
482
+ extra_film_use_concat=False, # If true, concatenate extrafilm condition with time embedding, else addition
483
+ resblock_updown=False,
484
+ use_new_attention_order=False,
485
+ use_spatial_transformer=False, # custom transformer support
486
+ transformer_depth=1, # custom transformer support
487
+ context_dim=None, # custom transformer support
488
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
489
+ legacy=True,
490
+ ):
491
+ super().__init__()
492
+ if num_heads_upsample == -1:
493
+ num_heads_upsample = num_heads
494
+
495
+ if num_heads == -1:
496
+ assert (
497
+ num_head_channels != -1
498
+ ), "Either num_heads or num_head_channels has to be set"
499
+
500
+ if num_head_channels == -1:
501
+ assert (
502
+ num_heads != -1
503
+ ), "Either num_heads or num_head_channels has to be set"
504
+
505
+ self.image_size = image_size
506
+ self.in_channels = in_channels
507
+ self.model_channels = model_channels
508
+ self.out_channels = out_channels
509
+ self.num_res_blocks = num_res_blocks
510
+ self.attention_resolutions = attention_resolutions
511
+ self.dropout = dropout
512
+ self.channel_mult = channel_mult
513
+ self.conv_resample = conv_resample
514
+ self.num_classes = num_classes
515
+ self.extra_film_condition_dim = extra_film_condition_dim
516
+ self.use_checkpoint = use_checkpoint
517
+ self.dtype = th.float16 if use_fp16 else th.float32
518
+ self.num_heads = num_heads
519
+ self.num_head_channels = num_head_channels
520
+ self.num_heads_upsample = num_heads_upsample
521
+ self.predict_codebook_ids = n_embed is not None
522
+ self.extra_film_use_concat = extra_film_use_concat
523
+ time_embed_dim = model_channels * 4
524
+ self.time_embed = nn.Sequential(
525
+ linear(model_channels, time_embed_dim),
526
+ nn.SiLU(),
527
+ linear(time_embed_dim, time_embed_dim),
528
+ )
529
+
530
+ assert not (
531
+ self.num_classes is not None and self.extra_film_condition_dim is not None
532
+ ), "As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim."
533
+
534
+ if self.num_classes is not None:
535
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
536
+
537
+ self.use_extra_film_by_concat = (
538
+ self.extra_film_condition_dim is not None and self.extra_film_use_concat
539
+ )
540
+ self.use_extra_film_by_addition = (
541
+ self.extra_film_condition_dim is not None and not self.extra_film_use_concat
542
+ )
543
+
544
+ if self.extra_film_condition_dim is not None:
545
+ self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim)
546
+ # print("+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. " % self.extra_film_condition_dim)
547
+ # if(self.use_extra_film_by_concat):
548
+ # print("\t By concatenation with time embedding")
549
+ # elif(self.use_extra_film_by_concat):
550
+ # print("\t By addition with time embedding")
551
+
552
+ if use_spatial_transformer and (
553
+ self.use_extra_film_by_concat or self.use_extra_film_by_addition
554
+ ):
555
+ # print("+ Spatial transformer will only be used as self-attention. Because you have choose to use film as your global condition.")
556
+ spatial_transformer_no_context = True
557
+ else:
558
+ spatial_transformer_no_context = False
559
+
560
+ if use_spatial_transformer and not spatial_transformer_no_context:
561
+ assert (
562
+ context_dim is not None
563
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
564
+
565
+ if context_dim is not None and not spatial_transformer_no_context:
566
+ assert (
567
+ use_spatial_transformer
568
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
569
+ from omegaconf.listconfig import ListConfig
570
+
571
+ if type(context_dim) == ListConfig:
572
+ context_dim = list(context_dim)
573
+
574
+ self.input_blocks = nn.ModuleList(
575
+ [
576
+ TimestepEmbedSequential(
577
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
578
+ )
579
+ ]
580
+ )
581
+ self._feature_size = model_channels
582
+ input_block_chans = [model_channels]
583
+ ch = model_channels
584
+ ds = 1
585
+ for level, mult in enumerate(channel_mult):
586
+ for _ in range(num_res_blocks):
587
+ layers = [
588
+ ResBlock(
589
+ ch,
590
+ time_embed_dim
591
+ if (not self.use_extra_film_by_concat)
592
+ else time_embed_dim * 2,
593
+ dropout,
594
+ out_channels=mult * model_channels,
595
+ dims=dims,
596
+ use_checkpoint=use_checkpoint,
597
+ use_scale_shift_norm=use_scale_shift_norm,
598
+ )
599
+ ]
600
+ ch = mult * model_channels
601
+ if ds in attention_resolutions:
602
+ if num_head_channels == -1:
603
+ dim_head = ch // num_heads
604
+ else:
605
+ num_heads = ch // num_head_channels
606
+ dim_head = num_head_channels
607
+ if legacy:
608
+ dim_head = (
609
+ ch // num_heads
610
+ if use_spatial_transformer
611
+ else num_head_channels
612
+ )
613
+ layers.append(
614
+ AttentionBlock(
615
+ ch,
616
+ use_checkpoint=use_checkpoint,
617
+ num_heads=num_heads,
618
+ num_head_channels=dim_head,
619
+ use_new_attention_order=use_new_attention_order,
620
+ )
621
+ if not use_spatial_transformer
622
+ else SpatialTransformer(
623
+ ch,
624
+ num_heads,
625
+ dim_head,
626
+ depth=transformer_depth,
627
+ context_dim=context_dim,
628
+ no_context=spatial_transformer_no_context,
629
+ )
630
+ )
631
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
632
+ self._feature_size += ch
633
+ input_block_chans.append(ch)
634
+ if level != len(channel_mult) - 1:
635
+ out_ch = ch
636
+ self.input_blocks.append(
637
+ TimestepEmbedSequential(
638
+ ResBlock(
639
+ ch,
640
+ time_embed_dim
641
+ if (not self.use_extra_film_by_concat)
642
+ else time_embed_dim * 2,
643
+ dropout,
644
+ out_channels=out_ch,
645
+ dims=dims,
646
+ use_checkpoint=use_checkpoint,
647
+ use_scale_shift_norm=use_scale_shift_norm,
648
+ down=True,
649
+ )
650
+ if resblock_updown
651
+ else Downsample(
652
+ ch, conv_resample, dims=dims, out_channels=out_ch
653
+ )
654
+ )
655
+ )
656
+ ch = out_ch
657
+ input_block_chans.append(ch)
658
+ ds *= 2
659
+ self._feature_size += ch
660
+
661
+ if num_head_channels == -1:
662
+ dim_head = ch // num_heads
663
+ else:
664
+ num_heads = ch // num_head_channels
665
+ dim_head = num_head_channels
666
+ if legacy:
667
+ # num_heads = 1
668
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
669
+ self.middle_block = TimestepEmbedSequential(
670
+ ResBlock(
671
+ ch,
672
+ time_embed_dim
673
+ if (not self.use_extra_film_by_concat)
674
+ else time_embed_dim * 2,
675
+ dropout,
676
+ dims=dims,
677
+ use_checkpoint=use_checkpoint,
678
+ use_scale_shift_norm=use_scale_shift_norm,
679
+ ),
680
+ AttentionBlock(
681
+ ch,
682
+ use_checkpoint=use_checkpoint,
683
+ num_heads=num_heads,
684
+ num_head_channels=dim_head,
685
+ use_new_attention_order=use_new_attention_order,
686
+ )
687
+ if not use_spatial_transformer
688
+ else SpatialTransformer(
689
+ ch,
690
+ num_heads,
691
+ dim_head,
692
+ depth=transformer_depth,
693
+ context_dim=context_dim,
694
+ no_context=spatial_transformer_no_context,
695
+ ),
696
+ ResBlock(
697
+ ch,
698
+ time_embed_dim
699
+ if (not self.use_extra_film_by_concat)
700
+ else time_embed_dim * 2,
701
+ dropout,
702
+ dims=dims,
703
+ use_checkpoint=use_checkpoint,
704
+ use_scale_shift_norm=use_scale_shift_norm,
705
+ ),
706
+ )
707
+ self._feature_size += ch
708
+
709
+ self.output_blocks = nn.ModuleList([])
710
+ for level, mult in list(enumerate(channel_mult))[::-1]:
711
+ for i in range(num_res_blocks + 1):
712
+ ich = input_block_chans.pop()
713
+ layers = [
714
+ ResBlock(
715
+ ch + ich,
716
+ time_embed_dim
717
+ if (not self.use_extra_film_by_concat)
718
+ else time_embed_dim * 2,
719
+ dropout,
720
+ out_channels=model_channels * mult,
721
+ dims=dims,
722
+ use_checkpoint=use_checkpoint,
723
+ use_scale_shift_norm=use_scale_shift_norm,
724
+ )
725
+ ]
726
+ ch = model_channels * mult
727
+ if ds in attention_resolutions:
728
+ if num_head_channels == -1:
729
+ dim_head = ch // num_heads
730
+ else:
731
+ num_heads = ch // num_head_channels
732
+ dim_head = num_head_channels
733
+ if legacy:
734
+ # num_heads = 1
735
+ dim_head = (
736
+ ch // num_heads
737
+ if use_spatial_transformer
738
+ else num_head_channels
739
+ )
740
+ layers.append(
741
+ AttentionBlock(
742
+ ch,
743
+ use_checkpoint=use_checkpoint,
744
+ num_heads=num_heads_upsample,
745
+ num_head_channels=dim_head,
746
+ use_new_attention_order=use_new_attention_order,
747
+ )
748
+ if not use_spatial_transformer
749
+ else SpatialTransformer(
750
+ ch,
751
+ num_heads,
752
+ dim_head,
753
+ depth=transformer_depth,
754
+ context_dim=context_dim,
755
+ no_context=spatial_transformer_no_context,
756
+ )
757
+ )
758
+ if level and i == num_res_blocks:
759
+ out_ch = ch
760
+ layers.append(
761
+ ResBlock(
762
+ ch,
763
+ time_embed_dim
764
+ if (not self.use_extra_film_by_concat)
765
+ else time_embed_dim * 2,
766
+ dropout,
767
+ out_channels=out_ch,
768
+ dims=dims,
769
+ use_checkpoint=use_checkpoint,
770
+ use_scale_shift_norm=use_scale_shift_norm,
771
+ up=True,
772
+ )
773
+ if resblock_updown
774
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
775
+ )
776
+ ds //= 2
777
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
778
+ self._feature_size += ch
779
+
780
+ self.out = nn.Sequential(
781
+ normalization(ch),
782
+ nn.SiLU(),
783
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
784
+ )
785
+ if self.predict_codebook_ids:
786
+ self.id_predictor = nn.Sequential(
787
+ normalization(ch),
788
+ conv_nd(dims, model_channels, n_embed, 1),
789
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
790
+ )
791
+
792
+ self.shape_reported = False
793
+
794
+ def convert_to_fp16(self):
795
+ """
796
+ Convert the torso of the model to float16.
797
+ """
798
+ self.input_blocks.apply(convert_module_to_f16)
799
+ self.middle_block.apply(convert_module_to_f16)
800
+ self.output_blocks.apply(convert_module_to_f16)
801
+
802
+ def convert_to_fp32(self):
803
+ """
804
+ Convert the torso of the model to float32.
805
+ """
806
+ self.input_blocks.apply(convert_module_to_f32)
807
+ self.middle_block.apply(convert_module_to_f32)
808
+ self.output_blocks.apply(convert_module_to_f32)
809
+
810
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
811
+ """
812
+ Apply the model to an input batch.
813
+ :param x: an [N x C x ...] Tensor of inputs.
814
+ :param timesteps: a 1-D batch of timesteps.
815
+ :param context: conditioning plugged in via crossattn
816
+ :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional
817
+ :return: an [N x C x ...] Tensor of outputs.
818
+ """
819
+ if not self.shape_reported:
820
+ # print("The shape of UNet input is", x.size())
821
+ self.shape_reported = True
822
+
823
+ assert (y is not None) == (
824
+ self.num_classes is not None or self.extra_film_condition_dim is not None
825
+ ), "must specify y if and only if the model is class-conditional or film embedding conditional"
826
+ hs = []
827
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
828
+ emb = self.time_embed(t_emb)
829
+
830
+ if self.num_classes is not None:
831
+ assert y.shape == (x.shape[0],)
832
+ emb = emb + self.label_emb(y)
833
+
834
+ if self.use_extra_film_by_addition:
835
+ emb = emb + self.film_emb(y)
836
+ elif self.use_extra_film_by_concat:
837
+ emb = th.cat([emb, self.film_emb(y)], dim=-1)
838
+
839
+ h = x.type(self.dtype)
840
+ for module in self.input_blocks:
841
+ h = module(h, emb, context)
842
+ hs.append(h)
843
+ h = self.middle_block(h, emb, context)
844
+ for module in self.output_blocks:
845
+ h = th.cat([h, hs.pop()], dim=1)
846
+ h = module(h, emb, context)
847
+ h = h.type(x.dtype)
848
+ if self.predict_codebook_ids:
849
+ return self.id_predictor(h)
850
+ else:
851
+ return self.out(h)
852
+
853
+
854
+ class EncoderUNetModel(nn.Module):
855
+ """
856
+ The half UNet model with attention and timestep embedding.
857
+ For usage, see UNet.
858
+ """
859
+
860
+ def __init__(
861
+ self,
862
+ image_size,
863
+ in_channels,
864
+ model_channels,
865
+ out_channels,
866
+ num_res_blocks,
867
+ attention_resolutions,
868
+ dropout=0,
869
+ channel_mult=(1, 2, 4, 8),
870
+ conv_resample=True,
871
+ dims=2,
872
+ use_checkpoint=False,
873
+ use_fp16=False,
874
+ num_heads=1,
875
+ num_head_channels=-1,
876
+ num_heads_upsample=-1,
877
+ use_scale_shift_norm=False,
878
+ resblock_updown=False,
879
+ use_new_attention_order=False,
880
+ pool="adaptive",
881
+ *args,
882
+ **kwargs,
883
+ ):
884
+ super().__init__()
885
+
886
+ if num_heads_upsample == -1:
887
+ num_heads_upsample = num_heads
888
+
889
+ self.in_channels = in_channels
890
+ self.model_channels = model_channels
891
+ self.out_channels = out_channels
892
+ self.num_res_blocks = num_res_blocks
893
+ self.attention_resolutions = attention_resolutions
894
+ self.dropout = dropout
895
+ self.channel_mult = channel_mult
896
+ self.conv_resample = conv_resample
897
+ self.use_checkpoint = use_checkpoint
898
+ self.dtype = th.float16 if use_fp16 else th.float32
899
+ self.num_heads = num_heads
900
+ self.num_head_channels = num_head_channels
901
+ self.num_heads_upsample = num_heads_upsample
902
+
903
+ time_embed_dim = model_channels * 4
904
+ self.time_embed = nn.Sequential(
905
+ linear(model_channels, time_embed_dim),
906
+ nn.SiLU(),
907
+ linear(time_embed_dim, time_embed_dim),
908
+ )
909
+
910
+ self.input_blocks = nn.ModuleList(
911
+ [
912
+ TimestepEmbedSequential(
913
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
914
+ )
915
+ ]
916
+ )
917
+ self._feature_size = model_channels
918
+ input_block_chans = [model_channels]
919
+ ch = model_channels
920
+ ds = 1
921
+ for level, mult in enumerate(channel_mult):
922
+ for _ in range(num_res_blocks):
923
+ layers = [
924
+ ResBlock(
925
+ ch,
926
+ time_embed_dim,
927
+ dropout,
928
+ out_channels=mult * model_channels,
929
+ dims=dims,
930
+ use_checkpoint=use_checkpoint,
931
+ use_scale_shift_norm=use_scale_shift_norm,
932
+ )
933
+ ]
934
+ ch = mult * model_channels
935
+ if ds in attention_resolutions:
936
+ layers.append(
937
+ AttentionBlock(
938
+ ch,
939
+ use_checkpoint=use_checkpoint,
940
+ num_heads=num_heads,
941
+ num_head_channels=num_head_channels,
942
+ use_new_attention_order=use_new_attention_order,
943
+ )
944
+ )
945
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
946
+ self._feature_size += ch
947
+ input_block_chans.append(ch)
948
+ if level != len(channel_mult) - 1:
949
+ out_ch = ch
950
+ self.input_blocks.append(
951
+ TimestepEmbedSequential(
952
+ ResBlock(
953
+ ch,
954
+ time_embed_dim,
955
+ dropout,
956
+ out_channels=out_ch,
957
+ dims=dims,
958
+ use_checkpoint=use_checkpoint,
959
+ use_scale_shift_norm=use_scale_shift_norm,
960
+ down=True,
961
+ )
962
+ if resblock_updown
963
+ else Downsample(
964
+ ch, conv_resample, dims=dims, out_channels=out_ch
965
+ )
966
+ )
967
+ )
968
+ ch = out_ch
969
+ input_block_chans.append(ch)
970
+ ds *= 2
971
+ self._feature_size += ch
972
+
973
+ self.middle_block = TimestepEmbedSequential(
974
+ ResBlock(
975
+ ch,
976
+ time_embed_dim,
977
+ dropout,
978
+ dims=dims,
979
+ use_checkpoint=use_checkpoint,
980
+ use_scale_shift_norm=use_scale_shift_norm,
981
+ ),
982
+ AttentionBlock(
983
+ ch,
984
+ use_checkpoint=use_checkpoint,
985
+ num_heads=num_heads,
986
+ num_head_channels=num_head_channels,
987
+ use_new_attention_order=use_new_attention_order,
988
+ ),
989
+ ResBlock(
990
+ ch,
991
+ time_embed_dim,
992
+ dropout,
993
+ dims=dims,
994
+ use_checkpoint=use_checkpoint,
995
+ use_scale_shift_norm=use_scale_shift_norm,
996
+ ),
997
+ )
998
+ self._feature_size += ch
999
+ self.pool = pool
1000
+ if pool == "adaptive":
1001
+ self.out = nn.Sequential(
1002
+ normalization(ch),
1003
+ nn.SiLU(),
1004
+ nn.AdaptiveAvgPool2d((1, 1)),
1005
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
1006
+ nn.Flatten(),
1007
+ )
1008
+ elif pool == "attention":
1009
+ assert num_head_channels != -1
1010
+ self.out = nn.Sequential(
1011
+ normalization(ch),
1012
+ nn.SiLU(),
1013
+ AttentionPool2d(
1014
+ (image_size // ds), ch, num_head_channels, out_channels
1015
+ ),
1016
+ )
1017
+ elif pool == "spatial":
1018
+ self.out = nn.Sequential(
1019
+ nn.Linear(self._feature_size, 2048),
1020
+ nn.ReLU(),
1021
+ nn.Linear(2048, self.out_channels),
1022
+ )
1023
+ elif pool == "spatial_v2":
1024
+ self.out = nn.Sequential(
1025
+ nn.Linear(self._feature_size, 2048),
1026
+ normalization(2048),
1027
+ nn.SiLU(),
1028
+ nn.Linear(2048, self.out_channels),
1029
+ )
1030
+ else:
1031
+ raise NotImplementedError(f"Unexpected {pool} pooling")
1032
+
1033
+ def convert_to_fp16(self):
1034
+ """
1035
+ Convert the torso of the model to float16.
1036
+ """
1037
+ self.input_blocks.apply(convert_module_to_f16)
1038
+ self.middle_block.apply(convert_module_to_f16)
1039
+
1040
+ def convert_to_fp32(self):
1041
+ """
1042
+ Convert the torso of the model to float32.
1043
+ """
1044
+ self.input_blocks.apply(convert_module_to_f32)
1045
+ self.middle_block.apply(convert_module_to_f32)
1046
+
1047
+ def forward(self, x, timesteps):
1048
+ """
1049
+ Apply the model to an input batch.
1050
+ :param x: an [N x C x ...] Tensor of inputs.
1051
+ :param timesteps: a 1-D batch of timesteps.
1052
+ :return: an [N x K] Tensor of outputs.
1053
+ """
1054
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
1055
+
1056
+ results = []
1057
+ h = x.type(self.dtype)
1058
+ for module in self.input_blocks:
1059
+ h = module(h, emb)
1060
+ if self.pool.startswith("spatial"):
1061
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1062
+ h = self.middle_block(h, emb)
1063
+ if self.pool.startswith("spatial"):
1064
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1065
+ h = th.cat(results, axis=-1)
1066
+ return self.out(h)
1067
+ else:
1068
+ h = h.type(x.dtype)
1069
+ return self.out(h)
audioldm/latent_diffusion/util.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ from audioldm.utils import instantiate_from_config
19
+
20
+
21
+ def make_beta_schedule(
22
+ schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
23
+ ):
24
+ if schedule == "linear":
25
+ betas = (
26
+ torch.linspace(
27
+ linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
28
+ )
29
+ ** 2
30
+ )
31
+
32
+ elif schedule == "cosine":
33
+ timesteps = (
34
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
35
+ )
36
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
37
+ alphas = torch.cos(alphas).pow(2)
38
+ alphas = alphas / alphas[0]
39
+ betas = 1 - alphas[1:] / alphas[:-1]
40
+ betas = np.clip(betas, a_min=0, a_max=0.999)
41
+
42
+ elif schedule == "sqrt_linear":
43
+ betas = torch.linspace(
44
+ linear_start, linear_end, n_timestep, dtype=torch.float64
45
+ )
46
+ elif schedule == "sqrt":
47
+ betas = (
48
+ torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
49
+ ** 0.5
50
+ )
51
+ else:
52
+ raise ValueError(f"schedule '{schedule}' unknown.")
53
+ return betas.numpy()
54
+
55
+
56
+ def make_ddim_timesteps(
57
+ ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
58
+ ):
59
+ if ddim_discr_method == "uniform":
60
+ c = num_ddpm_timesteps // num_ddim_timesteps
61
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
62
+ elif ddim_discr_method == "quad":
63
+ ddim_timesteps = (
64
+ (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
65
+ ).astype(int)
66
+ else:
67
+ raise NotImplementedError(
68
+ f'There is no ddim discretization method called "{ddim_discr_method}"'
69
+ )
70
+
71
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
72
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
73
+ steps_out = ddim_timesteps + 1
74
+ if verbose:
75
+ print(f"Selected timesteps for ddim sampler: {steps_out}")
76
+ return steps_out
77
+
78
+
79
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
80
+ # select alphas for computing the variance schedule
81
+ alphas = alphacums[ddim_timesteps]
82
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
83
+
84
+ # according the the formula provided in https://arxiv.org/abs/2010.02502
85
+ sigmas = eta * np.sqrt(
86
+ (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
87
+ )
88
+ if verbose:
89
+ print(
90
+ f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
91
+ )
92
+ print(
93
+ f"For the chosen value of eta, which is {eta}, "
94
+ f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
95
+ )
96
+ return sigmas, alphas, alphas_prev
97
+
98
+
99
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
100
+ """
101
+ Create a beta schedule that discretizes the given alpha_t_bar function,
102
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
103
+ :param num_diffusion_timesteps: the number of betas to produce.
104
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
105
+ produces the cumulative product of (1-beta) up to that
106
+ part of the diffusion process.
107
+ :param max_beta: the maximum beta to use; use values lower than 1 to
108
+ prevent singularities.
109
+ """
110
+ betas = []
111
+ for i in range(num_diffusion_timesteps):
112
+ t1 = i / num_diffusion_timesteps
113
+ t2 = (i + 1) / num_diffusion_timesteps
114
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
115
+ return np.array(betas)
116
+
117
+
118
+ def extract_into_tensor(a, t, x_shape):
119
+ b, *_ = t.shape
120
+ out = a.gather(-1, t).contiguous()
121
+ return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous()
122
+
123
+
124
+ def checkpoint(func, inputs, params, flag):
125
+ """
126
+ Evaluate a function without caching intermediate activations, allowing for
127
+ reduced memory at the expense of extra compute in the backward pass.
128
+ :param func: the function to evaluate.
129
+ :param inputs: the argument sequence to pass to `func`.
130
+ :param params: a sequence of parameters `func` depends on but does not
131
+ explicitly take as arguments.
132
+ :param flag: if False, disable gradient checkpointing.
133
+ """
134
+ if flag:
135
+ args = tuple(inputs) + tuple(params)
136
+ return CheckpointFunction.apply(func, len(inputs), *args)
137
+ else:
138
+ return func(*inputs)
139
+
140
+
141
+ class CheckpointFunction(torch.autograd.Function):
142
+ @staticmethod
143
+ def forward(ctx, run_function, length, *args):
144
+ ctx.run_function = run_function
145
+ ctx.input_tensors = list(args[:length])
146
+ ctx.input_params = list(args[length:])
147
+
148
+ with torch.no_grad():
149
+ output_tensors = ctx.run_function(*ctx.input_tensors)
150
+ return output_tensors
151
+
152
+ @staticmethod
153
+ def backward(ctx, *output_grads):
154
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155
+ with torch.enable_grad():
156
+ # Fixes a bug where the first op in run_function modifies the
157
+ # Tensor storage in place, which is not allowed for detach()'d
158
+ # Tensors.
159
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160
+ output_tensors = ctx.run_function(*shallow_copies)
161
+ input_grads = torch.autograd.grad(
162
+ output_tensors,
163
+ ctx.input_tensors + ctx.input_params,
164
+ output_grads,
165
+ allow_unused=True,
166
+ )
167
+ del ctx.input_tensors
168
+ del ctx.input_params
169
+ del output_tensors
170
+ return (None, None) + input_grads
171
+
172
+
173
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
174
+ """
175
+ Create sinusoidal timestep embeddings.
176
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
177
+ These may be fractional.
178
+ :param dim: the dimension of the output.
179
+ :param max_period: controls the minimum frequency of the embeddings.
180
+ :return: an [N x dim] Tensor of positional embeddings.
181
+ """
182
+ if not repeat_only:
183
+ half = dim // 2
184
+ freqs = torch.exp(
185
+ -math.log(max_period)
186
+ * torch.arange(start=0, end=half, dtype=torch.float32)
187
+ / half
188
+ ).to(device=timesteps.device)
189
+ args = timesteps[:, None].float() * freqs[None]
190
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
191
+ if dim % 2:
192
+ embedding = torch.cat(
193
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
194
+ )
195
+ else:
196
+ embedding = repeat(timesteps, "b -> b d", d=dim)
197
+ return embedding
198
+
199
+
200
+ def zero_module(module):
201
+ """
202
+ Zero out the parameters of a module and return it.
203
+ """
204
+ for p in module.parameters():
205
+ p.detach().zero_()
206
+ return module
207
+
208
+
209
+ def scale_module(module, scale):
210
+ """
211
+ Scale the parameters of a module and return it.
212
+ """
213
+ for p in module.parameters():
214
+ p.detach().mul_(scale)
215
+ return module
216
+
217
+
218
+ def mean_flat(tensor):
219
+ """
220
+ Take the mean over all non-batch dimensions.
221
+ """
222
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
223
+
224
+
225
+ def normalization(channels):
226
+ """
227
+ Make a standard normalization layer.
228
+ :param channels: number of input channels.
229
+ :return: an nn.Module for normalization.
230
+ """
231
+ return GroupNorm32(32, channels)
232
+
233
+
234
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
235
+ class SiLU(nn.Module):
236
+ def forward(self, x):
237
+ return x * torch.sigmoid(x)
238
+
239
+
240
+ class GroupNorm32(nn.GroupNorm):
241
+ def forward(self, x):
242
+ return super().forward(x.float()).type(x.dtype)
243
+
244
+
245
+ def conv_nd(dims, *args, **kwargs):
246
+ """
247
+ Create a 1D, 2D, or 3D convolution module.
248
+ """
249
+ if dims == 1:
250
+ return nn.Conv1d(*args, **kwargs)
251
+ elif dims == 2:
252
+ return nn.Conv2d(*args, **kwargs)
253
+ elif dims == 3:
254
+ return nn.Conv3d(*args, **kwargs)
255
+ raise ValueError(f"unsupported dimensions: {dims}")
256
+
257
+
258
+ def linear(*args, **kwargs):
259
+ """
260
+ Create a linear module.
261
+ """
262
+ return nn.Linear(*args, **kwargs)
263
+
264
+
265
+ def avg_pool_nd(dims, *args, **kwargs):
266
+ """
267
+ Create a 1D, 2D, or 3D average pooling module.
268
+ """
269
+ if dims == 1:
270
+ return nn.AvgPool1d(*args, **kwargs)
271
+ elif dims == 2:
272
+ return nn.AvgPool2d(*args, **kwargs)
273
+ elif dims == 3:
274
+ return nn.AvgPool3d(*args, **kwargs)
275
+ raise ValueError(f"unsupported dimensions: {dims}")
276
+
277
+
278
+ class HybridConditioner(nn.Module):
279
+ def __init__(self, c_concat_config, c_crossattn_config):
280
+ super().__init__()
281
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
282
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
283
+
284
+ def forward(self, c_concat, c_crossattn):
285
+ c_concat = self.concat_conditioner(c_concat)
286
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
287
+ return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
288
+
289
+
290
+ def noise_like(shape, device, repeat=False):
291
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
292
+ shape[0], *((1,) * (len(shape) - 1))
293
+ )
294
+ noise = lambda: torch.randn(shape, device=device)
295
+ return repeat_noise() if repeat else noise()
audioldm/ldm.py ADDED
@@ -0,0 +1,818 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from audioldm.utils import default, instantiate_from_config, save_wave
7
+ from audioldm.latent_diffusion.ddpm import DDPM
8
+ from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution
9
+ from audioldm.latent_diffusion.util import noise_like
10
+ from audioldm.latent_diffusion.ddim import DDIMSampler
11
+ import os
12
+
13
+
14
+ def disabled_train(self, mode=True):
15
+ """Overwrite model.train with this function to make sure train/eval mode
16
+ does not change anymore."""
17
+ return self
18
+
19
+
20
+ class LatentDiffusion(DDPM):
21
+ """main class"""
22
+
23
+ def __init__(
24
+ self,
25
+ device="cuda",
26
+ first_stage_config=None,
27
+ cond_stage_config=None,
28
+ num_timesteps_cond=None,
29
+ cond_stage_key="image",
30
+ cond_stage_trainable=False,
31
+ concat_mode=True,
32
+ cond_stage_forward=None,
33
+ conditioning_key=None,
34
+ scale_factor=1.0,
35
+ scale_by_std=False,
36
+ base_learning_rate=None,
37
+ *args,
38
+ **kwargs,
39
+ ):
40
+ self.device = device
41
+ self.learning_rate = base_learning_rate
42
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
43
+ self.scale_by_std = scale_by_std
44
+ assert self.num_timesteps_cond <= kwargs["timesteps"]
45
+ # for backwards compatibility after implementation of DiffusionWrapper
46
+ if conditioning_key is None:
47
+ conditioning_key = "concat" if concat_mode else "crossattn"
48
+ if cond_stage_config == "__is_unconditional__":
49
+ conditioning_key = None
50
+ ckpt_path = kwargs.pop("ckpt_path", None)
51
+ ignore_keys = kwargs.pop("ignore_keys", [])
52
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
53
+ self.concat_mode = concat_mode
54
+ self.cond_stage_trainable = cond_stage_trainable
55
+ self.cond_stage_key = cond_stage_key
56
+ self.cond_stage_key_orig = cond_stage_key
57
+ try:
58
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
59
+ except:
60
+ self.num_downs = 0
61
+ if not scale_by_std:
62
+ self.scale_factor = scale_factor
63
+ else:
64
+ self.register_buffer("scale_factor", torch.tensor(scale_factor))
65
+ self.instantiate_first_stage(first_stage_config)
66
+ self.instantiate_cond_stage(cond_stage_config)
67
+ self.cond_stage_forward = cond_stage_forward
68
+ self.clip_denoised = False
69
+
70
+ def make_cond_schedule(
71
+ self,
72
+ ):
73
+ self.cond_ids = torch.full(
74
+ size=(self.num_timesteps,),
75
+ fill_value=self.num_timesteps - 1,
76
+ dtype=torch.long,
77
+ )
78
+ ids = torch.round(
79
+ torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
80
+ ).long()
81
+ self.cond_ids[: self.num_timesteps_cond] = ids
82
+
83
+ def register_schedule(
84
+ self,
85
+ given_betas=None,
86
+ beta_schedule="linear",
87
+ timesteps=1000,
88
+ linear_start=1e-4,
89
+ linear_end=2e-2,
90
+ cosine_s=8e-3,
91
+ ):
92
+ super().register_schedule(
93
+ given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
94
+ )
95
+
96
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
97
+ if self.shorten_cond_schedule:
98
+ self.make_cond_schedule()
99
+
100
+ def instantiate_first_stage(self, config):
101
+ model = instantiate_from_config(config)
102
+ self.first_stage_model = model.eval()
103
+ self.first_stage_model.train = disabled_train
104
+ for param in self.first_stage_model.parameters():
105
+ param.requires_grad = False
106
+
107
+ def instantiate_cond_stage(self, config):
108
+ if not self.cond_stage_trainable:
109
+ if config == "__is_first_stage__":
110
+ print("Using first stage also as cond stage.")
111
+ self.cond_stage_model = self.first_stage_model
112
+ elif config == "__is_unconditional__":
113
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
114
+ self.cond_stage_model = None
115
+ # self.be_unconditional = True
116
+ else:
117
+ model = instantiate_from_config(config)
118
+ self.cond_stage_model = model.eval()
119
+ self.cond_stage_model.train = disabled_train
120
+ for param in self.cond_stage_model.parameters():
121
+ param.requires_grad = False
122
+ else:
123
+ assert config != "__is_first_stage__"
124
+ assert config != "__is_unconditional__"
125
+ model = instantiate_from_config(config)
126
+ self.cond_stage_model = model
127
+ self.cond_stage_model = self.cond_stage_model.to(self.device)
128
+
129
+ def get_first_stage_encoding(self, encoder_posterior):
130
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
131
+ z = encoder_posterior.sample()
132
+ elif isinstance(encoder_posterior, torch.Tensor):
133
+ z = encoder_posterior
134
+ else:
135
+ raise NotImplementedError(
136
+ f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
137
+ )
138
+ return self.scale_factor * z
139
+
140
+ def get_learned_conditioning(self, c):
141
+ if self.cond_stage_forward is None:
142
+ if hasattr(self.cond_stage_model, "encode") and callable(
143
+ self.cond_stage_model.encode
144
+ ):
145
+ c = self.cond_stage_model.encode(c)
146
+ if isinstance(c, DiagonalGaussianDistribution):
147
+ c = c.mode()
148
+ else:
149
+ # Text input is list
150
+ if type(c) == list and len(c) == 1:
151
+ c = self.cond_stage_model([c[0], c[0]])
152
+ c = c[0:1]
153
+ else:
154
+ c = self.cond_stage_model(c)
155
+ else:
156
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
157
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
158
+ return c
159
+
160
+ @torch.no_grad()
161
+ def get_input(
162
+ self,
163
+ batch,
164
+ k,
165
+ return_first_stage_encode=True,
166
+ return_first_stage_outputs=False,
167
+ force_c_encode=False,
168
+ cond_key=None,
169
+ return_original_cond=False,
170
+ bs=None,
171
+ ):
172
+ x = super().get_input(batch, k)
173
+
174
+ if bs is not None:
175
+ x = x[:bs]
176
+
177
+ x = x.to(self.device)
178
+
179
+ if return_first_stage_encode:
180
+ encoder_posterior = self.encode_first_stage(x)
181
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
182
+ else:
183
+ z = None
184
+
185
+ if self.model.conditioning_key is not None:
186
+ if cond_key is None:
187
+ cond_key = self.cond_stage_key
188
+ if cond_key != self.first_stage_key:
189
+ if cond_key in ["caption", "coordinates_bbox"]:
190
+ xc = batch[cond_key]
191
+ elif cond_key == "class_label":
192
+ xc = batch
193
+ else:
194
+ # [bs, 1, 527]
195
+ xc = super().get_input(batch, cond_key)
196
+ if type(xc) == torch.Tensor:
197
+ xc = xc.to(self.device)
198
+ else:
199
+ xc = x
200
+ if not self.cond_stage_trainable or force_c_encode:
201
+ if isinstance(xc, dict) or isinstance(xc, list):
202
+ c = self.get_learned_conditioning(xc)
203
+ else:
204
+ c = self.get_learned_conditioning(xc.to(self.device))
205
+ else:
206
+ c = xc
207
+
208
+ if bs is not None:
209
+ c = c[:bs]
210
+
211
+ else:
212
+ c = None
213
+ xc = None
214
+ if self.use_positional_encodings:
215
+ pos_x, pos_y = self.compute_latent_shifts(batch)
216
+ c = {"pos_x": pos_x, "pos_y": pos_y}
217
+ out = [z, c]
218
+ if return_first_stage_outputs:
219
+ xrec = self.decode_first_stage(z)
220
+ out.extend([x, xrec])
221
+ if return_original_cond:
222
+ out.append(xc)
223
+ return out
224
+
225
+ @torch.no_grad()
226
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
227
+ if predict_cids:
228
+ if z.dim() == 4:
229
+ z = torch.argmax(z.exp(), dim=1).long()
230
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
231
+ z = rearrange(z, "b h w c -> b c h w").contiguous()
232
+
233
+ z = 1.0 / self.scale_factor * z
234
+ return self.first_stage_model.decode(z)
235
+
236
+ def mel_spectrogram_to_waveform(self, mel):
237
+ # Mel: [bs, 1, t-steps, fbins]
238
+ if len(mel.size()) == 4:
239
+ mel = mel.squeeze(1)
240
+ mel = mel.permute(0, 2, 1)
241
+ waveform = self.first_stage_model.vocoder(mel)
242
+ waveform = waveform.cpu().detach().numpy()
243
+ return waveform
244
+
245
+ @torch.no_grad()
246
+ def encode_first_stage(self, x):
247
+ return self.first_stage_model.encode(x)
248
+
249
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
250
+
251
+ if isinstance(cond, dict):
252
+ # hybrid case, cond is exptected to be a dict
253
+ pass
254
+ else:
255
+ if not isinstance(cond, list):
256
+ cond = [cond]
257
+ if self.model.conditioning_key == "concat":
258
+ key = "c_concat"
259
+ elif self.model.conditioning_key == "crossattn":
260
+ key = "c_crossattn"
261
+ else:
262
+ key = "c_film"
263
+
264
+ cond = {key: cond}
265
+
266
+ x_recon = self.model(x_noisy, t, **cond)
267
+
268
+ if isinstance(x_recon, tuple) and not return_ids:
269
+ return x_recon[0]
270
+ else:
271
+ return x_recon
272
+
273
+ def p_mean_variance(
274
+ self,
275
+ x,
276
+ c,
277
+ t,
278
+ clip_denoised: bool,
279
+ return_codebook_ids=False,
280
+ quantize_denoised=False,
281
+ return_x0=False,
282
+ score_corrector=None,
283
+ corrector_kwargs=None,
284
+ ):
285
+ t_in = t
286
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
287
+
288
+ if score_corrector is not None:
289
+ assert self.parameterization == "eps"
290
+ model_out = score_corrector.modify_score(
291
+ self, model_out, x, t, c, **corrector_kwargs
292
+ )
293
+
294
+ if return_codebook_ids:
295
+ model_out, logits = model_out
296
+
297
+ if self.parameterization == "eps":
298
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
299
+ elif self.parameterization == "x0":
300
+ x_recon = model_out
301
+ else:
302
+ raise NotImplementedError()
303
+
304
+ if clip_denoised:
305
+ x_recon.clamp_(-1.0, 1.0)
306
+ if quantize_denoised:
307
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
308
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
309
+ x_start=x_recon, x_t=x, t=t
310
+ )
311
+ if return_codebook_ids:
312
+ return model_mean, posterior_variance, posterior_log_variance, logits
313
+ elif return_x0:
314
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
315
+ else:
316
+ return model_mean, posterior_variance, posterior_log_variance
317
+
318
+ @torch.no_grad()
319
+ def p_sample(
320
+ self,
321
+ x,
322
+ c,
323
+ t,
324
+ clip_denoised=False,
325
+ repeat_noise=False,
326
+ return_codebook_ids=False,
327
+ quantize_denoised=False,
328
+ return_x0=False,
329
+ temperature=1.0,
330
+ noise_dropout=0.0,
331
+ score_corrector=None,
332
+ corrector_kwargs=None,
333
+ ):
334
+ b, *_, device = *x.shape, x.device
335
+ outputs = self.p_mean_variance(
336
+ x=x,
337
+ c=c,
338
+ t=t,
339
+ clip_denoised=clip_denoised,
340
+ return_codebook_ids=return_codebook_ids,
341
+ quantize_denoised=quantize_denoised,
342
+ return_x0=return_x0,
343
+ score_corrector=score_corrector,
344
+ corrector_kwargs=corrector_kwargs,
345
+ )
346
+ if return_codebook_ids:
347
+ raise DeprecationWarning("Support dropped.")
348
+ model_mean, _, model_log_variance, logits = outputs
349
+ elif return_x0:
350
+ model_mean, _, model_log_variance, x0 = outputs
351
+ else:
352
+ model_mean, _, model_log_variance = outputs
353
+
354
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
355
+ if noise_dropout > 0.0:
356
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
357
+ # no noise when t == 0
358
+ nonzero_mask = (
359
+ (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous()
360
+ )
361
+
362
+ if return_codebook_ids:
363
+ return model_mean + nonzero_mask * (
364
+ 0.5 * model_log_variance
365
+ ).exp() * noise, logits.argmax(dim=1)
366
+ if return_x0:
367
+ return (
368
+ model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
369
+ x0,
370
+ )
371
+ else:
372
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
373
+
374
+ @torch.no_grad()
375
+ def progressive_denoising(
376
+ self,
377
+ cond,
378
+ shape,
379
+ verbose=True,
380
+ callback=None,
381
+ quantize_denoised=False,
382
+ img_callback=None,
383
+ mask=None,
384
+ x0=None,
385
+ temperature=1.0,
386
+ noise_dropout=0.0,
387
+ score_corrector=None,
388
+ corrector_kwargs=None,
389
+ batch_size=None,
390
+ x_T=None,
391
+ start_T=None,
392
+ log_every_t=None,
393
+ ):
394
+ if not log_every_t:
395
+ log_every_t = self.log_every_t
396
+ timesteps = self.num_timesteps
397
+ if batch_size is not None:
398
+ b = batch_size if batch_size is not None else shape[0]
399
+ shape = [batch_size] + list(shape)
400
+ else:
401
+ b = batch_size = shape[0]
402
+ if x_T is None:
403
+ img = torch.randn(shape, device=self.device)
404
+ else:
405
+ img = x_T
406
+ intermediates = []
407
+ if cond is not None:
408
+ if isinstance(cond, dict):
409
+ cond = {
410
+ key: cond[key][:batch_size]
411
+ if not isinstance(cond[key], list)
412
+ else list(map(lambda x: x[:batch_size], cond[key]))
413
+ for key in cond
414
+ }
415
+ else:
416
+ cond = (
417
+ [c[:batch_size] for c in cond]
418
+ if isinstance(cond, list)
419
+ else cond[:batch_size]
420
+ )
421
+
422
+ if start_T is not None:
423
+ timesteps = min(timesteps, start_T)
424
+ iterator = (
425
+ tqdm(
426
+ reversed(range(0, timesteps)),
427
+ desc="Progressive Generation",
428
+ total=timesteps,
429
+ )
430
+ if verbose
431
+ else reversed(range(0, timesteps))
432
+ )
433
+ if type(temperature) == float:
434
+ temperature = [temperature] * timesteps
435
+
436
+ for i in iterator:
437
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
438
+ if self.shorten_cond_schedule:
439
+ assert self.model.conditioning_key != "hybrid"
440
+ tc = self.cond_ids[ts].to(cond.device)
441
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
442
+
443
+ img, x0_partial = self.p_sample(
444
+ img,
445
+ cond,
446
+ ts,
447
+ clip_denoised=self.clip_denoised,
448
+ quantize_denoised=quantize_denoised,
449
+ return_x0=True,
450
+ temperature=temperature[i],
451
+ noise_dropout=noise_dropout,
452
+ score_corrector=score_corrector,
453
+ corrector_kwargs=corrector_kwargs,
454
+ )
455
+ if mask is not None:
456
+ assert x0 is not None
457
+ img_orig = self.q_sample(x0, ts)
458
+ img = img_orig * mask + (1.0 - mask) * img
459
+
460
+ if i % log_every_t == 0 or i == timesteps - 1:
461
+ intermediates.append(x0_partial)
462
+ if callback:
463
+ callback(i)
464
+ if img_callback:
465
+ img_callback(img, i)
466
+ return img, intermediates
467
+
468
+ @torch.no_grad()
469
+ def p_sample_loop(
470
+ self,
471
+ cond,
472
+ shape,
473
+ return_intermediates=False,
474
+ x_T=None,
475
+ verbose=True,
476
+ callback=None,
477
+ timesteps=None,
478
+ quantize_denoised=False,
479
+ mask=None,
480
+ x0=None,
481
+ img_callback=None,
482
+ start_T=None,
483
+ log_every_t=None,
484
+ ):
485
+
486
+ if not log_every_t:
487
+ log_every_t = self.log_every_t
488
+ device = self.betas.device
489
+ b = shape[0]
490
+ if x_T is None:
491
+ img = torch.randn(shape, device=device)
492
+ else:
493
+ img = x_T
494
+
495
+ intermediates = [img]
496
+ if timesteps is None:
497
+ timesteps = self.num_timesteps
498
+
499
+ if start_T is not None:
500
+ timesteps = min(timesteps, start_T)
501
+ iterator = (
502
+ tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps)
503
+ if verbose
504
+ else reversed(range(0, timesteps))
505
+ )
506
+
507
+ if mask is not None:
508
+ assert x0 is not None
509
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
510
+
511
+ for i in iterator:
512
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
513
+ if self.shorten_cond_schedule:
514
+ assert self.model.conditioning_key != "hybrid"
515
+ tc = self.cond_ids[ts].to(cond.device)
516
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
517
+
518
+ img = self.p_sample(
519
+ img,
520
+ cond,
521
+ ts,
522
+ clip_denoised=self.clip_denoised,
523
+ quantize_denoised=quantize_denoised,
524
+ )
525
+ if mask is not None:
526
+ img_orig = self.q_sample(x0, ts)
527
+ img = img_orig * mask + (1.0 - mask) * img
528
+
529
+ if i % log_every_t == 0 or i == timesteps - 1:
530
+ intermediates.append(img)
531
+ if callback:
532
+ callback(i)
533
+ if img_callback:
534
+ img_callback(img, i)
535
+
536
+ if return_intermediates:
537
+ return img, intermediates
538
+ return img
539
+
540
+ @torch.no_grad()
541
+ def sample(
542
+ self,
543
+ cond,
544
+ batch_size=16,
545
+ return_intermediates=False,
546
+ x_T=None,
547
+ verbose=True,
548
+ timesteps=None,
549
+ quantize_denoised=False,
550
+ mask=None,
551
+ x0=None,
552
+ shape=None,
553
+ **kwargs,
554
+ ):
555
+ if shape is None:
556
+ shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size)
557
+ if cond is not None:
558
+ if isinstance(cond, dict):
559
+ cond = {
560
+ key: cond[key][:batch_size]
561
+ if not isinstance(cond[key], list)
562
+ else list(map(lambda x: x[:batch_size], cond[key]))
563
+ for key in cond
564
+ }
565
+ else:
566
+ cond = (
567
+ [c[:batch_size] for c in cond]
568
+ if isinstance(cond, list)
569
+ else cond[:batch_size]
570
+ )
571
+ return self.p_sample_loop(
572
+ cond,
573
+ shape,
574
+ return_intermediates=return_intermediates,
575
+ x_T=x_T,
576
+ verbose=verbose,
577
+ timesteps=timesteps,
578
+ quantize_denoised=quantize_denoised,
579
+ mask=mask,
580
+ x0=x0,
581
+ **kwargs,
582
+ )
583
+
584
+ @torch.no_grad()
585
+ def sample_log(
586
+ self,
587
+ cond,
588
+ batch_size,
589
+ ddim,
590
+ ddim_steps,
591
+ unconditional_guidance_scale=1.0,
592
+ unconditional_conditioning=None,
593
+ use_plms=False,
594
+ mask=None,
595
+ **kwargs,
596
+ ):
597
+
598
+ if mask is not None:
599
+ shape = (self.channels, mask.size()[-2], mask.size()[-1])
600
+ else:
601
+ shape = (self.channels, self.latent_t_size, self.latent_f_size)
602
+
603
+ intermediate = None
604
+ if ddim and not use_plms:
605
+ # print("Use ddim sampler")
606
+
607
+ ddim_sampler = DDIMSampler(self)
608
+ samples, intermediates = ddim_sampler.sample(
609
+ ddim_steps,
610
+ batch_size,
611
+ shape,
612
+ cond,
613
+ verbose=False,
614
+ unconditional_guidance_scale=unconditional_guidance_scale,
615
+ unconditional_conditioning=unconditional_conditioning,
616
+ mask=mask,
617
+ **kwargs,
618
+ )
619
+
620
+ else:
621
+ # print("Use DDPM sampler")
622
+ samples, intermediates = self.sample(
623
+ cond=cond,
624
+ batch_size=batch_size,
625
+ return_intermediates=True,
626
+ unconditional_guidance_scale=unconditional_guidance_scale,
627
+ mask=mask,
628
+ unconditional_conditioning=unconditional_conditioning,
629
+ **kwargs,
630
+ )
631
+
632
+ return samples, intermediate
633
+
634
+ @torch.no_grad()
635
+ def generate_sample(
636
+ self,
637
+ batchs,
638
+ ddim_steps=200,
639
+ ddim_eta=1.0,
640
+ x_T=None,
641
+ n_candidate_gen_per_text=1,
642
+ unconditional_guidance_scale=1.0,
643
+ unconditional_conditioning=None,
644
+ name="waveform",
645
+ use_plms=False,
646
+ save=False,
647
+ **kwargs,
648
+ ):
649
+ # Generate n_candidate_gen_per_text times and select the best
650
+ # Batch: audio, text, fnames
651
+ assert x_T is None
652
+ try:
653
+ batchs = iter(batchs)
654
+ except TypeError:
655
+ raise ValueError("The first input argument should be an iterable object")
656
+
657
+ if use_plms:
658
+ assert ddim_steps is not None
659
+ use_ddim = ddim_steps is not None
660
+ # waveform_save_path = os.path.join(self.get_log_dir(), name)
661
+ # os.makedirs(waveform_save_path, exist_ok=True)
662
+ # print("Waveform save path: ", waveform_save_path)
663
+
664
+ with self.ema_scope("Generate"):
665
+ for batch in batchs:
666
+ z, c = self.get_input(
667
+ batch,
668
+ self.first_stage_key,
669
+ cond_key=self.cond_stage_key,
670
+ return_first_stage_outputs=False,
671
+ force_c_encode=True,
672
+ return_original_cond=False,
673
+ bs=None,
674
+ )
675
+ text = super().get_input(batch, "text")
676
+
677
+ # Generate multiple samples
678
+ batch_size = z.shape[0] * n_candidate_gen_per_text
679
+ c = torch.cat([c] * n_candidate_gen_per_text, dim=0)
680
+ text = text * n_candidate_gen_per_text
681
+
682
+ if unconditional_guidance_scale != 1.0:
683
+ unconditional_conditioning = (
684
+ self.cond_stage_model.get_unconditional_condition(batch_size)
685
+ )
686
+
687
+ samples, _ = self.sample_log(
688
+ cond=c,
689
+ batch_size=batch_size,
690
+ x_T=x_T,
691
+ ddim=use_ddim,
692
+ ddim_steps=ddim_steps,
693
+ eta=ddim_eta,
694
+ unconditional_guidance_scale=unconditional_guidance_scale,
695
+ unconditional_conditioning=unconditional_conditioning,
696
+ use_plms=use_plms,
697
+ )
698
+
699
+ if(torch.max(torch.abs(samples)) > 1e2):
700
+ samples = torch.clip(samples, min=-10, max=10)
701
+
702
+ mel = self.decode_first_stage(samples)
703
+
704
+ waveform = self.mel_spectrogram_to_waveform(mel)
705
+
706
+ if waveform.shape[0] > 1:
707
+ similarity = self.cond_stage_model.cos_similarity(
708
+ torch.FloatTensor(waveform).squeeze(1), text
709
+ )
710
+
711
+ best_index = []
712
+ for i in range(z.shape[0]):
713
+ candidates = similarity[i :: z.shape[0]]
714
+ max_index = torch.argmax(candidates).item()
715
+ best_index.append(i + max_index * z.shape[0])
716
+
717
+ waveform = waveform[best_index]
718
+ # print("Similarity between generated audio and text", similarity)
719
+ # print("Choose the following indexes:", best_index)
720
+
721
+ return waveform
722
+
723
+ @torch.no_grad()
724
+ def generate_sample_masked(
725
+ self,
726
+ batchs,
727
+ ddim_steps=200,
728
+ ddim_eta=1.0,
729
+ x_T=None,
730
+ n_candidate_gen_per_text=1,
731
+ unconditional_guidance_scale=1.0,
732
+ unconditional_conditioning=None,
733
+ name="waveform",
734
+ use_plms=False,
735
+ time_mask_ratio_start_and_end=(0.25, 0.75),
736
+ freq_mask_ratio_start_and_end=(0.75, 1.0),
737
+ save=False,
738
+ **kwargs,
739
+ ):
740
+ # Generate n_candidate_gen_per_text times and select the best
741
+ # Batch: audio, text, fnames
742
+ assert x_T is None
743
+ try:
744
+ batchs = iter(batchs)
745
+ except TypeError:
746
+ raise ValueError("The first input argument should be an iterable object")
747
+
748
+ if use_plms:
749
+ assert ddim_steps is not None
750
+ use_ddim = ddim_steps is not None
751
+ # waveform_save_path = os.path.join(self.get_log_dir(), name)
752
+ # os.makedirs(waveform_save_path, exist_ok=True)
753
+ # print("Waveform save path: ", waveform_save_path)
754
+
755
+ with self.ema_scope("Generate"):
756
+ for batch in batchs:
757
+ z, c = self.get_input(
758
+ batch,
759
+ self.first_stage_key,
760
+ cond_key=self.cond_stage_key,
761
+ return_first_stage_outputs=False,
762
+ force_c_encode=True,
763
+ return_original_cond=False,
764
+ bs=None,
765
+ )
766
+ text = super().get_input(batch, "text")
767
+
768
+ # Generate multiple samples
769
+ batch_size = z.shape[0] * n_candidate_gen_per_text
770
+
771
+ _, h, w = z.shape[0], z.shape[2], z.shape[3]
772
+
773
+ mask = torch.ones(batch_size, h, w).to(self.device)
774
+
775
+ mask[:, int(h * time_mask_ratio_start_and_end[0]) : int(h * time_mask_ratio_start_and_end[1]), :] = 0
776
+ mask[:, :, int(w * freq_mask_ratio_start_and_end[0]) : int(w * freq_mask_ratio_start_and_end[1])] = 0
777
+ mask = mask[:, None, ...]
778
+
779
+ c = torch.cat([c] * n_candidate_gen_per_text, dim=0)
780
+ text = text * n_candidate_gen_per_text
781
+
782
+ if unconditional_guidance_scale != 1.0:
783
+ unconditional_conditioning = (
784
+ self.cond_stage_model.get_unconditional_condition(batch_size)
785
+ )
786
+
787
+ samples, _ = self.sample_log(
788
+ cond=c,
789
+ batch_size=batch_size,
790
+ x_T=x_T,
791
+ ddim=use_ddim,
792
+ ddim_steps=ddim_steps,
793
+ eta=ddim_eta,
794
+ unconditional_guidance_scale=unconditional_guidance_scale,
795
+ unconditional_conditioning=unconditional_conditioning,
796
+ use_plms=use_plms, mask=mask, x0=torch.cat([z] * n_candidate_gen_per_text)
797
+ )
798
+
799
+ mel = self.decode_first_stage(samples)
800
+
801
+ waveform = self.mel_spectrogram_to_waveform(mel)
802
+
803
+ if waveform.shape[0] > 1:
804
+ similarity = self.cond_stage_model.cos_similarity(
805
+ torch.FloatTensor(waveform).squeeze(1), text
806
+ )
807
+
808
+ best_index = []
809
+ for i in range(z.shape[0]):
810
+ candidates = similarity[i :: z.shape[0]]
811
+ max_index = torch.argmax(candidates).item()
812
+ best_index.append(i + max_index * z.shape[0])
813
+
814
+ waveform = waveform[best_index]
815
+ # print("Similarity between generated audio and text", similarity)
816
+ # print("Choose the following indexes:", best_index)
817
+
818
+ return waveform
audioldm/pipeline.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import argparse
4
+ import yaml
5
+ import torch
6
+ from torch import autocast
7
+ from tqdm import tqdm, trange
8
+
9
+ from audioldm import LatentDiffusion, seed_everything
10
+ from audioldm.utils import default_audioldm_config, get_duration, get_bit_depth, get_metadata, download_checkpoint
11
+ from audioldm.audio import wav_to_fbank, TacotronSTFT, read_wav_file
12
+ from audioldm.latent_diffusion.ddim import DDIMSampler
13
+ from einops import repeat
14
+ import os
15
+
16
+ def make_batch_for_text_to_audio(text, waveform=None, fbank=None, batchsize=1):
17
+ text = [text] * batchsize
18
+ if batchsize < 1:
19
+ print("Warning: Batchsize must be at least 1. Batchsize is set to .")
20
+
21
+ if(fbank is None):
22
+ fbank = torch.zeros((batchsize, 1024, 64)) # Not used, here to keep the code format
23
+ else:
24
+ fbank = torch.FloatTensor(fbank)
25
+ fbank = fbank.expand(batchsize, 1024, 64)
26
+ assert fbank.size(0) == batchsize
27
+
28
+ stft = torch.zeros((batchsize, 1024, 512)) # Not used
29
+
30
+ if(waveform is None):
31
+ waveform = torch.zeros((batchsize, 160000)) # Not used
32
+ else:
33
+ waveform = torch.FloatTensor(waveform)
34
+ waveform = waveform.expand(batchsize, -1)
35
+ assert waveform.size(0) == batchsize
36
+
37
+ fname = [""] * batchsize # Not used
38
+
39
+ batch = (
40
+ fbank,
41
+ stft,
42
+ None,
43
+ fname,
44
+ waveform,
45
+ text,
46
+ )
47
+ return batch
48
+
49
+ def round_up_duration(duration):
50
+ return int(round(duration/2.5) + 1) * 2.5
51
+
52
+ def build_model(
53
+ ckpt_path=None,
54
+ config=None,
55
+ model_name="audioldm-s-full"
56
+ ):
57
+ print("Load AudioLDM: %s", model_name)
58
+
59
+ if(ckpt_path is None):
60
+ ckpt_path = get_metadata()[model_name]["path"]
61
+
62
+ if(not os.path.exists(ckpt_path)):
63
+ download_checkpoint(model_name)
64
+
65
+ if torch.cuda.is_available():
66
+ device = torch.device("cuda:0")
67
+ else:
68
+ device = torch.device("cpu")
69
+
70
+ if config is not None:
71
+ assert type(config) is str
72
+ config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
73
+ else:
74
+ config = default_audioldm_config(model_name)
75
+
76
+ # Use text as condition instead of using waveform during training
77
+ config["model"]["params"]["device"] = device
78
+ config["model"]["params"]["cond_stage_key"] = "text"
79
+
80
+ # No normalization here
81
+ latent_diffusion = LatentDiffusion(**config["model"]["params"])
82
+
83
+ resume_from_checkpoint = ckpt_path
84
+
85
+ checkpoint = torch.load(resume_from_checkpoint, map_location=device)
86
+ latent_diffusion.load_state_dict(checkpoint["state_dict"])
87
+
88
+ latent_diffusion.eval()
89
+ latent_diffusion = latent_diffusion.to(device)
90
+
91
+ latent_diffusion.cond_stage_model.embed_mode = "text"
92
+ return latent_diffusion
93
+
94
+ def duration_to_latent_t_size(duration):
95
+ return int(duration * 25.6)
96
+
97
+ def set_cond_audio(latent_diffusion):
98
+ latent_diffusion.cond_stage_key = "waveform"
99
+ latent_diffusion.cond_stage_model.embed_mode="audio"
100
+ return latent_diffusion
101
+
102
+ def set_cond_text(latent_diffusion):
103
+ latent_diffusion.cond_stage_key = "text"
104
+ latent_diffusion.cond_stage_model.embed_mode="text"
105
+ return latent_diffusion
106
+
107
+ def text_to_audio(
108
+ latent_diffusion,
109
+ text,
110
+ original_audio_file_path = None,
111
+ seed=42,
112
+ ddim_steps=200,
113
+ duration=10,
114
+ batchsize=1,
115
+ guidance_scale=2.5,
116
+ n_candidate_gen_per_text=3,
117
+ config=None,
118
+ ):
119
+ seed_everything(int(seed))
120
+ waveform = None
121
+ if(original_audio_file_path is not None):
122
+ waveform = read_wav_file(original_audio_file_path, int(duration * 102.4) * 160)
123
+
124
+ batch = make_batch_for_text_to_audio(text, waveform=waveform, batchsize=batchsize)
125
+
126
+ latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
127
+
128
+ if(waveform is not None):
129
+ print("Generate audio that has similar content as %s" % original_audio_file_path)
130
+ latent_diffusion = set_cond_audio(latent_diffusion)
131
+ else:
132
+ print("Generate audio using text %s" % text)
133
+ latent_diffusion = set_cond_text(latent_diffusion)
134
+
135
+ with torch.no_grad():
136
+ waveform = latent_diffusion.generate_sample(
137
+ [batch],
138
+ unconditional_guidance_scale=guidance_scale,
139
+ ddim_steps=ddim_steps,
140
+ n_candidate_gen_per_text=n_candidate_gen_per_text,
141
+ duration=duration,
142
+ )
143
+ return waveform
144
+
145
+ def style_transfer(
146
+ latent_diffusion,
147
+ text,
148
+ original_audio_file_path,
149
+ transfer_strength,
150
+ seed=42,
151
+ duration=10,
152
+ batchsize=1,
153
+ guidance_scale=2.5,
154
+ ddim_steps=200,
155
+ config=None,
156
+ ):
157
+ if torch.cuda.is_available():
158
+ device = torch.device("cuda:0")
159
+ else:
160
+ device = torch.device("cpu")
161
+
162
+ assert original_audio_file_path is not None, "You need to provide the original audio file path"
163
+
164
+ audio_file_duration = get_duration(original_audio_file_path)
165
+
166
+ assert get_bit_depth(original_audio_file_path) == 16, "The bit depth of the original audio file %s must be 16" % original_audio_file_path
167
+
168
+ # if(duration > 20):
169
+ # print("Warning: The duration of the audio file %s must be less than 20 seconds. Longer duration will result in Nan in model output (we are still debugging that); Automatically set duration to 20 seconds")
170
+ # duration = 20
171
+
172
+ if(duration >= audio_file_duration):
173
+ print("Warning: Duration you specified %s-seconds must equal or smaller than the audio file duration %ss" % (duration, audio_file_duration))
174
+ duration = round_up_duration(audio_file_duration)
175
+ print("Set new duration as %s-seconds" % duration)
176
+
177
+ # duration = round_up_duration(duration)
178
+
179
+ latent_diffusion = set_cond_text(latent_diffusion)
180
+
181
+ if config is not None:
182
+ assert type(config) is str
183
+ config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
184
+ else:
185
+ config = default_audioldm_config()
186
+
187
+ seed_everything(int(seed))
188
+ # latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
189
+ latent_diffusion.cond_stage_model.embed_mode = "text"
190
+
191
+ fn_STFT = TacotronSTFT(
192
+ config["preprocessing"]["stft"]["filter_length"],
193
+ config["preprocessing"]["stft"]["hop_length"],
194
+ config["preprocessing"]["stft"]["win_length"],
195
+ config["preprocessing"]["mel"]["n_mel_channels"],
196
+ config["preprocessing"]["audio"]["sampling_rate"],
197
+ config["preprocessing"]["mel"]["mel_fmin"],
198
+ config["preprocessing"]["mel"]["mel_fmax"],
199
+ )
200
+
201
+ mel, _, _ = wav_to_fbank(
202
+ original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT
203
+ )
204
+ mel = mel.unsqueeze(0).unsqueeze(0).to(device)
205
+ mel = repeat(mel, "1 ... -> b ...", b=batchsize)
206
+ init_latent = latent_diffusion.get_first_stage_encoding(
207
+ latent_diffusion.encode_first_stage(mel)
208
+ ) # move to latent space, encode and sample
209
+ if(torch.max(torch.abs(init_latent)) > 1e2):
210
+ init_latent = torch.clip(init_latent, min=-10, max=10)
211
+ sampler = DDIMSampler(latent_diffusion)
212
+ sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=1.0, verbose=False)
213
+
214
+ t_enc = int(transfer_strength * ddim_steps)
215
+ prompts = text
216
+
217
+ with torch.no_grad():
218
+ with autocast("cuda"):
219
+ with latent_diffusion.ema_scope():
220
+ uc = None
221
+ if guidance_scale != 1.0:
222
+ uc = latent_diffusion.cond_stage_model.get_unconditional_condition(
223
+ batchsize
224
+ )
225
+
226
+ c = latent_diffusion.get_learned_conditioning([prompts] * batchsize)
227
+ z_enc = sampler.stochastic_encode(
228
+ init_latent, torch.tensor([t_enc] * batchsize).to(device)
229
+ )
230
+ samples = sampler.decode(
231
+ z_enc,
232
+ c,
233
+ t_enc,
234
+ unconditional_guidance_scale=guidance_scale,
235
+ unconditional_conditioning=uc,
236
+ )
237
+ # x_samples = latent_diffusion.decode_first_stage(samples) # Will result in Nan in output
238
+ # print(torch.sum(torch.isnan(samples)))
239
+ x_samples = latent_diffusion.decode_first_stage(samples)
240
+ # print(x_samples)
241
+ x_samples = latent_diffusion.decode_first_stage(samples[:,:,:-3,:])
242
+ # print(x_samples)
243
+ waveform = latent_diffusion.first_stage_model.decode_to_waveform(
244
+ x_samples
245
+ )
246
+
247
+ return waveform
248
+
249
+ def super_resolution_and_inpainting(
250
+ latent_diffusion,
251
+ text,
252
+ original_audio_file_path = None,
253
+ seed=42,
254
+ ddim_steps=200,
255
+ duration=None,
256
+ batchsize=1,
257
+ guidance_scale=2.5,
258
+ n_candidate_gen_per_text=3,
259
+ time_mask_ratio_start_and_end=(0.10, 0.15), # regenerate the 10% to 15% of the time steps in the spectrogram
260
+ # time_mask_ratio_start_and_end=(1.0, 1.0), # no inpainting
261
+ # freq_mask_ratio_start_and_end=(0.75, 1.0), # regenerate the higher 75% to 100% mel bins
262
+ freq_mask_ratio_start_and_end=(1.0, 1.0), # no super-resolution
263
+ config=None,
264
+ ):
265
+ seed_everything(int(seed))
266
+ if config is not None:
267
+ assert type(config) is str
268
+ config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
269
+ else:
270
+ config = default_audioldm_config()
271
+ fn_STFT = TacotronSTFT(
272
+ config["preprocessing"]["stft"]["filter_length"],
273
+ config["preprocessing"]["stft"]["hop_length"],
274
+ config["preprocessing"]["stft"]["win_length"],
275
+ config["preprocessing"]["mel"]["n_mel_channels"],
276
+ config["preprocessing"]["audio"]["sampling_rate"],
277
+ config["preprocessing"]["mel"]["mel_fmin"],
278
+ config["preprocessing"]["mel"]["mel_fmax"],
279
+ )
280
+
281
+ # waveform = read_wav_file(original_audio_file_path, None)
282
+ mel, _, _ = wav_to_fbank(
283
+ original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT
284
+ )
285
+
286
+ batch = make_batch_for_text_to_audio(text, fbank=mel[None,...], batchsize=batchsize)
287
+
288
+ # latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
289
+ latent_diffusion = set_cond_text(latent_diffusion)
290
+
291
+ with torch.no_grad():
292
+ waveform = latent_diffusion.generate_sample_masked(
293
+ [batch],
294
+ unconditional_guidance_scale=guidance_scale,
295
+ ddim_steps=ddim_steps,
296
+ n_candidate_gen_per_text=n_candidate_gen_per_text,
297
+ duration=duration,
298
+ time_mask_ratio_start_and_end=time_mask_ratio_start_and_end,
299
+ freq_mask_ratio_start_and_end=freq_mask_ratio_start_and_end
300
+ )
301
+ return waveform
audioldm/utils.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import importlib
3
+
4
+ from inspect import isfunction
5
+ import os
6
+ import soundfile as sf
7
+ import time
8
+ import wave
9
+
10
+ import urllib.request
11
+ import progressbar
12
+
13
+ CACHE_DIR = os.getenv(
14
+ "AUDIOLDM_CACHE_DIR",
15
+ os.path.join(os.path.expanduser("~"), ".cache/audioldm"))
16
+
17
+ def get_duration(fname):
18
+ with contextlib.closing(wave.open(fname, 'r')) as f:
19
+ frames = f.getnframes()
20
+ rate = f.getframerate()
21
+ return frames / float(rate)
22
+
23
+ def get_bit_depth(fname):
24
+ with contextlib.closing(wave.open(fname, 'r')) as f:
25
+ bit_depth = f.getsampwidth() * 8
26
+ return bit_depth
27
+
28
+ def get_time():
29
+ t = time.localtime()
30
+ return time.strftime("%d_%m_%Y_%H_%M_%S", t)
31
+
32
+ def seed_everything(seed):
33
+ import random, os
34
+ import numpy as np
35
+ import torch
36
+
37
+ random.seed(seed)
38
+ os.environ["PYTHONHASHSEED"] = str(seed)
39
+ np.random.seed(seed)
40
+ torch.manual_seed(seed)
41
+ torch.cuda.manual_seed(seed)
42
+ torch.backends.cudnn.deterministic = True
43
+ torch.backends.cudnn.benchmark = True
44
+
45
+
46
+ def save_wave(waveform, savepath, name="outwav"):
47
+ if type(name) is not list:
48
+ name = [name] * waveform.shape[0]
49
+
50
+ for i in range(waveform.shape[0]):
51
+ path = os.path.join(
52
+ savepath,
53
+ "%s_%s.wav"
54
+ % (
55
+ os.path.basename(name[i])
56
+ if (not ".wav" in name[i])
57
+ else os.path.basename(name[i]).split(".")[0],
58
+ i,
59
+ ),
60
+ )
61
+ print("Save audio to %s" % path)
62
+ sf.write(path, waveform[i, 0], samplerate=16000)
63
+
64
+
65
+ def exists(x):
66
+ return x is not None
67
+
68
+
69
+ def default(val, d):
70
+ if exists(val):
71
+ return val
72
+ return d() if isfunction(d) else d
73
+
74
+
75
+ def count_params(model, verbose=False):
76
+ total_params = sum(p.numel() for p in model.parameters())
77
+ if verbose:
78
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
79
+ return total_params
80
+
81
+
82
+ def get_obj_from_str(string, reload=False):
83
+ module, cls = string.rsplit(".", 1)
84
+ if reload:
85
+ module_imp = importlib.import_module(module)
86
+ importlib.reload(module_imp)
87
+ return getattr(importlib.import_module(module, package=None), cls)
88
+
89
+
90
+ def instantiate_from_config(config):
91
+ if not "target" in config:
92
+ if config == "__is_first_stage__":
93
+ return None
94
+ elif config == "__is_unconditional__":
95
+ return None
96
+ raise KeyError("Expected key `target` to instantiate.")
97
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
98
+
99
+
100
+ def default_audioldm_config(model_name="audioldm-s-full"):
101
+ basic_config = {
102
+ "wave_file_save_path": "./output",
103
+ "id": {
104
+ "version": "v1",
105
+ "name": "default",
106
+ "root": "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/AudioLDM-python/config/default/latent_diffusion.yaml",
107
+ },
108
+ "preprocessing": {
109
+ "audio": {"sampling_rate": 16000, "max_wav_value": 32768},
110
+ "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024},
111
+ "mel": {
112
+ "n_mel_channels": 64,
113
+ "mel_fmin": 0,
114
+ "mel_fmax": 8000,
115
+ "freqm": 0,
116
+ "timem": 0,
117
+ "blur": False,
118
+ "mean": -4.63,
119
+ "std": 2.74,
120
+ "target_length": 1024,
121
+ },
122
+ },
123
+ "model": {
124
+ "device": "cuda",
125
+ "target": "audioldm.pipline.LatentDiffusion",
126
+ "params": {
127
+ "base_learning_rate": 5e-06,
128
+ "linear_start": 0.0015,
129
+ "linear_end": 0.0195,
130
+ "num_timesteps_cond": 1,
131
+ "log_every_t": 200,
132
+ "timesteps": 1000,
133
+ "first_stage_key": "fbank",
134
+ "cond_stage_key": "waveform",
135
+ "latent_t_size": 256,
136
+ "latent_f_size": 16,
137
+ "channels": 8,
138
+ "cond_stage_trainable": True,
139
+ "conditioning_key": "film",
140
+ "monitor": "val/loss_simple_ema",
141
+ "scale_by_std": True,
142
+ "unet_config": {
143
+ "target": "audioldm.latent_diffusion.openaimodel.UNetModel",
144
+ "params": {
145
+ "image_size": 64,
146
+ "extra_film_condition_dim": 512,
147
+ "extra_film_use_concat": True,
148
+ "in_channels": 8,
149
+ "out_channels": 8,
150
+ "model_channels": 128,
151
+ "attention_resolutions": [8, 4, 2],
152
+ "num_res_blocks": 2,
153
+ "channel_mult": [1, 2, 3, 5],
154
+ "num_head_channels": 32,
155
+ "use_spatial_transformer": True,
156
+ },
157
+ },
158
+ "first_stage_config": {
159
+ "base_learning_rate": 4.5e-05,
160
+ "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL",
161
+ "params": {
162
+ "monitor": "val/rec_loss",
163
+ "image_key": "fbank",
164
+ "subband": 1,
165
+ "embed_dim": 8,
166
+ "time_shuffle": 1,
167
+ "ddconfig": {
168
+ "double_z": True,
169
+ "z_channels": 8,
170
+ "resolution": 256,
171
+ "downsample_time": False,
172
+ "in_channels": 1,
173
+ "out_ch": 1,
174
+ "ch": 128,
175
+ "ch_mult": [1, 2, 4],
176
+ "num_res_blocks": 2,
177
+ "attn_resolutions": [],
178
+ "dropout": 0.0,
179
+ },
180
+ },
181
+ },
182
+ "cond_stage_config": {
183
+ "target": "audioldm.clap.encoders.CLAPAudioEmbeddingClassifierFreev2",
184
+ "params": {
185
+ "key": "waveform",
186
+ "sampling_rate": 16000,
187
+ "embed_mode": "audio",
188
+ "unconditional_prob": 0.1,
189
+ },
190
+ },
191
+ },
192
+ },
193
+ }
194
+
195
+ if("-l-" in model_name):
196
+ basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 256
197
+ basic_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = 64
198
+ elif("-m-" in model_name):
199
+ basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 192
200
+ basic_config["model"]["params"]["cond_stage_config"]["params"]["amodel"] = "HTSAT-base" # This model use a larger HTAST
201
+
202
+ return basic_config
203
+
204
+ def get_metadata():
205
+ return {
206
+ "audioldm-s-full": {
207
+ "path": os.path.join(
208
+ CACHE_DIR,
209
+ "audioldm-s-full.ckpt",
210
+ ),
211
+ "url": "https://zenodo.org/record/7600541/files/audioldm-s-full?download=1",
212
+ },
213
+ "audioldm-l-full": {
214
+ "path": os.path.join(
215
+ CACHE_DIR,
216
+ "audioldm-l-full.ckpt",
217
+ ),
218
+ "url": "https://zenodo.org/record/7698295/files/audioldm-full-l.ckpt?download=1",
219
+ },
220
+ "audioldm-s-full-v2": {
221
+ "path": os.path.join(
222
+ CACHE_DIR,
223
+ "audioldm-s-full-v2.ckpt",
224
+ ),
225
+ "url": "https://zenodo.org/record/7698295/files/audioldm-full-s-v2.ckpt?download=1",
226
+ },
227
+ "audioldm-m-text-ft": {
228
+ "path": os.path.join(
229
+ CACHE_DIR,
230
+ "audioldm-m-text-ft.ckpt",
231
+ ),
232
+ "url": "https://zenodo.org/record/7813012/files/audioldm-m-text-ft.ckpt?download=1",
233
+ },
234
+ "audioldm-s-text-ft": {
235
+ "path": os.path.join(
236
+ CACHE_DIR,
237
+ "audioldm-s-text-ft.ckpt",
238
+ ),
239
+ "url": "https://zenodo.org/record/7813012/files/audioldm-s-text-ft.ckpt?download=1",
240
+ },
241
+ "audioldm-m-full": {
242
+ "path": os.path.join(
243
+ CACHE_DIR,
244
+ "audioldm-m-full.ckpt",
245
+ ),
246
+ "url": "https://zenodo.org/record/7813012/files/audioldm-m-full.ckpt?download=1",
247
+ },
248
+ }
249
+
250
+ class MyProgressBar():
251
+ def __init__(self):
252
+ self.pbar = None
253
+
254
+ def __call__(self, block_num, block_size, total_size):
255
+ if not self.pbar:
256
+ self.pbar=progressbar.ProgressBar(maxval=total_size)
257
+ self.pbar.start()
258
+
259
+ downloaded = block_num * block_size
260
+ if downloaded < total_size:
261
+ self.pbar.update(downloaded)
262
+ else:
263
+ self.pbar.finish()
264
+
265
+ def download_checkpoint(checkpoint_name="audioldm-s-full"):
266
+ meta = get_metadata()
267
+ if(checkpoint_name not in meta.keys()):
268
+ print("The model name you provided is not supported. Please use one of the following: ", meta.keys())
269
+
270
+ if not os.path.exists(meta[checkpoint_name]["path"]) or os.path.getsize(meta[checkpoint_name]["path"]) < 2*10**9:
271
+ os.makedirs(os.path.dirname(meta[checkpoint_name]["path"]), exist_ok=True)
272
+ print(f"Downloading the main structure of {checkpoint_name} into {os.path.dirname(meta[checkpoint_name]['path'])}")
273
+
274
+ urllib.request.urlretrieve(meta[checkpoint_name]["url"], meta[checkpoint_name]["path"], MyProgressBar())
275
+ print(
276
+ "Weights downloaded in: {} Size: {}".format(
277
+ meta[checkpoint_name]["path"],
278
+ os.path.getsize(meta[checkpoint_name]["path"]),
279
+ )
280
+ )
281
+
audioldm/variational_autoencoder/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .autoencoder import AutoencoderKL
audioldm/variational_autoencoder/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (220 Bytes). View file
 
audioldm/variational_autoencoder/__pycache__/autoencoder.cpython-39.pyc ADDED
Binary file (4.37 kB). View file
 
audioldm/variational_autoencoder/__pycache__/distributions.cpython-39.pyc ADDED
Binary file (3.78 kB). View file
 
audioldm/variational_autoencoder/__pycache__/modules.cpython-39.pyc ADDED
Binary file (22.1 kB). View file
 
audioldm/variational_autoencoder/autoencoder.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from audioldm.latent_diffusion.ema import *
3
+ from audioldm.variational_autoencoder.modules import Encoder, Decoder
4
+ from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution
5
+
6
+ from audioldm.hifigan.utilities import get_vocoder, vocoder_infer
7
+
8
+
9
+ class AutoencoderKL(nn.Module):
10
+ def __init__(
11
+ self,
12
+ ddconfig=None,
13
+ lossconfig=None,
14
+ image_key="fbank",
15
+ embed_dim=None,
16
+ time_shuffle=1,
17
+ subband=1,
18
+ ckpt_path=None,
19
+ reload_from_ckpt=None,
20
+ ignore_keys=[],
21
+ colorize_nlabels=None,
22
+ monitor=None,
23
+ base_learning_rate=1e-5,
24
+ scale_factor=1
25
+ ):
26
+ super().__init__()
27
+
28
+ self.encoder = Encoder(**ddconfig)
29
+ self.decoder = Decoder(**ddconfig)
30
+
31
+ self.subband = int(subband)
32
+
33
+ if self.subband > 1:
34
+ print("Use subband decomposition %s" % self.subband)
35
+
36
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
37
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
38
+
39
+ self.vocoder = get_vocoder(None, "cpu")
40
+ self.embed_dim = embed_dim
41
+
42
+ if monitor is not None:
43
+ self.monitor = monitor
44
+
45
+ self.time_shuffle = time_shuffle
46
+ self.reload_from_ckpt = reload_from_ckpt
47
+ self.reloaded = False
48
+ self.mean, self.std = None, None
49
+
50
+ self.scale_factor = scale_factor
51
+
52
+ def encode(self, x):
53
+ # x = self.time_shuffle_operation(x)
54
+ x = self.freq_split_subband(x)
55
+ h = self.encoder(x)
56
+ moments = self.quant_conv(h)
57
+ posterior = DiagonalGaussianDistribution(moments)
58
+ return posterior
59
+
60
+ def decode(self, z):
61
+ z = self.post_quant_conv(z)
62
+ dec = self.decoder(z)
63
+ dec = self.freq_merge_subband(dec)
64
+ return dec
65
+
66
+ def decode_to_waveform(self, dec):
67
+ dec = dec.squeeze(1).permute(0, 2, 1)
68
+ wav_reconstruction = vocoder_infer(dec, self.vocoder)
69
+ return wav_reconstruction
70
+
71
+ def forward(self, input, sample_posterior=True):
72
+ posterior = self.encode(input)
73
+ if sample_posterior:
74
+ z = posterior.sample()
75
+ else:
76
+ z = posterior.mode()
77
+
78
+ if self.flag_first_run:
79
+ print("Latent size: ", z.size())
80
+ self.flag_first_run = False
81
+
82
+ dec = self.decode(z)
83
+
84
+ return dec, posterior
85
+
86
+ def freq_split_subband(self, fbank):
87
+ if self.subband == 1 or self.image_key != "stft":
88
+ return fbank
89
+
90
+ bs, ch, tstep, fbins = fbank.size()
91
+
92
+ assert fbank.size(-1) % self.subband == 0
93
+ assert ch == 1
94
+
95
+ return (
96
+ fbank.squeeze(1)
97
+ .reshape(bs, tstep, self.subband, fbins // self.subband)
98
+ .permute(0, 2, 1, 3)
99
+ )
100
+
101
+ def freq_merge_subband(self, subband_fbank):
102
+ if self.subband == 1 or self.image_key != "stft":
103
+ return subband_fbank
104
+ assert subband_fbank.size(1) == self.subband # Channel dimension
105
+ bs, sub_ch, tstep, fbins = subband_fbank.size()
106
+ return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)
107
+
108
+ def device(self):
109
+ return next(self.parameters()).device
110
+
111
+ @torch.no_grad()
112
+ def encode_first_stage(self, x):
113
+ return self.encode(x)
114
+
115
+ @torch.no_grad()
116
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
117
+ if predict_cids:
118
+ if z.dim() == 4:
119
+ z = torch.argmax(z.exp(), dim=1).long()
120
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
121
+ z = rearrange(z, "b h w c -> b c h w").contiguous()
122
+
123
+ z = 1.0 / self.scale_factor * z
124
+ return self.decode(z)
125
+
126
+ def get_first_stage_encoding(self, encoder_posterior):
127
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
128
+ z = encoder_posterior.sample()
129
+ elif isinstance(encoder_posterior, torch.Tensor):
130
+ z = encoder_posterior
131
+ else:
132
+ raise NotImplementedError(
133
+ f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
134
+ )
135
+ return self.scale_factor * z
audioldm/variational_autoencoder/distributions.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(
34
+ device=self.parameters.device
35
+ )
36
+
37
+ def sample(self):
38
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(
39
+ device=self.parameters.device
40
+ )
41
+ return x
42
+
43
+ def kl(self, other=None):
44
+ if self.deterministic:
45
+ return torch.Tensor([0.0])
46
+ else:
47
+ if other is None:
48
+ return 0.5 * torch.mean(
49
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50
+ dim=[1, 2, 3],
51
+ )
52
+ else:
53
+ return 0.5 * torch.mean(
54
+ torch.pow(self.mean - other.mean, 2) / other.var
55
+ + self.var / other.var
56
+ - 1.0
57
+ - self.logvar
58
+ + other.logvar,
59
+ dim=[1, 2, 3],
60
+ )
61
+
62
+ def nll(self, sample, dims=[1, 2, 3]):
63
+ if self.deterministic:
64
+ return torch.Tensor([0.0])
65
+ logtwopi = np.log(2.0 * np.pi)
66
+ return 0.5 * torch.sum(
67
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
68
+ dim=dims,
69
+ )
70
+
71
+ def mode(self):
72
+ return self.mean
73
+
74
+
75
+ def normal_kl(mean1, logvar1, mean2, logvar2):
76
+ """
77
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
78
+ Compute the KL divergence between two gaussians.
79
+ Shapes are automatically broadcasted, so batches can be compared to
80
+ scalars, among other use cases.
81
+ """
82
+ tensor = None
83
+ for obj in (mean1, logvar1, mean2, logvar2):
84
+ if isinstance(obj, torch.Tensor):
85
+ tensor = obj
86
+ break
87
+ assert tensor is not None, "at least one argument must be a Tensor"
88
+
89
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
90
+ # Tensors, but it does not work for torch.exp().
91
+ logvar1, logvar2 = [
92
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93
+ for x in (logvar1, logvar2)
94
+ ]
95
+
96
+ return 0.5 * (
97
+ -1.0
98
+ + logvar2
99
+ - logvar1
100
+ + torch.exp(logvar1 - logvar2)
101
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102
+ )
audioldm/variational_autoencoder/modules.py ADDED
@@ -0,0 +1,1066 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+
8
+ from audioldm.utils import instantiate_from_config
9
+ from audioldm.latent_diffusion.attention import LinearAttention
10
+
11
+
12
+ def get_timestep_embedding(timesteps, embedding_dim):
13
+ """
14
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
15
+ From Fairseq.
16
+ Build sinusoidal embeddings.
17
+ This matches the implementation in tensor2tensor, but differs slightly
18
+ from the description in Section 3.5 of "Attention Is All You Need".
19
+ """
20
+ assert len(timesteps.shape) == 1
21
+
22
+ half_dim = embedding_dim // 2
23
+ emb = math.log(10000) / (half_dim - 1)
24
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
25
+ emb = emb.to(device=timesteps.device)
26
+ emb = timesteps.float()[:, None] * emb[None, :]
27
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
28
+ if embedding_dim % 2 == 1: # zero pad
29
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
30
+ return emb
31
+
32
+
33
+ def nonlinearity(x):
34
+ # swish
35
+ return x * torch.sigmoid(x)
36
+
37
+
38
+ def Normalize(in_channels, num_groups=32):
39
+ return torch.nn.GroupNorm(
40
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
41
+ )
42
+
43
+
44
+ class Upsample(nn.Module):
45
+ def __init__(self, in_channels, with_conv):
46
+ super().__init__()
47
+ self.with_conv = with_conv
48
+ if self.with_conv:
49
+ self.conv = torch.nn.Conv2d(
50
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
51
+ )
52
+
53
+ def forward(self, x):
54
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
55
+ if self.with_conv:
56
+ x = self.conv(x)
57
+ return x
58
+
59
+
60
+ class UpsampleTimeStride4(nn.Module):
61
+ def __init__(self, in_channels, with_conv):
62
+ super().__init__()
63
+ self.with_conv = with_conv
64
+ if self.with_conv:
65
+ self.conv = torch.nn.Conv2d(
66
+ in_channels, in_channels, kernel_size=5, stride=1, padding=2
67
+ )
68
+
69
+ def forward(self, x):
70
+ x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
71
+ if self.with_conv:
72
+ x = self.conv(x)
73
+ return x
74
+
75
+
76
+ class Downsample(nn.Module):
77
+ def __init__(self, in_channels, with_conv):
78
+ super().__init__()
79
+ self.with_conv = with_conv
80
+ if self.with_conv:
81
+ # Do time downsampling here
82
+ # no asymmetric padding in torch conv, must do it ourselves
83
+ self.conv = torch.nn.Conv2d(
84
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
85
+ )
86
+
87
+ def forward(self, x):
88
+ if self.with_conv:
89
+ pad = (0, 1, 0, 1)
90
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
91
+ x = self.conv(x)
92
+ else:
93
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
94
+ return x
95
+
96
+
97
+ class DownsampleTimeStride4(nn.Module):
98
+ def __init__(self, in_channels, with_conv):
99
+ super().__init__()
100
+ self.with_conv = with_conv
101
+ if self.with_conv:
102
+ # Do time downsampling here
103
+ # no asymmetric padding in torch conv, must do it ourselves
104
+ self.conv = torch.nn.Conv2d(
105
+ in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
106
+ )
107
+
108
+ def forward(self, x):
109
+ if self.with_conv:
110
+ pad = (0, 1, 0, 1)
111
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
112
+ x = self.conv(x)
113
+ else:
114
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
115
+ return x
116
+
117
+
118
+ class ResnetBlock(nn.Module):
119
+ def __init__(
120
+ self,
121
+ *,
122
+ in_channels,
123
+ out_channels=None,
124
+ conv_shortcut=False,
125
+ dropout,
126
+ temb_channels=512,
127
+ ):
128
+ super().__init__()
129
+ self.in_channels = in_channels
130
+ out_channels = in_channels if out_channels is None else out_channels
131
+ self.out_channels = out_channels
132
+ self.use_conv_shortcut = conv_shortcut
133
+
134
+ self.norm1 = Normalize(in_channels)
135
+ self.conv1 = torch.nn.Conv2d(
136
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
137
+ )
138
+ if temb_channels > 0:
139
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
140
+ self.norm2 = Normalize(out_channels)
141
+ self.dropout = torch.nn.Dropout(dropout)
142
+ self.conv2 = torch.nn.Conv2d(
143
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
144
+ )
145
+ if self.in_channels != self.out_channels:
146
+ if self.use_conv_shortcut:
147
+ self.conv_shortcut = torch.nn.Conv2d(
148
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
149
+ )
150
+ else:
151
+ self.nin_shortcut = torch.nn.Conv2d(
152
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
153
+ )
154
+
155
+ def forward(self, x, temb):
156
+ h = x
157
+ h = self.norm1(h)
158
+ h = nonlinearity(h)
159
+ h = self.conv1(h)
160
+
161
+ if temb is not None:
162
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
163
+
164
+ h = self.norm2(h)
165
+ h = nonlinearity(h)
166
+ h = self.dropout(h)
167
+ h = self.conv2(h)
168
+
169
+ if self.in_channels != self.out_channels:
170
+ if self.use_conv_shortcut:
171
+ x = self.conv_shortcut(x)
172
+ else:
173
+ x = self.nin_shortcut(x)
174
+
175
+ return x + h
176
+
177
+
178
+ class LinAttnBlock(LinearAttention):
179
+ """to match AttnBlock usage"""
180
+
181
+ def __init__(self, in_channels):
182
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
183
+
184
+
185
+ class AttnBlock(nn.Module):
186
+ def __init__(self, in_channels):
187
+ super().__init__()
188
+ self.in_channels = in_channels
189
+
190
+ self.norm = Normalize(in_channels)
191
+ self.q = torch.nn.Conv2d(
192
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
193
+ )
194
+ self.k = torch.nn.Conv2d(
195
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
196
+ )
197
+ self.v = torch.nn.Conv2d(
198
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
199
+ )
200
+ self.proj_out = torch.nn.Conv2d(
201
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
202
+ )
203
+
204
+ def forward(self, x):
205
+ h_ = x
206
+ h_ = self.norm(h_)
207
+ q = self.q(h_)
208
+ k = self.k(h_)
209
+ v = self.v(h_)
210
+
211
+ # compute attention
212
+ b, c, h, w = q.shape
213
+ q = q.reshape(b, c, h * w).contiguous()
214
+ q = q.permute(0, 2, 1).contiguous() # b,hw,c
215
+ k = k.reshape(b, c, h * w).contiguous() # b,c,hw
216
+ w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
217
+ w_ = w_ * (int(c) ** (-0.5))
218
+ w_ = torch.nn.functional.softmax(w_, dim=2)
219
+
220
+ # attend to values
221
+ v = v.reshape(b, c, h * w).contiguous()
222
+ w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
223
+ h_ = torch.bmm(
224
+ v, w_
225
+ ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
226
+ h_ = h_.reshape(b, c, h, w).contiguous()
227
+
228
+ h_ = self.proj_out(h_)
229
+
230
+ return x + h_
231
+
232
+
233
+ def make_attn(in_channels, attn_type="vanilla"):
234
+ assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
235
+ # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
236
+ if attn_type == "vanilla":
237
+ return AttnBlock(in_channels)
238
+ elif attn_type == "none":
239
+ return nn.Identity(in_channels)
240
+ else:
241
+ return LinAttnBlock(in_channels)
242
+
243
+
244
+ class Model(nn.Module):
245
+ def __init__(
246
+ self,
247
+ *,
248
+ ch,
249
+ out_ch,
250
+ ch_mult=(1, 2, 4, 8),
251
+ num_res_blocks,
252
+ attn_resolutions,
253
+ dropout=0.0,
254
+ resamp_with_conv=True,
255
+ in_channels,
256
+ resolution,
257
+ use_timestep=True,
258
+ use_linear_attn=False,
259
+ attn_type="vanilla",
260
+ ):
261
+ super().__init__()
262
+ if use_linear_attn:
263
+ attn_type = "linear"
264
+ self.ch = ch
265
+ self.temb_ch = self.ch * 4
266
+ self.num_resolutions = len(ch_mult)
267
+ self.num_res_blocks = num_res_blocks
268
+ self.resolution = resolution
269
+ self.in_channels = in_channels
270
+
271
+ self.use_timestep = use_timestep
272
+ if self.use_timestep:
273
+ # timestep embedding
274
+ self.temb = nn.Module()
275
+ self.temb.dense = nn.ModuleList(
276
+ [
277
+ torch.nn.Linear(self.ch, self.temb_ch),
278
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
279
+ ]
280
+ )
281
+
282
+ # downsampling
283
+ self.conv_in = torch.nn.Conv2d(
284
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
285
+ )
286
+
287
+ curr_res = resolution
288
+ in_ch_mult = (1,) + tuple(ch_mult)
289
+ self.down = nn.ModuleList()
290
+ for i_level in range(self.num_resolutions):
291
+ block = nn.ModuleList()
292
+ attn = nn.ModuleList()
293
+ block_in = ch * in_ch_mult[i_level]
294
+ block_out = ch * ch_mult[i_level]
295
+ for i_block in range(self.num_res_blocks):
296
+ block.append(
297
+ ResnetBlock(
298
+ in_channels=block_in,
299
+ out_channels=block_out,
300
+ temb_channels=self.temb_ch,
301
+ dropout=dropout,
302
+ )
303
+ )
304
+ block_in = block_out
305
+ if curr_res in attn_resolutions:
306
+ attn.append(make_attn(block_in, attn_type=attn_type))
307
+ down = nn.Module()
308
+ down.block = block
309
+ down.attn = attn
310
+ if i_level != self.num_resolutions - 1:
311
+ down.downsample = Downsample(block_in, resamp_with_conv)
312
+ curr_res = curr_res // 2
313
+ self.down.append(down)
314
+
315
+ # middle
316
+ self.mid = nn.Module()
317
+ self.mid.block_1 = ResnetBlock(
318
+ in_channels=block_in,
319
+ out_channels=block_in,
320
+ temb_channels=self.temb_ch,
321
+ dropout=dropout,
322
+ )
323
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
324
+ self.mid.block_2 = ResnetBlock(
325
+ in_channels=block_in,
326
+ out_channels=block_in,
327
+ temb_channels=self.temb_ch,
328
+ dropout=dropout,
329
+ )
330
+
331
+ # upsampling
332
+ self.up = nn.ModuleList()
333
+ for i_level in reversed(range(self.num_resolutions)):
334
+ block = nn.ModuleList()
335
+ attn = nn.ModuleList()
336
+ block_out = ch * ch_mult[i_level]
337
+ skip_in = ch * ch_mult[i_level]
338
+ for i_block in range(self.num_res_blocks + 1):
339
+ if i_block == self.num_res_blocks:
340
+ skip_in = ch * in_ch_mult[i_level]
341
+ block.append(
342
+ ResnetBlock(
343
+ in_channels=block_in + skip_in,
344
+ out_channels=block_out,
345
+ temb_channels=self.temb_ch,
346
+ dropout=dropout,
347
+ )
348
+ )
349
+ block_in = block_out
350
+ if curr_res in attn_resolutions:
351
+ attn.append(make_attn(block_in, attn_type=attn_type))
352
+ up = nn.Module()
353
+ up.block = block
354
+ up.attn = attn
355
+ if i_level != 0:
356
+ up.upsample = Upsample(block_in, resamp_with_conv)
357
+ curr_res = curr_res * 2
358
+ self.up.insert(0, up) # prepend to get consistent order
359
+
360
+ # end
361
+ self.norm_out = Normalize(block_in)
362
+ self.conv_out = torch.nn.Conv2d(
363
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
364
+ )
365
+
366
+ def forward(self, x, t=None, context=None):
367
+ # assert x.shape[2] == x.shape[3] == self.resolution
368
+ if context is not None:
369
+ # assume aligned context, cat along channel axis
370
+ x = torch.cat((x, context), dim=1)
371
+ if self.use_timestep:
372
+ # timestep embedding
373
+ assert t is not None
374
+ temb = get_timestep_embedding(t, self.ch)
375
+ temb = self.temb.dense[0](temb)
376
+ temb = nonlinearity(temb)
377
+ temb = self.temb.dense[1](temb)
378
+ else:
379
+ temb = None
380
+
381
+ # downsampling
382
+ hs = [self.conv_in(x)]
383
+ for i_level in range(self.num_resolutions):
384
+ for i_block in range(self.num_res_blocks):
385
+ h = self.down[i_level].block[i_block](hs[-1], temb)
386
+ if len(self.down[i_level].attn) > 0:
387
+ h = self.down[i_level].attn[i_block](h)
388
+ hs.append(h)
389
+ if i_level != self.num_resolutions - 1:
390
+ hs.append(self.down[i_level].downsample(hs[-1]))
391
+
392
+ # middle
393
+ h = hs[-1]
394
+ h = self.mid.block_1(h, temb)
395
+ h = self.mid.attn_1(h)
396
+ h = self.mid.block_2(h, temb)
397
+
398
+ # upsampling
399
+ for i_level in reversed(range(self.num_resolutions)):
400
+ for i_block in range(self.num_res_blocks + 1):
401
+ h = self.up[i_level].block[i_block](
402
+ torch.cat([h, hs.pop()], dim=1), temb
403
+ )
404
+ if len(self.up[i_level].attn) > 0:
405
+ h = self.up[i_level].attn[i_block](h)
406
+ if i_level != 0:
407
+ h = self.up[i_level].upsample(h)
408
+
409
+ # end
410
+ h = self.norm_out(h)
411
+ h = nonlinearity(h)
412
+ h = self.conv_out(h)
413
+ return h
414
+
415
+ def get_last_layer(self):
416
+ return self.conv_out.weight
417
+
418
+
419
+ class Encoder(nn.Module):
420
+ def __init__(
421
+ self,
422
+ *,
423
+ ch,
424
+ out_ch,
425
+ ch_mult=(1, 2, 4, 8),
426
+ num_res_blocks,
427
+ attn_resolutions,
428
+ dropout=0.0,
429
+ resamp_with_conv=True,
430
+ in_channels,
431
+ resolution,
432
+ z_channels,
433
+ double_z=True,
434
+ use_linear_attn=False,
435
+ attn_type="vanilla",
436
+ downsample_time_stride4_levels=[],
437
+ **ignore_kwargs,
438
+ ):
439
+ super().__init__()
440
+ if use_linear_attn:
441
+ attn_type = "linear"
442
+ self.ch = ch
443
+ self.temb_ch = 0
444
+ self.num_resolutions = len(ch_mult)
445
+ self.num_res_blocks = num_res_blocks
446
+ self.resolution = resolution
447
+ self.in_channels = in_channels
448
+ self.downsample_time_stride4_levels = downsample_time_stride4_levels
449
+
450
+ if len(self.downsample_time_stride4_levels) > 0:
451
+ assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
452
+ "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
453
+ % str(self.num_resolutions)
454
+ )
455
+
456
+ # downsampling
457
+ self.conv_in = torch.nn.Conv2d(
458
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
459
+ )
460
+
461
+ curr_res = resolution
462
+ in_ch_mult = (1,) + tuple(ch_mult)
463
+ self.in_ch_mult = in_ch_mult
464
+ self.down = nn.ModuleList()
465
+ for i_level in range(self.num_resolutions):
466
+ block = nn.ModuleList()
467
+ attn = nn.ModuleList()
468
+ block_in = ch * in_ch_mult[i_level]
469
+ block_out = ch * ch_mult[i_level]
470
+ for i_block in range(self.num_res_blocks):
471
+ block.append(
472
+ ResnetBlock(
473
+ in_channels=block_in,
474
+ out_channels=block_out,
475
+ temb_channels=self.temb_ch,
476
+ dropout=dropout,
477
+ )
478
+ )
479
+ block_in = block_out
480
+ if curr_res in attn_resolutions:
481
+ attn.append(make_attn(block_in, attn_type=attn_type))
482
+ down = nn.Module()
483
+ down.block = block
484
+ down.attn = attn
485
+ if i_level != self.num_resolutions - 1:
486
+ if i_level in self.downsample_time_stride4_levels:
487
+ down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
488
+ else:
489
+ down.downsample = Downsample(block_in, resamp_with_conv)
490
+ curr_res = curr_res // 2
491
+ self.down.append(down)
492
+
493
+ # middle
494
+ self.mid = nn.Module()
495
+ self.mid.block_1 = ResnetBlock(
496
+ in_channels=block_in,
497
+ out_channels=block_in,
498
+ temb_channels=self.temb_ch,
499
+ dropout=dropout,
500
+ )
501
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
502
+ self.mid.block_2 = ResnetBlock(
503
+ in_channels=block_in,
504
+ out_channels=block_in,
505
+ temb_channels=self.temb_ch,
506
+ dropout=dropout,
507
+ )
508
+
509
+ # end
510
+ self.norm_out = Normalize(block_in)
511
+ self.conv_out = torch.nn.Conv2d(
512
+ block_in,
513
+ 2 * z_channels if double_z else z_channels,
514
+ kernel_size=3,
515
+ stride=1,
516
+ padding=1,
517
+ )
518
+
519
+ def forward(self, x):
520
+ # timestep embedding
521
+ temb = None
522
+ # downsampling
523
+ hs = [self.conv_in(x)]
524
+ for i_level in range(self.num_resolutions):
525
+ for i_block in range(self.num_res_blocks):
526
+ h = self.down[i_level].block[i_block](hs[-1], temb)
527
+ if len(self.down[i_level].attn) > 0:
528
+ h = self.down[i_level].attn[i_block](h)
529
+ hs.append(h)
530
+ if i_level != self.num_resolutions - 1:
531
+ hs.append(self.down[i_level].downsample(hs[-1]))
532
+
533
+ # middle
534
+ h = hs[-1]
535
+ h = self.mid.block_1(h, temb)
536
+ h = self.mid.attn_1(h)
537
+ h = self.mid.block_2(h, temb)
538
+
539
+ # end
540
+ h = self.norm_out(h)
541
+ h = nonlinearity(h)
542
+ h = self.conv_out(h)
543
+ return h
544
+
545
+
546
+ class Decoder(nn.Module):
547
+ def __init__(
548
+ self,
549
+ *,
550
+ ch,
551
+ out_ch,
552
+ ch_mult=(1, 2, 4, 8),
553
+ num_res_blocks,
554
+ attn_resolutions,
555
+ dropout=0.0,
556
+ resamp_with_conv=True,
557
+ in_channels,
558
+ resolution,
559
+ z_channels,
560
+ give_pre_end=False,
561
+ tanh_out=False,
562
+ use_linear_attn=False,
563
+ downsample_time_stride4_levels=[],
564
+ attn_type="vanilla",
565
+ **ignorekwargs,
566
+ ):
567
+ super().__init__()
568
+ if use_linear_attn:
569
+ attn_type = "linear"
570
+ self.ch = ch
571
+ self.temb_ch = 0
572
+ self.num_resolutions = len(ch_mult)
573
+ self.num_res_blocks = num_res_blocks
574
+ self.resolution = resolution
575
+ self.in_channels = in_channels
576
+ self.give_pre_end = give_pre_end
577
+ self.tanh_out = tanh_out
578
+ self.downsample_time_stride4_levels = downsample_time_stride4_levels
579
+
580
+ if len(self.downsample_time_stride4_levels) > 0:
581
+ assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
582
+ "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
583
+ % str(self.num_resolutions)
584
+ )
585
+
586
+ # compute in_ch_mult, block_in and curr_res at lowest res
587
+ in_ch_mult = (1,) + tuple(ch_mult)
588
+ block_in = ch * ch_mult[self.num_resolutions - 1]
589
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
590
+ self.z_shape = (1, z_channels, curr_res, curr_res)
591
+ # print("Working with z of shape {} = {} dimensions.".format(
592
+ # self.z_shape, np.prod(self.z_shape)))
593
+
594
+ # z to block_in
595
+ self.conv_in = torch.nn.Conv2d(
596
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
597
+ )
598
+
599
+ # middle
600
+ self.mid = nn.Module()
601
+ self.mid.block_1 = ResnetBlock(
602
+ in_channels=block_in,
603
+ out_channels=block_in,
604
+ temb_channels=self.temb_ch,
605
+ dropout=dropout,
606
+ )
607
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
608
+ self.mid.block_2 = ResnetBlock(
609
+ in_channels=block_in,
610
+ out_channels=block_in,
611
+ temb_channels=self.temb_ch,
612
+ dropout=dropout,
613
+ )
614
+
615
+ # upsampling
616
+ self.up = nn.ModuleList()
617
+ for i_level in reversed(range(self.num_resolutions)):
618
+ block = nn.ModuleList()
619
+ attn = nn.ModuleList()
620
+ block_out = ch * ch_mult[i_level]
621
+ for i_block in range(self.num_res_blocks + 1):
622
+ block.append(
623
+ ResnetBlock(
624
+ in_channels=block_in,
625
+ out_channels=block_out,
626
+ temb_channels=self.temb_ch,
627
+ dropout=dropout,
628
+ )
629
+ )
630
+ block_in = block_out
631
+ if curr_res in attn_resolutions:
632
+ attn.append(make_attn(block_in, attn_type=attn_type))
633
+ up = nn.Module()
634
+ up.block = block
635
+ up.attn = attn
636
+ if i_level != 0:
637
+ if i_level - 1 in self.downsample_time_stride4_levels:
638
+ up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
639
+ else:
640
+ up.upsample = Upsample(block_in, resamp_with_conv)
641
+ curr_res = curr_res * 2
642
+ self.up.insert(0, up) # prepend to get consistent order
643
+
644
+ # end
645
+ self.norm_out = Normalize(block_in)
646
+ self.conv_out = torch.nn.Conv2d(
647
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
648
+ )
649
+
650
+ def forward(self, z):
651
+ # assert z.shape[1:] == self.z_shape[1:]
652
+ self.last_z_shape = z.shape
653
+
654
+ # timestep embedding
655
+ temb = None
656
+
657
+ # z to block_in
658
+ h = self.conv_in(z)
659
+
660
+ # middle
661
+ h = self.mid.block_1(h, temb)
662
+ h = self.mid.attn_1(h)
663
+ h = self.mid.block_2(h, temb)
664
+
665
+ # upsampling
666
+ for i_level in reversed(range(self.num_resolutions)):
667
+ for i_block in range(self.num_res_blocks + 1):
668
+ h = self.up[i_level].block[i_block](h, temb)
669
+ if len(self.up[i_level].attn) > 0:
670
+ h = self.up[i_level].attn[i_block](h)
671
+ if i_level != 0:
672
+ h = self.up[i_level].upsample(h)
673
+
674
+ # end
675
+ if self.give_pre_end:
676
+ return h
677
+
678
+ h = self.norm_out(h)
679
+ h = nonlinearity(h)
680
+ h = self.conv_out(h)
681
+ if self.tanh_out:
682
+ h = torch.tanh(h)
683
+ return h
684
+
685
+
686
+ class SimpleDecoder(nn.Module):
687
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
688
+ super().__init__()
689
+ self.model = nn.ModuleList(
690
+ [
691
+ nn.Conv2d(in_channels, in_channels, 1),
692
+ ResnetBlock(
693
+ in_channels=in_channels,
694
+ out_channels=2 * in_channels,
695
+ temb_channels=0,
696
+ dropout=0.0,
697
+ ),
698
+ ResnetBlock(
699
+ in_channels=2 * in_channels,
700
+ out_channels=4 * in_channels,
701
+ temb_channels=0,
702
+ dropout=0.0,
703
+ ),
704
+ ResnetBlock(
705
+ in_channels=4 * in_channels,
706
+ out_channels=2 * in_channels,
707
+ temb_channels=0,
708
+ dropout=0.0,
709
+ ),
710
+ nn.Conv2d(2 * in_channels, in_channels, 1),
711
+ Upsample(in_channels, with_conv=True),
712
+ ]
713
+ )
714
+ # end
715
+ self.norm_out = Normalize(in_channels)
716
+ self.conv_out = torch.nn.Conv2d(
717
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
718
+ )
719
+
720
+ def forward(self, x):
721
+ for i, layer in enumerate(self.model):
722
+ if i in [1, 2, 3]:
723
+ x = layer(x, None)
724
+ else:
725
+ x = layer(x)
726
+
727
+ h = self.norm_out(x)
728
+ h = nonlinearity(h)
729
+ x = self.conv_out(h)
730
+ return x
731
+
732
+
733
+ class UpsampleDecoder(nn.Module):
734
+ def __init__(
735
+ self,
736
+ in_channels,
737
+ out_channels,
738
+ ch,
739
+ num_res_blocks,
740
+ resolution,
741
+ ch_mult=(2, 2),
742
+ dropout=0.0,
743
+ ):
744
+ super().__init__()
745
+ # upsampling
746
+ self.temb_ch = 0
747
+ self.num_resolutions = len(ch_mult)
748
+ self.num_res_blocks = num_res_blocks
749
+ block_in = in_channels
750
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
751
+ self.res_blocks = nn.ModuleList()
752
+ self.upsample_blocks = nn.ModuleList()
753
+ for i_level in range(self.num_resolutions):
754
+ res_block = []
755
+ block_out = ch * ch_mult[i_level]
756
+ for i_block in range(self.num_res_blocks + 1):
757
+ res_block.append(
758
+ ResnetBlock(
759
+ in_channels=block_in,
760
+ out_channels=block_out,
761
+ temb_channels=self.temb_ch,
762
+ dropout=dropout,
763
+ )
764
+ )
765
+ block_in = block_out
766
+ self.res_blocks.append(nn.ModuleList(res_block))
767
+ if i_level != self.num_resolutions - 1:
768
+ self.upsample_blocks.append(Upsample(block_in, True))
769
+ curr_res = curr_res * 2
770
+
771
+ # end
772
+ self.norm_out = Normalize(block_in)
773
+ self.conv_out = torch.nn.Conv2d(
774
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
775
+ )
776
+
777
+ def forward(self, x):
778
+ # upsampling
779
+ h = x
780
+ for k, i_level in enumerate(range(self.num_resolutions)):
781
+ for i_block in range(self.num_res_blocks + 1):
782
+ h = self.res_blocks[i_level][i_block](h, None)
783
+ if i_level != self.num_resolutions - 1:
784
+ h = self.upsample_blocks[k](h)
785
+ h = self.norm_out(h)
786
+ h = nonlinearity(h)
787
+ h = self.conv_out(h)
788
+ return h
789
+
790
+
791
+ class LatentRescaler(nn.Module):
792
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
793
+ super().__init__()
794
+ # residual block, interpolate, residual block
795
+ self.factor = factor
796
+ self.conv_in = nn.Conv2d(
797
+ in_channels, mid_channels, kernel_size=3, stride=1, padding=1
798
+ )
799
+ self.res_block1 = nn.ModuleList(
800
+ [
801
+ ResnetBlock(
802
+ in_channels=mid_channels,
803
+ out_channels=mid_channels,
804
+ temb_channels=0,
805
+ dropout=0.0,
806
+ )
807
+ for _ in range(depth)
808
+ ]
809
+ )
810
+ self.attn = AttnBlock(mid_channels)
811
+ self.res_block2 = nn.ModuleList(
812
+ [
813
+ ResnetBlock(
814
+ in_channels=mid_channels,
815
+ out_channels=mid_channels,
816
+ temb_channels=0,
817
+ dropout=0.0,
818
+ )
819
+ for _ in range(depth)
820
+ ]
821
+ )
822
+
823
+ self.conv_out = nn.Conv2d(
824
+ mid_channels,
825
+ out_channels,
826
+ kernel_size=1,
827
+ )
828
+
829
+ def forward(self, x):
830
+ x = self.conv_in(x)
831
+ for block in self.res_block1:
832
+ x = block(x, None)
833
+ x = torch.nn.functional.interpolate(
834
+ x,
835
+ size=(
836
+ int(round(x.shape[2] * self.factor)),
837
+ int(round(x.shape[3] * self.factor)),
838
+ ),
839
+ )
840
+ x = self.attn(x).contiguous()
841
+ for block in self.res_block2:
842
+ x = block(x, None)
843
+ x = self.conv_out(x)
844
+ return x
845
+
846
+
847
+ class MergedRescaleEncoder(nn.Module):
848
+ def __init__(
849
+ self,
850
+ in_channels,
851
+ ch,
852
+ resolution,
853
+ out_ch,
854
+ num_res_blocks,
855
+ attn_resolutions,
856
+ dropout=0.0,
857
+ resamp_with_conv=True,
858
+ ch_mult=(1, 2, 4, 8),
859
+ rescale_factor=1.0,
860
+ rescale_module_depth=1,
861
+ ):
862
+ super().__init__()
863
+ intermediate_chn = ch * ch_mult[-1]
864
+ self.encoder = Encoder(
865
+ in_channels=in_channels,
866
+ num_res_blocks=num_res_blocks,
867
+ ch=ch,
868
+ ch_mult=ch_mult,
869
+ z_channels=intermediate_chn,
870
+ double_z=False,
871
+ resolution=resolution,
872
+ attn_resolutions=attn_resolutions,
873
+ dropout=dropout,
874
+ resamp_with_conv=resamp_with_conv,
875
+ out_ch=None,
876
+ )
877
+ self.rescaler = LatentRescaler(
878
+ factor=rescale_factor,
879
+ in_channels=intermediate_chn,
880
+ mid_channels=intermediate_chn,
881
+ out_channels=out_ch,
882
+ depth=rescale_module_depth,
883
+ )
884
+
885
+ def forward(self, x):
886
+ x = self.encoder(x)
887
+ x = self.rescaler(x)
888
+ return x
889
+
890
+
891
+ class MergedRescaleDecoder(nn.Module):
892
+ def __init__(
893
+ self,
894
+ z_channels,
895
+ out_ch,
896
+ resolution,
897
+ num_res_blocks,
898
+ attn_resolutions,
899
+ ch,
900
+ ch_mult=(1, 2, 4, 8),
901
+ dropout=0.0,
902
+ resamp_with_conv=True,
903
+ rescale_factor=1.0,
904
+ rescale_module_depth=1,
905
+ ):
906
+ super().__init__()
907
+ tmp_chn = z_channels * ch_mult[-1]
908
+ self.decoder = Decoder(
909
+ out_ch=out_ch,
910
+ z_channels=tmp_chn,
911
+ attn_resolutions=attn_resolutions,
912
+ dropout=dropout,
913
+ resamp_with_conv=resamp_with_conv,
914
+ in_channels=None,
915
+ num_res_blocks=num_res_blocks,
916
+ ch_mult=ch_mult,
917
+ resolution=resolution,
918
+ ch=ch,
919
+ )
920
+ self.rescaler = LatentRescaler(
921
+ factor=rescale_factor,
922
+ in_channels=z_channels,
923
+ mid_channels=tmp_chn,
924
+ out_channels=tmp_chn,
925
+ depth=rescale_module_depth,
926
+ )
927
+
928
+ def forward(self, x):
929
+ x = self.rescaler(x)
930
+ x = self.decoder(x)
931
+ return x
932
+
933
+
934
+ class Upsampler(nn.Module):
935
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
936
+ super().__init__()
937
+ assert out_size >= in_size
938
+ num_blocks = int(np.log2(out_size // in_size)) + 1
939
+ factor_up = 1.0 + (out_size % in_size)
940
+ print(
941
+ f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
942
+ )
943
+ self.rescaler = LatentRescaler(
944
+ factor=factor_up,
945
+ in_channels=in_channels,
946
+ mid_channels=2 * in_channels,
947
+ out_channels=in_channels,
948
+ )
949
+ self.decoder = Decoder(
950
+ out_ch=out_channels,
951
+ resolution=out_size,
952
+ z_channels=in_channels,
953
+ num_res_blocks=2,
954
+ attn_resolutions=[],
955
+ in_channels=None,
956
+ ch=in_channels,
957
+ ch_mult=[ch_mult for _ in range(num_blocks)],
958
+ )
959
+
960
+ def forward(self, x):
961
+ x = self.rescaler(x)
962
+ x = self.decoder(x)
963
+ return x
964
+
965
+
966
+ class Resize(nn.Module):
967
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
968
+ super().__init__()
969
+ self.with_conv = learned
970
+ self.mode = mode
971
+ if self.with_conv:
972
+ print(
973
+ f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
974
+ )
975
+ raise NotImplementedError()
976
+ assert in_channels is not None
977
+ # no asymmetric padding in torch conv, must do it ourselves
978
+ self.conv = torch.nn.Conv2d(
979
+ in_channels, in_channels, kernel_size=4, stride=2, padding=1
980
+ )
981
+
982
+ def forward(self, x, scale_factor=1.0):
983
+ if scale_factor == 1.0:
984
+ return x
985
+ else:
986
+ x = torch.nn.functional.interpolate(
987
+ x, mode=self.mode, align_corners=False, scale_factor=scale_factor
988
+ )
989
+ return x
990
+
991
+
992
+ class FirstStagePostProcessor(nn.Module):
993
+ def __init__(
994
+ self,
995
+ ch_mult: list,
996
+ in_channels,
997
+ pretrained_model: nn.Module = None,
998
+ reshape=False,
999
+ n_channels=None,
1000
+ dropout=0.0,
1001
+ pretrained_config=None,
1002
+ ):
1003
+ super().__init__()
1004
+ if pretrained_config is None:
1005
+ assert (
1006
+ pretrained_model is not None
1007
+ ), 'Either "pretrained_model" or "pretrained_config" must not be None'
1008
+ self.pretrained_model = pretrained_model
1009
+ else:
1010
+ assert (
1011
+ pretrained_config is not None
1012
+ ), 'Either "pretrained_model" or "pretrained_config" must not be None'
1013
+ self.instantiate_pretrained(pretrained_config)
1014
+
1015
+ self.do_reshape = reshape
1016
+
1017
+ if n_channels is None:
1018
+ n_channels = self.pretrained_model.encoder.ch
1019
+
1020
+ self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
1021
+ self.proj = nn.Conv2d(
1022
+ in_channels, n_channels, kernel_size=3, stride=1, padding=1
1023
+ )
1024
+
1025
+ blocks = []
1026
+ downs = []
1027
+ ch_in = n_channels
1028
+ for m in ch_mult:
1029
+ blocks.append(
1030
+ ResnetBlock(
1031
+ in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
1032
+ )
1033
+ )
1034
+ ch_in = m * n_channels
1035
+ downs.append(Downsample(ch_in, with_conv=False))
1036
+
1037
+ self.model = nn.ModuleList(blocks)
1038
+ self.downsampler = nn.ModuleList(downs)
1039
+
1040
+ def instantiate_pretrained(self, config):
1041
+ model = instantiate_from_config(config)
1042
+ self.pretrained_model = model.eval()
1043
+ # self.pretrained_model.train = False
1044
+ for param in self.pretrained_model.parameters():
1045
+ param.requires_grad = False
1046
+
1047
+ @torch.no_grad()
1048
+ def encode_with_pretrained(self, x):
1049
+ c = self.pretrained_model.encode(x)
1050
+ if isinstance(c, DiagonalGaussianDistribution):
1051
+ c = c.mode()
1052
+ return c
1053
+
1054
+ def forward(self, x):
1055
+ z_fs = self.encode_with_pretrained(x)
1056
+ z = self.proj_norm(z_fs)
1057
+ z = self.proj(z)
1058
+ z = nonlinearity(z)
1059
+
1060
+ for submodel, downmodel in zip(self.model, self.downsampler):
1061
+ z = submodel(z, temb=None)
1062
+ z = downmodel(z)
1063
+
1064
+ if self.do_reshape:
1065
+ z = rearrange(z, "b c h w -> b (h w) c")
1066
+ return z
diffusers/CITATION.cff ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cff-version: 1.2.0
2
+ title: 'Diffusers: State-of-the-art diffusion models'
3
+ message: >-
4
+ If you use this software, please cite it using the
5
+ metadata from this file.
6
+ type: software
7
+ authors:
8
+ - given-names: Patrick
9
+ family-names: von Platen
10
+ - given-names: Suraj
11
+ family-names: Patil
12
+ - given-names: Anton
13
+ family-names: Lozhkov
14
+ - given-names: Pedro
15
+ family-names: Cuenca
16
+ - given-names: Nathan
17
+ family-names: Lambert
18
+ - given-names: Kashif
19
+ family-names: Rasul
20
+ - given-names: Mishig
21
+ family-names: Davaadorj
22
+ - given-names: Thomas
23
+ family-names: Wolf
24
+ repository-code: 'https://github.com/huggingface/diffusers'
25
+ abstract: >-
26
+ Diffusers provides pretrained diffusion models across
27
+ multiple modalities, such as vision and audio, and serves
28
+ as a modular toolbox for inference and training of
29
+ diffusion models.
30
+ keywords:
31
+ - deep-learning
32
+ - pytorch
33
+ - image-generation
34
+ - diffusion
35
+ - text2image
36
+ - image2image
37
+ - score-based-generative-modeling
38
+ - stable-diffusion
39
+ license: Apache-2.0
40
+ version: 0.12.1
diffusers/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Contributor Covenant Code of Conduct
3
+
4
+ ## Our Pledge
5
+
6
+ We as members, contributors, and leaders pledge to make participation in our
7
+ community a harassment-free experience for everyone, regardless of age, body
8
+ size, visible or invisible disability, ethnicity, sex characteristics, gender
9
+ identity and expression, level of experience, education, socio-economic status,
10
+ nationality, personal appearance, race, religion, or sexual identity
11
+ and orientation.
12
+
13
+ We pledge to act and interact in ways that contribute to an open, welcoming,
14
+ diverse, inclusive, and healthy community.
15
+
16
+ ## Our Standards
17
+
18
+ Examples of behavior that contributes to a positive environment for our
19
+ community include:
20
+
21
+ * Demonstrating empathy and kindness toward other people
22
+ * Being respectful of differing opinions, viewpoints, and experiences
23
+ * Giving and gracefully accepting constructive feedback
24
+ * Accepting responsibility and apologizing to those affected by our mistakes,
25
+ and learning from the experience
26
+ * Focusing on what is best not just for us as individuals, but for the
27
+ overall diffusers community
28
+
29
+ Examples of unacceptable behavior include:
30
+
31
+ * The use of sexualized language or imagery, and sexual attention or
32
+ advances of any kind
33
+ * Trolling, insulting or derogatory comments, and personal or political attacks
34
+ * Public or private harassment
35
+ * Publishing others' private information, such as a physical or email
36
+ address, without their explicit permission
37
+ * Spamming issues or PRs with links to projects unrelated to this library
38
+ * Other conduct which could reasonably be considered inappropriate in a
39
+ professional setting
40
+
41
+ ## Enforcement Responsibilities
42
+
43
+ Community leaders are responsible for clarifying and enforcing our standards of
44
+ acceptable behavior and will take appropriate and fair corrective action in
45
+ response to any behavior that they deem inappropriate, threatening, offensive,
46
+ or harmful.
47
+
48
+ Community leaders have the right and responsibility to remove, edit, or reject
49
+ comments, commits, code, wiki edits, issues, and other contributions that are
50
+ not aligned to this Code of Conduct, and will communicate reasons for moderation
51
+ decisions when appropriate.
52
+
53
+ ## Scope
54
+
55
+ This Code of Conduct applies within all community spaces, and also applies when
56
+ an individual is officially representing the community in public spaces.
57
+ Examples of representing our community include using an official e-mail address,
58
+ posting via an official social media account, or acting as an appointed
59
+ representative at an online or offline event.
60
+
61
+ ## Enforcement
62
+
63
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
64
+ reported to the community leaders responsible for enforcement at
65
66
+ All complaints will be reviewed and investigated promptly and fairly.
67
+
68
+ All community leaders are obligated to respect the privacy and security of the
69
+ reporter of any incident.
70
+
71
+ ## Enforcement Guidelines
72
+
73
+ Community leaders will follow these Community Impact Guidelines in determining
74
+ the consequences for any action they deem in violation of this Code of Conduct:
75
+
76
+ ### 1. Correction
77
+
78
+ **Community Impact**: Use of inappropriate language or other behavior deemed
79
+ unprofessional or unwelcome in the community.
80
+
81
+ **Consequence**: A private, written warning from community leaders, providing
82
+ clarity around the nature of the violation and an explanation of why the
83
+ behavior was inappropriate. A public apology may be requested.
84
+
85
+ ### 2. Warning
86
+
87
+ **Community Impact**: A violation through a single incident or series
88
+ of actions.
89
+
90
+ **Consequence**: A warning with consequences for continued behavior. No
91
+ interaction with the people involved, including unsolicited interaction with
92
+ those enforcing the Code of Conduct, for a specified period of time. This
93
+ includes avoiding interactions in community spaces as well as external channels
94
+ like social media. Violating these terms may lead to a temporary or
95
+ permanent ban.
96
+
97
+ ### 3. Temporary Ban
98
+
99
+ **Community Impact**: A serious violation of community standards, including
100
+ sustained inappropriate behavior.
101
+
102
+ **Consequence**: A temporary ban from any sort of interaction or public
103
+ communication with the community for a specified period of time. No public or
104
+ private interaction with the people involved, including unsolicited interaction
105
+ with those enforcing the Code of Conduct, is allowed during this period.
106
+ Violating these terms may lead to a permanent ban.
107
+
108
+ ### 4. Permanent Ban
109
+
110
+ **Community Impact**: Demonstrating a pattern of violation of community
111
+ standards, including sustained inappropriate behavior, harassment of an
112
+ individual, or aggression toward or disparagement of classes of individuals.
113
+
114
+ **Consequence**: A permanent ban from any sort of public interaction within
115
+ the community.
116
+
117
+ ## Attribution
118
+
119
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage],
120
+ version 2.0, available at
121
+ https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
122
+
123
+ Community Impact Guidelines were inspired by [Mozilla's code of conduct
124
+ enforcement ladder](https://github.com/mozilla/diversity).
125
+
126
+ [homepage]: https://www.contributor-covenant.org
127
+
128
+ For answers to common questions about this code of conduct, see the FAQ at
129
+ https://www.contributor-covenant.org/faq. Translations are available at
130
+ https://www.contributor-covenant.org/translations.