Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

/dataset_v2
.venv
21 changes: 0 additions & 21 deletions LICENSE

This file was deleted.

113 changes: 22 additions & 91 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,102 +1,33 @@
# Diffusion Schrödinger Bridge with Applications to Score-Based Generative Modeling
# Pont de Schrödinger Diffusif avec Applications à la Modélisation Générative basée sur le Score

This repository contains the implementation for the paper Diffusion
Schrödinger Bridge with Applications to Score-Based Generative Modeling.
Ce dépôt contient l'implémentation de l'article "Pont de Schrödinger Diffusif avec Applications à la Modélisation Générative basée sur le Score".

If using this code, please cite the paper:
```
## Qu'est-ce qu'un pont de Schrödinger ?

Le problème du Pont de Schrödinger (SB) est un problème classique en mathématiques appliquées, contrôle optimal et probabilité ; voir [1, 2, 3]. En temps discret, il prend la forme dynamique suivante : on considère une densité de référence p(x<sub>0:N</sub>) décrivant le processus d'ajout de bruit aux données. On cherche à trouver p\*(x<sub>0:N</sub>) telle que p\*(x<sub>0</sub>) = p<sub>data</sub>(x<sub>0</sub>) et p\*(x<sub>N</sub>) = p<sub>prior</sub>(x<sub>N</sub>), tout en minimisant la divergence de Kullback-Leibler entre p\* et p. Dans ce travail, nous introduisons le **Pont de Schrödinger Diffusif** (DSB), un nouvel algorithme utilisant des approches de score-matching [4] pour approximer l'algorithme *Iterative Proportional Fitting*, une méthode itérative pour résoudre le problème SB. DSB peut être vu comme un raffinement des méthodes existantes de modélisation générative basée sur le score.

## Applications

Une application prometteuse de cette approche est le virtual staining en histopathologie, par exemple la conversion d'images colorées HES (Hématoxyline-Éosine-Safran) vers des images IHC. Le DSB permet d'apprendre une transformation probabiliste entre deux distributions d'images, ici entre la distribution des coupes HES et celle des coupes IHC, tout en garantissant une correspondance réaliste et contrôlée. Cela ouvre la voie à la génération d'images IHC virtuelles à partir de coupes HES, facilitant l'analyse biomédicale sans recourir à des colorations coûteuses ou destructives, et en conservant la structure et l'information du tissu original.

```text
@article{de2021diffusion,
title={Diffusion Schr$\backslash$" odinger Bridge with Applications to Score-Based Generative Modeling},
author={De Bortoli, Valentin and Thornton, James and Heng, Jeremy and Doucet, Arnaud},
author={De Bortoli, Valentin et Thornton, James et Heng, Jeremy et Doucet, Arnaud},
journal={arXiv preprint arXiv:2106.01357},
year={2021}
}
```

Contributors
------------

* Valentin De Bortoli
* James Thornton
* Jeremy Heng
* Arnaud Doucet

What is a Schr&ouml;dinger bridge?
-----------------------------

The Schr&ouml;dinger Bridge (SB) problem is a classical problem appearing in
applied mathematics, optimal control and probability; see [1, 2, 3]. In the
discrete-time setting, it takes the following (dynamic) form. Consider as
reference density p(x<sub>0:N</sub>) describing the process adding noise to the
data. We aim to find p\*(x<sub>0:N</sub>) such that p\*(x<sub>0</sub>) =
p<sub>data</sub>(x<sub>0</sub>) and p\*(x<sub>N</sub>) =
p<sub>prior</sub>(x<sub>N</sub>) and minimize the Kullback-Leibler divergence
between p\* and p. In this work we introduce **Diffusion Schrodinger Bridge**
(DSB), a new algorithm which uses score-matching approaches [4] to
approximate the *Iterative Proportional Fitting* algorithm, an iterative method
to find the solutions of the SB problem. DSB can be seen as a refinement of
existing score-based generative modeling methods [5, 6].


![Schrodinger bridge](schrodinger_bridge.png)

Installation
------------

This project can be installed from its git repository.

1. Obtain the sources by:

`git clone https://github.com/anon284/schrodinger_bridge.git`

or, if `git` is unavailable, download as a ZIP from GitHub https://github.com/<repository>.

2. Install:
## Création de l'environnement virtuel sur le CCUB

`conda env create -f conda.yaml`

`conda activate bridge`
### à fair par Isen

3. Download data examples:

- CelebA: `python data.py --data celeba --data_dir './data/' `
- MNIST: `python data.py --data mnist --data_dir './data/' `


How to use this code?
---------------------

3. Train Networks:
- 2d: `python main.py dataset=2d model=Basic num_steps=20 num_iter=5000`
- mnist `python main.py dataset=stackedmnist num_steps=30 model=UNET num_iter=5000 data_dir=<insert filepath of data dir <local paths/data/>`
- celeba `python main.py dataset=celeba num_steps=50 model=UNET num_iter=5000 data_dir=<insert filepath of data dir <local paths/data/>`

Checkpoints and sampled images will be saved to a newly created directory. If GPU has insufficient memory, then reduce cache size. 2D dataset should train on CPU. MNIST and CelebA was ran on 2 high-memory V100 GPUs.


References
----------

.. [1] Hans F&ouml;llmer
*Random fields and diffusion processes*
In: École d'été de Probabilités de Saint-Flour 1985-1987

.. [2] Christian Léonard
*A survey of the Schr&ouml;dinger problem and some of its connections with optimal transport*
In: Discrete & Continuous Dynamical Systems-A 2014

.. [3] Yongxin Chen, Tryphon Georgiou and Michele Pavon
*Optimal Transport in Systems and Control*
In: Annual Review of Control, Robotics, and Autonomous Systems 2020

.. [4] Aapo Hyv&auml;rinen and Peter Dayan
*Estimation of non-normalized statistical models by score matching*
In: Journal of Machine Learning Research 2005

.. [5] Yang Song and Stefano Ermon
*Generative modeling by estimating gradients of the data distribution*
In: Advances in Neural Information Processing Systems 2019

.. [6] Jonathan Ho, Ajay Jain and Pieter Abbeel
*Denoising diffusion probabilistic models*
In: Advances in Neural Information Processing Systems 2020
```bash
module load python
python3 -m venv venv
source venv/bin/activate
pip3 install --prefix=/work/imvia/in156281/diffusion_schrodinger_bridge/venv -r requirements.txt
export PYTHONPATH=/work/imvia/in156281/diffusion_schrodinger_bridge/venv/lib/python3.9/site-packages:$PYTHONPATH
pip3 list
```
Binary file added article.pdf
Binary file not shown.
Empty file removed bridge/__init__.py
Empty file.
151 changes: 44 additions & 107 deletions bridge/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,137 +1,74 @@
import os
"""
Ce fichier regroupe des fonctions utilitaires liées au pré-traitement des données.
Il contient des transformations appliquées aux images avant de les envoyer au modèle
(ex : ajout de bruit léger, changement d’échelle des valeurs, transformation logit).
L’objectif est de mettre les données dans un format plus adapté et plus stable pour l’apprentissage.
"""

import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from torchvision.datasets import CIFAR10, LSUN
from .cacheloader import CacheLoader
from .celeba import CelebA
#from datasets.ffhq import FFHQ
#from datasets.stackedmnist import Stacked_MNIST


def get_dataloader(d_config, data_init, data_folder, bs, training=False):
"""
Args:
d_config: data configuration
data_init: whether to initialize sampling from images
data_folder: data folder
bs: dataset batch size
training: whether the dataloader are used for training. If False, sampling is assumed
Returns:
training dataloader (and test dataloader if training)
"""
kwargs = {'batch_size': bs, 'shuffle': True, 'num_workers': d_config.num_workers, 'drop_last': True}
if training:
dataset, testset = get_dataset(d_config, data_folder)
dataloader = DataLoader(dataset, **kwargs)
testloader = DataLoader(testset, **kwargs)
return dataloader, testloader

if data_init:
dataset, _ = get_dataset(d_config, data_folder)
return DataLoader(dataset, **kwargs)
else:
return None


def get_dataset(d_config, data_folder):
cmp = lambda x: transforms.Compose([*x])

if d_config.dataset == 'CIFAR10':

train_transform = [transforms.Resize(d_config.image_size), transforms.ToTensor()]
test_transform = [transforms.Resize(d_config.image_size), transforms.ToTensor()]
if d_config.random_flip:
train_transform.insert(1, transforms.RandomHorizontalFlip())

path = os.path.join(data_folder, 'CIFAR10')
dataset = CIFAR10(path, train=True, download=True, transform=cmp(train_transform))
test_dataset = CIFAR10(path, train=False, download=True, transform=cmp(test_transform))

elif d_config.dataset == 'CELEBA':

train_transform = [transforms.CenterCrop(140), transforms.Resize(d_config.image_size), transforms.ToTensor()]
test_transform = [transforms.CenterCrop(140), transforms.Resize(d_config.image_size), transforms.ToTensor()]
if d_config.random_flip:
train_transform.insert(2, transforms.RandomHorizontalFlip())

path = os.path.join(data_folder, 'celeba')
dataset = CelebA(path, split='train', transform=cmp(train_transform), download=True)
test_dataset = CelebA(path, split='test', transform=cmp(test_transform), download=True)

# elif d_config.dataset == 'Stacked_MNIST':

# dataset = Stacked_MNIST(root=os.path.join(data_folder, 'stackedmnist_train'), load=False,
# source_root=data_folder, train=True)
# test_dataset = Stacked_MNIST(root=os.path.join(data_folder, 'stackedmnist_test'), load=False,
# source_root=data_folder, train=False)

elif d_config.dataset == 'LSUN':

ims = d_config.image_size
train_transform = [transforms.Resize(ims), transforms.CenterCrop(ims), transforms.ToTensor()]
test_transform = [transforms.Resize(ims), transforms.CenterCrop(ims), transforms.ToTensor()]
if d_config.random_flip:
train_transform.insert(2, transforms.RandomHorizontalFlip())

path = data_folder
dataset = LSUN(path, classes=[d_config.category + "_train"], transform=cmp(train_transform))
test_dataset = LSUN(path, classes=[d_config.category + "_val"], transform=cmp(test_transform))

# elif d_config.dataset == "FFHQ":

# train_transform = [transforms.ToTensor()]
# test_transform = [transforms.ToTensor()]
# if d_config.random_flip:
# train_transform.insert(0, transforms.RandomHorizontalFlip())

# path = os.path.join(data_folder, 'FFHQ')
# dataset = FFHQ(path, transform=train_transform, resolution=d_config.image_size)
# test_dataset = FFHQ(path, transform=test_transform, resolution=d_config.image_size)

# num_items = len(dataset)
# indices = list(range(num_items))
# random_state = np.random.get_state()
# np.random.seed(2019)
# np.random.shuffle(indices)
# np.random.set_state(random_state)
# train_indices, test_indices = indices[:int(num_items * 0.9)], indices[int(num_items * 0.9):]
# dataset = Subset(dataset, train_indices)
# test_dataset = Subset(test_dataset, test_indices)

else:
raise ValueError("Dataset [" + d_config.dataset + "] not configured.")

return dataset, test_dataset
import torch


def logit_transform(image: torch.Tensor, lam=1e-6):
"""" La fonction logit_transform sert à changer la "forme" des valeurs de l’image.
Au départ, les pixels sont entre 0 et 1.
Ici, on les transforme pour qu’ils puissent prendre n’importe quelle valeur
(positives ou négatives), ce qui est souvent plus pratique pour entraîner un modèle.

lam sert à éviter les valeurs exactement égales à 0 ou 1,
car elles posent des problèmes mathématiques.
"""
image = lam + (1 - 2 * lam) * image
return torch.log(image) - torch.log1p(-image)




def data_transform(d_config, X):
""" La fonction data_transform prépare les images AVANT de les donner au modèle.
Elle applique différentes transformations selon la configuration choisie.
L’idée générale est :
- rendre les images plus "continues"= éviter des valeurs trop "rigides" dans les pixels (0,1,2…)
→ aide le modèle à apprendre en douceur
- adapter les nombres pour faciliter l’apprentissage du modèle
"""
if d_config.uniform_dequantization:
# Ajoute un petit bruit aléatoire uniforme aux pixels.
# Cela permet d’éviter que les valeurs soient trop "discrètes"
# (ex: uniquement 0, 1, 2, 3...).
X = X / 256. * 255. + torch.rand_like(X) / 256.
elif d_config.gaussian_dequantization:
# Ajoute un très léger bruit gaussien aux pixels.
# Le but est le même : lisser les valeurs pour aider l’apprentissage.
X = X + torch.randn_like(X) * 0.01

if d_config.rescaled:
# Change l’échelle des pixels :
# au lieu d’être entre 0 et 1, ils seront entre -1 et 1.
# Beaucoup de réseaux de neurones fonctionnent mieux avec cette échelle.
X = 2 * X - 1.
elif d_config.logit_transform:
# Applique la transformation logit définie plus haut,
# pour enlever les bornes [0,1].
X = logit_transform(X)

return X


def inverse_data_transform(d_config, X):


def inverse_data_transform(d_config, X):
# La fonction inverse_data_transform fait l’inverse de data_transform.
# Elle sert à remettre la sortie du modèle sous forme d’image "normale"
# que l’on peut afficher ou sauvegarder.
if d_config.logit_transform:
# Ramène les valeurs vers l’intervalle [0,1]
# après une transformation logit.
X = torch.sigmoid(X)
elif d_config.rescaled:
# Ramène les valeurs de [-1,1] à [0,1].
X = (X + 1.) / 2.

# Sécurité : on force toutes les valeurs à rester entre 0 et 1
# pour éviter des pixels invalides.
return torch.clamp(X, 0.0, 1.0)
Loading