ZheqiDAI committed
Commit 070e26e · 1 Parent(s): f40e29b

Initial commit
LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 AudioFans

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README copy.md ADDED
@@ -0,0 +1,141 @@
# Musimple: Text2Music with DiT Made Simple

## Introduction

This repository provides a simple and clear implementation of a **Text-to-Music Generation** pipeline using a **DiT (Diffusion Transformer)** model. The codebase includes the key components of **model training**, **inference**, and **evaluation**. We use the **GTZAN dataset** as an example to demonstrate a minimal, working pipeline for text-conditioned music generation.

The repository is designed to be easy to use and customize, making it simple to reproduce our results on a single **NVIDIA RTX 4090 GPU**. The code is also structured to be flexible, so you can modify it for your own tasks and datasets.

We plan to continue maintaining and improving this repository with new features, model improvements, and extended documentation in the future.

## Features

- **Text-to-Music Generation**: Generate music directly from text descriptions using a DiT model.
- **GTZAN Example**: A simple pipeline using the GTZAN dataset to demonstrate the workflow.
- **End-to-End Pipeline**: Includes model training, inference, and evaluation, with support for generating audio files.
- **Customizable**: Easy to modify and extend for different datasets or use cases.
- **Single-GPU Training**: Optimized for training on a single RTX 4090 GPU but adaptable to different hardware setups.

## Requirements

Before using the code, ensure that the following dependencies are installed:

- Python >= 3.9
- CUDA (if available)
- Required Python libraries from `requirements.txt`

You can install the dependencies using:

```bash
conda create -n musimple python=3.9
conda activate musimple
pip install -r requirements.txt
```

## Data Preprocessing

To begin with, you will need to download the **GTZAN dataset**. Once downloaded, use the `gtzan_split.py` script located in the `tools` directory to split the dataset into training and testing sets:

```bash
python gtzan_split.py --root_dir /path/to/gtzan/genres --output_dir /path/to/output/directory
```

Next, convert the audio files into an HDF5 file using the `gtzan2h5.py` script:

```bash
python gtzan2h5.py --root_dir /path/to/audio/files --output_h5_file /path/to/output.h5 --config_path bigvgan_v2_22khz_80band_256x/config.json --sr 22050
```

### Preprocessed Data

If this process seems cumbersome, don't worry: **we have already preprocessed the dataset**, and you can find it in the **musimple/dataset** directory. You can download and use this data directly to skip the preprocessing steps.

### Data Breakdown

The preprocessing stage has two main parts:

- **Text to Latent Transformation**: A Sentence Transformer converts the text labels into latent representations.
- **Audio to Mel Spectrogram**: The original audio files are converted into mel spectrograms.

Both the latent representations and the mel spectrograms are stored in an HDF5 file, making them easy to access during training and inference (see the sketch below).
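As a concrete illustration, here is a minimal sketch of how a record in the preprocessed HDF5 file could be inspected with `h5py`. The group and dataset names (`mel`, `text_latent`) are assumptions for illustration only; check `gtzan2h5.py` for the actual layout.

```python
import h5py

# Hypothetical layout: one group per clip, holding a mel spectrogram and a text latent.
with h5py.File("./dataset/gtzan_train.h5", "r") as f:
    key = sorted(f.keys())[0]
    mel = f[key]["mel"][:]                  # e.g. an (80, n_frames) mel spectrogram
    text_latent = f[key]["text_latent"][:]  # Sentence-Transformer embedding of the genre label
    print(key, mel.shape, text_latent.shape)
```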
## Training

To begin training, simply navigate to the `Musimple` directory and run:

```bash
cd Musimple
python train.py
```

### Configurable Parameters

All training-related parameters can be adjusted in the configuration file located at:

```
./config/train.yaml
```

This lets you easily modify settings such as the learning rate, batch size, and number of epochs to suit your hardware or dataset requirements.

We also provide a **pre-trained checkpoint**, trained for two days on a single **NVIDIA RTX 4090**, which you can use for inference or fine-tuning. The key training parameters for this checkpoint are:

- `batch_size`: 48
- `mel_frames`: 800
- `lr`: 0.0001
- `num_epochs`: 100000
- `sample_interval`: 250
- `h5_file_path`: './dataset/gtzan_train.h5'
- `device`: 'cuda:4'
- `input_size`: [80, 800]
- `patch_size`: 8
- `in_channels`: 1
- `hidden_size`: 384
- `depth`: 12
- `num_heads`: 6
- `checkpoint_dir`: 'gtzan-ck'

You can modify the model architecture and parameters in the `train.yaml` configuration file to compare your models against ours; a short sketch of what these settings imply for the model is given below. We will continue to release more checkpoints and models in future updates.
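For orientation, this small sketch (not part of the repo) shows how the configuration translates into the DiT's input geometry: the mel spectrogram is treated as a 1-channel 80×800 "image" that is cut into `patch_size`×`patch_size` patches, which become the transformer's token sequence. It assumes PyYAML is available, as `train.py` presumably uses it to read this file.

```python
import yaml

with open("./config/train.yaml") as f:
    cfg = yaml.safe_load(f)

n_mels, n_frames = cfg["input_size"]   # 80 mel bins x 800 frames
patch = cfg["patch_size"]              # 8
num_tokens = (n_mels // patch) * (n_frames // patch)   # 10 * 100 = 1000 tokens per clip
head_dim = cfg["hidden_size"] // cfg["num_heads"]      # 384 / 6 = 64
print(f"tokens per clip: {num_tokens}, attention head dim: {head_dim}")
```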
## Inference

Once you have trained your own model, you can run inference with it:

```bash
python sample.py --checkpoint ./gtzan-ck/model_epoch_20000.pt \
                 --h5_file ./dataset/gtzan_test.h5 \
                 --output_gt_dir ./sample/gt \
                 --output_gen_dir ./sample/gn \
                 --segment_length 800 \
                 --sample_rate 22050
```

You can also try running inference with our pre-trained model to familiarize yourself with the process. We have saved some inference results in the `sample` folder as a demo. However, due to the limited size of our model, the generated results are not of the highest quality; they are intended as simple examples to guide further evaluation.

## Evaluation

For the evaluation phase, we highly recommend creating a new environment and using the evaluation library available at [Generated Music Evaluation](https://github.com/HarlandZZC/generated_music_evaluation). That repository provides detailed instructions on setting up the environment and using the evaluation tools, and new features will be added to it over time.

Once you have set up the environment following those instructions, you can run the following script to evaluate your generated music:

```bash
python eval.py \
    --ref_path ../sample/gt \
    --gen_path ../sample/gn \
    --id2text_csv_path ../gtzan-test.csv \
    --output_path ./output \
    --device_id 0 \
    --batch_size 32 \
    --original_sample_rate 24000 \
    --fad_sample_rate 16000 \
    --kl_sample_rate 16000 \
    --clap_sample_rate 48000 \
    --run_fad 1 \
    --run_kl 1 \
    --run_clap 1
```

This script evaluates the generated music against the reference music and reports FAD, KL, and CLAP scores.

## To-Do

The following features and improvements are planned for future updates:

- **EMA Model**: Implement an Exponential Moving Average (EMA) of the model weights to stabilize training and improve final generation quality (see the sketch below).
- **Long-Term Music Fine-tuning**: Explore fine-tuning the model to generate longer music with more coherent structure.
- **VAE Integration**: Integrate a Variational Autoencoder (VAE) to improve latent-space representations and potentially enhance generation diversity.
- **T5-based Text Conditioning**: Add T5 to enhance text conditioning, improving the control and accuracy of text-to-music generation.
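The EMA item refers to the standard practice of keeping a shadow copy of the weights that is updated as an exponential moving average after every optimizer step; sampling from the EMA weights usually gives smoother results. A minimal sketch (not yet in this repo; the 0.9999 decay is a common default, not a project setting):

```python
import copy
import torch

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    # ema_param <- decay * ema_param + (1 - decay) * param
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)

# ema_model = copy.deepcopy(model)   # created once before training starts
# update_ema(ema_model, model)       # called after each optimizer.step()
```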
config/train.yaml ADDED
@@ -0,0 +1,14 @@
batch_size: 48
mel_frames: 800
lr: 0.0001
num_epochs: 100000
sample_interval: 250
h5_file_path: './dataset/gtzan_train.h5'
device: 'cuda:4'
input_size: [80, 800]
patch_size: 8
in_channels: 1
hidden_size: 384
depth: 12
num_heads: 6
checkpoint_dir: 'gtzan-ck'
dataset/gtzan_test.h5 ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f9c40a6548fcd65c8bf4296968e1bf8289ba422e9fdfacd6745d4c9dfc86082
size 90507648
dataset/gtzan_train.h5 ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:682e8998af88b14af1132d3fafc916f30fcfe21d4e91743fa2e91828667b9d6d
size 813506352
diffusion/__init__.py ADDED
@@ -0,0 +1,46 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps


def create_diffusion(
    timestep_respacing,
    noise_schedule="linear",
    use_kl=False,
    sigma_small=False,
    predict_xstart=False,
    learn_sigma=True,
    rescale_learned_sigmas=False,
    diffusion_steps=1000
):
    betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]
    return SpacedDiffusion(
        use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
        ),
        model_var_type=(
            (
                gd.ModelVarType.FIXED_LARGE
                if not sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )
            if not learn_sigma
            else gd.ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type
        # rescale_timesteps=rescale_timesteps,
    )
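A minimal usage sketch (not part of this commit) of how this factory is typically driven for training and DDIM sampling. The conditioning keyword (`y` carrying the text latent) is an assumption based on the DiT convention this code is ported from; check `train.py` and `sample.py` for the actual interface.

```python
import torch
from diffusion import create_diffusion

train_diffusion = create_diffusion(timestep_respacing="")        # full 1000-step chain for training
ddim_diffusion = create_diffusion(timestep_respacing="ddim50")   # respaced 50-step chain for fast sampling

x = torch.randn(4, 1, 80, 800)                             # batch of mel spectrograms (B, C, mels, frames)
t = torch.randint(0, train_diffusion.num_timesteps, (4,))  # one timestep per example
# With learn_sigma=True (the default) the model must output 2*C channels (mean and variance).
# losses = train_diffusion.training_losses(model, x, t, model_kwargs={"y": text_latents})
# mels = ddim_diffusion.ddim_sample_loop(model, x.shape, model_kwargs={"y": text_latents})
```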
diffusion/diffusion_utils.py ADDED
@@ -0,0 +1,88 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

import torch as th
import numpy as np


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, th.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for th.exp().
    logvar1, logvar2 = [
        x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    )


def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))


def continuous_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a continuous Gaussian distribution.
    :param x: the targets
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    normalized_x = centered_x * inv_stdv
    log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
    return log_probs


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.
    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = th.where(
        x < -0.999,
        log_cdf_plus,
        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs
diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,873 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch as th
11
+ import enum
12
+
13
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14
+
15
+
16
+ def mean_flat(tensor):
17
+ """
18
+ Take the mean over all non-batch dimensions.
19
+ """
20
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
21
+
22
+
23
+ class ModelMeanType(enum.Enum):
24
+ """
25
+ Which type of output the model predicts.
26
+ """
27
+
28
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29
+ START_X = enum.auto() # the model predicts x_0
30
+ EPSILON = enum.auto() # the model predicts epsilon
31
+
32
+
33
+ class ModelVarType(enum.Enum):
34
+ """
35
+ What is used as the model's output variance.
36
+ The LEARNED_RANGE option has been added to allow the model to predict
37
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
38
+ """
39
+
40
+ LEARNED = enum.auto()
41
+ FIXED_SMALL = enum.auto()
42
+ FIXED_LARGE = enum.auto()
43
+ LEARNED_RANGE = enum.auto()
44
+
45
+
46
+ class LossType(enum.Enum):
47
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
48
+ RESCALED_MSE = (
49
+ enum.auto()
50
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
51
+ KL = enum.auto() # use the variational lower-bound
52
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
53
+
54
+ def is_vb(self):
55
+ return self == LossType.KL or self == LossType.RESCALED_KL
56
+
57
+
58
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
59
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
60
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
61
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
62
+ return betas
63
+
64
+
65
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
66
+ """
67
+ This is the deprecated API for creating beta schedules.
68
+ See get_named_beta_schedule() for the new library of schedules.
69
+ """
70
+ if beta_schedule == "quad":
71
+ betas = (
72
+ np.linspace(
73
+ beta_start ** 0.5,
74
+ beta_end ** 0.5,
75
+ num_diffusion_timesteps,
76
+ dtype=np.float64,
77
+ )
78
+ ** 2
79
+ )
80
+ elif beta_schedule == "linear":
81
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
82
+ elif beta_schedule == "warmup10":
83
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
84
+ elif beta_schedule == "warmup50":
85
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
86
+ elif beta_schedule == "const":
87
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
88
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
89
+ betas = 1.0 / np.linspace(
90
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
91
+ )
92
+ else:
93
+ raise NotImplementedError(beta_schedule)
94
+ assert betas.shape == (num_diffusion_timesteps,)
95
+ return betas
96
+
97
+
98
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
99
+ """
100
+ Get a pre-defined beta schedule for the given name.
101
+ The beta schedule library consists of beta schedules which remain similar
102
+ in the limit of num_diffusion_timesteps.
103
+ Beta schedules may be added, but should not be removed or changed once
104
+ they are committed to maintain backwards compatibility.
105
+ """
106
+ if schedule_name == "linear":
107
+ # Linear schedule from Ho et al, extended to work for any number of
108
+ # diffusion steps.
109
+ scale = 1000 / num_diffusion_timesteps
110
+ return get_beta_schedule(
111
+ "linear",
112
+ beta_start=scale * 0.0001,
113
+ beta_end=scale * 0.02,
114
+ num_diffusion_timesteps=num_diffusion_timesteps,
115
+ )
116
+ elif schedule_name == "squaredcos_cap_v2":
117
+ return betas_for_alpha_bar(
118
+ num_diffusion_timesteps,
119
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
120
+ )
121
+ else:
122
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
123
+
124
+
125
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
126
+ """
127
+ Create a beta schedule that discretizes the given alpha_t_bar function,
128
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
129
+ :param num_diffusion_timesteps: the number of betas to produce.
130
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
131
+ produces the cumulative product of (1-beta) up to that
132
+ part of the diffusion process.
133
+ :param max_beta: the maximum beta to use; use values lower than 1 to
134
+ prevent singularities.
135
+ """
136
+ betas = []
137
+ for i in range(num_diffusion_timesteps):
138
+ t1 = i / num_diffusion_timesteps
139
+ t2 = (i + 1) / num_diffusion_timesteps
140
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
141
+ return np.array(betas)
142
+
143
+
144
+ class GaussianDiffusion:
145
+ """
146
+ Utilities for training and sampling diffusion models.
147
+ Original ported from this codebase:
148
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
149
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
150
+ starting at T and going to 1.
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ *,
156
+ betas,
157
+ model_mean_type,
158
+ model_var_type,
159
+ loss_type
160
+ ):
161
+
162
+ self.model_mean_type = model_mean_type
163
+ self.model_var_type = model_var_type
164
+ self.loss_type = loss_type
165
+
166
+ # Use float64 for accuracy.
167
+ betas = np.array(betas, dtype=np.float64)
168
+ self.betas = betas
169
+ assert len(betas.shape) == 1, "betas must be 1-D"
170
+ assert (betas > 0).all() and (betas <= 1).all()
171
+
172
+ self.num_timesteps = int(betas.shape[0])
173
+
174
+ alphas = 1.0 - betas
175
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
176
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
177
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
178
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
179
+
180
+ # calculations for diffusion q(x_t | x_{t-1}) and others
181
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
182
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
183
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
184
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
185
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
186
+
187
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
188
+ self.posterior_variance = (
189
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
190
+ )
191
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
192
+ self.posterior_log_variance_clipped = np.log(
193
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
194
+ ) if len(self.posterior_variance) > 1 else np.array([])
195
+
196
+ self.posterior_mean_coef1 = (
197
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
198
+ )
199
+ self.posterior_mean_coef2 = (
200
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
201
+ )
202
+
203
+ def q_mean_variance(self, x_start, t):
204
+ """
205
+ Get the distribution q(x_t | x_0).
206
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
207
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
208
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
209
+ """
210
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
211
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
212
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
213
+ return mean, variance, log_variance
214
+
215
+ def q_sample(self, x_start, t, noise=None):
216
+ """
217
+ Diffuse the data for a given number of diffusion steps.
218
+ In other words, sample from q(x_t | x_0).
219
+ :param x_start: the initial data batch.
220
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
221
+ :param noise: if specified, the split-out normal noise.
222
+ :return: A noisy version of x_start.
223
+ """
224
+ if noise is None:
225
+ noise = th.randn_like(x_start)
226
+ assert noise.shape == x_start.shape
227
+ return (
228
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
229
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
230
+ )
231
+
232
+ def q_posterior_mean_variance(self, x_start, x_t, t):
233
+ """
234
+ Compute the mean and variance of the diffusion posterior:
235
+ q(x_{t-1} | x_t, x_0)
236
+ """
237
+ assert x_start.shape == x_t.shape
238
+ posterior_mean = (
239
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
240
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
241
+ )
242
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
243
+ posterior_log_variance_clipped = _extract_into_tensor(
244
+ self.posterior_log_variance_clipped, t, x_t.shape
245
+ )
246
+ assert (
247
+ posterior_mean.shape[0]
248
+ == posterior_variance.shape[0]
249
+ == posterior_log_variance_clipped.shape[0]
250
+ == x_start.shape[0]
251
+ )
252
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
253
+
254
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
255
+ """
256
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
257
+ the initial x, x_0.
258
+ :param model: the model, which takes a signal and a batch of timesteps
259
+ as input.
260
+ :param x: the [N x C x ...] tensor at time t.
261
+ :param t: a 1-D Tensor of timesteps.
262
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
263
+ :param denoised_fn: if not None, a function which applies to the
264
+ x_start prediction before it is used to sample. Applies before
265
+ clip_denoised.
266
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
267
+ pass to the model. This can be used for conditioning.
268
+ :return: a dict with the following keys:
269
+ - 'mean': the model mean output.
270
+ - 'variance': the model variance output.
271
+ - 'log_variance': the log of 'variance'.
272
+ - 'pred_xstart': the prediction for x_0.
273
+ """
274
+ if model_kwargs is None:
275
+ model_kwargs = {}
276
+
277
+ B, C = x.shape[:2]
278
+ assert t.shape == (B,)
279
+ model_output = model(x, t, **model_kwargs)
280
+ if isinstance(model_output, tuple):
281
+ model_output, extra = model_output
282
+ else:
283
+ extra = None
284
+
285
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
286
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
287
+ model_output, model_var_values = th.split(model_output, C, dim=1)
288
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
289
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
290
+ # The model_var_values is [-1, 1] for [min_var, max_var].
291
+ frac = (model_var_values + 1) / 2
292
+ model_log_variance = frac * max_log + (1 - frac) * min_log
293
+ model_variance = th.exp(model_log_variance)
294
+ else:
295
+ model_variance, model_log_variance = {
296
+ # for fixedlarge, we set the initial (log-)variance like so
297
+ # to get a better decoder log likelihood.
298
+ ModelVarType.FIXED_LARGE: (
299
+ np.append(self.posterior_variance[1], self.betas[1:]),
300
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
301
+ ),
302
+ ModelVarType.FIXED_SMALL: (
303
+ self.posterior_variance,
304
+ self.posterior_log_variance_clipped,
305
+ ),
306
+ }[self.model_var_type]
307
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
308
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
309
+
310
+ def process_xstart(x):
311
+ if denoised_fn is not None:
312
+ x = denoised_fn(x)
313
+ if clip_denoised:
314
+ return x.clamp(-1, 1)
315
+ return x
316
+
317
+ if self.model_mean_type == ModelMeanType.START_X:
318
+ pred_xstart = process_xstart(model_output)
319
+ else:
320
+ pred_xstart = process_xstart(
321
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
322
+ )
323
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
324
+
325
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
326
+ return {
327
+ "mean": model_mean,
328
+ "variance": model_variance,
329
+ "log_variance": model_log_variance,
330
+ "pred_xstart": pred_xstart,
331
+ "extra": extra,
332
+ }
333
+
334
+ def _predict_xstart_from_eps(self, x_t, t, eps):
335
+ assert x_t.shape == eps.shape
336
+ return (
337
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
338
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
339
+ )
340
+
341
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
342
+ return (
343
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
344
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
345
+
346
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
347
+ """
348
+ Compute the mean for the previous step, given a function cond_fn that
349
+ computes the gradient of a conditional log probability with respect to
350
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
351
+ condition on y.
352
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
353
+ """
354
+ gradient = cond_fn(x, t, **model_kwargs)
355
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
356
+ return new_mean
357
+
358
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
359
+ """
360
+ Compute what the p_mean_variance output would have been, should the
361
+ model's score function be conditioned by cond_fn.
362
+ See condition_mean() for details on cond_fn.
363
+ Unlike condition_mean(), this instead uses the conditioning strategy
364
+ from Song et al (2020).
365
+ """
366
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
367
+
368
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
369
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
370
+
371
+ out = p_mean_var.copy()
372
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
373
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
374
+ return out
375
+
376
+ def p_sample(
377
+ self,
378
+ model,
379
+ x,
380
+ t,
381
+ clip_denoised=True,
382
+ denoised_fn=None,
383
+ cond_fn=None,
384
+ model_kwargs=None,
385
+ ):
386
+ """
387
+ Sample x_{t-1} from the model at the given timestep.
388
+ :param model: the model to sample from.
389
+ :param x: the current tensor at x_{t-1}.
390
+ :param t: the value of t, starting at 0 for the first diffusion step.
391
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
392
+ :param denoised_fn: if not None, a function which applies to the
393
+ x_start prediction before it is used to sample.
394
+ :param cond_fn: if not None, this is a gradient function that acts
395
+ similarly to the model.
396
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
397
+ pass to the model. This can be used for conditioning.
398
+ :return: a dict containing the following keys:
399
+ - 'sample': a random sample from the model.
400
+ - 'pred_xstart': a prediction of x_0.
401
+ """
402
+ out = self.p_mean_variance(
403
+ model,
404
+ x,
405
+ t,
406
+ clip_denoised=clip_denoised,
407
+ denoised_fn=denoised_fn,
408
+ model_kwargs=model_kwargs,
409
+ )
410
+ noise = th.randn_like(x)
411
+ nonzero_mask = (
412
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
413
+ ) # no noise when t == 0
414
+ if cond_fn is not None:
415
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
416
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
417
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
418
+
419
+ def p_sample_loop(
420
+ self,
421
+ model,
422
+ shape,
423
+ noise=None,
424
+ clip_denoised=True,
425
+ denoised_fn=None,
426
+ cond_fn=None,
427
+ model_kwargs=None,
428
+ device=None,
429
+ progress=False,
430
+ ):
431
+ """
432
+ Generate samples from the model.
433
+ :param model: the model module.
434
+ :param shape: the shape of the samples, (N, C, H, W).
435
+ :param noise: if specified, the noise from the encoder to sample.
436
+ Should be of the same shape as `shape`.
437
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
438
+ :param denoised_fn: if not None, a function which applies to the
439
+ x_start prediction before it is used to sample.
440
+ :param cond_fn: if not None, this is a gradient function that acts
441
+ similarly to the model.
442
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
443
+ pass to the model. This can be used for conditioning.
444
+ :param device: if specified, the device to create the samples on.
445
+ If not specified, use a model parameter's device.
446
+ :param progress: if True, show a tqdm progress bar.
447
+ :return: a non-differentiable batch of samples.
448
+ """
449
+ final = None
450
+ for sample in self.p_sample_loop_progressive(
451
+ model,
452
+ shape,
453
+ noise=noise,
454
+ clip_denoised=clip_denoised,
455
+ denoised_fn=denoised_fn,
456
+ cond_fn=cond_fn,
457
+ model_kwargs=model_kwargs,
458
+ device=device,
459
+ progress=progress,
460
+ ):
461
+ final = sample
462
+ return final["sample"]
463
+
464
+ def p_sample_loop_progressive(
465
+ self,
466
+ model,
467
+ shape,
468
+ noise=None,
469
+ clip_denoised=True,
470
+ denoised_fn=None,
471
+ cond_fn=None,
472
+ model_kwargs=None,
473
+ device=None,
474
+ progress=False,
475
+ ):
476
+ """
477
+ Generate samples from the model and yield intermediate samples from
478
+ each timestep of diffusion.
479
+ Arguments are the same as p_sample_loop().
480
+ Returns a generator over dicts, where each dict is the return value of
481
+ p_sample().
482
+ """
483
+ if device is None:
484
+ device = next(model.parameters()).device
485
+ assert isinstance(shape, (tuple, list))
486
+ if noise is not None:
487
+ img = noise
488
+ else:
489
+ img = th.randn(*shape, device=device)
490
+ indices = list(range(self.num_timesteps))[::-1]
491
+
492
+ if progress:
493
+ # Lazy import so that we don't depend on tqdm.
494
+ from tqdm.auto import tqdm
495
+
496
+ indices = tqdm(indices)
497
+
498
+ for i in indices:
499
+ t = th.tensor([i] * shape[0], device=device)
500
+ with th.no_grad():
501
+ out = self.p_sample(
502
+ model,
503
+ img,
504
+ t,
505
+ clip_denoised=clip_denoised,
506
+ denoised_fn=denoised_fn,
507
+ cond_fn=cond_fn,
508
+ model_kwargs=model_kwargs,
509
+ )
510
+ yield out
511
+ img = out["sample"]
512
+
513
+ def ddim_sample(
514
+ self,
515
+ model,
516
+ x,
517
+ t,
518
+ clip_denoised=True,
519
+ denoised_fn=None,
520
+ cond_fn=None,
521
+ model_kwargs=None,
522
+ eta=0.0,
523
+ ):
524
+ """
525
+ Sample x_{t-1} from the model using DDIM.
526
+ Same usage as p_sample().
527
+ """
528
+ out = self.p_mean_variance(
529
+ model,
530
+ x,
531
+ t,
532
+ clip_denoised=clip_denoised,
533
+ denoised_fn=denoised_fn,
534
+ model_kwargs=model_kwargs,
535
+ )
536
+ if cond_fn is not None:
537
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
538
+
539
+ # Usually our model outputs epsilon, but we re-derive it
540
+ # in case we used x_start or x_prev prediction.
541
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
542
+
543
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
544
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
545
+ sigma = (
546
+ eta
547
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
548
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
549
+ )
550
+ # Equation 12.
551
+ noise = th.randn_like(x)
552
+ mean_pred = (
553
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
554
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
555
+ )
556
+ nonzero_mask = (
557
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
558
+ ) # no noise when t == 0
559
+ sample = mean_pred + nonzero_mask * sigma * noise
560
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
561
+
562
+ def ddim_reverse_sample(
563
+ self,
564
+ model,
565
+ x,
566
+ t,
567
+ clip_denoised=True,
568
+ denoised_fn=None,
569
+ cond_fn=None,
570
+ model_kwargs=None,
571
+ eta=0.0,
572
+ ):
573
+ """
574
+ Sample x_{t+1} from the model using DDIM reverse ODE.
575
+ """
576
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
577
+ out = self.p_mean_variance(
578
+ model,
579
+ x,
580
+ t,
581
+ clip_denoised=clip_denoised,
582
+ denoised_fn=denoised_fn,
583
+ model_kwargs=model_kwargs,
584
+ )
585
+ if cond_fn is not None:
586
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
587
+ # Usually our model outputs epsilon, but we re-derive it
588
+ # in case we used x_start or x_prev prediction.
589
+ eps = (
590
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
591
+ - out["pred_xstart"]
592
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
593
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
594
+
595
+ # Equation 12. reversed
596
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
597
+
598
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
599
+
600
+ def ddim_sample_loop(
601
+ self,
602
+ model,
603
+ shape,
604
+ noise=None,
605
+ clip_denoised=True,
606
+ denoised_fn=None,
607
+ cond_fn=None,
608
+ model_kwargs=None,
609
+ device=None,
610
+ progress=False,
611
+ eta=0.0,
612
+ ):
613
+ """
614
+ Generate samples from the model using DDIM.
615
+ Same usage as p_sample_loop().
616
+ """
617
+ final = None
618
+ for sample in self.ddim_sample_loop_progressive(
619
+ model,
620
+ shape,
621
+ noise=noise,
622
+ clip_denoised=clip_denoised,
623
+ denoised_fn=denoised_fn,
624
+ cond_fn=cond_fn,
625
+ model_kwargs=model_kwargs,
626
+ device=device,
627
+ progress=progress,
628
+ eta=eta,
629
+ ):
630
+ final = sample
631
+ return final["sample"]
632
+
633
+ def ddim_sample_loop_progressive(
634
+ self,
635
+ model,
636
+ shape,
637
+ noise=None,
638
+ clip_denoised=True,
639
+ denoised_fn=None,
640
+ cond_fn=None,
641
+ model_kwargs=None,
642
+ device=None,
643
+ progress=False,
644
+ eta=0.0,
645
+ ):
646
+ """
647
+ Use DDIM to sample from the model and yield intermediate samples from
648
+ each timestep of DDIM.
649
+ Same usage as p_sample_loop_progressive().
650
+ """
651
+ if device is None:
652
+ device = next(model.parameters()).device
653
+ assert isinstance(shape, (tuple, list))
654
+ if noise is not None:
655
+ img = noise
656
+ else:
657
+ img = th.randn(*shape, device=device)
658
+ indices = list(range(self.num_timesteps))[::-1]
659
+
660
+ if progress:
661
+ # Lazy import so that we don't depend on tqdm.
662
+ from tqdm.auto import tqdm
663
+
664
+ indices = tqdm(indices)
665
+
666
+ for i in indices:
667
+ t = th.tensor([i] * shape[0], device=device)
668
+ with th.no_grad():
669
+ out = self.ddim_sample(
670
+ model,
671
+ img,
672
+ t,
673
+ clip_denoised=clip_denoised,
674
+ denoised_fn=denoised_fn,
675
+ cond_fn=cond_fn,
676
+ model_kwargs=model_kwargs,
677
+ eta=eta,
678
+ )
679
+ yield out
680
+ img = out["sample"]
681
+
682
+ def _vb_terms_bpd(
683
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
684
+ ):
685
+ """
686
+ Get a term for the variational lower-bound.
687
+ The resulting units are bits (rather than nats, as one might expect).
688
+ This allows for comparison to other papers.
689
+ :return: a dict with the following keys:
690
+ - 'output': a shape [N] tensor of NLLs or KLs.
691
+ - 'pred_xstart': the x_0 predictions.
692
+ """
693
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
694
+ x_start=x_start, x_t=x_t, t=t
695
+ )
696
+ out = self.p_mean_variance(
697
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
698
+ )
699
+ kl = normal_kl(
700
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
701
+ )
702
+ kl = mean_flat(kl) / np.log(2.0)
703
+
704
+ decoder_nll = -discretized_gaussian_log_likelihood(
705
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
706
+ )
707
+ assert decoder_nll.shape == x_start.shape
708
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
709
+
710
+ # At the first timestep return the decoder NLL,
711
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
712
+ output = th.where((t == 0), decoder_nll, kl)
713
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
714
+
715
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
716
+ """
717
+ Compute training losses for a single timestep.
718
+ :param model: the model to evaluate loss on.
719
+ :param x_start: the [N x C x ...] tensor of inputs.
720
+ :param t: a batch of timestep indices.
721
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
722
+ pass to the model. This can be used for conditioning.
723
+ :param noise: if specified, the specific Gaussian noise to try to remove.
724
+ :return: a dict with the key "loss" containing a tensor of shape [N].
725
+ Some mean or variance settings may also have other keys.
726
+ """
727
+ if model_kwargs is None:
728
+ model_kwargs = {}
729
+ if noise is None:
730
+ noise = th.randn_like(x_start)
731
+ x_t = self.q_sample(x_start, t, noise=noise)
732
+
733
+ terms = {}
734
+
735
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
736
+ terms["loss"] = self._vb_terms_bpd(
737
+ model=model,
738
+ x_start=x_start,
739
+ x_t=x_t,
740
+ t=t,
741
+ clip_denoised=False,
742
+ model_kwargs=model_kwargs,
743
+ )["output"]
744
+ if self.loss_type == LossType.RESCALED_KL:
745
+ terms["loss"] *= self.num_timesteps
746
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
747
+ model_output = model(x_t, t, **model_kwargs)
748
+
749
+ if self.model_var_type in [
750
+ ModelVarType.LEARNED,
751
+ ModelVarType.LEARNED_RANGE,
752
+ ]:
753
+ B, C = x_t.shape[:2]
754
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
755
+ model_output, model_var_values = th.split(model_output, C, dim=1)
756
+ # Learn the variance using the variational bound, but don't let
757
+ # it affect our mean prediction.
758
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
759
+ terms["vb"] = self._vb_terms_bpd(
760
+ model=lambda *args, r=frozen_out: r,
761
+ x_start=x_start,
762
+ x_t=x_t,
763
+ t=t,
764
+ clip_denoised=False,
765
+ )["output"]
766
+ if self.loss_type == LossType.RESCALED_MSE:
767
+ # Divide by 1000 for equivalence with initial implementation.
768
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
769
+ terms["vb"] *= self.num_timesteps / 1000.0
770
+
771
+ target = {
772
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
773
+ x_start=x_start, x_t=x_t, t=t
774
+ )[0],
775
+ ModelMeanType.START_X: x_start,
776
+ ModelMeanType.EPSILON: noise,
777
+ }[self.model_mean_type]
778
+ assert model_output.shape == target.shape == x_start.shape
779
+ terms["mse"] = mean_flat((target - model_output) ** 2)
780
+ if "vb" in terms:
781
+ terms["loss"] = terms["mse"] + terms["vb"]
782
+ else:
783
+ terms["loss"] = terms["mse"]
784
+ else:
785
+ raise NotImplementedError(self.loss_type)
786
+
787
+ return terms
788
+
789
+ def _prior_bpd(self, x_start):
790
+ """
791
+ Get the prior KL term for the variational lower-bound, measured in
792
+ bits-per-dim.
793
+ This term can't be optimized, as it only depends on the encoder.
794
+ :param x_start: the [N x C x ...] tensor of inputs.
795
+ :return: a batch of [N] KL values (in bits), one per batch element.
796
+ """
797
+ batch_size = x_start.shape[0]
798
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
799
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
800
+ kl_prior = normal_kl(
801
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
802
+ )
803
+ return mean_flat(kl_prior) / np.log(2.0)
804
+
805
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
806
+ """
807
+ Compute the entire variational lower-bound, measured in bits-per-dim,
808
+ as well as other related quantities.
809
+ :param model: the model to evaluate loss on.
810
+ :param x_start: the [N x C x ...] tensor of inputs.
811
+ :param clip_denoised: if True, clip denoised samples.
812
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
813
+ pass to the model. This can be used for conditioning.
814
+ :return: a dict containing the following keys:
815
+ - total_bpd: the total variational lower-bound, per batch element.
816
+ - prior_bpd: the prior term in the lower-bound.
817
+ - vb: an [N x T] tensor of terms in the lower-bound.
818
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
819
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
820
+ """
821
+ device = x_start.device
822
+ batch_size = x_start.shape[0]
823
+
824
+ vb = []
825
+ xstart_mse = []
826
+ mse = []
827
+ for t in list(range(self.num_timesteps))[::-1]:
828
+ t_batch = th.tensor([t] * batch_size, device=device)
829
+ noise = th.randn_like(x_start)
830
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
831
+ # Calculate VLB term at the current timestep
832
+ with th.no_grad():
833
+ out = self._vb_terms_bpd(
834
+ model,
835
+ x_start=x_start,
836
+ x_t=x_t,
837
+ t=t_batch,
838
+ clip_denoised=clip_denoised,
839
+ model_kwargs=model_kwargs,
840
+ )
841
+ vb.append(out["output"])
842
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
843
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
844
+ mse.append(mean_flat((eps - noise) ** 2))
845
+
846
+ vb = th.stack(vb, dim=1)
847
+ xstart_mse = th.stack(xstart_mse, dim=1)
848
+ mse = th.stack(mse, dim=1)
849
+
850
+ prior_bpd = self._prior_bpd(x_start)
851
+ total_bpd = vb.sum(dim=1) + prior_bpd
852
+ return {
853
+ "total_bpd": total_bpd,
854
+ "prior_bpd": prior_bpd,
855
+ "vb": vb,
856
+ "xstart_mse": xstart_mse,
857
+ "mse": mse,
858
+ }
859
+
860
+
861
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
862
+ """
863
+ Extract values from a 1-D numpy array for a batch of indices.
864
+ :param arr: the 1-D numpy array.
865
+ :param timesteps: a tensor of indices into the array to extract.
866
+ :param broadcast_shape: a larger shape of K dimensions with the batch
867
+ dimension equal to the length of timesteps.
868
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
869
+ """
870
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
871
+ while len(res.shape) < len(broadcast_shape):
872
+ res = res[..., None]
873
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
diffusion/respace.py ADDED
@@ -0,0 +1,129 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

import numpy as np
import torch as th

from .gaussian_diffusion import GaussianDiffusion


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {num_timesteps} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)


class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.
    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)

    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)

    def _wrap_model(self, model):
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(
            model, self.timestep_map, self.original_num_steps
        )

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t


class _WrappedModel:
    def __init__(self, model, timestep_map, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        # self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        # if self.rescale_timesteps:
        #     new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
        return self.model(x, new_ts, **kwargs)
diffusion/timestep_sampler.py ADDED
@@ -0,0 +1,150 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import numpy as np
9
+ import torch as th
10
+ import torch.distributed as dist
11
+
12
+
13
+ def create_named_schedule_sampler(name, diffusion):
14
+ """
15
+ Create a ScheduleSampler from a library of pre-defined samplers.
16
+ :param name: the name of the sampler.
17
+ :param diffusion: the diffusion object to sample for.
18
+ """
19
+ if name == "uniform":
20
+ return UniformSampler(diffusion)
21
+ elif name == "loss-second-moment":
22
+ return LossSecondMomentResampler(diffusion)
23
+ else:
24
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
25
+
26
+
27
+ class ScheduleSampler(ABC):
28
+ """
29
+ A distribution over timesteps in the diffusion process, intended to reduce
30
+ variance of the objective.
31
+ By default, samplers perform unbiased importance sampling, in which the
32
+ objective's mean is unchanged.
33
+ However, subclasses may override sample() to change how the resampled
34
+ terms are reweighted, allowing for actual changes in the objective.
35
+ """
36
+
37
+ @abstractmethod
38
+ def weights(self):
39
+ """
40
+ Get a numpy array of weights, one per diffusion step.
41
+ The weights needn't be normalized, but must be positive.
42
+ """
43
+
44
+ def sample(self, batch_size, device):
45
+ """
46
+ Importance-sample timesteps for a batch.
47
+ :param batch_size: the number of timesteps.
48
+ :param device: the torch device to save to.
49
+ :return: a tuple (timesteps, weights):
50
+ - timesteps: a tensor of timestep indices.
51
+ - weights: a tensor of weights to scale the resulting losses.
52
+ """
53
+ w = self.weights()
54
+ p = w / np.sum(w)
55
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56
+ indices = th.from_numpy(indices_np).long().to(device)
57
+ weights_np = 1 / (len(p) * p[indices_np])
58
+ weights = th.from_numpy(weights_np).float().to(device)
59
+ return indices, weights
60
+
61
+
62
+ class UniformSampler(ScheduleSampler):
63
+ def __init__(self, diffusion):
64
+ self.diffusion = diffusion
65
+ self._weights = np.ones([diffusion.num_timesteps])
66
+
67
+ def weights(self):
68
+ return self._weights
69
+
70
+
71
+ class LossAwareSampler(ScheduleSampler):
72
+ def update_with_local_losses(self, local_ts, local_losses):
73
+ """
74
+ Update the reweighting using losses from a model.
75
+ Call this method from each rank with a batch of timesteps and the
76
+ corresponding losses for each of those timesteps.
77
+ This method will perform synchronization to make sure all of the ranks
78
+ maintain the exact same reweighting.
79
+ :param local_ts: an integer Tensor of timesteps.
80
+ :param local_losses: a 1D Tensor of losses.
81
+ """
82
+ batch_sizes = [
83
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
84
+ for _ in range(dist.get_world_size())
85
+ ]
86
+ dist.all_gather(
87
+ batch_sizes,
88
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89
+ )
90
+
91
+ # Pad all_gather batches to be the maximum batch size.
92
+ batch_sizes = [x.item() for x in batch_sizes]
93
+ max_bs = max(batch_sizes)
94
+
95
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97
+ dist.all_gather(timestep_batches, local_ts)
98
+ dist.all_gather(loss_batches, local_losses)
99
+ timesteps = [
100
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101
+ ]
102
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103
+ self.update_with_all_losses(timesteps, losses)
104
+
105
+ @abstractmethod
106
+ def update_with_all_losses(self, ts, losses):
107
+ """
108
+ Update the reweighting using losses from a model.
109
+ Sub-classes should override this method to update the reweighting
110
+ using losses from the model.
111
+ This method directly updates the reweighting without synchronizing
112
+ between workers. It is called by update_with_local_losses from all
113
+ ranks with identical arguments. Thus, it should have deterministic
114
+ behavior to maintain state across workers.
115
+ :param ts: a list of int timesteps.
116
+ :param losses: a list of float losses, one per timestep.
117
+ """
118
+
119
+
120
+ class LossSecondMomentResampler(LossAwareSampler):
121
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122
+ self.diffusion = diffusion
123
+ self.history_per_term = history_per_term
124
+ self.uniform_prob = uniform_prob
125
+ self._loss_history = np.zeros(
126
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
127
+ )
128
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed in NumPy 1.24+
129
+
130
+ def weights(self):
131
+ if not self._warmed_up():
132
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134
+ weights /= np.sum(weights)
135
+ weights *= 1 - self.uniform_prob
136
+ weights += self.uniform_prob / len(weights)
137
+ return weights
138
+
139
+ def update_with_all_losses(self, ts, losses):
140
+ for t, loss in zip(ts, losses):
141
+ if self._loss_counts[t] == self.history_per_term:
142
+ # Shift out the oldest loss term.
143
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
144
+ self._loss_history[t, -1] = loss
145
+ else:
146
+ self._loss_history[t, self._loss_counts[t]] = loss
147
+ self._loss_counts[t] += 1
148
+
149
+ def _warmed_up(self):
150
+ return (self._loss_counts == self.history_per_term).all()
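`train.py` below draws timesteps uniformly with `torch.randint`; the samplers in this file are an optional drop-in that importance-samples timesteps and reweights the loss. A hedged sketch of where they would slot into a training step (the loss-aware update calls `torch.distributed`, so a single-process group is initialized here just so `all_gather` works):

```python
import torch
import torch.distributed as dist
from models import DiT
from diffusion import create_diffusion
from diffusion.timestep_sampler import create_named_schedule_sampler, LossAwareSampler

if not dist.is_initialized():  # single-process group so update_with_local_losses can all_gather
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = DiT(input_size=(80, 800), patch_size=8, in_channels=1,
            hidden_size=384, depth=12, num_heads=6).to(device)
diffusion = create_diffusion(timestep_respacing="")
sampler = create_named_schedule_sampler("loss-second-moment", diffusion)

mel = torch.randn(4, 1, 80, 800, device=device)    # stand-in for a normalized mel batch
meta = torch.randn(4, 1, 384, device=device)       # stand-in for sentence-embedding conditioning
t, weights = sampler.sample(mel.shape[0], device)  # importance-sampled timesteps + loss weights
losses = diffusion.training_losses(model, mel, t, dict(y=meta))
loss = (weights * losses["loss"]).mean()           # reweight so the objective stays unbiased
if isinstance(sampler, LossAwareSampler):
    sampler.update_with_local_losses(t, losses["loss"].detach())
```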
gtzan-ck/model_epoch_20000.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47347904d66464d7c77044b00ec00c6c24ce4a034df87f8c3f735564b2a328cb
3
+ size 392135773
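This is a Git LFS pointer, so the actual ~392 MB checkpoint is only downloaded after `git lfs pull`. Once it is present, a minimal loading sketch (mirroring `load_trained_model` in `sample.py`, which uses the same DiT hyper-parameters):

```python
import torch
from models import DiT

model = DiT(input_size=(80, 800), patch_size=8, in_channels=1,
            hidden_size=384, depth=12, num_heads=6)
checkpoint = torch.load("gtzan-ck/model_epoch_20000.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])  # train.py saves model and optimizer state dicts
model.eval()
```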
gtzan-test.csv ADDED
@@ -0,0 +1,101 @@
1
+ ids,descri
2
+ classical_classical.00019,classical
3
+ rock_rock.00092,rock
4
+ reggae_reggae.00083,reggae
5
+ reggae_reggae.00087,reggae
6
+ country_country.00020,country
7
+ reggae_reggae.00080,reggae
8
+ metal_metal.00061,metal
9
+ jazz_jazz.00058,jazz
10
+ disco_disco.00033,disco
11
+ metal_metal.00068,metal
12
+ rock_rock.00096,rock
13
+ pop_pop.00070,pop
14
+ blues_blues.00001,blues
15
+ jazz_jazz.00050,jazz
16
+ country_country.00022,country
17
+ rock_rock.00095,rock
18
+ metal_metal.00066,metal
19
+ disco_disco.00032,disco
20
+ pop_pop.00071,pop
21
+ blues_blues.00006,blues
22
+ disco_disco.00039,disco
23
+ reggae_reggae.00081,reggae
24
+ reggae_reggae.00088,reggae
25
+ hiphop_hiphop.00046,hiphop
26
+ country_country.00023,country
27
+ hiphop_hiphop.00040,hiphop
28
+ classical_classical.00012,classical
29
+ reggae_reggae.00084,reggae
30
+ reggae_reggae.00085,reggae
31
+ hiphop_hiphop.00043,hiphop
32
+ jazz_jazz.00052,jazz
33
+ blues_blues.00004,blues
34
+ disco_disco.00037,disco
35
+ hiphop_hiphop.00047,hiphop
36
+ pop_pop.00076,pop
37
+ classical_classical.00014,classical
38
+ rock_rock.00090,rock
39
+ classical_classical.00013,classical
40
+ blues_blues.00002,blues
41
+ rock_rock.00098,rock
42
+ hiphop_hiphop.00044,hiphop
43
+ rock_rock.00099,rock
44
+ metal_metal.00065,metal
45
+ metal_metal.00062,metal
46
+ blues_blues.00007,blues
47
+ pop_pop.00073,pop
48
+ jazz_jazz.00053,jazz
49
+ country_country.00024,country
50
+ pop_pop.00078,pop
51
+ blues_blues.00000,blues
52
+ jazz_jazz.00055,jazz
53
+ blues_blues.00003,blues
54
+ hiphop_hiphop.00041,hiphop
55
+ hiphop_hiphop.00048,hiphop
56
+ pop_pop.00077,pop
57
+ metal_metal.00067,metal
58
+ reggae_reggae.00089,reggae
59
+ jazz_jazz.00056,jazz
60
+ hiphop_hiphop.00049,hiphop
61
+ disco_disco.00038,disco
62
+ jazz_jazz.00057,jazz
63
+ reggae_reggae.00082,reggae
64
+ rock_rock.00091,rock
65
+ metal_metal.00060,metal
66
+ country_country.00028,country
67
+ pop_pop.00075,pop
68
+ rock_rock.00094,rock
69
+ classical_classical.00010,classical
70
+ rock_rock.00097,rock
71
+ jazz_jazz.00051,jazz
72
+ country_country.00025,country
73
+ country_country.00029,country
74
+ country_country.00027,country
75
+ pop_pop.00072,pop
76
+ metal_metal.00063,metal
77
+ classical_classical.00011,classical
78
+ blues_blues.00008,blues
79
+ classical_classical.00018,classical
80
+ pop_pop.00079,pop
81
+ jazz_jazz.00059,jazz
82
+ disco_disco.00034,disco
83
+ country_country.00021,country
84
+ hiphop_hiphop.00045,hiphop
85
+ reggae_reggae.00086,reggae
86
+ metal_metal.00069,metal
87
+ classical_classical.00016,classical
88
+ classical_classical.00015,classical
89
+ disco_disco.00036,disco
90
+ blues_blues.00009,blues
91
+ country_country.00026,country
92
+ jazz_jazz.00054,jazz
93
+ disco_disco.00035,disco
94
+ pop_pop.00074,pop
95
+ rock_rock.00093,rock
96
+ hiphop_hiphop.00042,hiphop
97
+ disco_disco.00031,disco
98
+ blues_blues.00005,blues
99
+ disco_disco.00030,disco
100
+ classical_classical.00017,classical
101
+ metal_metal.00064,metal
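The held-out split lists 10 tracks per genre under the `ids,descri` columns. A quick sanity check (pandas is already in the requirements):

```python
import pandas as pd

df = pd.read_csv("gtzan-test.csv")   # columns: ids, descri (genre label)
print(len(df))                        # 100 test tracks
print(df["descri"].value_counts())    # 10 per genre
```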
models.py ADDED
@@ -0,0 +1,375 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import numpy as np
15
+ import math
16
+ from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
17
+
18
+
19
+ def modulate(x, shift, scale):
20
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
21
+
22
+
23
+ #################################################################################
24
+ # Embedding Layers for Timesteps and Class Labels #
25
+ #################################################################################
26
+
27
+ class TimestepEmbedder(nn.Module):
28
+ """
29
+ Embeds scalar timesteps into vector representations.
30
+ """
31
+ def __init__(self, hidden_size, frequency_embedding_size=256):
32
+ super().__init__()
33
+ self.mlp = nn.Sequential(
34
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
35
+ nn.SiLU(),
36
+ nn.Linear(hidden_size, hidden_size, bias=True),
37
+ )
38
+ self.frequency_embedding_size = frequency_embedding_size
39
+
40
+ @staticmethod
41
+ def timestep_embedding(t, dim, max_period=10000):
42
+ """
43
+ Create sinusoidal timestep embeddings.
44
+ :param t: a 1-D Tensor of N indices, one per batch element.
45
+ These may be fractional.
46
+ :param dim: the dimension of the output.
47
+ :param max_period: controls the minimum frequency of the embeddings.
48
+ :return: an (N, D) Tensor of positional embeddings.
49
+ """
50
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
51
+ half = dim // 2
52
+ freqs = torch.exp(
53
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
54
+ ).to(device=t.device)
55
+ args = t[:, None].float() * freqs[None]
56
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
57
+ if dim % 2:
58
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
59
+ return embedding
60
+
61
+ def forward(self, t):
62
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
63
+ t_emb = self.mlp(t_freq)
64
+ return t_emb
65
+
66
+
67
+ class LabelEmbedder(nn.Module):
68
+ """
69
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
70
+ """
71
+ def __init__(self, num_classes, hidden_size, dropout_prob):
72
+ super().__init__()
73
+ use_cfg_embedding = dropout_prob > 0
74
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
75
+ self.num_classes = num_classes
76
+ self.dropout_prob = dropout_prob
77
+
78
+ def token_drop(self, labels, force_drop_ids=None):
79
+ """
80
+ Drops labels to enable classifier-free guidance.
81
+ """
82
+ if force_drop_ids is None:
83
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
84
+ else:
85
+ drop_ids = force_drop_ids == 1
86
+ labels = torch.where(drop_ids, self.num_classes, labels)
87
+ return labels
88
+
89
+ def forward(self, labels, train, force_drop_ids=None):
90
+ use_dropout = self.dropout_prob > 0
91
+ if (train and use_dropout) or (force_drop_ids is not None):
92
+ labels = self.token_drop(labels, force_drop_ids)
93
+ embeddings = self.embedding_table(labels)
94
+ return embeddings
95
+
96
+
97
+ #################################################################################
98
+ # Core DiT Model #
99
+ #################################################################################
100
+
101
+ class DiTBlock(nn.Module):
102
+ """
103
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
104
+ """
105
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
106
+ super().__init__()
107
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
108
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
109
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
110
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
111
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
112
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
113
+ self.adaLN_modulation = nn.Sequential(
114
+ nn.SiLU(),
115
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
116
+ )
117
+
118
+ def forward(self, x, c):
119
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
120
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
121
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
122
+ return x
123
+
124
+
125
+ class FinalLayer(nn.Module):
126
+ """
127
+ The final layer of DiT.
128
+ """
129
+ def __init__(self, hidden_size, patch_size, out_channels):
130
+ super().__init__()
131
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
132
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
133
+ self.adaLN_modulation = nn.Sequential(
134
+ nn.SiLU(),
135
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
136
+ )
137
+
138
+ def forward(self, x, c):
139
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
140
+ x = modulate(self.norm_final(x), shift, scale)
141
+ x = self.linear(x)
142
+ return x
143
+
144
+
145
+ class DiT(nn.Module):
146
+ """
147
+ Diffusion model with a Transformer backbone.
148
+ """
149
+ def __init__(
150
+ self,
151
+ input_size=32,
152
+ patch_size=2,
153
+ in_channels=4,
154
+ hidden_size=1152,
155
+ depth=28,
156
+ num_heads=16,
157
+ mlp_ratio=4.0,
158
+ #num_classes=1000,
159
+ learn_sigma=True,
160
+ ):
161
+ super().__init__()
162
+ self.learn_sigma = learn_sigma
163
+ self.in_channels = in_channels
164
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
165
+ self.patch_size = patch_size
166
+ self.num_heads = num_heads
167
+
168
+ #self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
169
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
170
+ self.t_embedder = TimestepEmbedder(hidden_size)
171
+ #self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
172
+ num_patches = self.x_embedder.num_patches
173
+ # Will use fixed sin-cos embedding:
174
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
175
+
176
+ self.blocks = nn.ModuleList([
177
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
178
+ ])
179
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
180
+ self.initialize_weights()
181
+
182
+ def initialize_weights(self):
183
+ # Initialize transformer layers:
184
+ def _basic_init(module):
185
+ if isinstance(module, nn.Linear):
186
+ torch.nn.init.xavier_uniform_(module.weight)
187
+ if module.bias is not None:
188
+ nn.init.constant_(module.bias, 0)
189
+ self.apply(_basic_init)
190
+
191
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
192
+ #pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
193
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size)
194
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
195
+
196
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
197
+ w = self.x_embedder.proj.weight.data
198
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
199
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
200
+
201
+ # Initialize label embedding table:
202
+ #nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
203
+
204
+ # Initialize timestep embedding MLP:
205
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
206
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
207
+
208
+ # Zero-out adaLN modulation layers in DiT blocks:
209
+ for block in self.blocks:
210
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
211
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
212
+
213
+ # Zero-out output layers:
214
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
215
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
216
+ nn.init.constant_(self.final_layer.linear.weight, 0)
217
+ nn.init.constant_(self.final_layer.linear.bias, 0)
218
+
219
+ def unpatchify(self, x):
220
+ """
221
+ x: (N, T, patch_size**2 * C)
222
+ imgs: (N, H, W, C)
223
+ """
224
+ c = self.out_channels
225
+ p = self.x_embedder.patch_size[0]
226
+ #h = w = int(x.shape[1] ** 0.5)
227
+ h = int(self.x_embedder.grid_size[0])
228
+ w = int(self.x_embedder.grid_size[1])
229
+ #assert h * w == x.shape[1]
230
+
231
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
232
+ x = torch.einsum('nhwpqc->nchpwq', x)
233
+ #imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
234
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
235
+ return imgs
236
+
237
+ def forward(self, x, t, y):
238
+ """
239
+ Forward pass of DiT.
240
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
241
+ t: (N,) tensor of diffusion timesteps
242
+ y: (N, 1, D) tensor of conditioning embeddings (here, sentence embeddings of the genre label, added to the timestep embedding)
243
+ """
244
+ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
245
+ t = self.t_embedder(t) # (N, D)
246
+ #y = self.y_embedder(y, self.training) # (N, D)
247
+ y = y.squeeze(dim=1)
248
+ c = t + y # (N, D)
249
+ for block in self.blocks:
250
+ x = block(x, c) # (N, T, D)
251
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
252
+ x = self.unpatchify(x) # (N, out_channels, H, W)
253
+ return x
254
+
255
+ def forward_with_cfg(self, x, t, y, cfg_scale):
256
+ """
257
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
258
+ """
259
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
260
+ half = x[: len(x) // 2]
261
+ combined = torch.cat([half, half], dim=0)
262
+ model_out = self.forward(combined, t, y)
263
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
264
+ # three channels by default. The standard approach to cfg applies it to all channels.
265
+ # This can be done by uncommenting the following line and commenting-out the line following that.
266
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
267
+ eps, rest = model_out[:, :3], model_out[:, 3:]
268
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
269
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
270
+ eps = torch.cat([half_eps, half_eps], dim=0)
271
+ return torch.cat([eps, rest], dim=1)
272
+
273
+
274
+ #################################################################################
275
+ # Sine/Cosine Positional Embedding Functions #
276
+ #################################################################################
277
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
278
+
279
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
280
+ """
281
+ grid_size: int of the grid height and width
282
+ return:
283
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
284
+ """
285
+ grid_h = np.arange(grid_size[0], dtype=np.float32)
286
+ grid_w = np.arange(grid_size[1], dtype=np.float32)
287
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
288
+ grid = np.stack(grid, axis=0)
289
+
290
+ grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
291
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
292
+ if cls_token and extra_tokens > 0:
293
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
294
+ return pos_embed
295
+
296
+
297
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
298
+ assert embed_dim % 2 == 0
299
+
300
+ # use half of dimensions to encode grid_h
301
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
302
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
303
+
304
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
305
+ return emb
306
+
307
+
308
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
309
+ """
310
+ embed_dim: output dimension for each position
311
+ pos: a list of positions to be encoded: size (M,)
312
+ out: (M, D)
313
+ """
314
+ assert embed_dim % 2 == 0
315
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
316
+ omega /= embed_dim / 2.
317
+ omega = 1. / 10000**omega # (D/2,)
318
+
319
+ pos = pos.reshape(-1) # (M,)
320
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
321
+
322
+ emb_sin = np.sin(out) # (M, D/2)
323
+ emb_cos = np.cos(out) # (M, D/2)
324
+
325
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
326
+ return emb
327
+
328
+
329
+ #################################################################################
330
+ # DiT Configs #
331
+ #################################################################################
332
+
333
+ def DiT_XL_2(**kwargs):
334
+ return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
335
+
336
+ def DiT_XL_4(**kwargs):
337
+ return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
338
+
339
+ def DiT_XL_8(**kwargs):
340
+ return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
341
+
342
+ def DiT_L_2(**kwargs):
343
+ return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
344
+
345
+ def DiT_L_4(**kwargs):
346
+ return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
347
+
348
+ def DiT_L_8(**kwargs):
349
+ return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
350
+
351
+ def DiT_B_2(**kwargs):
352
+ return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
353
+
354
+ def DiT_B_4(**kwargs):
355
+ return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
356
+
357
+ def DiT_B_8(**kwargs):
358
+ return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
359
+
360
+ def DiT_S_2(**kwargs):
361
+ return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
362
+
363
+ def DiT_S_4(**kwargs):
364
+ return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
365
+
366
+ def DiT_S_8(**kwargs):
367
+ return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
368
+
369
+
370
+ DiT_models = {
371
+ 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
372
+ 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
373
+ 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
374
+ 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
375
+ }
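A hedged smoke test of the configuration this repo actually trains: a rectangular 80×800 mel input, one channel, and a conditioning embedding that `forward` squeezes and adds to the timestep embedding (384 matches the SentenceTransformer dimension used during preprocessing):

```python
import torch
from models import DiT

model = DiT(input_size=(80, 800), patch_size=8, in_channels=1,
            hidden_size=384, depth=12, num_heads=6, learn_sigma=True)
x = torch.randn(2, 1, 80, 800)    # normalized mel spectrograms
t = torch.randint(0, 1000, (2,))  # diffusion timesteps
y = torch.randn(2, 1, 384)        # conditioning embeddings, same layout as the batched H5 'meta' entries
out = model(x, t, y)
print(out.shape)                  # torch.Size([2, 2, 80, 800]) -- mean and sigma channels with learn_sigma=True
```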
requirement.txt ADDED
@@ -0,0 +1,26 @@
1
+ absl-py==2.1.0
2
+ aiohttp==3.10.0
3
+ attrs==23.2.0
4
+ audioread==3.0.1
5
+ cffi==1.16.0
6
+ datasets==2.20.0
7
+ einops==0.8.0
8
+ fsspec==2024.5.0
9
+ GitPython==3.1.43
10
+ h5py==3.11.0
11
+ huggingface-hub==0.24.5
12
+ joblib==1.4.2
13
+ librosa==0.10.2.post1
14
+ numpy==1.26.4
15
+ pandas==2.2.2
16
+ pydub==0.25.1
17
+ scipy==1.13.1
18
+ sentence-transformers==3.1.0
19
+ six==1.16.0
20
+ soundfile==0.12.1
21
+ timm==0.9.2
22
+ tqdm==4.66.4
23
+ torch==2.0.0
24
+ torchmetrics==1.4.1
25
+ transformers==4.43.3
26
+ tensorboard
sample.py ADDED
@@ -0,0 +1,135 @@
1
+ import os
2
+ import torch
3
+ import h5py
4
+ import random
5
+ import numpy as np
6
+ import soundfile as sf
7
+ from models import DiT
8
+ from diffusion import create_diffusion
9
+ from tqdm import tqdm
10
+ import sys
11
+ sys.path.append('./tools/bigvgan_v2_22khz_80band_256x')
12
+ from bigvgan import BigVGAN
13
+ from torch import nn
14
+ import torch.nn.functional as F
15
+ import argparse
16
+
17
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'  # first visible GPU; set CUDA_VISIBLE_DEVICES to pick another
18
+
19
+ class MelToAudio_bigvgan(nn.Module):
20
+ def __init__(self):
21
+ super().__init__()
22
+ self.vocoder = BigVGAN.from_pretrained('./tools/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)  # local BigVGAN submodule
23
+ self.vocoder.remove_weight_norm()
24
+
25
+ def __call__(self, z):
26
+ x = self.mel_to_audio(z)
27
+ return x
28
+
29
+ def mel_to_audio(self, x):
30
+ with torch.no_grad():
31
+ self.vocoder.eval()
32
+ y = self.vocoder(x[:, :, :])
33
+ y = y.squeeze(0)
34
+ return y
35
+
36
+ vocoder = MelToAudio_bigvgan().to(device)
37
+
38
+ def load_trained_model(checkpoint_path):
39
+ model = DiT(
40
+ input_size=(80, 800),
41
+ patch_size=8,
42
+ in_channels=1,
43
+ hidden_size=384,
44
+ depth=12,
45
+ num_heads=6,
46
+ )
47
+ model.to(device)
48
+ checkpoint = torch.load(checkpoint_path)
49
+ model.load_state_dict(checkpoint['model_state_dict'])
50
+ model.eval()
51
+ return model
52
+
53
+ def load_all_meta_and_mel_from_h5(h5_file):
54
+ with h5py.File(h5_file, 'r') as f:
55
+ keys = list(f.keys())
56
+ for key in keys:
57
+ meta_latent = torch.FloatTensor(f[key]['meta'][:]).to(device)
58
+ mel = torch.FloatTensor(f[key]['mel'][:]).to(device)
59
+ yield key, meta_latent, mel
60
+
61
+ def extract_random_mel_segment(mel, segment_length=800):
62
+ total_length = mel.shape[2]
63
+ if total_length > segment_length:
64
+ start = np.random.randint(0, total_length - segment_length)
65
+ mel_segment = mel[:, :, start:start + segment_length]
66
+ else:
67
+ padding = segment_length - total_length
68
+ mel_segment = F.pad(mel, (0, padding), mode='constant', value=0)
69
+
70
+ mel_segment = (mel_segment + 10) / 20
71
+ return mel_segment
72
+
73
+ def infer_and_generate_audio(model, diffusion, meta_latent):
74
+ latent_size = (80, 800)
75
+ z = torch.randn(1, 1, latent_size[0], latent_size[1], device=device)
76
+ model_kwargs = dict(y=meta_latent)
77
+
78
+ with torch.no_grad():
79
+ samples = diffusion.p_sample_loop(
80
+ model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
81
+ )
82
+
83
+ return samples
84
+
85
+ def save_audio(mel, vocoder, output_path, sample_rate=22050):
86
+ with torch.no_grad():
87
+ if mel.dim() == 4 and mel.shape[1] == 1:
88
+ mel = mel[0, 0, :, :]
89
+ elif mel.dim() == 3 and mel.shape[0] == 1:
90
+ mel = mel[0]
91
+ else:
92
+ raise ValueError(f"Unexpected mel shape: {mel.shape}")
93
+
94
+ mel = mel.unsqueeze(0)
95
+ wav = vocoder(mel * 20 - 10).cpu().numpy()
96
+
97
+ sf.write(output_path, wav[0], samplerate=sample_rate)
98
+ print(f"Saved audio to: {output_path}")
99
+
100
+ def main():
101
+ parser = argparse.ArgumentParser(description='Generate audio using DiT and BigVGAN')
102
+ parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
103
+ parser.add_argument('--h5_file', type=str, required=True, help='Path to input H5 file')
104
+ parser.add_argument('--output_gt_dir', type=str, required=True, help='Directory to save ground truth audio')
105
+ parser.add_argument('--output_gen_dir', type=str, required=True, help='Directory to save generated audio')
106
+ parser.add_argument('--segment_length', type=int, default=800, help='Segment length for mel slices (default: 800)')
107
+ parser.add_argument('--sample_rate', type=int, default=22050, help='Sample rate for output audio (default: 22050)')
108
+ args = parser.parse_args()
109
+
110
+ model = load_trained_model(args.checkpoint)
111
+ diffusion = create_diffusion(timestep_respacing="")
112
+
113
+ for i, (key, meta_latent, mel) in enumerate(tqdm(load_all_meta_and_mel_from_h5(args.h5_file))):
114
+ mel_segment = extract_random_mel_segment(mel, segment_length=args.segment_length)
115
+
116
+ ground_truth_wav_path = os.path.join(args.output_gt_dir, f"{key}.wav")
117
+ save_audio(mel_segment, vocoder, ground_truth_wav_path, sample_rate=args.sample_rate)
118
+
119
+ generated_mel = infer_and_generate_audio(model, diffusion, meta_latent)
120
+
121
+ output_wav_path = os.path.join(args.output_gen_dir, f"{key}.wav")
122
+ save_audio(generated_mel, vocoder, output_wav_path, sample_rate=args.sample_rate)
123
+
124
+ if __name__ == "__main__":
125
+ main()
126
+
127
+ ### how to use
128
+ '''
129
+ python sample.py --checkpoint ./gtzan-ck/model_epoch_20000.pt \
130
+ --h5_file ./dataset/gtzan_test.h5 \
131
+ --output_gt_dir ./sample/gn \
132
+ --output_gen_dir ./sample/gt \
133
+ --segment_length 800 \
134
+ --sample_rate 22050
135
+ '''
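`sample.py` conditions on embeddings stored in the H5 file, but the conditioning is just a sentence embedding, so you can also generate from an arbitrary text label by encoding it with the same model used in `tools/gtzan2h5.py`. A hedged sketch (importing from `sample` instantiates the BigVGAN vocoder at import time, so the submodule weights must be available; paths and the label are examples):

```python
import torch
from sentence_transformers import SentenceTransformer
from diffusion import create_diffusion
from sample import load_trained_model, infer_and_generate_audio, save_audio, vocoder, device

encoder = SentenceTransformer('paraphrase-MiniLM-L6-v2')       # same encoder as preprocessing
meta = torch.FloatTensor(encoder.encode(["jazz"])).to(device)  # (1, 384), same layout as the H5 'meta'

model = load_trained_model("./gtzan-ck/model_epoch_20000.pt")
diffusion = create_diffusion(timestep_respacing="")
mel = infer_and_generate_audio(model, diffusion, meta)
save_audio(mel, vocoder, "jazz_sample.wav", sample_rate=22050)
```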
sample/gn/blues_blues.00000.mp3 ADDED
Binary file (58.2 kB). View file
 
sample/gn/blues_blues.00001.mp3 ADDED
Binary file (57.9 kB). View file
 
sample/gn/blues_blues.00002.mp3 ADDED
Binary file (56.7 kB). View file
 
sample/gn/blues_blues.00003.mp3 ADDED
Binary file (55.7 kB). View file
 
sample/gn/blues_blues.00004.mp3 ADDED
Binary file (53.4 kB). View file
 
sample/gn/blues_blues.00005.mp3 ADDED
Binary file (60.2 kB). View file
 
sample/gn/blues_blues.00006.mp3 ADDED
Binary file (53.8 kB). View file
 
sample/gn/blues_blues.00007.mp3 ADDED
Binary file (55 kB). View file
 
sample/gn/blues_blues.00008.mp3 ADDED
Binary file (54.8 kB). View file
 
sample/gn/blues_blues.00009.mp3 ADDED
Binary file (54.1 kB). View file
 
sample/gt/blues_blues.00000.mp3 ADDED
Binary file (54.8 kB). View file
 
sample/gt/blues_blues.00001.mp3 ADDED
Binary file (55.4 kB). View file
 
sample/gt/blues_blues.00002.mp3 ADDED
Binary file (63.9 kB). View file
 
sample/gt/blues_blues.00003.mp3 ADDED
Binary file (57.2 kB). View file
 
sample/gt/blues_blues.00004.mp3 ADDED
Binary file (59.8 kB). View file
 
sample/gt/blues_blues.00005.mp3 ADDED
Binary file (58.2 kB). View file
 
sample/gt/blues_blues.00006.mp3 ADDED
Binary file (60.4 kB). View file
 
sample/gt/blues_blues.00007.mp3 ADDED
Binary file (59.6 kB). View file
 
sample/gt/blues_blues.00008.mp3 ADDED
Binary file (56.6 kB). View file
 
sample/gt/blues_blues.00009.mp3 ADDED
Binary file (52.6 kB). View file
 
tools/bigvgan_v2_22khz_80band_256x ADDED
@@ -0,0 +1 @@
1
+ Subproject commit 633ff708ed5b74903e86ff1298cf4a98e921c513
tools/gtzan2h5.py ADDED
@@ -0,0 +1,160 @@
1
+ import os
2
+ import torch
3
+ import h5py
4
+ import random
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ from sentence_transformers import SentenceTransformer
8
+ import librosa
9
+ from bigvgan_v2_22khz_80band_256x.meldataset import get_mel_spectrogram
10
+ from types import SimpleNamespace
11
+ from torch import nn
12
+ from einops import rearrange
13
+ import json
14
+ import argparse
15
+
16
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
17
+
18
+ # Load SentenceTransformer model
19
+ sentence_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
20
+
21
+ class AudioToMel_bigvgan(nn.Module):
22
+ def __init__(self, config_path):
23
+ super().__init__()
24
+
25
+ # Load configuration file
26
+ with open(config_path, 'r') as f:
27
+ self.h = json.load(f, object_hook=lambda d: SimpleNamespace(**d))
28
+
29
+ def __call__(self, audio):
30
+ x = self.audio_to_mel(audio) # Extract mel spectrogram
31
+ return x
32
+
33
+ def audio_to_mel(self, audio):
34
+ # Convert to mono channel
35
+ audio = audio[:, 0, :] # Assuming input is (b, c, t), take first channel
36
+ audio = torch.tensor(audio)
37
+
38
+ # Extract mel spectrogram
39
+ x = get_mel_spectrogram(
40
+ wav=audio[:, :],
41
+ h=self.h
42
+ ) # Shape: (b, f, t)
43
+
44
+ return x
45
+
46
+ # Initialize BigVGAN Mel extraction model
47
+ audio_to_mel_model = None # Placeholder, will be initialized later
48
+
49
+ def extract_mel_features(audio_path, sr=22050):
50
+ """
51
+ Extract Mel features using BigVGAN model, with normalization.
52
+ :param audio_path: Path to the audio file
53
+ :param sr: Sampling rate (default 22050, matching the 22 kHz BigVGAN vocoder)
54
+ :return: Mel spectrogram
55
+ """
56
+ # Load and normalize audio
57
+ wav, _ = librosa.load(audio_path, sr=sr)
58
+ max_val = np.max(np.abs(wav))
59
+ if max_val > 1.0:
60
+ wav = wav / max_val
61
+
62
+ wav_tensor = torch.FloatTensor(wav).unsqueeze(0).unsqueeze(0).to(device) # Shape: (1, 1, T)
63
+
64
+ # Extract Mel spectrogram
65
+ mel_spectrogram = audio_to_mel_model(wav_tensor).cpu().numpy()
66
+ return mel_spectrogram
67
+
68
+ def get_embedding_from_folder_name(folder_name):
69
+ """
70
+ Convert folder name into embedding using SentenceTransformer.
71
+ :param folder_name: Name of the folder
72
+ :return: Corresponding embedding
73
+ """
74
+ try:
75
+ embedding = sentence_model.encode([folder_name])
76
+ return embedding
77
+ except Exception as e:
78
+ print(f"Error encoding label for {folder_name}: {e}")
79
+ return None
80
+
81
+ def process_single_file(file_info):
82
+ """
83
+ Process a single audio file and return its key, mel features, and meta embedding.
84
+ :param file_info: (root_dir, audio_path) tuple
85
+ :return: (key, mel_features, embedding)
86
+ """
87
+ root_dir, audio_path = file_info
88
+ try:
89
+ # Get file and folder names
90
+ file_name_with_ext = os.path.basename(audio_path)
91
+ folder_name = os.path.basename(os.path.dirname(audio_path))
92
+
93
+ # Extract Mel features
94
+ mel_features = extract_mel_features(audio_path)
95
+
96
+ # Get embedding from folder name
97
+ embedding = get_embedding_from_folder_name(folder_name)
98
+
99
+ if embedding is None:
100
+ return None, None, None
101
+
102
+ key = os.path.relpath(audio_path, root_dir).replace('/', '_').replace('\\', '_')
103
+ return key, mel_features, embedding
104
+ except Exception as e:
105
+ print(f"Error processing {audio_path}: {e}")
106
+ return None, None, None
107
+
108
+ def process_and_save_files(audio_files, output_h5_file):
109
+ """
110
+ Process audio files and save Mel features and meta embeddings to an HDF5 file.
111
+ :param audio_files: List of audio file paths
112
+ :param output_h5_file: Path to the HDF5 output file
113
+ """
114
+ with h5py.File(output_h5_file, 'w') as h5f:
115
+ for file_info in tqdm(audio_files, desc="Processing audio files"):
116
+ key, mel_features, embedding = process_single_file(file_info)
117
+ if key is not None and mel_features is not None and embedding is not None:
118
+ group = h5f.create_group(key)
119
+ group.create_dataset('mel', data=mel_features)
120
+ group.create_dataset('meta', data=embedding)
121
+
122
+ def process_audio_files(root_dir, output_h5_file):
123
+ """
124
+ Walk through a directory and process all audio files, saving them to an HDF5 file.
125
+ :param root_dir: Root directory containing audio files
126
+ :param output_h5_file: Path to the HDF5 output file
127
+ """
128
+ audio_files = []
129
+
130
+ for subdir, _, files in os.walk(root_dir):
131
+ for file in files:
132
+ if file.endswith('.wav') or file.endswith('.mp3') or file.endswith('.flac'):
133
+ audio_path = os.path.join(subdir, file)
134
+ audio_files.append((root_dir, audio_path))
135
+
136
+ random.shuffle(audio_files)
137
+
138
+ print(f"Processing {len(audio_files)} files...")
139
+ process_and_save_files(audio_files, output_h5_file)
140
+
141
+ if __name__ == "__main__":
142
+ # Argument parser for command line arguments
143
+ parser = argparse.ArgumentParser(description="Process audio files and extract mel features.")
144
+ parser.add_argument('--root_dir', type=str, required=True, help='Root directory of the audio files.')
145
+ parser.add_argument('--output_h5_file', type=str, required=True, help='Output HDF5 file path.')
146
+ parser.add_argument('--config_path', type=str, required=True, help='Path to the BigVGAN config.json file.')
147
+ parser.add_argument('--sr', type=int, default=22050, help='Sampling rate (default: 22050).')
148
+
149
+ args = parser.parse_args()
150
+
151
+ # Initialize the BigVGAN Mel extraction model
152
+ audio_to_mel_model = AudioToMel_bigvgan(args.config_path).to(device)
153
+
154
+ # Process audio files
155
+ process_audio_files(args.root_dir, args.output_h5_file)
156
+
157
+ print(f"Processing completed. H5 file saved at: {args.output_h5_file}")
158
+
159
+ ### how to use
160
+ # python tools/gtzan2h5.py --root_dir /path/to/audio/files --output_h5_file /path/to/output.h5 --config_path ./tools/bigvgan_v2_22khz_80band_256x/config.json --sr 22050
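The script writes one HDF5 group per track, holding a `mel` spectrogram of shape (1, 80, n_frames) and a `meta` sentence embedding of shape (1, 384). A quick inspection of the result (the file path is an example):

```python
import h5py

with h5py.File("./dataset/gtzan_test.h5", "r") as f:
    key = next(iter(f.keys()))
    print(key, f[key]["mel"].shape, f[key]["meta"].shape)  # e.g. 'blues_blues.00000.mp3' (1, 80, n_frames) (1, 384)
```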
tools/gtzan_split.py ADDED
@@ -0,0 +1,98 @@
1
+ import os
2
+ import re
3
+ import argparse
4
+ from pydub import AudioSegment
5
+
6
+ class GTZAN:
7
+ def __init__(self, root_dir, output_dir, labels):
8
+ """
9
+ Args:
10
+ root_dir (str): Root directory of the dataset.
11
+ output_dir (str): Output directory to save converted MP3 files.
12
+ labels (list): List of genres in the dataset.
13
+ """
14
+ self.root_dir = root_dir
15
+ self.output_dir = output_dir
16
+ self.labels = labels
17
+
18
+ # Create output directory structure for MP3 files
19
+ self.create_output_dirs()
20
+
21
+ def create_output_dirs(self):
22
+ """Create directories to store train and test audio files"""
23
+ for split in ['train', 'test']:
24
+ for genre in self.labels:
25
+ genre_dir = os.path.join(self.output_dir, split, genre)
26
+ os.makedirs(genre_dir, exist_ok=True)
27
+
28
+ def split_train_test(self, audio_names, test_fold):
29
+ """
30
+ Split the dataset into train and test sets based on test_fold.
31
+ E.g., test_ids = [30, 31, 32, ..., 39].
32
+ """
33
+ test_audio_names = []
34
+ train_audio_names = []
35
+
36
+ test_ids = range(test_fold * 10, (test_fold + 1) * 10)
37
+
38
+ for audio_name in audio_names:
39
+ # Extract the numeric ID from the audio file name
40
+ audio_id = int(re.search(r'\d+', audio_name).group())
41
+
42
+ if audio_id in test_ids:
43
+ test_audio_names.append(audio_name)
44
+ else:
45
+ train_audio_names.append(audio_name)
46
+
47
+ return train_audio_names, test_audio_names
48
+
49
+ def convert_and_save(self, file_path, target_path):
50
+ """Convert AU format to MP3 and save to target path"""
51
+ audio = AudioSegment.from_file(file_path, format="au")
52
+ audio.export(target_path, format="mp3")
53
+ print(f"Converted and saved {target_path}")
54
+
55
+ def process_genre(self, genre, test_fold):
56
+ """Process a single genre, split the dataset, and convert formats"""
57
+ genre_path = os.path.join(self.root_dir, genre)
58
+ audio_files = os.listdir(genre_path)
59
+
60
+ # Split the dataset
61
+ train_files, test_files = self.split_train_test(audio_files, test_fold)
62
+
63
+ # Process training set
64
+ for audio_name in train_files:
65
+ file_path = os.path.join(genre_path, audio_name)
66
+ target_path = os.path.join(self.output_dir, 'train', genre, audio_name.replace('.au', '.mp3'))
67
+ self.convert_and_save(file_path, target_path)
68
+
69
+ # Process test set
70
+ for audio_name in test_files:
71
+ file_path = os.path.join(genre_path, audio_name)
72
+ target_path = os.path.join(self.output_dir, 'test', genre, audio_name.replace('.au', '.mp3'))
73
+ self.convert_and_save(file_path, target_path)
74
+
75
+ def process_dataset(self):
76
+ """Process the entire GTZAN dataset and split it into train and test sets"""
77
+ for idx, genre in enumerate(self.labels):
78
+ print(f"Processing genre: {genre}...")
79
+ test_fold = idx % 10 # Each genre has a different test_fold
80
+ self.process_genre(genre, test_fold)
81
+
82
+
83
+ if __name__ == "__main__":
84
+ # Define argument parser
85
+ parser = argparse.ArgumentParser(description="GTZAN Dataset Converter")
86
+ parser.add_argument('--root_dir', type=str, required=True, help='Root directory of the GTZAN dataset')
87
+ parser.add_argument('--output_dir', type=str, required=True, help='Directory to save the converted MP3 files')
88
+ args = parser.parse_args()
89
+
90
+ # Example genre labels in the GTZAN dataset
91
+ labels = ["blues", "classical", "country", "disco", "hiphop", "jazz", "metal", "pop", "reggae", "rock"]
92
+
93
+ # Initialize the GTZAN processor
94
+ gtzan = GTZAN(args.root_dir, args.output_dir, labels)
95
+ gtzan.process_dataset()
96
+
97
+ ### how to use
98
+ # python tools/gtzan_split.py --root_dir /path/to/gtzan/genres --output_dir /path/to/output/directory
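The split rule above is deterministic: genre index `idx` gets `test_fold = idx % 10`, and a track goes to the test set when its numeric id falls in `[fold*10, fold*10 + 10)`. A small check of which ids that selects:

```python
labels = ["blues", "classical", "country", "disco", "hiphop",
          "jazz", "metal", "pop", "reggae", "rock"]
for idx, genre in enumerate(labels):
    fold = idx % 10
    test_ids = range(fold * 10, (fold + 1) * 10)
    print(f"{genre}: test ids {test_ids.start:05d}-{test_ids.stop - 1:05d}")
# blues: 00000-00009, classical: 00010-00019, ..., rock: 00090-00099,
# which matches the tracks listed in gtzan-test.csv.
```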
train.py ADDED
@@ -0,0 +1,114 @@
1
+ import os
2
+ import h5py
3
+ import torch
4
+ import random
5
+ import yaml
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from tqdm import tqdm
9
+ from diffusion import create_diffusion
10
+ from models import DiT
11
+ import torch.optim as optim
12
+ from torch.utils.tensorboard import SummaryWriter # TensorBoard
13
+
14
+ # Load hyperparameters from YAML file
15
+ with open('config/train.yaml', 'r') as file:
16
+ config = yaml.safe_load(file)
17
+
18
+ # Create TensorBoard writer
19
+ writer = SummaryWriter()
20
+
21
+ class MelMetaDataset(Dataset):
22
+ def __init__(self, h5_file, mel_frames):
23
+ self.h5_file = h5_file
24
+ self.mel_frames = mel_frames
25
+ with h5py.File(h5_file, 'r') as f:
26
+ self.keys = list(f.keys())
27
+
28
+ def __len__(self):
29
+ return len(self.keys)
30
+
31
+ def pad_mel(self, mel_segment, total_frames):
32
+ if total_frames < self.mel_frames:
33
+ padding_frames = self.mel_frames - total_frames
34
+ mel_segment = F.pad(mel_segment, (0, padding_frames), mode='constant', value=0)
35
+ return mel_segment
36
+
37
+ def __getitem__(self, idx):
38
+ key = self.keys[idx]
39
+ with h5py.File(self.h5_file, 'r') as f:
40
+ mel = torch.FloatTensor(f[key]['mel'][:])
41
+ meta_latent = torch.FloatTensor(f[key]['meta'][:])
42
+
43
+ total_frames = mel.shape[2]
44
+ if total_frames > self.mel_frames:
45
+ start_frame = random.randint(0, total_frames - self.mel_frames)
46
+ mel_segment = mel[:, :, start_frame:start_frame + self.mel_frames]
47
+ else:
48
+ mel_segment = self.pad_mel(mel, total_frames)
49
+ mel_segment = (mel_segment + 10) / 20
50
+ return mel_segment, meta_latent
51
+
52
+ # Dataset & DataLoader
53
+ dataset = MelMetaDataset(config['h5_file_path'], mel_frames=config['mel_frames'])
54
+ dataloader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)
55
+
56
+ # Model and optimizer
57
+ device = config['device'] if torch.cuda.is_available() else "cpu"
58
+ model = DiT(
59
+ input_size=tuple(config['input_size']),
60
+ patch_size=config['patch_size'],
61
+ in_channels=config['in_channels'],
62
+ hidden_size=config['hidden_size'],
63
+ depth=config['depth'],
64
+ num_heads=config['num_heads'],
65
+ )
66
+ model.to(device)
67
+
68
+ # Create diffusion model
69
+ diffusion = create_diffusion(timestep_respacing="")
70
+
71
+ # Optimizer
72
+ optimizer = optim.AdamW(model.parameters(), lr=config['lr'])
73
+
74
+ # Create directory to save model checkpoints
75
+ os.makedirs(config['checkpoint_dir'], exist_ok=True)
76
+
77
+ # Training function
78
+ def train_model(model, dataloader, optimizer, diffusion, num_epochs, sample_interval):
79
+ model.train()
80
+ for epoch in range(num_epochs):
81
+ total_loss = 0.0
82
+ for step, (mel_segment, meta_latent) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")):
83
+ mel_segment = mel_segment.to(device)
84
+ meta_latent = meta_latent.to(device)
85
+ t = torch.randint(0, diffusion.num_timesteps, (mel_segment.shape[0],), device=device)
86
+ model_kwargs = dict(y=meta_latent)
87
+ loss_dict = diffusion.training_losses(model, mel_segment, t, model_kwargs)
88
+ loss = loss_dict["loss"].mean()
89
+
90
+ optimizer.zero_grad()
91
+ loss.backward()
92
+ optimizer.step()
93
+
94
+ total_loss += loss.item()
95
+
96
+ avg_loss = total_loss / len(dataloader)
97
+ print(f"Epoch {epoch + 1}/{num_epochs}: Average Loss: {avg_loss:.4f}")
98
+ writer.add_scalar('Loss/epoch', avg_loss, epoch + 1)
99
+
100
+ if (epoch + 1) % sample_interval == 0:
101
+ checkpoint = {
102
+ 'epoch': epoch + 1,
103
+ 'model_state_dict': model.state_dict(),
104
+ 'optimizer_state_dict': optimizer.state_dict(),
105
+ }
106
+ checkpoint_path = f"{config['checkpoint_dir']}/model_epoch_{epoch + 1}.pt"
107
+ torch.save(checkpoint, checkpoint_path)
108
+ print(f"Model checkpoint saved at epoch {epoch + 1}")
109
+
110
+ # Start training
111
+ train_model(model, dataloader, optimizer, diffusion, num_epochs=config['num_epochs'], sample_interval=config['sample_interval'])
112
+
113
+ # Close TensorBoard writer
114
+ writer.close()
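`train.py` reads all of its hyper-parameters from `config/train.yaml`. That file is not shown here, so the following is only a sketch of the keys the script expects, with values chosen to match the DiT configuration used in `sample.py`; batch size, learning rate, and intervals are placeholder assumptions:

```python
import os
import yaml

config = {
    "h5_file_path": "./dataset/gtzan_train.h5",  # output of tools/gtzan2h5.py (example path)
    "mel_frames": 800,          # mel frames per training segment
    "batch_size": 16,           # placeholder; tune to your GPU memory
    "device": "cuda",
    "input_size": [80, 800],    # (mel bins, frames), matches sample.py
    "patch_size": 8,
    "in_channels": 1,
    "hidden_size": 384,
    "depth": 12,
    "num_heads": 6,
    "lr": 1.0e-4,               # placeholder learning rate
    "checkpoint_dir": "./gtzan-ck",
    "num_epochs": 20000,        # the released checkpoint is model_epoch_20000.pt
    "sample_interval": 1000,    # checkpoint frequency in epochs (placeholder)
}
os.makedirs("config", exist_ok=True)
with open("config/train.yaml", "w") as f:
    yaml.safe_dump(config, f, sort_keys=False)
```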