We introduce ReMDM, a simple and general framework to design remasking samplers for masked discrete diffusion models. In this repo, we provide our implementation of different ReMDM strategies for unconditional text generation on OpenWebText. We also provide a demo in this notebook showing how to download the MDLM checkpoint and implement ReMDM-loop on top of it.
Our main additions to MDLM are:
- Evaluation Metrics
  - Add MAUVE computation code.
  - Add entropy computation code.
- Sampling Tricks
  - Replace fp32 Gumbel noise with the correct fp64 Gumbel noise.
  - Implement nucleus sampling.
- ReMDM Strategies
  - Implement different ReMDM strategies, including ReMDM-cap, ReMDM-rescale, ReMDM-conf, and ReMDM-loop.
- Predictor-Corrector Samplers
  - Implement forward-backward and discrete flow matching corrector samplers as extra baselines.
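The two sampling tricks can be illustrated on a single probability vector in plain Python. This is a minimal sketch, not the repository's batched implementation: `top_p_filter` and `gumbel_max_sample` are hypothetical names. Nucleus (top-p) sampling keeps the smallest set of most likely tokens whose total probability reaches `p` and renormalizes. The Gumbel-max trick then draws a token as the argmax of log-probability plus Gumbel(0, 1) noise. Python floats are IEEE-754 doubles, so the noise below is computed in fp64; with fp32 uniforms, values very close to 1 are rounded, which truncates the upper tail of the Gumbel noise.

```python
import math
import random


def top_p_filter(probs, p):
    """Keep the smallest set of top tokens whose mass reaches p, then renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    kept = [probs[i] if i in keep else 0.0 for i in range(len(probs))]
    total = sum(kept)
    return [q / total for q in kept]


def gumbel_max_sample(probs, rng):
    """Sample an index as argmax(log p_i + g_i) with g_i ~ Gumbel(0, 1) in fp64."""
    best, best_score = None, -math.inf
    for i, q in enumerate(probs):
        if q <= 0.0:
            continue  # log(0) = -inf can never win the argmax
        u = rng.random()  # uniform in [0, 1), double precision
        # Gumbel noise -log(-log(u)); clamp u away from 0 where log is undefined
        g = -math.log(-math.log(max(u, 1e-300)))
        score = math.log(q) + g
        if score > best_score:
            best, best_score = i, score
    return best
```

For example, `top_p_filter([0.5, 0.25, 0.125, 0.125], p=0.7)` keeps only the first two tokens, and repeated calls to `gumbel_max_sample` on the result return index 0 about two thirds of the time.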
To get started, create a conda environment containing the required dependencies:

```bash
conda env create -f requirements.yaml
conda activate remdm
# install flash-attention separately
pip install flash-attn==2.6.3
```
Create the following directories to store saved models and SLURM logs:

```bash
mkdir outputs
mkdir watch_folder
```

Download the checkpoints from this Google Drive folder released by the MDLM repo and place them under `./outputs/checkpoints`.
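The same directory layout can be prepared from Python if you prefer to script the setup. This is a minimal sketch; `prepare_dirs` is a hypothetical helper, not part of the repository. It only creates the folders and lists what is already in `outputs/checkpoints`; it does not download anything.

```python
from pathlib import Path


def prepare_dirs(base="."):
    """Create outputs/checkpoints and watch_folder under base; list checkpoint files."""
    base = Path(base)
    checkpoints = base / "outputs" / "checkpoints"
    # parents=True also creates outputs/; exist_ok=True makes reruns harmless
    checkpoints.mkdir(parents=True, exist_ok=True)
    (base / "watch_folder").mkdir(parents=True, exist_ok=True)
    # File names depend on what was downloaded from the Drive folder
    return sorted(p.name for p in checkpoints.iterdir())
```

Calling `prepare_dirs()` from the repository root returns an empty list until the checkpoints have been copied in.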
Below, we show how to generate text samples with different models and samplers. Replace "YOUR-BASE-PATH" in the scripts with the path to your copy of the repository, then submit the corresponding SLURM batch script:
- AR: `sbatch scripts/ar.sh`
- SEDD: `sbatch scripts/sedd.sh`
- MDLM: `sbatch scripts/mdlm.sh`
- Forward-backward corrector: `sbatch scripts/fb.sh`
- Discrete flow matching corrector: `sbatch scripts/dfm.sh`
- ReMDM: `sbatch scripts/remdm-{YOUR-CHOSEN-STRATEGY}.sh`
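Filling in the placeholder can also be scripted. A minimal sketch, assuming each script contains the literal string `YOUR-BASE-PATH`; `fill_base_path` and `fill_script_file` are hypothetical helpers, not part of the repository.

```python
from pathlib import Path

PLACEHOLDER = "YOUR-BASE-PATH"


def fill_base_path(script_text, base_path):
    """Return script_text with every YOUR-BASE-PATH occurrence replaced by base_path."""
    if PLACEHOLDER not in script_text:
        # Fail loudly so an already-edited script is not silently left unchanged
        raise ValueError(f"{PLACEHOLDER} not found")
    return script_text.replace(PLACEHOLDER, str(base_path))


def fill_script_file(script_path, base_path):
    """Rewrite one script on disk in place."""
    path = Path(script_path)
    path.write_text(fill_base_path(path.read_text(), base_path))
```

For example, `fill_script_file("scripts/mdlm.sh", "/path/to/remdm")` before running `sbatch scripts/mdlm.sh`. A text editor or `sed` works just as well; the explicit error for a missing placeholder is the main reason to use a script.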
To run your own hyperparameter search, change the following values in the bash scripts:
- `sampling_steps`: number of sampling steps
- `p`: top-p value for nucleus sampling
- `eta`: $\eta$ value used by the ReMDM strategies
- `t_on`: $t_{on}$ in ReMDM-loop
- `t_off`: $t_{off}$ in ReMDM-loop
- `alpha_on`: $\alpha(t_{on})$ in ReMDM-loop
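To make the roles of `eta`, `t_on`, `t_off`, and `alpha_on` concrete, here is a scalar sketch based on our reading of the ReMDM paper (arXiv:2503.00307). For a reverse step from time $t$ to $s < t$, where $\alpha$ is the probability that a token is unmasked, an already decoded token is remasked with probability $\sigma_t$, which must satisfy $\sigma_t \le \min\{1, (1-\alpha_s)/\alpha_t\}$ for the step to remain a valid distribution. ReMDM-cap clips a constant $\eta$ to this bound, ReMDM-rescale scales the bound by $\eta$, and ReMDM-loop remasks only while $\alpha$ is frozen at $\alpha(t_{on})$ between $t_{on}$ and $t_{off}$. The function names, the linear shape of $\alpha(t)$ outside the loop phase, and the clipping in `sigma_loop` are assumptions of this sketch; ReMDM-conf is omitted, and the repository's batched PyTorch code may differ.

```python
def sigma_max(alpha_t, alpha_s):
    """Upper bound on the remasking probability for a step from time t to s < t."""
    if alpha_t <= 0.0:
        return 0.0  # every token is still masked, so there is nothing to remask
    return min(1.0, (1.0 - alpha_s) / alpha_t)


def sigma_cap(alpha_t, alpha_s, eta):
    """ReMDM-cap: a constant eta, clipped to the valid range."""
    return min(eta, sigma_max(alpha_t, alpha_s))


def sigma_rescale(alpha_t, alpha_s, eta):
    """ReMDM-rescale: a fixed fraction eta of the largest valid value."""
    return eta * sigma_max(alpha_t, alpha_s)


def masked_token_probs(alpha_t, alpha_s, sigma):
    """(P[reveal the prediction], P[stay masked]) for a token masked at time t.

    A token that is unmasked at time t is remasked with probability sigma.
    sigma = 0 recovers the MDLM step (alpha_s - alpha_t) / (1 - alpha_t).
    """
    p_reveal = (alpha_s - (1.0 - sigma) * alpha_t) / (1.0 - alpha_t)
    p_masked = (1.0 - alpha_s - sigma * alpha_t) / (1.0 - alpha_t)
    return p_reveal, p_masked


def loop_alpha(t, t_on, t_off, alpha_on):
    """Hypothetical piecewise-linear alpha(t) for ReMDM-loop, with t running from 1 to 0."""
    if t >= t_on:
        # phase 1: alpha rises linearly from 0 at t = 1 to alpha_on at t = t_on
        return alpha_on * (1.0 - t) / (1.0 - t_on)
    if t >= t_off:
        # phase 2: alpha is held at alpha_on while tokens are remasked and re-predicted
        return alpha_on
    # phase 3: alpha rises linearly from alpha_on at t = t_off to 1 at t = 0
    return alpha_on + (1.0 - alpha_on) * (t_off - t) / t_off


def sigma_loop(t, t_on, t_off, alpha_on, eta):
    """ReMDM-loop: remask with probability eta only inside [t_off, t_on]."""
    if t_off <= t <= t_on:
        # clip to the valid range (a safeguard added in this sketch)
        return min(eta, sigma_max(alpha_on, alpha_on))
    return 0.0  # plain MDLM steps outside the loop phase
```

With illustrative values such as `t_on=0.55`, `t_off=0.05`, `alpha_on=0.9`, and `eta=0.02`, `sigma_loop` returns 0 during the first and last phases and 0.02 in between, so the middle phase spends its sampling steps revisiting already decoded tokens.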
This repository builds on MDLM, which was in turn based on SEDD.
```bibtex
@article{wang2025remasking,
  title={Remasking Discrete Diffusion Models with Inference-Time Scaling},
  author={Wang, Guanghan and Schiff, Yair and Sahoo, Subham and Kuleshov, Volodymyr},
  journal={arXiv preprint arXiv:2503.00307},
  year={2025}
}
```