---
license: mit
datasets:
- Skylion007/openwebtext
tags:
- diffusion
---

# Generalized Interpolating Discrete Diffusion

By Dimitri von Rütte, Janis Fluri, Yuhui Ding, Antonio Orvieto, Bernhard Schölkopf, Thomas Hofmann

---
We present Generalized Interpolating Discrete Diffusion (GIDD), a novel framework for training discrete diffusion models. GIDD generalizes the popular masked diffusion model (MDM) paradigm to any diffusion process that can be written as a linear interpolation between the data distribution and a (time-dependent) mixing distribution. We demonstrate the flexibility of GIDD by training models on a hybrid diffusion process that combines masking and uniform noise. As a result, the model learns not only to "fill in the blanks" (i.e., the masked tokens) but also to judge whether already-filled-in tokens are correct and, if necessary, to replace incorrect tokens with more plausible ones. We show that GIDD models trained on hybrid noise achieve better sample quality (as measured by generative perplexity, PPL) than mask-only models, and that they can identify and correct their own mistakes in generated samples through a self-correction step.

This repository contains all training and evaluation code necessary to reproduce the results in the paper.

## Pretrained Checkpoints

Our trained checkpoints are available at the links below. All of them were trained on 131B tokens from the [OpenWebText](https://huggingface.co/datasets/Skylion007/openwebtext) dataset with the [GPT-2 tokenizer](https://huggingface.co/openai-community/gpt2). Here, p_u controls the amount of uniform noise in the hybrid diffusion process; p_u = 0.0 corresponds to mask-only diffusion.

| Model | Small (169.6M) | Base (424.5M) |
|-------|----------------|---------------|
| GIDD+ (p_u = 0.0) | [dvruette/gidd-small-p_unif-0.0](https://huggingface.co/dvruette/gidd-small-p_unif-0.0) | dvruette/gidd-base-p_unif-0.0 |
| GIDD+ (p_u = 0.1) | [dvruette/gidd-small-p_unif-0.1](https://huggingface.co/dvruette/gidd-small-p_unif-0.1) | [dvruette/gidd-base-p_unif-0.1](https://huggingface.co/dvruette/gidd-base-p_unif-0.1) |
| GIDD+ (p_u = 0.2) | [dvruette/gidd-small-p_unif-0.2](https://huggingface.co/dvruette/gidd-small-p_unif-0.2) | [dvruette/gidd-base-p_unif-0.2](https://huggingface.co/dvruette/gidd-base-p_unif-0.2) |

## Use the Model

1. Install the GIDD repo:

   ```bash
   pip install git+https://github.com/dvruette/gidd
   ```

2. The `GiddPipeline` class is the most convenient way to quickly download a trained model and experiment with it:

   ```python
   from gidd import GiddPipeline

   # Download a pretrained model from HuggingFace
   pipe = GiddPipeline.from_pretrained("dvruette/gidd-base-p_unif-0.0", trust_remote_code=True)

   # Generate samples
   texts = pipe.generate(num_samples=4, num_inference_steps=128)

   # Run self-correction step
   corrected_texts = pipe.self_correction(texts, num_inference_steps=128, early_stopping=True, temperature=0.1)

   print(corrected_texts)
   ```
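The interpolating forward process described in the introduction can be illustrated for individual tokens with a short sketch. It writes the noisy marginal as q_t(z_t | x) = α_t · onehot(x) + (1 − α_t) · π, a linear interpolation between the clean token and a mixing distribution π. This is a minimal illustration under loudly stated assumptions, not the repository's implementation: it assumes a linear schedule α_t = 1 − t and a *time-invariant* π that puts mass 1 − p_u on the mask token and spreads p_u uniformly over all other tokens (in GIDD the mixing distribution may vary with t). The helper names `mixing_distribution`, `marginal`, and `sample_noisy` are hypothetical.

```python
import numpy as np


def mixing_distribution(vocab_size: int, mask_id: int, p_u: float) -> np.ndarray:
    """Hypothetical time-invariant mixing distribution (an assumption, not GIDD's exact one):
    mass 1 - p_u on the mask token, p_u spread uniformly over the remaining tokens."""
    pi = np.full(vocab_size, p_u / (vocab_size - 1))
    pi[mask_id] = 1.0 - p_u
    return pi


def marginal(x0: int, t: float, vocab_size: int, mask_id: int, p_u: float) -> np.ndarray:
    """q_t(z_t | x0) = alpha_t * onehot(x0) + (1 - alpha_t) * pi,
    assuming a linear schedule alpha_t = 1 - t with t in [0, 1]."""
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")
    alpha_t = 1.0 - t
    onehot = np.zeros(vocab_size)
    onehot[x0] = 1.0
    return alpha_t * onehot + (1.0 - alpha_t) * mixing_distribution(vocab_size, mask_id, p_u)


def sample_noisy(x: np.ndarray, t: float, vocab_size: int, mask_id: int,
                 p_u: float, rng: np.random.Generator) -> np.ndarray:
    """Noise each token of x independently by sampling from its marginal q_t(. | x_i)."""
    return np.array([
        rng.choice(vocab_size, p=marginal(int(tok), t, vocab_size, mask_id, p_u))
        for tok in x
    ])


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    clean = np.array([1, 2, 3, 4, 5])
    # Toy vocabulary of 10 tokens where id 9 plays the role of the mask token.
    print(sample_noisy(clean, t=0.5, vocab_size=10, mask_id=9, p_u=0.2, rng=rng))
```

With p_u = 0.0 the sketch reduces to pure masking: every token either keeps its value or becomes the mask token. With p_u > 0, some tokens are instead swapped for random vocabulary entries, which is the kind of corruption a hybrid-noise model has to learn to detect and fix.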