PRISM is a plug-in remasking framework that fine-tunes any pretrained Masked Diffusion Model (MDM) to predict per-token quality in a single forward pass, enabling self-correction at inference. It adds a lightweight objective without changing the base MDM architecture.
- Theory-backed head: estimates per-token quality $q_i \approx \Pr[x_i = y_i \mid y \oplus m_i]$.
- Practical: simple to fine-tune; improves over remasking baselines on Sudoku, unconditional text (170M), and code (LLaDA-8B on MBPP) with modest fine-tuning compute.
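As a concrete illustration of how such a head enables self-correction, here is a minimal sketch of quality-guided remasking. It is not the repository's sampler: `quality_logits`, `mask_id`, and the function name are illustrative assumptions; only the rule of re-masking the positions with the lowest estimated $q_i$ comes from the description above.

```python
import torch

@torch.no_grad()
def remask_lowest_quality(tokens, quality_logits, num_remask, mask_id):
    """Re-mask the positions a PRISM-style quality head trusts least.

    tokens:         (B, L) currently decoded sequence y
    quality_logits: (B, L) raw head outputs; sigmoid gives q_i, the estimated
                    probability that position i is already correct
    """
    q = torch.sigmoid(quality_logits)
    # The lowest-q positions are the likeliest errors, so they become
    # the candidates for self-correction.
    worst = torch.topk(q, k=num_remask, dim=-1, largest=False).indices
    corrected = tokens.clone()
    corrected.scatter_(-1, worst, mask_id)  # the base MDM re-predicts these next step
    return corrected
```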
## 🔔 News
- [2025.11.09] PRISM codebase and Sudoku dataset have been open-sourced.
- [2025.10.01] Paper posted on arXiv.
To get started, create a conda environment and install FlashAttention:
```bash
conda env create -f requirements.yaml
conda activate PRISM
pip install flash-attn==2.6.3
```
Create the following directories to store saved models and logs:
```bash
mkdir outputs
mkdir watch_folder
```
If you are pre-training or fine-tuning on the Sudoku dataset (48K), make sure the CSV files from the ./data/ directory are placed in ${data.cache_dir}/sudoku/.
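For example, assuming a cache directory of `~/.cache/prism` (the actual value is whatever `${data.cache_dir}` resolves to in your config):

```bash
# Illustrative path; substitute your configured ${data.cache_dir}.
DATA_CACHE_DIR=~/.cache/prism
mkdir -p "${DATA_CACHE_DIR}/sudoku"
cp ./data/*.csv "${DATA_CACHE_DIR}/sudoku/"
```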
You can download the pretrained MDLM checkpoint (OWT) from this Google Drive folder, provided by the MDLM repository. Place the downloaded files in the ./outputs/checkpoints directory.
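For example (the checkpoint filename below is a placeholder; keep whatever names the Drive folder provides):

```bash
mkdir -p ./outputs/checkpoints
mv ~/Downloads/mdlm.ckpt ./outputs/checkpoints/  # placeholder filename
```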
We provide a pre-training script for Sudoku only:
```bash
./scripts/pretrain_sudoku.sh
```
For PRISM fine-tuning, we provide the following scripts:
- Sudoku: `./scripts/finetune_sudoku_prism.sh`
- OWT: `./scripts/finetune_owt_prism.sh`
- LLaDA: `./PRISM_llada/training_scripts/test_run.sh`
We provide evaluation scripts for the fine-tuned module using a static sampler:
- Sudoku (Success Rate): `./scripts/sample_sudoku_prism.sh`
- OWT (MAUVE, Gen PPL, Entropy): for faster evaluation, we support multi-node text generation: `./scripts/sample_owt_prism.sh`
You can enhance the generated texts by applying a loop strategy. Modify the following parameters in the bash scripts to customize the behavior (see the example after this list):
- `sampling.loop_steps`: Number of loop iterations to perform.
- `sampling.num_remask_loop`: Number of tokens to remask during each iteration.
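For example, a run with four loops that re-masks eight tokens per loop might look like this inside `./scripts/sample_owt_prism.sh`. The entry point and `mode` flag are assumptions in the style of MDLM-based repositories; keep whatever the provided script already uses and only adjust the two `sampling.*` overrides:

```bash
# Hydra-style overrides; values are illustrative.
python -u main.py \
  mode=sample_eval \
  sampling.loop_steps=4 \
  sampling.num_remask_loop=8
```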
- LLaDA (HumanEval, MBPP): code evaluation follows two steps (see the sketch after this list):
  - generate samples and save a .jsonl file: `PRISM_llada/generate_samples.py`
  - evaluate the .jsonl file: `PRISM_llada/eval_mbpp.py`
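A typical two-step invocation might look like the following; the flag names are hypothetical placeholders, so check each script's arguments before running:

```bash
# Flag names are illustrative, not the scripts' documented interface.
python PRISM_llada/generate_samples.py --output_path samples.jsonl
python PRISM_llada/eval_mbpp.py --input_path samples.jsonl
```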
We provide baseline implementations for comparison:
- Sudoku:
  - Token-Critic: an unofficial training recipe for the Token-Critic approach on the Sudoku dataset: `./scripts/finetune_sudoku_token_critic.sh`
  - ReMDM-conf sampler: evaluate using the ReMDM-conf sampler: `./scripts/sample_sudoku_remdm-conf.sh`
- OWT:
  - ReMDM-cap sampler: evaluate using the ReMDM-cap sampler: `./scripts/sample_owt_remdm.sh`
This repository was built on top of ReMDM, which was based on MDLM and SEDD.
If you use PRISM in your research, please cite:

```bibtex
@article{kim2025fine,
  title={Fine-Tuning Masked Diffusion for Provable Self-Correction},
  author={Kim, Jaeyeon and Kim, Seunggeun and Lee, Taekyun and Pan, David Z and Kim, Hyeji and Kakade, Sham and Chen, Sitan},
  note={Jaeyeon Kim and Seunggeun Kim contributed equally; Taekyun Lee is also a co-first author.},
  journal={arXiv preprint arXiv:2510.01384},
  year={2025}
}
```
