Musimple: Text2Music with DiT Made Simple
Due to repository size limitations, the complete dataset and checkpoints are available on Hugging Face: https://huggingface.co/ZheqiDAI/Musimple.
Introduction
This repository provides a simple and clear implementation of a Text-to-Music Generation pipeline using a DiT (Diffusion Transformer) model. The codebase includes key components such as model training, inference, and evaluation. We use the GTZAN dataset as an example to demonstrate a minimal, working pipeline for text-conditioned music generation.
The repository is designed to be easy to use and customize, making it simple to reproduce our results on a single NVIDIA RTX 4090 GPU. Additionally, the code is structured to be flexible, allowing you to modify it for your own tasks and datasets.
We plan to continue maintaining and improving this repository with new features, model improvements, and extended documentation in the future.
Features
- Text-to-Music Generation: Generate music directly from text descriptions using a DiT model.
- GTZAN Example: A simple pipeline using the GTZAN dataset to demonstrate the workflow.
- End-to-End Pipeline: Includes model training, inference, and evaluation with support for generating audio files.
- Customizable: Easy to modify and extend for different datasets or use cases.
- Single GPU Training: Optimized for training on a single RTX 4090 GPU but adaptable to different hardware setups.
Requirements
Before using the code, ensure that the following dependencies are installed:
- Python >= 3.9
- CUDA (if available)
- Required Python libraries listed in requirements.txt
You can install the dependencies using:
conda create -n musimple python=3.9
conda activate musimple
pip install -r requirements.txt
Data Preprocessing
To begin, download the GTZAN dataset. Then use the gtzan_split.py script in the tools directory to split the dataset into training and testing sets by running:
python gtzan_split.py --root_dir /path/to/gtzan/genres --output_dir /path/to/output/directory
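One common way to perform this step is a per-genre (stratified) split, so that every genre keeps the same train/test proportion. The sketch below illustrates that idea only. The helper `split_by_genre`, the 10% test ratio, the fixed seed, and the `genre.00000.wav` file naming are assumptions, not the actual logic or defaults of gtzan_split.py.

```python
import random
from collections import defaultdict

# Hypothetical helper illustrating a stratified split; not taken from gtzan_split.py.
def split_by_genre(files, test_ratio=0.1, seed=0):
    """Split (genre, filename) pairs into train/test lists, separately per genre."""
    by_genre = defaultdict(list)
    for genre, name in files:
        by_genre[genre].append(name)
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    train, test = [], []
    for genre in sorted(by_genre):
        names = sorted(by_genre[genre])
        rng.shuffle(names)
        n_test = round(len(names) * test_ratio)
        test += [(genre, n) for n in names[:n_test]]
        train += [(genre, n) for n in names[n_test:]]
    return train, test

# GTZAN has 10 genres with 100 clips each; the file naming here is assumed.
GENRES = ["blues", "classical", "country", "disco", "hiphop",
          "jazz", "metal", "pop", "reggae", "rock"]
files = [(g, f"{g}.{i:05d}.wav") for g in GENRES for i in range(100)]
train, test = split_by_genre(files, test_ratio=0.1)
print(len(train), len(test))  # 900 100
```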
Next, convert the audio files to HDF5 format using the gtzan2h5.py script:
python gtzan2h5.py --root_dir /path/to/audio/files --output_h5_file /path/to/output.h5 --config_path bigvgan_v2_22khz_80band_256x/config.json --sr 22050
If this process seems cumbersome, don’t worry! We have already preprocessed the dataset, and you can find it in the musimple/dataset directory. You can download and use this data directly to skip the preprocessing steps.
This preprocessing stage has two main parts:
- Text to Latent Transformation: We use a Sentence Transformer to convert text labels into latent representations.
- Audio to Mel Spectrogram: The original audio files are converted into mel spectrograms.
Both the latent representations and the mel spectrograms are stored in an HDF5 file, making them easily accessible during training and inference.
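During training, the model works on fixed-length mel windows (the released configuration uses mel_frames: 800 with 80 mel bands). The following is a minimal sketch of cutting such a window. It assumes each stored mel spectrogram is an (n_mels, T) array and that shorter clips are right-padded with their minimum value, treated here as silence in log-mel space. The function `crop_or_pad_mel` is hypothetical, and train.py may select windows differently.

```python
import numpy as np

# Hypothetical windowing helper; the real data loader in train.py may differ.
def crop_or_pad_mel(mel, mel_frames=800, rng=None):
    """Return an (n_mels, mel_frames) window from an (n_mels, T) mel spectrogram."""
    rng = rng or np.random.default_rng()
    n_mels, t = mel.shape
    if t >= mel_frames:
        # Random start so that different epochs see different parts of the clip.
        start = int(rng.integers(0, t - mel_frames + 1))
        return mel[:, start:start + mel_frames]
    # Shorter clip: right-pad with its minimum value (assumed to represent silence).
    pad = np.full((n_mels, mel_frames - t), mel.min(), dtype=mel.dtype)
    return np.concatenate([mel, pad], axis=1)

rng = np.random.default_rng(0)
long_mel = rng.standard_normal((80, 2000)).astype(np.float32)
short_mel = rng.standard_normal((80, 300)).astype(np.float32)
print(crop_or_pad_mel(long_mel, rng=rng).shape)   # (80, 800)
print(crop_or_pad_mel(short_mel, rng=rng).shape)  # (80, 800)
```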
Training
To begin training, navigate to the Musimple directory and run the following commands:
cd Musimple
python train.py
All training-related parameters can be adjusted in the configuration file ./config/train.yaml.
This allows you to easily modify aspects like the learning rate, batch size, number of epochs, and more to suit your hardware or dataset requirements.
We also provide a pre-trained checkpoint trained for two days on a single NVIDIA RTX 4090. You can use this checkpoint for inference or fine-tuning. The key training parameters for this checkpoint are as follows:
- batch_size: 48
- mel_frames: 800
- lr: 0.0001
- num_epochs: 100000
- sample_interval: 250
- h5_file_path: './dataset/gtzan_train.h5'
- device: 'cuda:4' (the GPU index on our machine; on a single-GPU machine, use 'cuda:0')
- input_size: [80, 800]
- patch_size: 8
- in_channels: 1
- hidden_size: 384
- depth: 12
- num_heads: 6
- checkpoint_dir: 'gtzan-ck'
You can modify the model architecture and parameters in the train.yaml configuration file to compare your models against ours. We will continue to release more checkpoints and models in future updates.
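To see what input_size, patch_size, hidden_size, and num_heads imply for the transformer, here is a sketch of DiT-style patchification. It assumes Musimple follows the original DiT design: the input is cut into non-overlapping patch_size × patch_size patches, and each patch is linearly projected to hidden_size. With input_size [80, 800] and patch_size 8, the model sees (80/8) × (800/8) = 1000 tokens of 1 × 8 × 8 = 64 values each. With hidden_size 384 and 6 heads, each attention head has dimension 64. The projection matrix below is a zero placeholder, not trained weights.

```python
import numpy as np

def patchify(x, patch_size=8):
    """Split a (C, H, W) array into non-overlapping p x p patches -> (N, C*p*p)."""
    c, h, w = x.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "input_size must be divisible by patch_size"
    x = x.reshape(c, h // p, p, w // p, p)
    x = x.transpose(1, 3, 0, 2, 4)  # (H/p, W/p, C, p, p), row-major patch order
    return x.reshape((h // p) * (w // p), c * p * p)

# Values from the released train.yaml
in_channels, input_size, patch_size, hidden_size, num_heads = 1, (80, 800), 8, 384, 6

mel = np.zeros((in_channels, *input_size), dtype=np.float32)
tokens = patchify(mel, patch_size)
print(tokens.shape)  # (1000, 64)

# Placeholder linear patch embedding (DiT implements this as a strided conv).
W = np.zeros((tokens.shape[1], hidden_size), dtype=np.float32)
print((tokens @ W).shape)        # (1000, 384)
print(hidden_size // num_heads)  # 64
```

Because self-attention cost grows quadratically with the number of tokens, increasing mel_frames or decreasing patch_size raises memory use quickly: halving patch_size to 4 would give 20 × 200 = 4000 tokens.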
Inference
Once you have trained a model, you can run inference with it using the following command:
python sample.py --checkpoint ./gtzan-ck/model_epoch_20000.pt \
--h5_file ./dataset/gtzan_test.h5 \
--output_gt_dir ./sample/gt \
--output_gen_dir ./sample/gn \
--segment_length 800 \
--sample_rate 22050
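The --segment_length 800 and --sample_rate 22050 arguments determine how much audio each generated sample covers. Assuming a mel hop size of 256 samples (inferred from the "256x" in the bigvgan_v2_22khz_80band_256x config used during preprocessing; check hop_size in that config.json), 800 frames correspond to 800 × 256 = 204,800 samples, or about 9.29 s at 22,050 Hz. The helper names below are hypothetical.

```python
# Hypothetical helpers; hop_size=256 is an assumption taken from the vocoder config name.
def n_samples(n_frames, hop_size=256):
    """Waveform length a 256x-upsampling vocoder produces for n_frames mel frames."""
    return n_frames * hop_size

def frames_to_seconds(n_frames, hop_size=256, sample_rate=22050):
    """Approximate duration in seconds covered by n_frames mel frames."""
    return n_samples(n_frames, hop_size) / sample_rate

print(n_samples(800))                    # 204800
print(round(frames_to_seconds(800), 2))  # 9.29
```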
You can also run inference with our pre-trained checkpoint to familiarize yourself with the process. Some inference results are saved in the sample folder as a demo. Because our model is small, the generated audio is not of the highest quality; these samples are intended as simple examples to guide further evaluation.
Evaluation
For evaluation, we recommend creating a separate environment and using the Generated Music Evaluation library. Its repository provides detailed instructions for setting up the environment and using the evaluation tools, and new features will be added to it over time.
Once you have set up the environment following the instructions from the evaluation repository, you can run the following script to evaluate your generated music:
python eval.py \
--ref_path ../sample/gt \
--gen_path ../sample/gn \
--id2text_csv_path ../gtzan-test.csv \
--output_path ./output \
--device_id 0 \
--batch_size 32 \
--original_sample_rate 24000 \
--fad_sample_rate 16000 \
--kl_sample_rate 16000 \
--clap_sample_rate 48000 \
--run_fad 1 \
--run_kl 1 \
--run_clap 1
This script evaluates the generated music against the reference music and reports FAD, KL, and CLAP scores, selected with the --run_fad, --run_kl, and --run_clap flags.
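FAD (Fréchet Audio Distance) fits one Gaussian to the embeddings of the reference set and another to the embeddings of the generated set, then computes ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^(1/2)), where mu and S are the mean and covariance of each set. Lower is better. The NumPy sketch below computes only this distance and assumes the embeddings have already been extracted. It does not show which embedding model the evaluation library uses or how it resamples audio to --fad_sample_rate 16000, and it is an illustration of the metric, not the library's implementation.

```python
import numpy as np

def frechet_distance(emb_a, emb_b):
    """Frechet distance between Gaussians fitted to two (N, D) embedding sets."""
    mu_a, mu_b = emb_a.mean(axis=0), emb_b.mean(axis=0)
    cov_a = np.cov(emb_a, rowvar=False)
    cov_b = np.cov(emb_b, rowvar=False)
    # Tr((A B)^(1/2)) equals Tr((A^(1/2) B A^(1/2))^(1/2)); the latter is symmetric,
    # so it can be computed stably with eigendecompositions.
    vals, vecs = np.linalg.eigh(cov_a)
    sqrt_a = (vecs * np.sqrt(np.clip(vals, 0, None))) @ vecs.T
    middle = sqrt_a @ cov_b @ sqrt_a
    tr_sqrt = np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0, None)).sum()
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2 * tr_sqrt)

rng = np.random.default_rng(0)
ref = rng.standard_normal((2000, 8))
print(frechet_distance(ref, ref))        # ~0 for identical sets
print(frechet_distance(ref, ref + 1.0))  # ~8: mean shifted by 1 in each of 8 dims
```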
To-Do
The following features and improvements are planned for future updates:
- EMA Model: Implement Exponential Moving Average (EMA) for model weights to stabilize training and improve final generation quality.
- Long-Term Music Fine-tuning: Explore fine-tuning the model to generate longer-term music with more coherent structures.
- VAE Integration: Integrate a Variational Autoencoder (VAE) to improve latent space representations and potentially enhance generation diversity.
- T5-based Text Conditioning: Add T5 to enhance text conditioning, improving the control and accuracy of the text-to-music generation process.
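For reference, the EMA item usually means keeping a second copy of the weights that is updated after every optimizer step as ema = decay * ema + (1 - decay) * param, and using that copy for sampling. The following is a minimal sketch over a dict of NumPy arrays. The 0.9999 default mirrors the value used in the original DiT training script, but it is an assumption here, since EMA is not yet implemented in this repository.

```python
import numpy as np

def ema_update(ema_params, params, decay=0.9999):
    """In-place EMA update: ema <- decay * ema + (1 - decay) * param, per tensor."""
    for name, value in params.items():
        ema_params[name] *= decay
        ema_params[name] += (1.0 - decay) * value

params = {"w": np.zeros(3)}
ema = {"w": np.ones(3)}
ema_update(ema, params, decay=0.9)
print(ema["w"])  # [0.9 0.9 0.9]
```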