
Fine-Tuning Masked Diffusion for Provable Self-Correction

Figure: graphical abstract of PRISM.

PRISM is a plug-in remasking framework that fine-tunes any pretrained Masked Diffusion Model (MDM) to predict per-token quality in a single forward pass, enabling self-correction at inference. It adds a lightweight objective without changing the base MDM architecture.

  • Theory-backed head: estimates per-token quality $q_i \approx \Pr[x_i = y_i \mid y \oplus m_i]$.
  • Practical: simple to fine-tune; improves over remasking baselines on Sudoku, unconditional text (170M), and code (LLaDA-8B on MBPP) with modest fine-tuning compute.
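
Here $y$ is the current sample and $y \oplus m_i$ is $y$ with token $i$ masked out (one natural reading of the notation above), so $q_i$ scores how likely the current token already agrees with the clean data. A simple way to use these scores at inference, sketched below under the assumption of a bottom-$k$ selection rule (the exact rule is defined in the sampling code), is to remask the $k$ positions with the lowest predicted quality and denoise them again:

$$\mathcal{R} = \operatorname*{arg\,min}_{S : |S| = k} \sum_{i \in S} q_i$$

That is, $\mathcal{R}$ collects the $k$ tokens the head is least confident about; these are re-noised and re-predicted in the next correction step.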

🔔 News

  • [2025.11.09] PRISM codebase and Sudoku dataset have been open-sourced.
  • [2025.10.01] Paper posted on arXiv.

Getting started 🔥

To get started, create a conda environment and install FlashAttention:

conda env create -f requirements.yaml
conda activate PRISM
pip install flash-attn==2.6.3
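
Optionally, you can sanity-check that FlashAttention imports correctly (a quick check of our own, not part of the official setup):

python -c "import flash_attn; print(flash_attn.__version__)"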

Create the following directories to store saved models and logs:

mkdir outputs
mkdir watch_folder

If you are pre-training or fine-tuning on the Sudoku dataset (48K), copy the CSV files from the ./data/ directory into ${data.cache_dir}/sudoku/.
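
For example, assuming ${data.cache_dir} resolves to ~/.cache/prism (a hypothetical path; substitute your own setting):

mkdir -p ~/.cache/prism/sudoku
cp ./data/*.csv ~/.cache/prism/sudoku/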

You can download the pretrained MDLM checkpoint (OWT) from the Google Drive folder provided by the MDLM repository. Place the downloaded files in the ./outputs/checkpoints directory.
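
For example (the checkpoint filename below is hypothetical; keep whatever names the Drive folder uses):

mkdir -p ./outputs/checkpoints
mv ~/Downloads/mdlm-owt.ckpt ./outputs/checkpoints/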

Training 🦾

Pre-training

We provide a pre-training script for Sudoku only:

./scripts/pretrain_sudoku.sh

Fine-tuning

For PRISM fine-tuning, we provide the following scripts:

  • Sudoku:

    ./scripts/finetune_sudoku_prism.sh
  • OWT:

    ./scripts/finetune_owt_prism.sh
  • LLaDA:

    ./PRISM_llada/training_scripts/test_run.sh

Evaluation 🎯

We provide evaluation scripts for the fine-tuned module using a static sampler:

  • Sudoku (Success Rate):

    ./scripts/sample_sudoku_prism.sh
  • OWT (MAUVE, Gen PPL, Entropy): For faster evaluation, we support multi-node text generation:

    ./scripts/sample_owt_prism.sh

    Additional Loop Strategy

    You can further improve the generated text by applying a loop strategy. Modify the following parameters in the bash scripts to customize its behavior (an example invocation appears after this list):

    1. sampling.loop_steps: Number of loop iterations to perform.
    2. sampling.num_remask_loop: Number of tokens to remask in each iteration.
  • LLaDA (HumanEval, MBPP): Code evaluation proceeds in two steps (sketched below):

    1. Generate samples and save a .jsonl file: PRISM_llada/generate_samples.py
    2. Evaluate the .jsonl file: PRISM_llada/eval_mbpp.py
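
For the loop strategy, one possible invocation, assuming the script forwards Hydra-style overrides (the values below are illustrative; otherwise edit the script directly):

./scripts/sample_owt_prism.sh sampling.loop_steps=4 sampling.num_remask_loop=8

For the LLaDA code evaluation, a minimal two-step sketch (command-line flags omitted; see each script for its actual arguments):

python PRISM_llada/generate_samples.py   # writes generations to a .jsonl file
python PRISM_llada/eval_mbpp.py          # scores the .jsonl file on MBPP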

Baselines 🆚

We provide baseline implementations for comparison:

  • Sudoku:

    • Token-Critic: An unofficial training recipe for the Token-Critic approach on the Sudoku dataset:
      ./scripts/finetune_sudoku_token_critic.sh
    • ReMDM-conf Sampler: Evaluate using the ReMDM-conf sampler:
      ./scripts/sample_sudoku_remdm-conf.sh
  • OWT:

    • ReMDM-cap Sampler: Evaluate using the ReMDM-cap sampler:
      ./scripts/sample_owt_remdm.sh

Acknowledgements 🙏

This repository was built on top of ReMDM, which was based on MDLM and SEDD.

Citation 📝

@article{kim2025fine,
  title={Fine-Tuning Masked Diffusion for Provable Self-Correction},
  author={Kim, Jaeyeon and Kim, Seunggeun and Lee, Taekyun and Pan, David Z and Kim, Hyeji and Kakade, Sham and Chen, Sitan},
  note={Jaeyeon Kim and Seunggeun Kim contributed equally; Taekyun Lee is also a co–first author.},
  journal={arXiv preprint arXiv:2510.01384},
  year={2025}
}
