diff --git a/.gitignore b/.gitignore
index b6e4761..fb278ff 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,3 +127,6 @@ dmypy.json
# Pyre type checker
.pyre/
+
+/dataset_v2
+.venv
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
deleted file mode 100644
index bb1b3e0..0000000
--- a/LICENSE
+++ /dev/null
@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2021 anon284
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
diff --git a/README.md b/README.md
index e9e2ef3..86f38eb 100644
--- a/README.md
+++ b/README.md
@@ -1,102 +1,33 @@
-# Diffusion Schrödinger Bridge with Applications to Score-Based Generative Modeling
+# Pont de Schrödinger Diffusif avec Applications à la Modélisation Générative basée sur le Score
-This repository contains the implementation for the paper Diffusion
-Schrödinger Bridge with Applications to Score-Based Generative Modeling.
+Ce dépôt contient l'implémentation de l'article "Pont de Schrödinger Diffusif avec Applications à la Modélisation Générative basée sur le Score".
-If using this code, please cite the paper:
-```
+## Qu'est-ce qu'un pont de Schrödinger ?
+
+Le problème du Pont de Schrödinger (SB) est un problème classique en mathématiques appliquées, contrôle optimal et probabilité ; voir [1, 2, 3]. En temps discret, il prend la forme dynamique suivante : on considère une densité de référence p(x0:N) décrivant le processus d'ajout de bruit aux données. On cherche à trouver p\*(x0:N) telle que p\*(x0) = pdata(x0) et p\*(xN) = pprior(xN), tout en minimisant la divergence de Kullback-Leibler entre p\* et p. Dans ce travail, nous introduisons le **Pont de Schrödinger Diffusif** (DSB), un nouvel algorithme utilisant des approches de score-matching [4] pour approximer l'algorithme *Iterative Proportional Fitting*, une méthode itérative pour résoudre le problème SB. DSB peut être vu comme un raffinement des méthodes existantes de modélisation générative basée sur le score.
+
+## Applications
+
+Une application prometteuse de cette approche est le virtual staining en histopathologie, par exemple la conversion d'images colorées HES (Hématoxyline-Éosine-Safran) vers des images IHC. Le DSB permet d'apprendre une transformation probabiliste entre deux distributions d'images, ici entre la distribution des coupes HES et celle des coupes IHC, tout en garantissant une correspondance réaliste et contrôlée. Cela ouvre la voie à la génération d'images IHC virtuelles à partir de coupes HES, facilitant l'analyse biomédicale sans recourir à des colorations coûteuses ou destructives, et en conservant la structure et l'information du tissu original.
+
+```text
@article{de2021diffusion,
title={Diffusion Schr$\backslash$" odinger Bridge with Applications to Score-Based Generative Modeling},
- author={De Bortoli, Valentin and Thornton, James and Heng, Jeremy and Doucet, Arnaud},
+ author={De Bortoli, Valentin et Thornton, James et Heng, Jeremy et Doucet, Arnaud},
journal={arXiv preprint arXiv:2106.01357},
year={2021}
}
```
-Contributors
-------------
-
-* Valentin De Bortoli
-* James Thornton
-* Jeremy Heng
-* Arnaud Doucet
-
-What is a Schrödinger bridge?
------------------------------
-
-The Schrödinger Bridge (SB) problem is a classical problem appearing in
-applied mathematics, optimal control and probability; see [1, 2, 3]. In the
-discrete-time setting, it takes the following (dynamic) form. Consider as
-reference density p(x0:N) describing the process adding noise to the
-data. We aim to find p\*(x0:N) such that p\*(x0) =
-pdata(x0) and p\*(xN) =
-pprior(xN) and minimize the Kullback-Leibler divergence
-between p\* and p. In this work we introduce **Diffusion Schrodinger Bridge**
-(DSB), a new algorithm which uses score-matching approaches [4] to
-approximate the *Iterative Proportional Fitting* algorithm, an iterative method
-to find the solutions of the SB problem. DSB can be seen as a refinement of
-existing score-based generative modeling methods [5, 6].
-
-
-
-
-Installation
-------------
-
-This project can be installed from its git repository.
-
-1. Obtain the sources by:
-
- `git clone https://github.com/anon284/schrodinger_bridge.git`
-
-or, if `git` is unavailable, download as a ZIP from GitHub https://github.com/.
-
-2. Install:
+## Création de l'environnement virtuel sur le CCUB
- `conda env create -f conda.yaml`
-
- `conda activate bridge`
+### à fair par Isen
-3. Download data examples:
-
- - CelebA: `python data.py --data celeba --data_dir './data/' `
- - MNIST: `python data.py --data mnist --data_dir './data/' `
-
-
-How to use this code?
----------------------
-
-3. Train Networks:
- - 2d: `python main.py dataset=2d model=Basic num_steps=20 num_iter=5000`
- - mnist `python main.py dataset=stackedmnist num_steps=30 model=UNET num_iter=5000 data_dir=`
- - celeba `python main.py dataset=celeba num_steps=50 model=UNET num_iter=5000 data_dir=`
-
-Checkpoints and sampled images will be saved to a newly created directory. If GPU has insufficient memory, then reduce cache size. 2D dataset should train on CPU. MNIST and CelebA was ran on 2 high-memory V100 GPUs.
-
-
-References
-----------
-
-.. [1] Hans Föllmer
- *Random fields and diffusion processes*
- In: École d'été de Probabilités de Saint-Flour 1985-1987
-
-.. [2] Christian Léonard
- *A survey of the Schrödinger problem and some of its connections with optimal transport*
- In: Discrete & Continuous Dynamical Systems-A 2014
-
-.. [3] Yongxin Chen, Tryphon Georgiou and Michele Pavon
- *Optimal Transport in Systems and Control*
- In: Annual Review of Control, Robotics, and Autonomous Systems 2020
-
-.. [4] Aapo Hyvärinen and Peter Dayan
- *Estimation of non-normalized statistical models by score matching*
- In: Journal of Machine Learning Research 2005
-
-.. [5] Yang Song and Stefano Ermon
- *Generative modeling by estimating gradients of the data distribution*
- In: Advances in Neural Information Processing Systems 2019
-
-.. [6] Jonathan Ho, Ajay Jain and Pieter Abbeel
- *Denoising diffusion probabilistic models*
- In: Advances in Neural Information Processing Systems 2020
+```bash
+module load python
+python3 -m venv venv
+source venv/bin/activate
+pip3 install --prefix=/work/imvia/in156281/diffusion_schrodinger_bridge/venv -r requirements.txt
+export PYTHONPATH=/work/imvia/in156281/diffusion_schrodinger_bridge/venv/lib/python3.9/site-packages:$PYTHONPATH
+pip3 list
+```
diff --git a/article.pdf b/article.pdf
new file mode 100644
index 0000000..23cbd86
Binary files /dev/null and b/article.pdf differ
diff --git a/bridge/__init__.py b/bridge/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/bridge/data/__init__.py b/bridge/data/__init__.py
index e065a56..86431f2 100644
--- a/bridge/data/__init__.py
+++ b/bridge/data/__init__.py
@@ -1,137 +1,74 @@
-import os
+"""
+Ce fichier regroupe des fonctions utilitaires liées au pré-traitement des données.
+Il contient des transformations appliquées aux images avant de les envoyer au modèle
+(ex : ajout de bruit léger, changement d’échelle des valeurs, transformation logit).
+L’objectif est de mettre les données dans un format plus adapté et plus stable pour l’apprentissage.
+"""
-import numpy as np
-import torch
-import torchvision.transforms as transforms
-from torch.utils.data import DataLoader
-from torch.utils.data import Subset
-from torchvision.datasets import CIFAR10, LSUN
-from .cacheloader import CacheLoader
-from .celeba import CelebA
-#from datasets.ffhq import FFHQ
-#from datasets.stackedmnist import Stacked_MNIST
-
-
-def get_dataloader(d_config, data_init, data_folder, bs, training=False):
- """
- Args:
- d_config: data configuration
- data_init: whether to initialize sampling from images
- data_folder: data folder
- bs: dataset batch size
- training: whether the dataloader are used for training. If False, sampling is assumed
- Returns:
- training dataloader (and test dataloader if training)
- """
- kwargs = {'batch_size': bs, 'shuffle': True, 'num_workers': d_config.num_workers, 'drop_last': True}
- if training:
- dataset, testset = get_dataset(d_config, data_folder)
- dataloader = DataLoader(dataset, **kwargs)
- testloader = DataLoader(testset, **kwargs)
- return dataloader, testloader
-
- if data_init:
- dataset, _ = get_dataset(d_config, data_folder)
- return DataLoader(dataset, **kwargs)
- else:
- return None
-
-
-def get_dataset(d_config, data_folder):
- cmp = lambda x: transforms.Compose([*x])
-
- if d_config.dataset == 'CIFAR10':
-
- train_transform = [transforms.Resize(d_config.image_size), transforms.ToTensor()]
- test_transform = [transforms.Resize(d_config.image_size), transforms.ToTensor()]
- if d_config.random_flip:
- train_transform.insert(1, transforms.RandomHorizontalFlip())
-
- path = os.path.join(data_folder, 'CIFAR10')
- dataset = CIFAR10(path, train=True, download=True, transform=cmp(train_transform))
- test_dataset = CIFAR10(path, train=False, download=True, transform=cmp(test_transform))
-
- elif d_config.dataset == 'CELEBA':
-
- train_transform = [transforms.CenterCrop(140), transforms.Resize(d_config.image_size), transforms.ToTensor()]
- test_transform = [transforms.CenterCrop(140), transforms.Resize(d_config.image_size), transforms.ToTensor()]
- if d_config.random_flip:
- train_transform.insert(2, transforms.RandomHorizontalFlip())
-
- path = os.path.join(data_folder, 'celeba')
- dataset = CelebA(path, split='train', transform=cmp(train_transform), download=True)
- test_dataset = CelebA(path, split='test', transform=cmp(test_transform), download=True)
-
- # elif d_config.dataset == 'Stacked_MNIST':
- # dataset = Stacked_MNIST(root=os.path.join(data_folder, 'stackedmnist_train'), load=False,
- # source_root=data_folder, train=True)
- # test_dataset = Stacked_MNIST(root=os.path.join(data_folder, 'stackedmnist_test'), load=False,
- # source_root=data_folder, train=False)
-
- elif d_config.dataset == 'LSUN':
-
- ims = d_config.image_size
- train_transform = [transforms.Resize(ims), transforms.CenterCrop(ims), transforms.ToTensor()]
- test_transform = [transforms.Resize(ims), transforms.CenterCrop(ims), transforms.ToTensor()]
- if d_config.random_flip:
- train_transform.insert(2, transforms.RandomHorizontalFlip())
-
- path = data_folder
- dataset = LSUN(path, classes=[d_config.category + "_train"], transform=cmp(train_transform))
- test_dataset = LSUN(path, classes=[d_config.category + "_val"], transform=cmp(test_transform))
-
- # elif d_config.dataset == "FFHQ":
-
- # train_transform = [transforms.ToTensor()]
- # test_transform = [transforms.ToTensor()]
- # if d_config.random_flip:
- # train_transform.insert(0, transforms.RandomHorizontalFlip())
-
- # path = os.path.join(data_folder, 'FFHQ')
- # dataset = FFHQ(path, transform=train_transform, resolution=d_config.image_size)
- # test_dataset = FFHQ(path, transform=test_transform, resolution=d_config.image_size)
-
- # num_items = len(dataset)
- # indices = list(range(num_items))
- # random_state = np.random.get_state()
- # np.random.seed(2019)
- # np.random.shuffle(indices)
- # np.random.set_state(random_state)
- # train_indices, test_indices = indices[:int(num_items * 0.9)], indices[int(num_items * 0.9):]
- # dataset = Subset(dataset, train_indices)
- # test_dataset = Subset(test_dataset, test_indices)
-
- else:
- raise ValueError("Dataset [" + d_config.dataset + "] not configured.")
-
- return dataset, test_dataset
+import torch
def logit_transform(image: torch.Tensor, lam=1e-6):
+ """" La fonction logit_transform sert à changer la "forme" des valeurs de l’image.
+ Au départ, les pixels sont entre 0 et 1.
+ Ici, on les transforme pour qu’ils puissent prendre n’importe quelle valeur
+ (positives ou négatives), ce qui est souvent plus pratique pour entraîner un modèle.
+
+ lam sert à éviter les valeurs exactement égales à 0 ou 1,
+ car elles posent des problèmes mathématiques.
+ """
image = lam + (1 - 2 * lam) * image
return torch.log(image) - torch.log1p(-image)
+
+
def data_transform(d_config, X):
+ """ La fonction data_transform prépare les images AVANT de les donner au modèle.
+ Elle applique différentes transformations selon la configuration choisie.
+ L’idée générale est :
+ - rendre les images plus "continues"= éviter des valeurs trop "rigides" dans les pixels (0,1,2…)
+ → aide le modèle à apprendre en douceur
+ - adapter les nombres pour faciliter l’apprentissage du modèle
+ """
if d_config.uniform_dequantization:
+ # Ajoute un petit bruit aléatoire uniforme aux pixels.
+ # Cela permet d’éviter que les valeurs soient trop "discrètes"
+ # (ex: uniquement 0, 1, 2, 3...).
X = X / 256. * 255. + torch.rand_like(X) / 256.
elif d_config.gaussian_dequantization:
+ # Ajoute un très léger bruit gaussien aux pixels.
+ # Le but est le même : lisser les valeurs pour aider l’apprentissage.
X = X + torch.randn_like(X) * 0.01
if d_config.rescaled:
+ # Change l’échelle des pixels :
+ # au lieu d’être entre 0 et 1, ils seront entre -1 et 1.
+ # Beaucoup de réseaux de neurones fonctionnent mieux avec cette échelle.
X = 2 * X - 1.
elif d_config.logit_transform:
+ # Applique la transformation logit définie plus haut,
+ # pour enlever les bornes [0,1].
X = logit_transform(X)
return X
-def inverse_data_transform(d_config, X):
+
+def inverse_data_transform(d_config, X):
+ # La fonction inverse_data_transform fait l’inverse de data_transform.
+ # Elle sert à remettre la sortie du modèle sous forme d’image "normale"
+ # que l’on peut afficher ou sauvegarder.
if d_config.logit_transform:
+ # Ramène les valeurs vers l’intervalle [0,1]
+ # après une transformation logit.
X = torch.sigmoid(X)
elif d_config.rescaled:
+ # Ramène les valeurs de [-1,1] à [0,1].
X = (X + 1.) / 2.
+ # Sécurité : on force toutes les valeurs à rester entre 0 et 1
+ # pour éviter des pixels invalides.
return torch.clamp(X, 0.0, 1.0)
\ No newline at end of file
diff --git a/bridge/data/cacheloader.py b/bridge/data/cacheloader.py
deleted file mode 100644
index 7ae1e64..0000000
--- a/bridge/data/cacheloader.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import torch
-from torch.utils.data import Dataset
-from tqdm import tqdm
-import time
-
-
-class CacheLoader(Dataset):
- def __init__(self, fb,
- sample_net,
- dataloader_b,
- num_batches,
- langevin,
- n,
- mean, std,
- batch_size, device='cpu',
- dataloader_f=None,
- transfer=False):
-
- super().__init__()
- start = time.time()
- shape = langevin.d
- num_steps = langevin.num_steps
- self.data = torch.zeros(
- (num_batches, batch_size*num_steps, 2, *shape)).to(device) # .cpu()
- # self.steps_data = torch.zeros(
- # (num_batches, batch_size*num_steps, 1), dtype=torch.long).to(device) # .cpu() # steps
- self.steps_data = torch.zeros(
- (num_batches, batch_size*num_steps, 1)).to(device) # .cpu() # steps
- with torch.no_grad():
- for b in range(num_batches):
- if fb == 'b':
- batch = next(dataloader_b)[0]
- batch = batch.to(device)
- elif fb == 'f' and transfer:
- batch = next(dataloader_f)[0]
- batch = batch.to(device)
- else:
- batch = mean + std * \
- torch.randn((batch_size, *shape), device=device)
-
- if (n == 1) & (fb == 'b'):
- x, out, steps_expanded = langevin.record_init_langevin(
- batch)
- else:
- x, out, steps_expanded = langevin.record_langevin_seq(
- sample_net, batch, ipf_it=n)
-
- # store x, out
- x = x.unsqueeze(2)
- out = out.unsqueeze(2)
- batch_data = torch.cat((x, out), dim=2)
- flat_data = batch_data.flatten(start_dim=0, end_dim=1)
- self.data[b] = flat_data
-
- # store steps
- flat_steps = steps_expanded.flatten(start_dim=0, end_dim=1)
- self.steps_data[b] = flat_steps
-
- self.data = self.data.flatten(start_dim=0, end_dim=1)
- self.steps_data = self.steps_data.flatten(start_dim=0, end_dim=1)
-
- stop = time.time()
- print('Cache size: {0}'.format(self.data.shape))
- print("Load time: {0}".format(stop-start))
-
- def __getitem__(self, index):
- item = self.data[index]
- x = item[0]
- out = item[1]
- steps = self.steps_data[index]
-
- return x, out, steps
-
- def __len__(self):
- return self.data.shape[0]
diff --git a/bridge/data/celeba.py b/bridge/data/celeba.py
deleted file mode 100644
index 5fd1529..0000000
--- a/bridge/data/celeba.py
+++ /dev/null
@@ -1,143 +0,0 @@
-import os
-
-import PIL
-import torch
-
-from .utils import download_file_from_google_drive, check_integrity
-from .vision import VisionDataset
-
-class CelebA(VisionDataset):
- """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset.
- Args:
- root (string): Root directory where images are downloaded to.
- split (string): One of {'train', 'valid', 'test'}.
- Accordingly dataset is selected.
- target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
- or ``landmarks``. Can also be a list to output a tuple with all specified target types.
- The targets represent:
- ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
- ``identity`` (int): label for each person (data points with the same identity are the same person)
- ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
- ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
- righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
- Defaults to ``attr``.
- transform (callable, optional): A function/transform that takes in an PIL image
- and returns a transformed version. E.g, ``transforms.ToTensor``
- target_transform (callable, optional): A function/transform that takes in the
- target and transforms it.
- download (bool, optional): If true, downloads the dataset from the internet and
- puts it in root directory. If dataset is already downloaded, it is not
- downloaded again.
- """
-
- base_folder = "./"
- # There currently does not appear to be a easy way to extract 7z in python (without introducing additional
- # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
- # right now.
- file_list = [['1ak57neUpo1hbikxozBCWTe_Vf3MmqnAT', 'list_landmarks_align_celeba.txt'],
- ['1g9dqxOv-jrkR_1p9hhsHwRZqdLUNS6XG', 'list_eval_partition.txt'],
- ['1raAP7l0kKaZg7W01Zk8DyAyN4B_E41TJ', 'list_bbox_celeba.txt'],
- ['1zGgRizyR872PG1N4TzVuI7t1BsX4lqpq', 'list_attr_celeba.txt'],
- ['1T_FfvbnT7NwqGYwF9-OB4ZBXaTK0hb4c', 'img_align_celeba.zip'],
- ['1F5XjLVZ7PjTDybzUDV9pi6KUmf_9j_Nz', 'identity_CelebA.txt']
- ]
-
- def __init__(self, root,
- split="train",
- target_type="attr",
- transform=None,
- target_transform=None,
- download=False):
- import pandas
- super(CelebA, self).__init__(root)
- self.split = split
- if isinstance(target_type, list):
- self.target_type = target_type
- else:
- self.target_type = [target_type]
- self.transform = transform
- self.target_transform = target_transform
-
- if download:
- self.download()
-
-
- self.transform = transform
- self.target_transform = target_transform
-
- if split.lower() == "train":
- split = 0
- elif split.lower() == "valid":
- split = 1
- elif split.lower() == "test":
- split = 2
- else:
- raise ValueError('Wrong split entered! Please use split="train" '
- 'or split="valid" or split="test"')
-
- with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f:
- splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)
-
- with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f:
- self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)
-
- with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f:
- self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0)
-
- with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f:
- self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1)
-
- with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f:
- self.attr = pandas.read_csv(f, delim_whitespace=True, header=1)
-
- mask = (splits[1] == split)
- self.filename = splits[mask].index.values
- self.identity = torch.as_tensor(self.identity[mask].values)
- self.bbox = torch.as_tensor(self.bbox[mask].values)
- self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values)
- self.attr = torch.as_tensor(self.attr[mask].values)
- self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1}
-
-
- def download(self):
- import zipfile
-
- for (file_id, filename) in self.file_list:
- fp = os.path.join(self.root, self.base_folder, filename)
- if not os.path.exists(fp):
- download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename)
-
- with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
- f.extractall(os.path.join(self.root, self.base_folder))
-
- def __getitem__(self, index):
- X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
-
- target = []
- for t in self.target_type:
- if t == "attr":
- target.append(self.attr[index, :])
- elif t == "identity":
- target.append(self.identity[index, 0])
- elif t == "bbox":
- target.append(self.bbox[index, :])
- elif t == "landmarks":
- target.append(self.landmarks_align[index, :])
- else:
- raise ValueError("Target type \"{}\" is not recognized.".format(t))
- target = tuple(target) if len(target) > 1 else target[0]
-
- if self.transform is not None:
- X = self.transform(X)
-
- if self.target_transform is not None:
- target = self.target_transform(target)
-
- return X, target
-
- def __len__(self):
- return len(self.attr)
-
- def extra_repr(self):
- lines = ["Target type: {target_type}", "Split: {split}"]
- return '\n'.join(lines).format(**self.__dict__)
\ No newline at end of file
diff --git a/bridge/data/emnist.py b/bridge/data/emnist.py
deleted file mode 100644
index 4fbe9bf..0000000
--- a/bridge/data/emnist.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import os, shutil
-import urllib
-import torch
-import torchvision.datasets
-import torchvision.transforms as transforms
-import torchvision.utils as vutils
-from torch.utils.data import DataLoader
-from torch.utils.data import Dataset
-from tqdm import tqdm
-from torchvision.utils import save_image
-
-
-class EMNIST(Dataset):
- def __init__(self, root="./dataset", load=True, source_root=None, imageSize=28,
- train=True, num_channels=3, device='cpu'): # load=True means loading the dataset from existed files.
- super(EMNIST, self).__init__()
- self.data = torch.load(os.path.join(root, "data.pt"))
- self.targets = torch.load(os.path.join(root, "targets.pt"))
-
- def __getitem__(self, index):
- img, targets = self.data[index], self.targets[index]
-
- return img, targets
-
- def __len__(self):
- return len(self.targets)
diff --git a/bridge/data/hes_cd30.py b/bridge/data/hes_cd30.py
new file mode 100644
index 0000000..a96c313
--- /dev/null
+++ b/bridge/data/hes_cd30.py
@@ -0,0 +1,111 @@
+"""
+Ce fichier définit le dataset PyTorch HES_CD30.
+Il sert à charger des images depuis un dossier organisé par domaines (HES ou CD30),
+à appliquer les transformations (resize, crop, conversion en tenseur), puis à fournir
+les images au DataLoader pendant l’entraînement. Il ne fait pas l’entraînement : il ne fait
+que "lire et préparer" les images.
+"""
+
+
+import os
+import glob
+import torch
+from torch.utils.data import Dataset
+from PIL import Image
+import torchvision.transforms as transforms
+
+
+class HES_CD30(Dataset):
+ """
+ Cette classe permet de charger des images médicales
+ pour les utiliser dans un modèle PyTorch.
+
+ Elle lit des images stockées sur le disque et les transforme
+ en tenseurs exploitables par un réseau de neurones.
+
+ Même si la documentation parle d’images HES et CD30 appariées,
+ cette classe charge UN SEUL type d’image à la fois
+ (soit HES, soit CD30).
+ """
+
+ """Dataset for paired HES / CD30 virtual staining.
+
+ Expected folder structure:
+ root/
+ ├── HES/
+ │ ├── c/
+ │ │ ├── patch_x2000_y32000.jpg
+ │ │ └── ...
+ │ ├── e/
+ │ └── ...
+ └── CD30/
+ ├── c/
+ │ ├── patch_x2000_y32000.jpg
+ │ └── ...
+ ├── e/
+ └── ...
+
+ Each HES image has a paired CD30 image with the same subfolder
+ and filename.
+ """
+
+ def __init__(self, root, image_size=256, domain='HES', transform=None):
+ """
+ Args:
+ root: path to dataset_v2/ directory
+ image_size: resize images to this size
+ domain: 'HES' or 'CD30'
+ transform: optional torchvision transform (overrides default)
+ """
+ super().__init__()
+ # Sauvegarde des paramètres
+ self.root = root
+ self.domain = domain
+ self.image_size = image_size
+
+ # Construction du chemin vers le dossier des images
+ # Exemple : root/HES/ ou root/CD30/
+ domain_dir = os.path.join(root, domain)
+ # Recherche de toutes les images .jpg dans les sous-dossiers
+ # (par exemple c/, e/, etc.)
+ self.image_paths = sorted(
+ glob.glob(os.path.join(domain_dir, '*', '*.jpg')))
+
+ if len(self.image_paths) == 0:
+ raise RuntimeError(
+ f"No images found in {domain_dir}. "
+ f"Expected structure: {domain_dir}//*.jpg"
+ )
+
+ # Définition des transformations appliquées aux images
+ if transform is not None:
+ # Si l’utilisateur fournit ses propres transformations
+ self.transform = transform
+ else:
+ # Transformations par défaut :
+ # - redimensionnement
+ # - découpe centrale
+ # - conversion en tenseur PyTorch
+ self.transform = transforms.Compose([
+ transforms.Resize(image_size),
+ transforms.CenterCrop(image_size),
+ transforms.ToTensor(),
+ ])
+ # Message informatif pour vérifier que le chargement s’est bien passé
+ print(f"[HES_CD30] Loaded {len(self.image_paths)} images from {domain_dir}")
+
+ def __len__(self):
+ # Retourne le nombre total d’images du dataset.
+ # PyTorch utilise cette information pour parcourir les données.
+ return len(self.image_paths)
+
+ def __getitem__(self, index):
+ # Récupère le chemin de l’image demandée
+ img_path = self.image_paths[index]
+ # Ouvre l’image et la convertit en RGB (3 canaux)
+ img = Image.open(img_path).convert('RGB')
+ # Applique les transformations définies plus haut
+ img = self.transform(img)
+ # PyTorch attend souvent un couple (image, label).
+ # Ici, il n’y a pas de label, donc on renvoie une valeur factice.
+ return img, 0 # label fictif pour compatibilité
diff --git a/bridge/data/stackedmnist.py b/bridge/data/stackedmnist.py
deleted file mode 100644
index 4765135..0000000
--- a/bridge/data/stackedmnist.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import os, shutil
-import urllib
-import torch
-import torchvision.datasets
-import torchvision.transforms as transforms
-import torchvision.utils as vutils
-from torch.utils.data import DataLoader
-from torch.utils.data import Dataset
-from tqdm import tqdm
-from torchvision.utils import save_image
-
-
-class Stacked_MNIST(Dataset):
- def __init__(self, root="./dataset", load=True, source_root=None, imageSize=28,
- train=True, num_channels=3, device='cpu'): # load=True means loading the dataset from existed files.
- super(Stacked_MNIST, self).__init__()
- self.num_channels = min(3,num_channels)
- if load:
- self.data = torch.load(os.path.join(root, "data.pt"))
- self.targets = torch.load(os.path.join(root, "targets.pt"))
- else:
- if source_root is None:
- source_root = "./datasets"
-
- source_data = torchvision.datasets.MNIST(source_root, train=train, transform=transforms.Compose([
- transforms.Resize(imageSize),
- transforms.ToTensor(),
- transforms.Normalize((0.5,), (0.5,)),
- ]), download=True)
- self.data = torch.zeros((0, self.num_channels, imageSize, imageSize))
- self.targets = torch.zeros((0), dtype=torch.int64)
- # has 60000 images in total
- dataloader_R = DataLoader(source_data, batch_size=100, shuffle=True)
- dataloader_G = DataLoader(source_data, batch_size=100, shuffle=True)
- dataloader_B = DataLoader(source_data, batch_size=100, shuffle=True)
-
- im_dir = root + '/im'
- if os.path.exists(im_dir):
- shutil.rmtree(im_dir)
- os.makedirs(im_dir)
-
- idx = 0
- for (xR, yR), (xG, yG), (xB, yB) in tqdm(zip(dataloader_R, dataloader_G, dataloader_B)):
- x = torch.cat([xR, xG, xB][-self.num_channels:], dim=1)
- y = (100 * yR + 10 * yG + yB) % 10**self.num_channels
- self.data = torch.cat((self.data, x), dim=0)
- self.targets = torch.cat((self.targets, y), dim=0)
-
- for k in range(100):
- if idx < 10000:
- im = x[k]
- filename = root + '/im/{:05}.jpg'.format(idx)
- save_image(im, filename)
- idx += 1
-
- if not os.path.isdir(root):
- os.makedirs(root)
- torch.save(self.data, os.path.join(root, "data.pt"))
- torch.save(self.targets, os.path.join(root, "targets.pt"))
- vutils.save_image(x, "ali.png", nrow=10)
-
- self.data = self.data#.to(device)
- self.targets = self.targets#.to(device)
-
- def __getitem__(self, index):
- img, targets = self.data[index], self.targets[index]
-
- return img, targets
-
- def __len__(self):
- return len(self.targets)
diff --git a/bridge/data/two_dim.py b/bridge/data/two_dim.py
deleted file mode 100644
index cde91c1..0000000
--- a/bridge/data/two_dim.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import numpy as np
-import torch
-from sklearn import datasets
-from torch.utils.data import TensorDataset
-
-# checker/pinwheel/8gaussians can be found at
-# https://github.com/rtqichen/ffjord/blob/994864ad0517db3549717c25170f9b71e96788b1/lib/toy_data.py#L8
-
-def data_distrib(npar, data):
-
- if data == 'mixture':
- init_sample = torch.randn(npar, 2)
- p = init_sample.shape[0]//2
- init_sample[:p,0] = init_sample[:p,0] - 7.
- init_sample[p:,0] = init_sample[p:,0] + 7.
-
- if data == 'scurve':
- X, y = datasets.make_s_curve(n_samples=npar, noise=0.0, random_state=None)
- init_sample = torch.tensor(X)[:,[0,2]]
- scaling_factor = 7
- init_sample = (init_sample - init_sample.mean()) / init_sample.std() * scaling_factor
-
- if data == 'swiss':
- X, y = datasets.make_swiss_roll(n_samples=npar, noise=0.0, random_state=None)
- init_sample = torch.tensor(X)[:,[0,2]]
- scaling_factor = 7
- init_sample = (init_sample - init_sample.mean()) / init_sample.std() * scaling_factor
-
- if data == 'moon':
- X, y = datasets.make_moons(n_samples=npar, noise=0.0, random_state=None)
- scaling_factor = 7.
- init_sample = torch.tensor(X)
- init_sample = (init_sample - init_sample.mean()) / init_sample.std() * scaling_factor
-
- if data == 'circle':
- X, y = datasets.make_circles(n_samples=npar, noise=0.0, random_state=None, factor=.5)
- init_sample = torch.tensor(X) * 10
-
- if data == 'checker':
- x1 = np.random.rand(npar) * 4 - 2
- x2_ = np.random.rand(npar) - np.random.randint(0, 2, npar) * 2
- x2 = x2_ + (np.floor(x1) % 2)
- x = np.concatenate([x1[:, None], x2[:, None]], 1) * 7.5
- init_sample = torch.from_numpy(x)
-
- if data == 'pinwheel':
- radial_std = 0.3
- tangential_std = 0.1
- num_classes = 5
- num_per_class = npar // 5
- rate = 0.25
- rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)
-
- features = np.random.randn(num_classes*num_per_class, 2) \
- * np.array([radial_std, tangential_std])
- features[:, 0] += 1.
- labels = np.repeat(np.arange(num_classes), num_per_class)
-
- angles = rads[labels] + rate * np.exp(features[:, 0])
- rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)])
- rotations = np.reshape(rotations.T, (-1, 2, 2))
- x = 7.5 * np.random.permutation(np.einsum("ti,tij->tj", features, rotations))
- init_sample = torch.from_numpy(x)
-
- if data == '8gaussians':
- scale = 4.
- centers = [(1, 0), (-1, 0), (0, 1), (0, -1), (1. / np.sqrt(2), 1. / np.sqrt(2)),
- (1. / np.sqrt(2), -1. / np.sqrt(2)), (-1. / np.sqrt(2),
- 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. / np.sqrt(2))]
- centers = [(scale * x, scale * y) for x, y in centers]
-
- dataset = []
- for i in range(npar):
- point = np.random.randn(2) * 0.5
- idx = np.random.randint(8)
- center = centers[idx]
- point[0] += center[0]
- point[1] += center[1]
- dataset.append(point)
- dataset = np.array(dataset, dtype="float32")
- dataset *= 3
- init_sample = torch.from_numpy(dataset)
-
- init_sample = init_sample.float()
-
- return init_sample
-
-def two_dim_ds(npar, data_tag):
- init_sample = data_distrib(npar, data_tag)
- init_ds = TensorDataset(init_sample)
- return init_ds
diff --git a/bridge/data/utils.py b/bridge/data/utils.py
deleted file mode 100644
index 24cdc38..0000000
--- a/bridge/data/utils.py
+++ /dev/null
@@ -1,183 +0,0 @@
-import errno
-import hashlib
-import os
-import os.path
-
-from torch.utils.model_zoo import tqdm
-
-
-def gen_bar_updater():
- pbar = tqdm(total=None)
-
- def bar_update(count, block_size, total_size):
- if pbar.total is None and total_size:
- pbar.total = total_size
- progress_bytes = count * block_size
- pbar.update(progress_bytes - pbar.n)
-
- return bar_update
-
-
-def check_integrity(fpath, md5=None):
- if md5 is None:
- return True
- if not os.path.isfile(fpath):
- return False
- md5o = hashlib.md5()
- with open(fpath, 'rb') as f:
- # read in 1MB chunks
- for chunk in iter(lambda: f.read(1024 * 1024), b''):
- md5o.update(chunk)
- md5c = md5o.hexdigest()
- if md5c != md5:
- return False
- return True
-
-
-def makedir_exist_ok(dirpath):
- """
- Python2 support for os.makedirs(.., exist_ok=True)
- """
- try:
- os.makedirs(dirpath)
- except OSError as e:
- if e.errno == errno.EEXIST:
- pass
- else:
- raise
-
-
-def download_url(url, root, filename=None, md5=None):
- """Download a file from a url and place it in root.
- Args:
- url (str): URL to download file from
- root (str): Directory to place downloaded file in
- filename (str, optional): Name to save the file under. If None, use the basename of the URL
- md5 (str, optional): MD5 checksum of the download. If None, do not check
- """
- from six.moves import urllib
-
- root = os.path.expanduser(root)
- if not filename:
- filename = os.path.basename(url)
- fpath = os.path.join(root, filename)
-
- makedir_exist_ok(root)
-
- # downloads file
- if os.path.isfile(fpath) and check_integrity(fpath, md5):
- print('Using downloaded and verified file: ' + fpath)
- else:
- try:
- print('Downloading ' + url + ' to ' + fpath)
- urllib.request.urlretrieve(
- url, fpath,
- reporthook=gen_bar_updater()
- )
- except OSError:
- if url[:5] == 'https':
- url = url.replace('https:', 'http:')
- print('Failed download. Trying https -> http instead.'
- ' Downloading ' + url + ' to ' + fpath)
- urllib.request.urlretrieve(
- url, fpath,
- reporthook=gen_bar_updater()
- )
-
-
-def list_dir(root, prefix=False):
- """List all directories at a given root
- Args:
- root (str): Path to directory whose folders need to be listed
- prefix (bool, optional): If true, prepends the path to each result, otherwise
- only returns the name of the directories found
- """
- root = os.path.expanduser(root)
- directories = list(
- filter(
- lambda p: os.path.isdir(os.path.join(root, p)),
- os.listdir(root)
- )
- )
-
- if prefix is True:
- directories = [os.path.join(root, d) for d in directories]
-
- return directories
-
-
-def list_files(root, suffix, prefix=False):
- """List all files ending with a suffix at a given root
- Args:
- root (str): Path to directory whose folders need to be listed
- suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
- It uses the Python "str.endswith" method and is passed directly
- prefix (bool, optional): If true, prepends the path to each result, otherwise
- only returns the name of the files found
- """
- root = os.path.expanduser(root)
- files = list(
- filter(
- lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
- os.listdir(root)
- )
- )
-
- if prefix is True:
- files = [os.path.join(root, d) for d in files]
-
- return files
-
-
-def download_file_from_google_drive(file_id, root, filename=None, md5=None):
- """Download a Google Drive file from and place it in root.
- Args:
- file_id (str): id of file to be downloaded
- root (str): Directory to place downloaded file in
- filename (str, optional): Name to save the file under. If None, use the id of the file.
- md5 (str, optional): MD5 checksum of the download. If None, do not check
- """
- # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
- import requests
- url = "https://docs.google.com/uc?export=download"
-
- root = os.path.expanduser(root)
- if not filename:
- filename = file_id
- fpath = os.path.join(root, filename)
-
- makedir_exist_ok(root)
-
- if os.path.isfile(fpath) and check_integrity(fpath, md5):
- print('Using downloaded and verified file: ' + fpath)
- else:
- session = requests.Session()
-
- response = session.get(url, params={'id': file_id}, stream=True)
- token = _get_confirm_token(response)
-
- if token:
- params = {'id': file_id, 'confirm': token}
- response = session.get(url, params=params, stream=True)
-
- _save_response_content(response, fpath)
-
-
-def _get_confirm_token(response):
- for key, value in response.cookies.items():
- if key.startswith('download_warning'):
- return value
-
- return None
-
-
-def _save_response_content(response, destination, chunk_size=32768):
- with open(destination, "wb") as f:
- pbar = tqdm(total=None)
- progress = 0
- for chunk in response.iter_content(chunk_size):
- if chunk: # filter out keep-alive new chunks
- f.write(chunk)
- progress += len(chunk)
- pbar.update(progress - pbar.n)
- pbar.close()
\ No newline at end of file
diff --git a/bridge/data/vision.py b/bridge/data/vision.py
deleted file mode 100644
index ce9418c..0000000
--- a/bridge/data/vision.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import os
-
-import torch
-import torch.utils.data as data
-
-
-class VisionDataset(data.Dataset):
- _repr_indent = 4
-
- def __init__(self, root):
- if isinstance(root, torch._six.string_classes):
- root = os.path.expanduser(root)
- self.root = root
-
- def __getitem__(self, index):
- raise NotImplementedError
-
- def __len__(self):
- raise NotImplementedError
-
- def __repr__(self):
- head = "Dataset " + self.__class__.__name__
- body = ["Number of datapoints: {}".format(self.__len__())]
- if self.root is not None:
- body.append("Root location: {}".format(self.root))
- body += self.extra_repr().splitlines()
- if hasattr(self, 'transform') and self.transform is not None:
- body += self._format_transform_repr(self.transform,
- "Transforms: ")
- if hasattr(self, 'target_transform') and self.target_transform is not None:
- body += self._format_transform_repr(self.target_transform,
- "Target transforms: ")
- lines = [head] + [" " * self._repr_indent + line for line in body]
- return '\n'.join(lines)
-
- def _format_transform_repr(self, transform, head):
- lines = transform.__repr__().splitlines()
- return (["{}{}".format(head, lines[0])] +
- ["{}{}".format(" " * len(head), line) for line in lines[1:]])
-
- def extra_repr(self):
- return ""
\ No newline at end of file
diff --git a/bridge/langevin.py b/bridge/langevin.py
deleted file mode 100644
index 01593cb..0000000
--- a/bridge/langevin.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import copy
-import torch
-import torch.nn.functional as F
-from tqdm import tqdm
-import os
-import numpy as np
-
-
-def grad_gauss(x, m, var):
- xout = (x - m) / var
- return -xout
-
-
-def ornstein_ulhenbeck(x, gradx, gamma):
- xout = x + gamma * gradx + \
- torch.sqrt(2 * gamma) * torch.randn(x.shape, device=x.device)
- return xout
-
-
-class Langevin(torch.nn.Module):
-
- def __init__(self, num_steps, shape, gammas, time_sampler, device=None,
- mean_final=torch.tensor([0., 0.]), var_final=torch.tensor([.5, .5]), mean_match=True):
- super().__init__()
-
- self.mean_match = mean_match
- self.mean_final = mean_final
- self.var_final = var_final
-
- self.num_steps = num_steps # num diffusion steps
- self.d = shape # shape of object to diffuse
- self.gammas = gammas.float() # schedule
- gammas_vec = torch.ones(self.num_steps, *self.d, device=device)
- for k in range(num_steps):
- gammas_vec[k] = gammas[k].float()
- self.gammas_vec = gammas_vec
-
- if device is not None:
- self.device = device
- else:
- self.device = gammas.device
-
- self.steps = torch.arange(self.num_steps).to(self.device)
- self.time = torch.cumsum(self.gammas, 0).to(self.device).float()
- self.time_sampler = time_sampler
-
- def record_init_langevin(self, init_samples):
- mean_final = self.mean_final
- var_final = self.var_final
-
- x = init_samples
- N = x.shape[0]
- steps = self.steps.reshape((1, self.num_steps, 1)).repeat((N, 1, 1))
- time = self.time.reshape((1, self.num_steps, 1)).repeat((N, 1, 1))
- gammas = self.gammas.reshape((1, self.num_steps, 1)).repeat((N, 1, 1))
- steps = time
-
- x_tot = torch.Tensor(N, self.num_steps, *self.d).to(x.device)
- out = torch.Tensor(N, self.num_steps, *self.d).to(x.device)
- store_steps = self.steps
- num_iter = self.num_steps
- steps_expanded = time
-
- for k in range(num_iter):
- gamma = self.gammas[k]
- gradx = grad_gauss(x, mean_final, var_final)
- t_old = x + gamma * gradx
- z = torch.randn(x.shape, device=x.device)
- x = t_old + torch.sqrt(2 * gamma)*z
- gradx = grad_gauss(x, mean_final, var_final)
- t_new = x + gamma * gradx
-
- x_tot[:, k, :] = x
- out[:, k, :] = (t_old - t_new) # / (2 * gamma)
-
- return x_tot, out, steps_expanded
-
- def record_langevin_seq(self, net, init_samples, t_batch=None, ipf_it=0, sample=False):
- mean_final = self.mean_final
- var_final = self.var_final
-
- x = init_samples
- N = x.shape[0]
- steps = self.steps.reshape((1, self.num_steps, 1)).repeat((N, 1, 1))
- time = self.time.reshape((1, self.num_steps, 1)).repeat((N, 1, 1))
- gammas = self.gammas.reshape((1, self.num_steps, 1)).repeat((N, 1, 1))
- steps = time
-
- x_tot = torch.Tensor(N, self.num_steps, *self.d).to(x.device)
- out = torch.Tensor(N, self.num_steps, *self.d).to(x.device)
- store_steps = self.steps
- steps_expanded = steps
- num_iter = self.num_steps
-
- if self.mean_match:
- for k in range(num_iter):
- gamma = self.gammas[k]
- t_old = net(x, steps[:, k, :])
-
- if sample & (k == num_iter-1):
- x = t_old
- else:
- z = torch.randn(x.shape, device=x.device)
- x = t_old + torch.sqrt(2 * gamma) * z
-
- t_new = net(x, steps[:, k, :])
- x_tot[:, k, :] = x
- out[:, k, :] = (t_old - t_new)
- else:
- for k in range(num_iter):
- gamma = self.gammas[k]
- t_old = x + net(x, steps[:, k, :])
-
- if sample & (k == num_iter-1):
- x = t_old
- else:
- z = torch.randn(x.shape, device=x.device)
- x = t_old + torch.sqrt(2 * gamma) * z
- t_new = x + net(x, steps[:, k, :])
-
- x_tot[:, k, :] = x
- out[:, k, :] = (t_old - t_new)
-
- return x_tot, out, steps_expanded
-
- def forward(self, net, init_samples, t_batch, ipf_it):
- return self.record_langevin_seq(net, init_samples, t_batch, ipf_it)
diff --git a/bridge/models/basic/__init__.py b/bridge/models/basic/__init__.py
index 5970f25..96a1efe 100644
--- a/bridge/models/basic/__init__.py
+++ b/bridge/models/basic/__init__.py
@@ -1 +1,2 @@
+# Ce fichier permet d'importer facilement ScoreNetwork depuis le dossier basic.
from .basic import ScoreNetwork
\ No newline at end of file
diff --git a/bridge/models/basic/basic.py b/bridge/models/basic/basic.py
index 82aaf93..280048f 100644
--- a/bridge/models/basic/basic.py
+++ b/bridge/models/basic/basic.py
@@ -1,37 +1,74 @@
+"""
+Ce fichier contient une version simple de réseau de score (ScoreNetwork).
+Il s’agit d’un petit réseau de type MLP qui prend en entrée :
+- une donnée x (ex : position / vecteur / état)
+- un temps t (timestep)
+et produit une sortie correspondant à une information "de score" utilisée dans les méthodes
+de diffusion / Schrödinger Bridge. C’est une version légère, plutôt utilisée sur des données simples
+(par ex en 2D), pas sur des images.
+"""
+
+
import torch
from .layers import MLP
from .time_embedding import get_timestep_embedding
class ScoreNetwork(torch.nn.Module):
+ """Réseau de neurones utilisé pour prédire une "direction de correction" (un score)
+ à partir :
+ - d'un état x (ex: un point, une image, un vecteur de données)
+ - d'un temps t (souvent un indice d'étape dans un processus progressif)
+
+ Idée simple : le réseau apprend "comment ajuster x" en fonction de t. """
def __init__(self, encoder_layers=[16], pos_dim=16, decoder_layers=[128,128], x_dim=2):
super().__init__()
+ # Taille de la représentation du temps (combien de nombres servent à représenter t)
self.temb_dim = pos_dim
+ # Après l'encodage du temps, on obtient souvent 2*pos_dim
+ # (car on utilise une représentation sinus/cosinus).
t_enc_dim = pos_dim *2
+ # Stocke les paramètres principaux (pratique pour debug / sauvegarde)
self.locals = [encoder_layers, pos_dim, decoder_layers, x_dim]
- self.net = MLP(2 * t_enc_dim,
+ # Partie principale du réseau :
+ # prend une représentation de x et une représentation de t,
+ # les combine, puis produit une sortie de même dimension que x.
+ self.net = MLP(2 * t_enc_dim, # entrée = concat(x_enc, t_enc)
layer_widths=decoder_layers +[x_dim],
activate_final = False,
activation_fn=torch.nn.LeakyReLU())
-
+
+ # Petit réseau qui transforme la représentation brute du temps
+ # en représentation plus riche (compréhensible par le modèle).
self.t_encoder = MLP(pos_dim,
layer_widths=encoder_layers +[t_enc_dim],
activate_final = False,
activation_fn=torch.nn.LeakyReLU())
-
+
+ # Petit réseau qui transforme x (les données) en représentation plus riche
+ # avant de la mélanger avec le temps.
self.x_encoder = MLP(x_dim,
layer_widths=encoder_layers +[t_enc_dim],
activate_final = False,
activation_fn=torch.nn.LeakyReLU())
def forward(self, x, t):
+ # Assure que x a bien une dimension "batch".
+ # Exemple : si x est un seul point (forme [x_dim]),
+ # on le transforme en [1, x_dim] pour que le réseau fonctionne pareil.
if len(x.shape) == 1:
x = x.unsqueeze(0)
-
+ # Convertit le temps t en un vecteur de nombres (embedding du temps).
+ # But : donner au réseau une façon "riche" de comprendre la notion de temps/étape.
temb = get_timestep_embedding(t, self.temb_dim)
+ # Rend cet embedding temps encore plus adapté au réseau (encodage appris).
temb = self.t_encoder(temb)
+ # Encode aussi x dans un espace de représentation.
xemb = self.x_encoder(x)
+ # Combine l'information "donnée" (x) et "étape" (t) dans un seul vecteur.
h = torch.cat([xemb ,temb], -1)
+ # Produit final : une sortie de dimension x_dim.
+ # Dans ce type de modèle, c'est souvent une "correction" ou un "score" appliqué à x.
out = self.net(h)
return out
diff --git a/bridge/models/basic/layers.py b/bridge/models/basic/layers.py
index 75b4427..1da2f1e 100644
--- a/bridge/models/basic/layers.py
+++ b/bridge/models/basic/layers.py
@@ -1,9 +1,19 @@
+"""
+Ce fichier regroupe des briques de réseau (layers) utilisées par les modèles "basic".
+En général, on y met des composants réutilisables (ex : MLP, blocs, fonctions d’activation).
+L’objectif est d’éviter de réécrire les mêmes couches dans plusieurs fichiers.
+"""
+
+
import torch
from torch import nn
import torch.nn.functional as F
import math
from functools import partial
+# ATTENTION : ce fichier semble contenir une copie de ScoreNetwork.
+# Normalement, layers.py devrait définir des briques de réseau (ex: MLP),
+# utilisées par basic.py.
class MLP(torch.nn.Module):
def __init__(self, input_dim, layer_widths, activate_final = False, activation_fn=F.relu):
diff --git a/bridge/models/basic/time_embedding.py b/bridge/models/basic/time_embedding.py
index be1c17d..e8ba80a 100644
--- a/bridge/models/basic/time_embedding.py
+++ b/bridge/models/basic/time_embedding.py
@@ -1,3 +1,10 @@
+"""
+Ce fichier fournit une fonction d’encodage du temps (timestep embedding).
+L’idée est de transformer un nombre t (étape de diffusion) en un vecteur de taille fixe,
+avec des sinusoïdes (sin/cos), afin que le réseau puisse mieux utiliser l’information temporelle.
+C’est le même principe que les positional encodings utilisés dans les Transformers.
+"""
+
import torch
import torch.nn.functional as F
import math
@@ -5,18 +12,34 @@
def get_timestep_embedding(timesteps, embedding_dim=128):
"""
+ Cette fonction transforme un "temps" (ou numéro d’étape) en un vecteur.
+ Exemple : t=10 devient une liste de nombres de taille embedding_dim.
+
+ Pourquoi faire ça ?
+ - Un réseau de neurones apprend mieux quand t n'est pas juste un nombre brut.
+ - On crée une représentation plus expressive avec des sinusoïdes (sin/cos),
+ comme dans les Transformers ("positional encoding").
+
+ Résultat :
+ - on obtient un vecteur qui change régulièrement avec t
+ - des temps proches donnent des vecteurs "proches"
+
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
https://github.com/pytorch/fairseq/blob/master/fairseq/modules/sinusoidal_positional_embedding.py
"""
+ # On construit moitié sin, moitié cos (donc 2 moitiés = embedding_dim)
half_dim = embedding_dim // 2
+ # Crée une série d'échelles (fréquences) pour couvrir plusieurs "rythmes"
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * -emb)
-
+ # Applique ces fréquences au temps
emb = timesteps.float() * emb.unsqueeze(0)
+ # Construit l'encodage final : [sin(...), cos(...)]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+ # Si embedding_dim est impair, on rajoute un 0 pour compléter la taille
if embedding_dim % 2 == 1: # zero pad
emb = F.pad(emb, [0,1])
diff --git a/bridge/models/unet/__init__.py b/bridge/models/unet/__init__.py
index 3e9efd9..4383bbc 100644
--- a/bridge/models/unet/__init__.py
+++ b/bridge/models/unet/__init__.py
@@ -1,2 +1,5 @@
+# Ce fichier simplifie les imports :
+# ailleurs dans le projet, on pourra faire :from models.unet import UNetModel
+
from .unet import UNetModel, SuperResModel
diff --git a/bridge/models/unet/fp16_util.py b/bridge/models/unet/fp16_util.py
index 06a6b56..6334d45 100644
--- a/bridge/models/unet/fp16_util.py
+++ b/bridge/models/unet/fp16_util.py
@@ -1,5 +1,9 @@
"""
-Helpers to train with 16-bit precision.
+Ce fichier contient des outils pour entraîner le modèle avec moins de mémoire grâce au float16 (FP16).
+Le FP16 accélère souvent l’entraînement sur GPU et réduit l’utilisation mémoire, mais peut être moins stable.
+Le fichier propose donc des fonctions pour :
+- convertir certains modules en float16 / float32
+- gérer des paramètres "maîtres" en float32 pour garder une optimisation stable.
"""
import torch.nn as nn
@@ -7,8 +11,11 @@
def convert_module_to_f16(l):
- """
- Convert primitive modules to float16.
+ """
+ Convertit certains modules (convolutions) en float16.
+ Concept :
+ - réduire la mémoire utilisée
+ - accélérer l'entraînement (selon le GPU)
"""
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
l.weight.data = l.weight.data.half()
@@ -17,7 +24,9 @@ def convert_module_to_f16(l):
def convert_module_to_f32(l):
"""
- Convert primitive modules to float32, undoing convert_module_to_f16().
+ Fait l'inverse : remet les convolutions en float32.
+ Concept :
+ - revenir en précision normale si besoin (stabilité, export, etc.)
"""
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
l.weight.data = l.weight.data.float()
@@ -26,8 +35,10 @@ def convert_module_to_f32(l):
def make_master_params(model_params):
"""
- Copy model parameters into a (differently-shaped) list of full-precision
- parameters.
+ Crée une copie des paramètres du modèle en float32 (paramètres "maîtres").
+ Concept :
+ - le modèle peut calculer en float16 (rapide)
+ - mais on garde une version float32 pour faire les mises à jour (plus stable)
"""
master_params = _flatten_dense_tensors(
[param.detach().float() for param in model_params]
@@ -39,8 +50,10 @@ def make_master_params(model_params):
def model_grads_to_master_grads(model_params, master_params):
"""
- Copy the gradients from the model parameters into the master parameters
- from make_master_params().
+ Copie les gradients calculés sur le modèle vers les paramètres maîtres.
+ Concept :
+ - le backward calcule des gradients sur les params du modèle
+ - on récupère ces gradients en float32 pour une mise à jour plus stable
"""
master_params[0].grad = _flatten_dense_tensors(
[param.grad.data.detach().float() for param in model_params]
@@ -49,7 +62,10 @@ def model_grads_to_master_grads(model_params, master_params):
def master_params_to_model_params(model_params, master_params):
"""
- Copy the master parameter data back into the model parameters.
+ Copie les paramètres maîtres (float32) vers les paramètres du modèle.
+ Concept :
+ - après l'optimisation (mise à jour), on remet les nouvelles valeurs
+ dans le modèle qui sert à faire les forward/backward
"""
# Without copying to a list, if a generator is passed, this will
# silently not copy any parameters.
@@ -63,12 +79,22 @@ def master_params_to_model_params(model_params, master_params):
def unflatten_master_params(model_params, master_params):
"""
- Unflatten the master parameters to look like model_params.
+ Re-transforme le gros vecteur de paramètres maîtres en une liste de tensors
+ qui ont les mêmes formes que les paramètres du modèle.
+ Concept :
+ - "flatten" = tout mettre bout à bout
+ - "unflatten" = remettre dans les formes d’origine
"""
return _unflatten_dense_tensors(master_params[0].detach(), model_params)
def zero_grad(model_params):
+ """
+ Remet les gradients à zéro.
+ Concept :
+ - en deep learning, on veut éviter d'accumuler les gradients
+ d'une itération à l'autre (sauf cas particulier).
+ """
for param in model_params:
# Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
if param.grad is not None:
diff --git a/bridge/models/unet/layers.py b/bridge/models/unet/layers.py
index 6195c8a..554cdd6 100644
--- a/bridge/models/unet/layers.py
+++ b/bridge/models/unet/layers.py
@@ -1,3 +1,13 @@
+"""
+Ce fichier contient les briques de base du UNet utilisé dans le modèle de diffusion / bridge.
+On y trouve des modules réutilisables comme :
+- blocs résiduels (ResBlock)
+- attention (AttentionBlock)
+- upsample / downsample
+- embeddings de temps
+- checkpointing (économie mémoire)
+Ce fichier sert de "boîte à outils" pour construire l’architecture du UNet dans unet.py.
+"""
import math
from abc import abstractmethod
@@ -8,11 +18,14 @@
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+# Activation : version compatible si PyTorch n'a pas SiLU.
+# Concept : fonction d’activation "douce", souvent utilisée dans les diffusions.
class SiLU(nn.Module):
def forward(self, x):
return x * th.sigmoid(x)
-
+# Normalisation : stabilise l’apprentissage.
+# Ici, on force les calculs internes en float32 pour éviter des erreurs numériques.
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
@@ -20,7 +33,9 @@ def forward(self, x):
def conv_nd(dims, *args, **kwargs):
"""
- Create a 1D, 2D, or 3D convolution module.
+ Fabrique une couche de convolution 1D / 2D / 3D selon le type de données.
+ Concept :
+ - même code pour images (2D), volumes (3D), signaux (1D)
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
@@ -33,14 +48,16 @@ def conv_nd(dims, *args, **kwargs):
def linear(*args, **kwargs):
"""
- Create a linear module.
+ Crée une couche fully-connected (linéaire).
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
- Create a 1D, 2D, or 3D average pooling module.
+ Pooling moyen 1D / 2D / 3D.
+ Concept :
+ - réduire la taille (downsampling) en gardant une moyenne
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
@@ -53,11 +70,10 @@ def avg_pool_nd(dims, *args, **kwargs):
def update_ema(target_params, source_params, rate=0.99):
"""
- Update target parameters to be closer to those of source parameters using
- an exponential moving average.
- :param target_params: the target parameter sequence.
- :param source_params: the source parameter sequence.
- :param rate: the EMA rate (closer to 1 means slower).
+ Exponential Moving Average (EMA) des poids.
+ Concept :
+ - garder une version "lissée" des poids (souvent meilleure en génération)
+ - rate proche de 1 => mise à jour lente (très lissée)
"""
for targ, src in zip(target_params, source_params):
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
@@ -65,7 +81,10 @@ def update_ema(target_params, source_params, rate=0.99):
def zero_module(module, active=True):
"""
- Zero out the parameters of a module and return it.
+ Met les poids du module à 0.
+ Concept :
+ - démarrer certains blocs comme "neutres" au début de l’entraînement
+ (pour stabiliser / contrôler l’impact d’un bloc)
"""
if active:
for p in module.parameters():
@@ -75,7 +94,9 @@ def zero_module(module, active=True):
def scale_module(module, scale):
"""
- Scale the parameters of a module and return it.
+ Multiplie les poids par une constante.
+ Concept :
+ - ajuster l’intensité d’un bloc (utile pour init / stabilité)
"""
for p in module.parameters():
p.detach().mul_(scale)
@@ -84,28 +105,28 @@ def scale_module(module, scale):
def mean_flat(tensor):
"""
- Take the mean over all non-batch dimensions.
+ Moyenne sur toutes les dimensions sauf le batch.
+ Concept :
+ - calculer une moyenne "par exemple" (par image, par élément du batch)
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels):
"""
- Make a standard normalization layer.
- :param channels: number of input channels.
- :return: an nn.Module for normalization.
+ Normalisation standard.
+ Concept :
+ - stabiliser l’apprentissage (éviter explosions/instabilités)
"""
return GroupNorm32(32, channels)
def timestep_embedding(timesteps, dim, max_period=10000):
"""
- Create sinusoidal timestep embeddings.
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
- These may be fractional.
- :param dim: the dimension of the output.
- :param max_period: controls the minimum frequency of the embeddings.
- :return: an [N x dim] Tensor of positional embeddings.
+ Encode le temps t en vecteur (sin/cos).
+ Concept :
+ - donner au réseau une représentation riche de l’étape de diffusion
+ (comme les positional encodings des Transformers)
"""
half = dim // 2
freqs = th.exp(
@@ -120,13 +141,10 @@ def timestep_embedding(timesteps, dim, max_period=10000):
def checkpoint(func, inputs, params, flag):
"""
- Evaluate a function without caching intermediate activations, allowing for
- reduced memory at the expense of extra compute in the backward pass.
- :param func: the function to evaluate.
- :param inputs: the argument sequence to pass to `func`.
- :param params: a sequence of parameters `func` depends on but does not
- explicitly take as arguments.
- :param flag: if False, disable gradient checkpointing.
+ Gradient checkpointing.
+ Concept :
+ - économiser de la mémoire pendant l’entraînement
+ - en échange : recalculer certaines choses au backward (plus lent)
"""
if flag:
args = tuple(inputs) + tuple(params)
@@ -136,6 +154,7 @@ def checkpoint(func, inputs, params, flag):
class CheckpointFunction(th.autograd.Function):
+ # Mécanisme interne pour faire du checkpointing (économie mémoire).
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
@@ -146,6 +165,8 @@ def forward(ctx, run_function, length, *args):
return output_tensors
@staticmethod
+ # Recalcule le forward pour pouvoir obtenir les gradients sans stocker
+ # toutes les activations en mémoire.
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with th.enable_grad():
@@ -168,9 +189,10 @@ def backward(ctx, *output_grads):
class TimestepBlock(nn.Module):
"""
- Any module where forward() takes timestep embeddings as a second argument.
+ Interface : bloc qui a besoin du temps (embedding) en plus de x.
+ Concept :
+ - certains blocs du UNet doivent être "conditionnés" par l’étape t.
"""
-
@abstractmethod
def forward(self, x, emb):
"""
@@ -180,10 +202,10 @@ def forward(self, x, emb):
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
- A sequential module that passes timestep embeddings to the children that
- support it as an extra input.
+ Variante de nn.Sequential qui sait passer 'emb' aux couches qui en ont besoin.
+ Concept :
+ - enchaîner des couches, mais garder la possibilité d'injecter l'information temps
"""
-
def forward(self, x, emb):
for layer in self:
if isinstance(layer, TimestepBlock):
@@ -195,13 +217,11 @@ def forward(self, x, emb):
class Upsample(nn.Module):
"""
- An upsampling layer with an optional convolution.
- :param channels: channels in the inputs and outputs.
- :param use_conv: a bool determining if a convolution is applied.
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
- upsampling occurs in the inner-two dimensions.
+ Agrandit une représentation (upsampling).
+ Concept :
+ - dans UNet, on remonte en résolution pour reconstruire une image détaillée
+ - option : ajouter une convolution après l’agrandissement
"""
-
def __init__(self, channels, use_conv, dims=2):
super().__init__()
self.channels = channels
@@ -225,11 +245,9 @@ def forward(self, x):
class Downsample(nn.Module):
"""
- A downsampling layer with an optional convolution.
- :param channels: channels in the inputs and outputs.
- :param use_conv: a bool determining if a convolution is applied.
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
- downsampling occurs in the inner-two dimensions.
+ Réduit la résolution (downsampling).
+ Concept :
+ - dans UNet, on descend en résolution pour capter du contexte global
"""
def __init__(self, channels, use_conv, dims=2):
@@ -247,19 +265,23 @@ def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
+"""
+⚠️ Note : dans Downsample, la ligne self.op = avg_pool_nd(stride) semble bizarre :
+normalement avg_pool_nd attend dims en premier. C’est peut-être un bug/copie.
+"""
class ResBlock(TimestepBlock):
- """
- A residual block that can optionally change the number of channels.
- :param channels: the number of input channels.
- :param emb_channels: the number of timestep embedding channels.
- :param dropout: the rate of dropout.
- :param out_channels: if specified, the number of out channels.
- :param use_conv: if True and out_channels is specified, use a spatial
- convolution instead of a smaller 1x1 convolution to change the
- channels in the skip connection.
- :param dims: determines if the signal is 1D, 2D, or 3D.
- :param use_checkpoint: if True, use gradient checkpointing on this module.
+ """
+ Bloc "résiduel" (ResNet-style), utilisé partout dans le UNet.
+
+ Idée pour débutant :
+ - on applique quelques transformations à l'entrée x
+ - MAIS on ajoute aussi x à la fin ("raccourci", skip connection)
+ → ça aide le réseau à apprendre sans se perdre (plus stable, plus profond)
+
+ Particularité ici :
+ - le bloc est "conditionné" par le temps (emb),
+ donc son comportement dépend de l'étape timesteps.
"""
def __init__(
@@ -274,19 +296,35 @@ def __init__(
use_checkpoint=False,
):
super().__init__()
+
+ # Informations de base (surtout utile pour comprendre/debug)
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
+
+ # Le bloc peut garder le même nombre de canaux ou en changer
self.out_channels = out_channels or channels
+
+ # Options d'architecture
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
+ # Partie 1 : traitement principal de x (normaliser + activer + convolution)
+ # Concept :
+ # - normalisation : stabilise l'entraînement
+ # - activation : rend le modèle non-linéaire
+ # - convolution : extrait / transforme des motifs (features)
self.in_layers = nn.Sequential(
normalization(channels),
SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
+
+ # Partie 2 : transformation de l'embedding temps (emb)
+ # Concept :
+ # - on convertit le "temps" en informations utilisables pour moduler le bloc
+ # - si use_scale_shift_norm=True, on produit 2 infos (scale et shift)
self.emb_layers = nn.Sequential(
SiLU(),
linear(
@@ -294,6 +332,12 @@ def __init__(
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
+
+ # Partie 3 : sortie du bloc (après injection de l'information temps)
+ # Concept :
+ # - on affine les features
+ # - dropout : limite l'overfitting
+ # - zero_module(...) : initialise la dernière conv à 0 pour démarrer doucement
self.out_layers = nn.Sequential(
normalization(self.out_channels),
SiLU(),
@@ -303,44 +347,68 @@ def __init__(
),
)
+ # Skip connection (raccourci) :
+ # Concept :
+ # - si le nombre de canaux ne change pas : on peut ajouter x directement
+ # - sinon : on transforme x pour qu'il ait la même forme que h avant de les additionner
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
+ # Option : utiliser une convolution "classique" (3x3) pour adapter les canaux
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
else:
+ # Option : convolution 1x1 (plus simple) pour adapter uniquement les canaux
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
"""
- Apply the block to a Tensor, conditioned on a timestep embedding.
- :param x: an [N x C x ...] Tensor of features.
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
- :return: an [N x C x ...] Tensor of outputs.
+ Applique le bloc à x en tenant compte de emb (le temps).
+ Concept :
+ - si use_checkpoint=True : on économise de la mémoire à l'entraînement
+ (mais c'est un peu plus lent)
"""
return checkpoint(
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
def _forward(self, x, emb):
+ # Transforme x (chemin principal)
h = self.in_layers(x)
+
+ # Transforme emb pour influencer le bloc
emb_out = self.emb_layers(emb).type(h.dtype)
+
+ # On adapte la forme de emb_out pour pouvoir l'appliquer sur une image/feature map
+ # (ex: passer de [N, C] à [N, C, 1, 1])
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
+ # Mode "scale/shift" :
+ # Concept : emb ne s'ajoute pas juste, il "modifie" la normalisation
+ # (souvent plus puissant)
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
+ # Mode simple :
+ # Concept : on injecte le temps en l'ajoutant aux features
h = h + emb_out
h = self.out_layers(h)
+ # Résiduel : on ajoute le raccourci (skip connection)
return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
"""
+ Bloc d'attention.
+ Idée pour débutant :
+ - une convolution "voit" surtout localement (autour d'un pixel).
+ - l'attention permet à une zone de l'image de "regarder" toutes les autres zones.
+ → utile pour capturer des relations à longue distance (motifs éloignés).
+
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
@@ -351,62 +419,80 @@ def __init__(self, channels, num_heads=1, use_checkpoint=False):
self.channels = channels
self.num_heads = num_heads
self.use_checkpoint = use_checkpoint
-
+ # Normalisation avant l'attention (stabilité)
self.norm = normalization(channels)
+ # On crée en une seule couche :
+ # Q = Query, K = Key, V = Value (les 3 ingrédients de l'attention)
self.qkv = conv_nd(1, channels, channels * 3, 1)
+ # Calcul de l'attention
self.attention = QKVAttention()
+ # Projection finale, initialisée à 0 pour démarrer doucement
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
+ # Même idée que plus haut : checkpoint optionnel pour économiser de la mémoire
return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
def _forward(self, x):
+ # b = batch size, c = channels, spatial = dimensions spatiales (H,W) ou autre
b, c, *spatial = x.shape
+ # On "aplatit" l'image pour la voir comme une suite de positions
+ # (T = nombre de positions = H*W)
x = x.reshape(b, c, -1)
+ # Prépare Q, K, V
qkv = self.qkv(self.norm(x))
+ # Sépare en plusieurs têtes (multi-head attention)
qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
+ # Calcule l'attention : mélange les infos entre positions
h = self.attention(qkv)
+ # Remet la forme batch normale
h = h.reshape(b, -1, h.shape[-1])
+ # Projection de sortie
h = self.proj_out(h)
+ # Skip connection : on ajoute l'entrée à la sortie (stabilité)
return (x + h).reshape(b, c, *spatial)
class QKVAttention(nn.Module):
"""
- A module which performs QKV attention.
+ Moteur de calcul de l'attention (à partir de Q, K, V).
+ Idée pour débutant :
+ - Q (query) : "qu'est-ce que je cherche ?"
+ - K (key) : "qu'est-ce que je représente ?"
+ - V (value) : "quelle info je donne ?"
+ L'attention calcule quelles positions doivent influencer les autres.
"""
def forward(self, qkv):
"""
- Apply QKV attention.
- :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
- :return: an [N x C x T] tensor after attention.
+ Entrée :
+ - qkv contient Q, K, V concaténés.
+ Sortie :
+ - une nouvelle représentation où chaque position est un mélange
+ d’informations venant d’autres positions.
"""
+ # On coupe qkv en 3 morceaux : Q, K,
ch = qkv.shape[1] // 3
q, k, v = th.split(qkv, ch, dim=1)
+ # Normalisation pour stabiliser les produits (évite des valeurs trop grandes)
scale = 1 / math.sqrt(math.sqrt(ch))
+ # Calcule la "similarité" entre toutes les positions (poids d'attention)
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
- ) # More stable with f16 than dividing afterwards
+ )
+ # Transforme ces similarités en probabilités (somme = 1)
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ # Combine les valeurs V selon ces poids
return th.einsum("bts,bcs->bct", weight, v)
@staticmethod
def count_flops(model, _x, y):
"""
- A counter for the `thop` package to count the operations in an
- attention operation.
- Meant to be used like:
- macs, params = thop.profile(
- model,
- inputs=(inputs, timestamps),
- custom_ops={QKVAttention: QKVAttention.count_flops},
- )
+ Sert uniquement à estimer le coût de calcul (nombre d'opérations).
+ Utile pour du profiling (mesurer la lourdeur du modèle).
"""
b, c, *spatial = y[0].shape
num_spatial = int(np.prod(spatial))
- # We perform two matmuls with the same number of ops.
- # The first computes the weight matrix, the second computes
- # the combination of the value vectors.
+ #L'attention est coûteuse car elle compare toutes les positions entre elles.
matmul_ops = 2 * b * (num_spatial ** 2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
diff --git a/bridge/models/unet/unet.py b/bridge/models/unet/unet.py
index 0a2ee29..a8c276c 100644
--- a/bridge/models/unet/unet.py
+++ b/bridge/models/unet/unet.py
@@ -1,3 +1,13 @@
+"""
+Ce fichier définit l’architecture principale UNetModel (et SuperResModel).
+Le UNet est le réseau utilisé pour traiter des images dans les modèles de diffusion :
+- il descend en résolution pour capturer le contexte global
+- puis remonte en résolution pour reconstruire des détails
+- il utilise des connexions "skip" pour ne pas perdre d’informations
+Le réseau est conditionné par le temps (timesteps) et peut inclure de l’attention à certaines résolutions.
+"""
+
+
from abc import abstractmethod
import math
@@ -13,24 +23,17 @@
class UNetModel(nn.Module):
"""
- The full UNet model with attention and timestep embedding.
- :param in_channels: channels in the input Tensor.
- :param model_channels: base channel count for the model.
- :param out_channels: channels in the output Tensor.
- :param num_res_blocks: number of residual blocks per downsample.
- :param attention_resolutions: a collection of downsample rates at which
- attention will take place. May be a set, list, or tuple.
- For example, if this contains 4, then at 4x downsampling, attention
- will be used.
- :param dropout: the dropout probability.
- :param channel_mult: channel multiplier for each level of the UNet.
- :param conv_resample: if True, use learned convolutions for upsampling and
- downsampling.
- :param dims: determines if the signal is 1D, 2D, or 3D.
- :param num_classes: if specified (as an int), then this model will be
- class-conditional with `num_classes` classes.
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
- :param num_heads: the number of attention heads in each attention layer.
+ UNet avec :
+ - un encodage du temps (timesteps)
+ - des blocs résiduels (ResBlock)
+ - de l’attention à certaines résolutions (AttentionBlock)
+
+ Concept UNet :
+ 1) on "descend" en résolution (downsampling) pour comprendre le contexte global
+ 2) on passe par un "milieu" (middle block)
+ 3) on "remonte" en résolution (upsampling) pour reconstruire des détails
+ 4) on utilise des "skip connections" : on garde des infos de la descente
+ et on les réinjecte pendant la remontée pour ne pas perdre les détails.
"""
def __init__(
@@ -54,6 +57,7 @@ def __init__(
if num_heads_upsample == -1:
num_heads_upsample = num_heads
+ # Stocke les hyperparamètres (utile pour reproduire/charger un modèle)
self.locals = [ in_channels,
model_channels,
out_channels,
@@ -82,16 +86,20 @@ def __init__(
self.num_heads = num_heads
self.num_heads_upsample = num_heads_upsample
+ # Encodage du temps :
+ # transforme un "timestep" en un vecteur qui sera utilisé partout dans le réseau.
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
SiLU(),
linear(time_embed_dim, time_embed_dim),
)
-
+ # Option : modèle conditionné par une classe (ex: générer une catégorie précise)
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ # input_blocks = la partie "descente" du UNet
+ # On commence par une convolution pour passer dans l’espace de features.
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
@@ -131,6 +139,13 @@ def __init__(
input_block_chans.append(ch)
ds *= 2
+ # On construit ensuite plusieurs niveaux :
+ # - à chaque niveau, on applique des ResBlocks
+ # - parfois, on ajoute Attention
+ # - puis on downsample pour aller vers une résolution plus basse
+
+ # middle_block = le fond du UNet
+ # Résolution la plus basse : beaucoup de contexte, moins de détail spatial.
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
@@ -178,7 +193,10 @@ def __init__(
layers.append(Upsample(ch, conv_resample, dims=dims))
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
-
+ # output_blocks = la partie "remontée" du UNet
+ # On remonte en résolution et on concatène les features de la descente (skip connections).
+
+ # out = dernière partie : remet dans le nombre de canaux souhaité (ex: prédire un bruit / une image)
self.out = nn.Sequential(
normalization(ch),
SiLU(),
@@ -187,17 +205,13 @@ def __init__(
def convert_to_fp16(self):
- """
- Convert the torso of the model to float16.
- """
+ # Passe les blocs principaux en float16 (gain mémoire/perf)
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
self.output_blocks.apply(convert_module_to_f16)
def convert_to_fp32(self):
- """
- Convert the torso of the model to float32.
- """
+ # Reviens en float32
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
@@ -206,60 +220,99 @@ def convert_to_fp32(self):
def forward(self, x, timesteps, y=None):
"""
- Apply the model to an input batch.
- :param x: an [N x C x ...] Tensor of inputs.
- :param timesteps: a 1-D batch of timesteps.
- :param y: an [N] Tensor of labels, if class-conditional.
- :return: an [N x C x ...] Tensor of outputs.
+ Concept : faire passer x dans le UNet en tenant compte du temps (timesteps).
+
+ Entrées :
+ - x : batch d’images/features
+ - timesteps : étape de diffusion (ou étape temporelle)
+ - y : label (si modèle conditionnel)
+
+ Sortie :
+ - tenseur de même forme générale que x (selon out_channels),
+ typiquement une prédiction liée au processus de diffusion (ex: bruit / score).
"""
+
+ # Mise en forme : on enlève les dimensions inutiles pour avoir un vecteur de timesteps propre.
timesteps = timesteps.squeeze()
+ # Vérification : si le modèle est conditionnel (num_classes défini),
+ # alors on DOIT fournir y, et inversement.
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
+ # hs va stocker les sorties intermédiaires de la "descente" (downsampling).
+ # Concept : ces valeurs serviront plus tard comme "raccourcis" (skip connections)
+ # pour récupérer les détails lors de la remontée.
hs = []
+
+ # On transforme timesteps en vecteur riche (embedding du temps),
+ # puis on l'adapte au réseau (time_embed).
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+ # Si le modèle est conditionné par une classe :
+ # on ajoute une information de classe à l'embedding du temps.
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
+
+ # h est le "signal" qui va traverser tout le UNet.
+ h = x #.type(self.inner_dtype)
- h = x#.type(self.inner_dtype)
+ # 1) DESCENTE du UNet :
+ # on applique les blocs d'entrée (resblocks + downsample) et on mémorise chaque étape.
for module in self.input_blocks:
h = module(h, emb)
hs.append(h)
+
+ # 2) MILIEU du UNet :
+ # partie la plus "profonde", à plus basse résolution (beaucoup de contexte global).
h = self.middle_block(h, emb)
+
+ # 3) REMONTÉE du UNet :
+ # à chaque étape, on concatène avec un état sauvegardé de la descente
+ # pour récupérer des informations fines (skip connections).
for module in self.output_blocks:
cat_in = th.cat([h, hs.pop()], dim=1)
h = module(cat_in, emb)
+
+ # Sécurité : on remet le même type numérique que l'entrée.
h = h.type(x.dtype)
+
+ # Dernière couche : produit la sortie dans le bon nombre de canaux.
return self.out(h)
def get_feature_vectors(self, x, timesteps, y=None):
"""
- Apply the model and return all of the intermediate tensors.
- :param x: an [N x C x ...] Tensor of inputs.
- :param timesteps: a 1-D batch of timesteps.
- :param y: an [N] Tensor of labels, if class-conditional.
- :return: a dict with the following keys:
- - 'down': a list of hidden state tensors from downsampling.
- - 'middle': the tensor of the output of the lowest-resolution
- block in the model.
- - 'up': a list of hidden state tensors from upsampling.
+ Concept : même passage que forward(), mais on récupère aussi les étapes intermédiaires.
+ Utilité :
+ - déboguer / visualiser ce que le réseau "apprend"
+ - analyser les features à différentes résolutions
"""
+ # Même logique que forward : on construit l'embedding temps (et classe si besoin)
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
+
+ # Dictionnaire de sortie :
+ # - down : ce qui sort pendant la descente
+ # - middle : la représentation au fond du UNet
+ # - up : ce qui sort pendant la remontée
result = dict(down=[], up=[])
+
h = x#.type(self.inner_dtype)
+
+ # DESCENTE : on stocke à la fois dans hs (pour les skip connections)
+ # et dans result["down"] pour pouvoir les retourner.
for module in self.input_blocks:
h = module(h, emb)
hs.append(h)
result["down"].append(h.type(x.dtype))
+ # MILIEU
h = self.middle_block(h, emb)
result["middle"] = h.type(x.dtype)
+ # REMONTÉE : pareil que forward, mais on sauvegarde chaque étape.
for module in self.output_blocks:
cat_in = th.cat([h, hs.pop()], dim=1)
h = module(cat_in, emb)
@@ -267,22 +320,33 @@ def get_feature_vectors(self, x, timesteps, y=None):
return result
+
class SuperResModel(UNetModel):
"""
- A UNetModel that performs super-resolution.
- Expects an extra kwarg `low_res` to condition on a low-resolution image.
+ Variante du UNet pour faire de la super-résolution.
+
+ Concept :
+ - le modèle reçoit une image (ou une estimation) en haute résolution
+ - et une image low_res (basse résolution) qui sert de "guide"
"""
def __init__(self, in_channels, *args, **kwargs):
+ # On multiplie par 2 car on va concaténer x et l'image low_res agrandie :
+ # donc on a 2 fois plus de canaux en entrée.
super().__init__(in_channels * 2, *args, **kwargs)
def forward(self, x, timesteps, low_res=None, **kwargs):
+ # On agrandit low_res pour qu'elle ait la même taille que x.
_, _, new_height, new_width = x.shape
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
+ # On concatène x et low_res agrandie (en canaux) pour donner au UNet
+ # l'image "guide" en plus de l'image à traiter.
x = th.cat([x, upsampled], dim=1)
+ # Puis on appelle le forward du UNet standard.
return super().forward(x, timesteps, **kwargs)
def get_feature_vectors(self, x, timesteps, low_res=None, **kwargs):
+ # Même idée que forward, mais on retourne aussi les features intermédiaires.
_, new_height, new_width, _ = x.shape
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
x = th.cat([x, upsampled], dim=1)
diff --git a/bridge/runners/__init__.py b/bridge/runners/__init__.py
deleted file mode 100644
index d8421a2..0000000
--- a/bridge/runners/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from itertools import repeat
-
-def repeater(data_loader):
- for loader in repeat(data_loader):
- for data in loader:
- yield data
\ No newline at end of file
diff --git a/bridge/runners/config_getters.py b/bridge/runners/config_getters.py
deleted file mode 100755
index 450d62b..0000000
--- a/bridge/runners/config_getters.py
+++ /dev/null
@@ -1,216 +0,0 @@
-import torch
-from ..models import *
-from ..data.two_dim import two_dim_ds
-from ..data.stackedmnist import Stacked_MNIST
-from ..data.emnist import EMNIST
-from ..data.celeba import CelebA
-from .plotters import TwoDPlotter, ImPlotter
-from torch.utils.data import TensorDataset
-import torchvision.transforms as transforms
-import os
-from .logger import CSVLogger, NeptuneLogger, Logger
-from torch.utils.data import DataLoader
-cmp = lambda x: transforms.Compose([*x])
-
-def get_plotter(runner, args):
- dataset_tag = getattr(args, DATASET)
- if dataset_tag == DATASET_2D:
- return TwoDPlotter(num_steps=runner.num_steps, gammas=runner.langevin.gammas)
- else:
- return ImPlotter(plot_level = args.plot_level)
-
-# Model
-#--------------------------------------------------------------------------------
-
-MODEL = 'Model'
-BASIC_MODEL = 'Basic'
-UNET_MODEL = 'UNET'
-
-
-def get_models(args):
- model_tag = getattr(args, MODEL)
-
- if model_tag == BASIC_MODEL:
- net_f, net_b = ScoreNetwork(), ScoreNetwork()
-
- if model_tag == UNET_MODEL:
- image_size=args.data.image_size
-
- if image_size == 256:
- channel_mult = (1, 1, 2, 2, 4, 4)
- elif image_size == 64:
- channel_mult = (1, 2, 3, 4)
- elif image_size == 32:
- channel_mult = (1, 2, 2, 2)
- elif image_size == 28:
- channel_mult = (1, 2, 2)
- else:
- raise ValueError(f"unsupported image size: {image_size}")
-
- attention_ds = []
- for res in args.model.attention_resolutions.split(","):
- attention_ds.append(image_size // int(res))
-
- kwargs = {
- "in_channels": args.data.channels,
- "model_channels": args.model.num_channels,
- "out_channels": args.data.channels,
- "num_res_blocks": args.model.num_res_blocks,
- "attention_resolutions": tuple(attention_ds),
- "dropout": args.model.dropout,
- "channel_mult": channel_mult,
- "num_classes": None,
- "use_checkpoint": args.model.use_checkpoint,
- "num_heads": args.model.num_heads,
- "num_heads_upsample": args.model.num_heads_upsample,
- "use_scale_shift_norm": args.model.use_scale_shift_norm
- }
-
- net_f, net_b = UNetModel(**kwargs), UNetModel(**kwargs)
-
- return net_f, net_b
-
-# Optimizer
-#--------------------------------------------------------------------------------
-def get_optimizers(net_f, net_b, lr):
- return torch.optim.Adam(net_f.parameters(), lr=lr), torch.optim.Adam(net_b.parameters(), lr=lr)
-
-# Dataset
-#--------------------------------------------------------------------------------
-
-DATASET = 'Dataset'
-DATASET_TRANSFER = 'Dataset_transfer'
-DATASET_2D = '2d'
-DATASET_CELEBA = 'celeba'
-DATASET_STACKEDMNIST = 'stackedmnist'
-DATASET_EMNIST = 'emnist'
-
-
-def get_datasets(args):
- dataset_tag = getattr(args, DATASET)
- if args.transfer:
- dataset_transfer_tag = getattr(args, DATASET_TRANSFER)
- else:
- dataset_transfer_tag = None
-
- # INITIAL (DATA) DATASET
-
- # 2D DATASET
-
- if dataset_tag == DATASET_2D:
- data_tag = args.data
- npar = max(args.npar, args.cache_npar)
- init_ds = two_dim_ds(npar, data_tag)
-
- if dataset_transfer_tag == DATASET_2D:
- data_tag = args.data_transfer
- npar = max(args.npar, args.cache_npar)
- final_ds = two_dim_ds(npar, data_tag)
- mean_final = torch.tensor(0.)
- var_final = torch.tensor(1.*10**3) #infty like
-
- # CELEBA DATASET
-
- if dataset_tag == DATASET_CELEBA:
-
- train_transform = [transforms.CenterCrop(140), transforms.Resize(args.data.image_size), transforms.ToTensor()]
- test_transform = [transforms.CenterCrop(140), transforms.Resize(args.data.image_size), transforms.ToTensor()]
- if args.data.random_flip:
- train_transform.insert(2, transforms.RandomHorizontalFlip())
-
-
- root = os.path.join(args.data_dir, 'celeba')
- init_ds = CelebA(root, split='train', transform=cmp(train_transform), download=False)
-
- # MNIST DATASET
-
- if dataset_tag == DATASET_STACKEDMNIST:
- root = os.path.join(args.data_dir, 'mnist')
- saved_file = os.path.join(root, "data.pt")
- load = os.path.exists(saved_file)
- load = args.load
- init_ds = Stacked_MNIST(root, load=load, source_root=root,
- train=True, num_channels = args.data.channels,
- imageSize=args.data.image_size,
- device=args.device)
-
- if dataset_transfer_tag == DATASET_STACKEDMNIST:
- root = os.path.join(args.data_dir, 'mnist')
- saved_file = os.path.join(root, "data.pt")
- load = os.path.exists(saved_file)
- load = args.load
- final_ds = Stacked_MNIST(root, load=load, source_root=root,
- train=True, num_channels = args.data.channels,
- imageSize=args.data.image_size,
- device=args.device)
- mean_final = torch.tensor(0.)
- var_final = torch.tensor(1.*10**3)
-
- # EMNIST DATASET
-
- if dataset_tag == DATASET_EMNIST:
- root = os.path.join(args.data_dir, 'EMNIST')
- saved_file = os.path.join(root, "data.pt")
- load = os.path.exists(saved_file)
- load = args.load
- init_ds = EMNIST(root, load=load, source_root=root,
- train=True, num_channels = args.data.channels,
- imageSize=args.data.image_size,
- device=args.device)
-
- if dataset_transfer_tag == DATASET_EMNIST:
- root = os.path.join(args.data_dir, 'EMNIST')
- saved_file = os.path.join(root, "data.pt")
- load = os.path.exists(saved_file)
- load = args.load
- final_ds = EMNIST(root, load=load, source_root=root,
- train=True, num_channels = args.data.channels,
- imageSize=args.data.image_size,
- device=args.device)
- mean_final = torch.tensor(0.)
- var_final = torch.tensor(1.*10**3)
-
-
- # FINAL (GAUSSIAN) DATASET (if no transfer)
-
- if not(args.transfer):
- if args.adaptive_mean:
- NAPPROX = 100
- vec = next(iter(DataLoader(init_ds, batch_size=NAPPROX)))[0]
- mean_final = vec.mean()
- mean_final = vec[0] * 0 + mean_final
- var_final = eval(args.var_final)
- final_ds = None
- elif args.final_adaptive:
- NAPPROX = 100
- vec = next(iter(DataLoader(init_ds, batch_size=NAPPROX)))[0]
- mean_final = vec.mean(axis=0)
- var_final = vec.var()
- final_ds = None
- else:
- mean_final = eval(args.mean_final)
- var_final = eval(args.var_final)
- final_ds = None
-
-
- return init_ds, final_ds, mean_final, var_final
-
-
-# Logger
-#--------------------------------------------------------------------------------
-
-LOGGER = 'LOGGER'
-LOGGER_PARAMS = 'LOGGER_PARAMS'
-
-CSV_TAG = 'CSV'
-NOLOG_TAG = 'NONE'
-
-def get_logger(args, name):
- logger_tag = getattr(args, LOGGER)
-
- if logger_tag == CSV_TAG:
- kwargs = {'directory': args.CSV_log_dir, 'name': name}
- return CSVLogger(**kwargs)
-
- if logger_tag == NOLOG_TAG:
- return Logger()
diff --git a/bridge/runners/ema.py b/bridge/runners/ema.py
deleted file mode 100644
index 392063e..0000000
--- a/bridge/runners/ema.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import torch
-from torch import nn
-
-class EMAHelper(object):
- def __init__(self, mu=0.999, device="cpu"):
- self.mu = mu
- self.shadow = {}
- self.device = device
-
- def register(self, module):
- if isinstance(module, nn.DataParallel) or isinstance(module, nn.parallel.distributed.DistributedDataParallel):
- module = module.module
- for name, param in module.named_parameters():
- if param.requires_grad:
- self.shadow[name] = param.data.clone()
-
- def update(self, module):
- if isinstance(module, nn.DataParallel) or isinstance(module, nn.parallel.distributed.DistributedDataParallel):
- module = module.module
- for name, param in module.named_parameters():
- if param.requires_grad:
- self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data
-
- def ema(self, module):
- if isinstance(module, nn.DataParallel) or isinstance(module, nn.parallel.distributed.DistributedDataParallel):
- module = module.module
- for name, param in module.named_parameters():
- if param.requires_grad:
- param.data.copy_(self.shadow[name].data)
-
- def ema_copy(self, module):
- if isinstance(module, nn.DataParallel) or isinstance(module, nn.parallel.distributed.DistributedDataParallel):
- inner_module = module.module
- locs = inner_module.locals
- module_copy = type(inner_module)(*locs).to(self.device)
- module_copy.load_state_dict(inner_module.state_dict())
- if isinstance(module, nn.DataParallel):
- module_copy = nn.DataParallel(module_copy)
- else:
- locs = module.locals
- module_copy = type(module)(*locs).to(self.device)
- module_copy.load_state_dict(module.state_dict())
- self.ema(module_copy)
- return module_copy
-
- def state_dict(self):
- return self.shadow
-
- def load_state_dict(self, state_dict):
- self.shadow = state_dict
diff --git a/bridge/runners/ipf.py b/bridge/runners/ipf.py
deleted file mode 100755
index d8a6d24..0000000
--- a/bridge/runners/ipf.py
+++ /dev/null
@@ -1,431 +0,0 @@
-import torch
-import os
-import sys
-import torch.nn.functional as F
-import numpy as np
-from ..langevin import Langevin
-from torch.utils.data import DataLoader
-from .config_getters import get_models, get_optimizers, get_datasets, get_plotter, get_logger
-import datetime
-from tqdm import tqdm
-from .ema import EMAHelper
-from . import repeater
-import time
-import random
-import torch.autograd.profiler as profiler
-from ..data import CacheLoader
-from torch.utils.data import WeightedRandomSampler
-from accelerate import Accelerator, DistributedType
-import time
-
-
-class IPFBase(torch.nn.Module):
-
- def __init__(self, args):
- super().__init__()
- self.args = args
-
- #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- self.accelerator = Accelerator(fp16=False, cpu=args.device == 'cpu')
- self.device = self.accelerator.device # torch.device(args.device)
-
- # training params
- self.n_ipf = self.args.n_ipf
- self.num_steps = self.args.num_steps
- self.batch_size = self.args.batch_size
- self.num_iter = self.args.num_iter
- self.grad_clipping = self.args.grad_clipping
- self.fast_sampling = self.args.fast_sampling
- self.lr = self.args.lr
-
- n = self.num_steps//2
- if self.args.gamma_space == 'linspace':
- gamma_half = np.linspace(self.args.gamma_min, args.gamma_max, n)
- elif self.args.gamma_space == 'geomspace':
- gamma_half = np.geomspace(
- self.args.gamma_min, self.args.gamma_max, n)
- gammas = np.concatenate([gamma_half, np.flip(gamma_half)])
- gammas = torch.tensor(gammas).to(self.device)
- self.T = torch.sum(gammas)
-
- # get models
- self.build_models()
- self.build_ema()
-
- # get optims
- self.build_optimizers()
-
- # get loggers
- self.logger = self.get_logger()
- self.save_logger = self.get_logger('plot_logs')
-
- # get data
- self.build_dataloaders()
-
- # langevin
- if self.args.weight_distrib:
- alpha = self.args.weight_distrib_alpha
- prob_vec = (1 + alpha) * torch.sum(gammas) - \
- torch.cumsum(gammas, 0)
- else:
- prob_vec = gammas * 0 + 1
- time_sampler = torch.distributions.categorical.Categorical(prob_vec)
-
- batch = next(self.save_init_dl)[0]
- shape = batch[0].shape
- self.shape = shape
- self.langevin = Langevin(self.num_steps, shape, gammas,
- time_sampler, device=self.device,
- mean_final=self.mean_final, var_final=self.var_final,
- mean_match=self.args.mean_match)
-
- # checkpoint
- date = str(datetime.datetime.now())[0:10]
- self.name_all = date
-
- # run from checkpoint
- self.checkpoint_run = self.args.checkpoint_run
- if self.args.checkpoint_run:
- self.checkpoint_it = self.args.checkpoint_it
- self.checkpoint_pass = self.args.checkpoint_pass
- else:
- self.checkpoint_it = 1
- self.checkpoint_pass = 'b'
-
- self.plotter = self.get_plotter()
-
- if self.accelerator.process_index == 0:
- if not os.path.exists('./im'):
- os.mkdir('./im')
- if not os.path.exists('./gif'):
- os.mkdir('./gif')
- if not os.path.exists('./checkpoints'):
- os.mkdir('./checkpoints')
-
- self.stride = self.args.gif_stride
- self.stride_log = self.args.log_stride
-
- def get_logger(self, name='logs'):
- return get_logger(self.args, name)
-
- def get_plotter(self):
- return get_plotter(self, self.args)
-
- def build_models(self, forward_or_backward=None):
- # running network
- net_f, net_b = get_models(self.args)
-
- if self.args.checkpoint_run:
- if "checkpoint_f" in self.args:
- net_f.load_state_dict(torch.load(self.args.checkpoint_f))
- if "checkpoint_b" in self.args:
- net_b.load_state_dict(torch.load(self.args.checkpoint_b))
-
- if self.args.dataparallel:
- net_f = torch.nn.DataParallel(net_f)
- net_b = torch.nn.DataParallel(net_b)
-
- if forward_or_backward is None:
- net_f = net_f.to(self.device)
- net_b = net_b.to(self.device)
- self.net = torch.nn.ModuleDict({'f': net_f, 'b': net_b})
- if forward_or_backward == 'f':
- net_f = net_f.to(self.device)
- self.net.update({'f': net_f})
- if forward_or_backward == 'b':
- net_b = net_b.to(self.device)
- self.net.update({'b': net_b})
-
- def accelerate(self, forward_or_backward):
- (self.net[forward_or_backward], self.optimizer[forward_or_backward]) = self.accelerator.prepare(
- self.net[forward_or_backward], self.optimizer[forward_or_backward])
-
- def update_ema(self, forward_or_backward):
- if self.args.ema:
- self.ema_helpers[forward_or_backward] = EMAHelper(
- mu=self.args.ema_rate, device=self.device)
- self.ema_helpers[forward_or_backward].register(
- self.net[forward_or_backward])
-
- def build_ema(self):
- if self.args.ema:
- self.ema_helpers = {}
- self.update_ema('f')
- self.update_ema('b')
-
- if self.args.checkpoint_run:
- # sample network
- sample_net_f, sample_net_b = get_models(self.args)
-
- if "sample_checkpoint_f" in self.args:
- sample_net_f.load_state_dict(
- torch.load(self.args.sample_checkpoint_f))
- if self.args.dataparallel:
- sample_net_f = torch.nn.DataParallel(sample_net_f)
- sample_net_f = sample_net_f.to(self.device)
- self.ema_helpers['f'].register(sample_net_f)
- if "sample_checkpoint_b" in self.args:
- sample_net_b.load_state_dict(
- torch.load(self.args.sample_checkpoint_b))
- if self.args.dataparallel:
- sample_net_b = torch.nn.DataParallel(sample_net_b)
- sample_net_b = sample_net_b.to(self.device)
- self.ema_helpers['b'].register(sample_net_b)
-
- def build_optimizers(self):
- optimizer_f, optimizer_b = get_optimizers(
- self.net['f'], self.net['b'], self.lr)
- optimizer_b = optimizer_b
- optimizer_f = optimizer_f
- self.optimizer = {'f': optimizer_f, 'b': optimizer_b}
-
- def build_dataloaders(self):
- init_ds, final_ds, mean_final, var_final = get_datasets(self.args)
-
- self.mean_final = mean_final.to(self.device)
- self.var_final = var_final.to(self.device)
- self.std_final = torch.sqrt(var_final).to(self.device)
-
- def worker_init_fn(worker_id):
- np.random.seed(np.random.get_state()[
- 1][0] + worker_id + self.accelerator.process_index)
-
- self.kwargs = {"num_workers": self.args.num_workers,
- "pin_memory": self.args.pin_memory,
- "worker_init_fn": worker_init_fn,
- "drop_last": True}
-
- # get plotter, gifs etc.
- self.save_init_dl = DataLoader(
- init_ds, batch_size=self.args.plot_npar, shuffle=True, **self.kwargs)
- self.cache_init_dl = DataLoader(
- init_ds, batch_size=self.args.cache_npar, shuffle=True, **self.kwargs)
- (self.cache_init_dl, self.save_init_dl) = self.accelerator.prepare(
- self.cache_init_dl, self.save_init_dl)
- self.cache_init_dl = repeater(self.cache_init_dl)
- self.save_init_dl = repeater(self.save_init_dl)
-
- if self.args.transfer:
- self.save_final_dl = DataLoader(
- final_ds, batch_size=self.args.plot_npar, shuffle=True, **self.kwargs)
- self.cache_final_dl = DataLoader(
- final_ds, batch_size=self.args.cache_npar, shuffle=True, **self.kwargs)
- (self.cache_final_dl, self.save_final_dl) = self.accelerator.prepare(
- self.cache_final_dl, self.save_final_dl)
- self.cache_final_dl = repeater(self.cache_final_dl)
- self.save_final_dl = repeater(self.save_final_dl)
- else:
- self.cache_final_dl = None
- self.save_final = None
-
- def new_cacheloader(self, forward_or_backward, n, use_ema=True):
-
- sample_direction = 'f' if forward_or_backward == 'b' else 'b'
- if use_ema:
- sample_net = self.ema_helpers[sample_direction].ema_copy(
- self.net[sample_direction])
- else:
- sample_net = self.net[sample_direction]
-
- if forward_or_backward == 'b':
- sample_net = self.accelerator.prepare(sample_net)
- new_dl = CacheLoader('b',
- sample_net,
- self.cache_init_dl,
- self.args.num_cache_batches,
- self.langevin, n,
- mean=None,
- std=None,
- batch_size=self.args.cache_npar,
- device=self.device,
- dataloader_f=self.cache_final_dl,
- transfer=self.args.transfer)
-
- else: # forward
- sample_net = self.accelerator.prepare(sample_net)
- new_dl = CacheLoader('f',
- sample_net,
- None,
- self.args.num_cache_batches,
- self.langevin, n,
- mean=self.mean_final,
- std=self.std_final,
- batch_size=self.args.cache_npar,
- device=self.device,
- dataloader_f=self.cache_final_dl,
- transfer=self.args.transfer)
-
- new_dl = DataLoader(new_dl, batch_size=self.batch_size)
-
- new_dl = self.accelerator.prepare(new_dl)
- new_dl = repeater(new_dl)
- return new_dl
-
- def train(self):
- pass
-
- def save_step(self, i, n, fb):
- if self.accelerator.is_local_main_process:
- if ((i % self.stride == 0) or (i % self.stride == 1)) and (i > 0):
-
- if self.args.ema:
- sample_net = self.ema_helpers[fb].ema_copy(self.net[fb])
- else:
- sample_net = self.net[fb]
-
- name_net = 'net' + '_' + fb + '_' + \
- str(n) + "_" + str(i) + '.ckpt'
- name_net_ckpt = './checkpoints/' + name_net
-
- if self.args.dataparallel:
- torch.save(self.net[fb].module.state_dict(), name_net_ckpt)
- else:
- torch.save(self.net[fb].state_dict(), name_net_ckpt)
-
- if self.args.ema:
- name_net = 'sample_net' + '_' + fb + \
- '_' + str(n) + "_" + str(i) + '.ckpt'
- name_net_ckpt = './checkpoints/' + name_net
- if self.args.dataparallel:
- torch.save(sample_net.module.state_dict(),
- name_net_ckpt)
- else:
- torch.save(sample_net.state_dict(), name_net_ckpt)
-
- with torch.no_grad():
- self.set_seed(seed=0 + self.accelerator.process_index)
- if fb == 'f':
- batch = next(self.save_init_dl)[0]
- batch = batch.to(self.device)
- elif self.args.transfer:
- batch = next(self.save_final_dl)[0]
- batch = batch.to(self.device)
- else:
- batch = self.mean_final + self.std_final * \
- torch.randn(
- (self.args.plot_npar, *self.shape), device=self.device)
-
- x_tot, out, steps_expanded = self.langevin.record_langevin_seq(
- sample_net, batch, ipf_it=n, sample=True)
- shape_len = len(x_tot.shape)
- x_tot = x_tot.permute(1, 0, *list(range(2, shape_len)))
- x_tot_plot = x_tot.detach() # .cpu().numpy()
-
- init_x = batch.detach().cpu().numpy()
- final_x = x_tot_plot[-1].detach().cpu().numpy()
- std_final = np.std(final_x)
- std_init = np.std(init_x)
- mean_final = np.mean(final_x)
- mean_init = np.mean(init_x)
-
- print('Initial variance: ' + str(std_init ** 2))
- print('Final variance: ' + str(std_final ** 2))
-
- self.save_logger.log_metrics({'FB': fb,
- 'init_var': std_init**2, 'final_var': std_final**2,
- 'mean_init': mean_init, 'mean_final': mean_final,
- 'T': self.T})
-
- self.plotter(batch, x_tot_plot, i, n, fb)
-
- def set_seed(self, seed=0):
- torch.manual_seed(seed)
- random.seed(seed)
- np.random.seed(seed)
- torch.cuda.manual_seed_all(seed)
-
- def clear(self):
- torch.cuda.empty_cache()
-
-
-class IPFSequential(IPFBase):
-
- def __init__(self, args):
- super().__init__(args)
-
- def ipf_step(self, forward_or_backward, n):
- new_dl = None
- new_dl = self.new_cacheloader(forward_or_backward, n, self.args.ema)
-
- if not self.args.use_prev_net:
- self.build_models(forward_or_backward)
- self.update_ema(forward_or_backward)
-
- self.build_optimizers()
- self.accelerate(forward_or_backward)
-
- for i in tqdm(range(self.num_iter+1)):
- self.set_seed(seed=n*self.num_iter+i)
-
- x, out, steps_expanded = next(new_dl)
- x = x.to(self.device)
- out = out.to(self.device)
- steps_expanded = steps_expanded.to(self.device)
- # eval_steps = self.num_steps - 1 - steps_expanded
- eval_steps = self.T - steps_expanded
-
- if self.args.mean_match:
- pred = self.net[forward_or_backward](
- x, eval_steps) - x
- else:
- pred = self.net[forward_or_backward](x, eval_steps)
-
- loss = F.mse_loss(pred, out)
-
- # loss.backward()
- self.accelerator.backward(loss)
-
- if self.grad_clipping:
- clipping_param = self.args.grad_clip
- total_norm = torch.nn.utils.clip_grad_norm_(
- self.net[forward_or_backward].parameters(), clipping_param)
- else:
- total_norm = 0.
-
- if (i % self.stride_log == 0) and (i > 0):
- self.logger.log_metrics({'forward_or_backward': forward_or_backward,
- 'loss': loss,
- 'grad_norm': total_norm}, step=i+self.num_iter*n)
-
- self.optimizer[forward_or_backward].step()
- self.optimizer[forward_or_backward].zero_grad()
- if self.args.ema:
- self.ema_helpers[forward_or_backward].update(
- self.net[forward_or_backward])
-
- self.save_step(i, n, forward_or_backward)
-
- if (i % self.args.cache_refresh_stride == 0) and (i > 0):
- new_dl = None
- torch.cuda.empty_cache()
- new_dl = self.new_cacheloader(
- forward_or_backward, n, self.args.ema)
-
- new_dl = None
- self.clear()
-
- def train(self):
-
- # INITIAL FORWARD PASS
- if self.accelerator.is_local_main_process:
- init_sample = next(self.save_init_dl)[0]
- init_sample = init_sample.to(self.device)
- x_tot, _, _ = self.langevin.record_init_langevin(init_sample)
- shape_len = len(x_tot.shape)
- x_tot = x_tot.permute(1, 0, *list(range(2, shape_len)))
- x_tot_plot = x_tot.detach() # .cpu().numpy()
-
- self.plotter(init_sample, x_tot_plot, 0, 0, 'f')
- x_tot_plot = None
- x_tot = None
- torch.cuda.empty_cache()
-
- for n in range(self.checkpoint_it, self.n_ipf+1):
-
- print('IPF iteration: ' + str(n) + '/' + str(self.n_ipf))
- # BACKWARD OPTIMISATION
- if (self.checkpoint_pass == 'f') and (n == self.checkpoint_it):
- self.ipf_step('f', n)
- else:
- self.ipf_step('b', n)
- self.ipf_step('f', n)
diff --git a/bridge/runners/logger.py b/bridge/runners/logger.py
deleted file mode 100644
index 3a5ac20..0000000
--- a/bridge/runners/logger.py
+++ /dev/null
@@ -1,50 +0,0 @@
-
-from pytorch_lightning.loggers import NeptuneLogger as _NeptuneLogger
-
-from pytorch_lightning.loggers import CSVLogger as _CSVLogger
-
-
-class Logger:
- def log_metrics(self, metric_dict, step, save=False):
- pass
-
- def log_hparams(self, hparams_dict):
- pass
-
-
-class CSVLogger(Logger):
-
- def __init__(self, directory='./', name='logs', save_stride=1):
- self.logger = _CSVLogger(directory, name=name)
- self.count = 0
- self.stride = save_stride
-
- def log_metrics(self, metrics, step=None,save=False):
- self.count += 1
- self.logger.log_metrics(metrics, step=step)
- if self.count % self.stride == 0:
- self.logger.save()
- self.logger.metrics = []
-
- if self.count > self.stride * 10:
- self.count = 0
-
- if save:
- self.logger.save()
-
- def log_hparams(self, hparams_dict):
- self.logger.log_hyperparams(hparams_dict)
- self.logger.save()
-
-
-class NeptuneLogger(Logger):
- def __init__(self, project_name, api_key, save_folder='./'):
- self.directory = save_folder
- self.logger = _NeptuneLogger(api_key=api_key, project_name=project_name)
-
- def log_metrics(self, metrics, step=None):
- self.logger.log_metrics(metrics,step=step)
-
- def log_hparams(self, hparams_dict):
- self.logger.log_hyperparams(hparams_dict)
-
diff --git a/bridge/runners/plotters.py b/bridge/runners/plotters.py
deleted file mode 100755
index b284f1b..0000000
--- a/bridge/runners/plotters.py
+++ /dev/null
@@ -1,211 +0,0 @@
-import numpy as np
-import matplotlib.pyplot as plt
-import matplotlib
-import torch
-import torchvision.utils as vutils
-from PIL import Image
-from ..data.two_dim import data_distrib
-import os, sys
-matplotlib.use('Agg')
-
-
-
-DPI = 200
-
-def make_gif(plot_paths, output_directory='./gif', gif_name='gif'):
- frames = [Image.open(fn) for fn in plot_paths]
-
- frames[0].save(os.path.join(output_directory, f'{gif_name}.gif'),
- format='GIF',
- append_images=frames[1:],
- save_all=True,
- duration=100,
- loop=0)
-
-def save_sequence(num_steps, x, name='', im_dir='./im', gif_dir = './gif', xlim=None, ylim=None, ipf_it=None, freq=1):
- if not os.path.isdir(im_dir):
- os.mkdir(im_dir)
- if not os.path.isdir(gif_dir):
- os.mkdir(gif_dir)
-
- # PARTICLES (INIT AND FINAL DISTRIB)
-
- plot_paths = []
- for k in range(num_steps):
- if k % freq == 0:
- filename = name + 'particle_' + str(k) + '.png'
- filename = os.path.join(im_dir, filename)
- plt.clf()
- if (xlim is not None) and (ylim is not None):
- plt.xlim(*xlim)
- plt.ylim(*ylim)
- plt.plot(x[-1, :, 0], x[-1, :, 1], '*')
- plt.plot(x[0, :, 0], x[0, :, 1], '*')
- plt.plot(x[k, :, 0], x[k, :, 1], '*')
- if ipf_it is not None:
- str_title = 'IPFP iteration: ' + str(ipf_it)
- plt.title(str_title)
-
- #plt.axis('equal')
- plt.savefig(filename, bbox_inches = 'tight', transparent = True, dpi=DPI)
- plot_paths.append(filename)
-
- # TRAJECTORIES
-
- N_part = 10
- filename = name + 'trajectory.png'
- filename = os.path.join(im_dir, filename)
- plt.clf()
- plt.plot(x[-1, :, 0], x[-1, :, 1], '*')
- plt.plot(x[0, :, 0], x[0, :, 1], '*')
- for j in range(N_part):
- xj = x[:, j, :]
- plt.plot(xj[:, 0], xj[:, 1], 'g', linewidth=2)
- plt.plot(xj[0,0], xj[0,1], 'rx')
- plt.plot(xj[-1,0], xj[-1,1], 'rx')
- plt.savefig(filename, bbox_inches = 'tight', transparent = True, dpi=DPI)
-
- make_gif(plot_paths, output_directory=gif_dir, gif_name=name)
-
- # REGISTRATION
-
- colors = np.cos(0.1 * x[0, :, 0]) * np.cos(0.1 * x[0, :, 1])
-
- name_gif = name + 'registration'
- plot_paths_reg = []
- for k in range(num_steps):
- if k % freq == 0:
- filename = name + 'registration_' + str(k) + '.png'
- filename = os.path.join(im_dir, filename)
- plt.clf()
- if (xlim is not None) and (ylim is not None):
- plt.xlim(*xlim)
- plt.ylim(*ylim)
- plt.plot(x[-1, :, 0], x[-1, :, 1], '*', alpha=0)
- plt.plot(x[0, :, 0], x[0, :, 1], '*', alpha=0)
- plt.scatter(x[k, :, 0], x[k, :, 1], c=colors)
- if ipf_it is not None:
- str_title = 'IPFP iteration: ' + str(ipf_it)
- plt.title(str_title)
- plt.savefig(filename, bbox_inches = 'tight', transparent = True, dpi=DPI)
- plot_paths_reg.append(filename)
-
- make_gif(plot_paths_reg, output_directory=gif_dir, gif_name=name_gif)
-
- # DENSITY
-
- name_gif = name + 'density'
- plot_paths_reg = []
- npts = 100
- for k in range(num_steps):
- if k % freq == 0:
- filename = name + 'density_' + str(k) + '.png'
- filename = os.path.join(im_dir, filename)
- plt.clf()
- if (xlim is not None) and (ylim is not None):
- plt.xlim(*xlim)
- plt.ylim(*ylim)
- else:
- xlim = [-15, 15]
- ylim = [-15, 15]
- if ipf_it is not None:
- str_title = 'IPFP iteration: ' + str(ipf_it)
- plt.title(str_title)
- plt.hist2d(x[k, :, 0], x[k, :, 1], range=[[xlim[0], xlim[1]], [ylim[0], ylim[1]]], bins=npts)
- plt.savefig(filename, bbox_inches = 'tight', transparent = True, dpi=DPI)
- plot_paths_reg.append(filename)
-
- make_gif(plot_paths_reg, output_directory=gif_dir, gif_name=name_gif)
-
-
-
-
-class Plotter(object):
-
- def __init__(self):
- pass
-
- def plot(self, x_tot_plot, net, i, n, forward_or_backward):
- pass
-
- def __call__(self, initial_sample, x_tot_plot, net, i, n, forward_or_backward):
- self.plot(initial_sample, x_tot_plot, net, i, n, forward_or_backward)
-
-
-class ImPlotter(object):
-
- def __init__(self, im_dir = './im', gif_dir='./gif', plot_level=3):
- if not os.path.isdir(im_dir):
- os.mkdir(im_dir)
- if not os.path.isdir(gif_dir):
- os.mkdir(gif_dir)
- self.im_dir = im_dir
- self.gif_dir = gif_dir
- self.num_plots = 100
- self.num_digits = 20
- self.plot_level = plot_level
-
-
- def plot(self, initial_sample, x_tot_plot, i, n, forward_or_backward):
- if self.plot_level > 0:
- x_tot_plot = x_tot_plot[:,:self.num_plots]
- name = '{0}_{1}_{2}'.format(forward_or_backward, n, i)
- im_dir = os.path.join(self.im_dir, name)
-
- if not os.path.isdir(im_dir):
- os.mkdir(im_dir)
-
- if self.plot_level > 0:
- plt.clf()
- filename_grid_png = os.path.join(im_dir, 'im_grid_first.png')
- vutils.save_image(initial_sample, filename_grid_png, nrow=10)
- filename_grid_png = os.path.join(im_dir, 'im_grid_final.png')
- vutils.save_image(x_tot_plot[-1], filename_grid_png, nrow=10)
-
- if self.plot_level >= 2:
- plt.clf()
- plot_paths = []
- num_steps, num_particles, channels, H, W = x_tot_plot.shape
- plot_steps = np.linspace(0,num_steps-1,self.num_plots, dtype=int)
-
- for k in plot_steps:
- # save png
- filename_grid_png = os.path.join(im_dir, 'im_grid_{0}.png'.format(k))
- plot_paths.append(filename_grid_png)
- vutils.save_image(x_tot_plot[k], filename_grid_png, nrow=10)
-
-
- make_gif(plot_paths, output_directory=self.gif_dir, gif_name=name)
-
- def __call__(self, initial_sample, x_tot_plot, i, n, forward_or_backward):
- self.plot(initial_sample, x_tot_plot, i, n, forward_or_backward)
-
-
-class TwoDPlotter(Plotter):
-
- def __init__(self, num_steps, gammas, im_dir = './im', gif_dir='./gif'):
-
- if not os.path.isdir(im_dir):
- os.mkdir(im_dir)
- if not os.path.isdir(gif_dir):
- os.mkdir(gif_dir)
-
- self.im_dir = im_dir
- self.gif_dir = gif_dir
-
- self.num_steps = num_steps
- self.gammas = gammas
-
- def plot(self, initial_sample, x_tot_plot, i, n, forward_or_backward):
- fb = forward_or_backward
- ipf_it = n
- x_tot_plot = x_tot_plot.cpu().numpy()
- name = str(i) + '_' + fb +'_' + str(n) + '_'
-
- save_sequence(num_steps=self.num_steps, x=x_tot_plot, name=name, xlim=(-15,15),
- ylim=(-15,15), ipf_it=ipf_it, freq=self.num_steps//min(self.num_steps,50),
- im_dir=self.im_dir, gif_dir=self.gif_dir)
-
-
- def __call__(self, initial_sample, x_tot_plot, i, n, forward_or_backward):
- self.plot(initial_sample, x_tot_plot, i, n, forward_or_backward)
diff --git a/conda.yaml b/conda.yaml
deleted file mode 100644
index 6d414b8..0000000
--- a/conda.yaml
+++ /dev/null
@@ -1,80 +0,0 @@
-name: bridge_test
-channels:
- - conda-forge
- - defaults
-dependencies:
- - _libgcc_mutex=0.1=main
- - ca-certificates=2020.10.14=0
- - certifi=2020.6.20=pyhd3eb1b0_3
- - ld_impl_linux-64=2.33.1=h53a641e_7
- - libedit=3.1.20191231=h14c3975_1
- - libffi=3.3=he6710b0_2
- - libgcc-ng=9.1.0=hdf63c60_0
- - libstdcxx-ng=9.1.0=hdf63c60_0
- - ncurses=6.2=he6710b0_1
- - openssl=1.1.1h=h7b6447c_0
- - pip=20.2.4=py38h06a4308_0
- - python=3.8.5=h7579374_1
- - readline=8.0=h7b6447c_0
- - setuptools=50.3.0=py38h06a4308_1
- - sqlite=3.33.0=h62c20be_0
- - tk=8.6.10=hbc83047_0
- - wheel=0.35.1=py_0
- - xz=5.2.5=h7b6447c_0
- - zlib=1.2.11=h7b6447c_3
- - pip:
- - absl-py==0.12.0
- - accelerate==0.2.1
- - aiohttp==3.7.4.post0
- - antlr4-python3-runtime==4.8
- - async-timeout==3.0.1
- - attrs==20.3.0
- - cachetools==4.2.1
- - chardet==4.0.0
- - cycler==0.10.0
- - cython==0.29.23
- - fsspec==2021.4.0
- - future==0.18.2
- - google-auth==1.29.0
- - google-auth-oauthlib==0.4.4
- - grpcio==1.37.0
- - hydra-core==1.0.6
- - idna==2.10
- - importlib-resources==5.1.2
- - joblib==1.0.1
- - kiwisolver==1.3.1
- - markdown==3.3.4
- - matplotlib==3.4.1
- - multidict==5.1.0
- - numpy==1.20.2
- - oauthlib==3.1.0
- - omegaconf==2.0.6
- - packaging==20.9
- - pandas==1.2.4
- - pillow==8.2.0
- - pot==0.7.0
- - protobuf==3.15.8
- - pyaml==20.4.0
- - pyasn1==0.4.8
- - pyasn1-modules==0.2.8
- - pyparsing==2.4.7
- - pytorch-lightning==1.2.10
- - pytz==2021.1
- - pyyaml==5.3.1
- - requests==2.25.1
- - requests-oauthlib==1.3.0
- - rsa==4.7.2
- - scikit-learn==0.24.1
- - scipy==1.6.2
- - six==1.15.0
- - tensorboard==2.4.1
- - tensorboard-plugin-wit==1.8.0
- - threadpoolctl==2.1.0
- - torch==1.8.1
- - torchmetrics==0.2.0
- - torchvision==0.9.1
- - tqdm==4.60.0
- - typing-extensions==3.7.4.3
- - urllib3==1.26.4
- - werkzeug==1.0.1
- - yarl==1.6.3
diff --git a/conf/config.yaml b/conf/config.yaml
deleted file mode 100644
index e3cde32..0000000
--- a/conf/config.yaml
+++ /dev/null
@@ -1,33 +0,0 @@
-# @package _global_
-
-defaults:
- - launcher: local
- - job
- - dataset: 2d #celeba, 2d, stackedmnist
- - model: Basic #Basic, UNET
-
-
-# data
-data_dir: ./data/
-
-# logging
-LOGGER: CSV # NEPTUNE, CSV, NONE
-CSV_log_dir: ./
-
-cache_gpu: False
-num_cache_batches: 1
-cache_refresh_stride: 100
-plot_level: 1
-mean_match: True
-paths:
- experiments_dir_name: experiments
-
-# checkpoint
-checkpoint_run: False
-checkpoint_it: 13
-checkpoint_pass: backward
-sample_checkpoint_f: ""
-sample_checkpoint_b: ""
-checkpoint_f: ""
-checkpoint_b: ""
-
diff --git a/conf/dataset/2d.yaml b/conf/dataset/2d.yaml
deleted file mode 100644
index 998a417..0000000
--- a/conf/dataset/2d.yaml
+++ /dev/null
@@ -1,52 +0,0 @@
-# @package _global_
-
-Dataset: 2d
-data: scurve
-
-# transfer
-transfer: False
-Dataset_transfer: 2d
-data_transfer: circle
-
-adaptive_mean: False
-final_adaptive: True
-mean_final: torch.tensor([0.,0.])
-var_final: 1.*torch.tensor([1., 1.])
-
-
-# device
-device: cpu
-dataparallel: False
-num_workers: 8
-pin_memory: False
-distributed: False
-
-# training
-use_prev_net: False
-mean_match: False
-ema: False
-ema_rate: 0.999
-grad_clipping: False
-grad_clip: 1.0
-npar: 10000
-batch_size: 512
-num_iter : 10000
-cache_npar: 10000
-n_ipf: 20
-lr: 0.0001
-
-# schedule
-num_steps : 20
-gamma_max: 0.01
-gamma_min: 0.01
-gamma_space: linspace
-weight_distrib: False
-weight_distrib_alpha: 100
-fast_sampling: True
-
-
-# logging
-plot_npar: 1000
-log_stride: 50
-gif_stride: ${num_iter}
-
diff --git a/conf/dataset/celeba.yaml b/conf/dataset/celeba.yaml
deleted file mode 100755
index 45fb3e2..0000000
--- a/conf/dataset/celeba.yaml
+++ /dev/null
@@ -1,55 +0,0 @@
-# @package _global_
-
-# data
-Dataset: celeba
-
-data:
- dataset: "CELEBA"
- image_size: 64
- channels: 3
- random_flip: true
-
-# transfer
-transfer: False
-Dataset_transfer: mnist
-
-
-final_adaptive: False
-adaptive_mean: True
-mean_final: torch.zeros([${data.channels}, ${data.image_size}, ${data.image_size}])
-var_final: .5 * torch.ones([${data.channels}, ${data.image_size}, ${data.image_size}])
-load: False
-
-# device
-device: cuda
-dataparallel: True
-num_workers: 8
-pin_memory: False
-
-# logging
-log_stride : 10
-gif_stride: 5000
-plot_npar: 100
-
-# training
-cache_npar: 300
-use_prev_net: True
-ema: True
-ema_rate: 0.999
-grad_clipping: True
-grad_clip: 1.0
-n_ipf_init: 1
-batch_size: 128
-num_iter : 50000
-n_ipf: 20
-lr: 0.0001
-
-# diffusion schedule
-num_steps : 50
-gamma_max: 0.1
-gamma_min: 0.00001
-gamma_space: linspace
-weight_distrib: False
-weight_distrib_alpha: 100
-fast_sampling: True
-
diff --git a/conf/dataset/stackedmnist.yaml b/conf/dataset/stackedmnist.yaml
deleted file mode 100644
index 9e9a6d6..0000000
--- a/conf/dataset/stackedmnist.yaml
+++ /dev/null
@@ -1,56 +0,0 @@
-# @package _global_
-
-# data
-Dataset: stackedmnist
-data:
- dataset: "Stacked_MNIST"
- category: ""
- image_size: 28
- channels: 1
-
-# transfer
-transfer: False
-Dataset_transfer: mnist
-
-
-final_adaptive: False
-adaptive_mean: True
-mean_final: torch.zeros([${data.channels}, ${data.image_size}, ${data.image_size}])
-var_final: 1 * torch.ones([${data.channels}, ${data.image_size}, ${data.image_size}])
-load: False
-
-# device
-device: cuda
-dataparallel: True
-num_workers: 8
-pin_memory: True
-
-# logging
-log_stride : 10
-gif_stride: 5000
-plot_npar: 100
-
-# training
-cache_npar: 1000
-num_cache_batches: 1
-use_prev_net: False
-ema: True
-ema_rate: 0.999
-grad_clipping: True
-grad_clip: 1.0
-n_ipf_init: 1
-batch_size: 128
-num_iter : 500000
-n_ipf: 20
-lr: 0.0001
-
-# schedule
-num_steps : 30
-gamma_max: 0.1
-gamma_min: 0.00001
-gamma_space: linspace
-weight_distrib: True
-weight_distrib_alpha: 100
-fast_sampling: True
-
-
diff --git a/conf/job.yaml b/conf/job.yaml
deleted file mode 100644
index e7db7c7..0000000
--- a/conf/job.yaml
+++ /dev/null
@@ -1,26 +0,0 @@
-# @package hydra
-
-job:
- config:
- # configuration for the ${hydra.job.override_dirname} runtime variable
- override_dirname:
- exclude_keys: [name, launcher, run, training, device, data_dir, dataset, load]
-
-run:
- # Output directory for normal runs
- dir: ./${paths.experiments_dir_name}/${now:%Y-%m-%d}/${name}/${hydra.job.override_dirname}/${now:%H-%M-%S}
-
-sweep:
-# # Output directory for sweep runs
- dir: ./${paths.experiments_dir_name}/${name}/${hydra.job.override_dirname}
- subdir: ./
-
-job_logging:
- formatters:
- simple:
- format: '[%(levelname)s] - %(message)s'
- handlers:
- file:
- filename: run.log
- root:
- handlers: [console, file]
\ No newline at end of file
diff --git a/conf/launcher/local.yaml b/conf/launcher/local.yaml
deleted file mode 100644
index e16d4e5..0000000
--- a/conf/launcher/local.yaml
+++ /dev/null
@@ -1,3 +0,0 @@
-# @package _global_
-
-name: ${hydra.job.name}
\ No newline at end of file
diff --git a/conf/mnist.yaml b/conf/mnist.yaml
deleted file mode 100644
index cc45780..0000000
--- a/conf/mnist.yaml
+++ /dev/null
@@ -1,31 +0,0 @@
-# @package _global_
-
-defaults:
- - launcher: local
- - job
- - dataset: stackedmnist #celeba, 2d, stackedmnist
- - model: UNET #Basic, UNET
-
-
-# data
-data_dir: ./data/
-
-# logging
-LOGGER: CSV # NEPTUNE, CSV, NONE
-CSV_log_dir: ./
-
-cache_gpu: False
-num_cache_batches: 10
-cache_refresh_stride: 1000
-plot_level: 1
-mean_match: True
-paths:
- experiments_dir_name: experiments
-
-# checkpoint
-checkpoint_run: False
-checkpoint_it: 1
-checkpoint_pass: backward
-sample_checkpoint_f: None
-sample_checkpoint_b: None
-checkpoint_f: None
diff --git a/conf/model/Basic.yaml b/conf/model/Basic.yaml
deleted file mode 100644
index 10181cd..0000000
--- a/conf/model/Basic.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-# @package _global_
-
-# model
-Model : Basic
\ No newline at end of file
diff --git a/conf/model/UNET.yaml b/conf/model/UNET.yaml
deleted file mode 100644
index 86615b6..0000000
--- a/conf/model/UNET.yaml
+++ /dev/null
@@ -1,12 +0,0 @@
-# @package _global_
-
-Model: UNET
-model:
- num_channels: 64
- num_res_blocks: 2
- num_heads: 4
- num_heads_upsample: -1
- attention_resolutions: "168"
- dropout: 0.0
- use_checkpoint: False
- use_scale_shift_norm: True
\ No newline at end of file
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..eadbfc6
--- /dev/null
+++ b/config.py
@@ -0,0 +1,82 @@
+"""
+Configuration pour l'entraînement DSB : HES -> CD30 virtual staining.
+Tous les hyperparamètres sont centralisés ici.
+Modifier directement les valeurs ci-dessous.
+"""
+
+# ----- Dataset -----
+DATASET = "hes_cd30"
+DATASET_TRANSFER = "hes_cd30"
+IMAGE_SIZE = 256
+CHANNELS = 3
+RANDOM_FLIP = True
+DATA_DIR = "./" # dataset_v2/ doit être dans ce dossier
+TRANSFER = True # mode transfer : HES -> CD30
+LOAD = False
+
+# ----- Modèle (UNET) -----
+MODEL = "UNET"
+NUM_CHANNELS = 64
+NUM_RES_BLOCKS = 2
+NUM_HEADS = 4
+NUM_HEADS_UPSAMPLE = -1
+ATTENTION_RESOLUTIONS = "16"
+DROPOUT = 0.0
+USE_CHECKPOINT = False
+USE_SCALE_SHIFT_NORM = True
+
+# ----- Device -----
+DEVICE = "cuda" # "cpu" si pas de GPU
+DATAPARALLEL = True
+NUM_WORKERS = 4
+PIN_MEMORY = True
+
+# ----- Entraînement -----
+BATCH_SIZE = 8
+LR = 1e-4
+NUM_ITER = 50000
+N_IPF = 20
+N_IPF_INIT = 1
+CACHE_NPAR = 32
+NUM_CACHE_BATCHES = 1
+CACHE_REFRESH_STRIDE = 100
+USE_PREV_NET = True
+MEAN_MATCH = True
+
+# ----- EMA -----
+EMA = True
+EMA_RATE = 0.999
+
+# ----- Gradient clipping -----
+GRAD_CLIPPING = True
+GRAD_CLIP = 1.0
+
+# ----- Schedule de diffusion -----
+NUM_STEPS = 50
+GAMMA_MAX = 0.1
+GAMMA_MIN = 1e-5
+GAMMA_SPACE = "linspace" # "linspace" ou "geomspace"
+WEIGHT_DISTRIB = True
+WEIGHT_DISTRIB_ALPHA = 100
+FAST_SAMPLING = True
+
+# ----- Gaussian final (non utilisé en mode transfer, mais requis) -----
+FINAL_ADAPTIVE = False
+ADAPTIVE_MEAN = False
+MEAN_FINAL = "torch.zeros([3, 256, 256])"
+VAR_FINAL = "torch.ones([3, 256, 256])"
+
+# ----- Logging -----
+LOGGER = "CSV"
+CSV_LOG_DIR = "./"
+LOG_STRIDE = 10
+GIF_STRIDE = 5000
+PLOT_NPAR = 16
+PLOT_LEVEL = 1
+
+# ----- Checkpoint -----
+CHECKPOINT_RUN = False
+CHECKPOINT_IT = 1
+CHECKPOINT_PASS = "b"
+SAMPLE_CHECKPOINT_F = ""
+SAMPLE_CHECKPOINT_B = ""
diff --git a/data.py b/data.py
deleted file mode 100644
index 6fe208e..0000000
--- a/data.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import os,sys
-
-import argparse
-
-parser = argparse.ArgumentParser(description='Download data.')
-parser.add_argument('--data', type=str, help='mnist, celeba')
-parser.add_argument('--data_dir', type=str, help='download location')
-
-
-sys.path.append('..')
-
-
-from bridge.data.stackedmnist import Stacked_MNIST
-from bridge.data.emnist import EMNIST
-from bridge.data.celeba import CelebA
-
-
-# SETTING PARAMETERS
-
-def main():
-
- args = parser.parse_args()
-
- if args.data == 'mnist':
- root = os.path.join(args.data_dir, 'mnist')
- Stacked_MNIST(root,
- load=False,
- source_root=root,
- train=True,
- num_channels = 1,
- imageSize=28,
- device='cpu')
-
- if args.data == 'celeba':
- root = os.path.join(args.data_dir, 'celeba')
- CelebA(root, split='train', download=True)
-
-
-if __name__ == '__main__':
- main()
diff --git a/main.py b/main.py
deleted file mode 100644
index a8c20e7..0000000
--- a/main.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import torch
-import hydra
-import os,sys
-
-sys.path.append('..')
-
-
-from bridge.runners.ipf import IPFSequential
-
-
-# SETTING PARAMETERS
-
-@hydra.main(config_path="./conf", config_name="config")
-def main(args):
-
- print('Directory: ' + os.getcwd())
- ipf = IPFSequential(args)
- ipf.train()
-
-
-if __name__ == '__main__':
- main()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..6ff4ef0
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+torch
+torchvision
+matplotlib
+tqdm
+pillow
+numpy
+accelerate
+pytorch-lightning
\ No newline at end of file
diff --git a/schrodinger_bridge.png b/schrodinger_bridge.png
deleted file mode 100644
index bde4faa..0000000
Binary files a/schrodinger_bridge.png and /dev/null differ
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..c38fc79
--- /dev/null
+++ b/train.py
@@ -0,0 +1,1254 @@
+"""
+Script d'entraînement complet DSB (Diffusion Schrödinger Bridge) : HES -> CD30.
+
+Idée générale (version débutant) :
+- On veut apprendre à transformer des images HES en images type CD30.
+- On utilise un modèle de type "diffusion / bridge" entraîné en plusieurs passes.
+- Le script contient tout : chargement des données, entraînement, sauvegardes, plots.
+
+Le modèle (UNet), le dataset (HES_CD30) et la config (cfg) sont importés depuis d'autres fichiers.
+"""
+
+import os
+import sys
+import copy
+import time
+import random
+import datetime
+from itertools import repeat
+
+import numpy as np
+import matplotlib
+matplotlib.use('Agg') # Permet de faire des figures même sans écran (serveur)
+import matplotlib.pyplot as plt
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset
+import torchvision.utils as vutils
+import torchvision.transforms as transforms
+from PIL import Image
+from tqdm import tqdm
+from accelerate import Accelerator # Gère GPU/multi-GPU/mixed precision plus simplement
+
+from pytorch_lightning.loggers import CSVLogger as _CSVLogger
+
+import config as cfg
+from bridge.models.unet import UNetModel
+from bridge.data.hes_cd30 import HES_CD30
+
+# Petit raccourci pour construire facilement une pipeline de transformations d'images
+cmp = lambda x: transforms.Compose([*x])
+
+
+# ============================================================================
+# REPEATER (boucle infinie sur un DataLoader)
+# ============================================================================
+
+def repeater(data_loader):
+ """
+ Concept :
+ - Un DataLoader PyTorch s'arrête quand il a tout parcouru.
+ - Ici on veut pouvoir appeler next(...) indéfiniment sans gérer les fins d'epoch.
+ → On crée une boucle infinie qui répète le DataLoader.
+ """
+ for loader in repeat(data_loader):
+ for data in loader:
+ yield data
+
+
+# ============================================================================
+# LOGGER
+# ============================================================================
+
+class Logger:
+ """
+ Interface minimale pour enregistrer des métriques (loss, etc.).
+ """
+ def log_metrics(self, metric_dict, step=None, save=False):
+ pass
+
+ def log_hparams(self, hparams_dict):
+ pass
+
+
+class CSVLogger(Logger):
+ """
+ Logger qui écrit les métriques dans des fichiers CSV.
+
+ Concept :
+ - garder une trace de l'entraînement (loss, normes de gradient, etc.)
+ - pouvoir relire ensuite dans Excel / Python / etc.
+ """
+ def __init__(self, directory='./', name='logs', save_stride=1):
+ self.logger = _CSVLogger(directory, name=name)
+ self.count = 0
+ self.stride = save_stride
+
+ def log_metrics(self, metrics, step=None, save=False):
+ # Ajoute des métriques (ex: loss) au logger
+ self.count += 1
+ self.logger.log_metrics(metrics, step=step)
+ # Sauvegarde périodique (évite de perdre les données en cas de crash)
+ if self.count % self.stride == 0:
+ self.logger.save()
+ self.logger.metrics = []
+ # "reset" occasionnel pour éviter de garder trop en mémoire
+ if self.count > self.stride * 10:
+ self.count = 0
+ if save:
+ self.logger.save()
+
+ def log_hparams(self, hparams_dict):
+ # Enregistre la config (hyperparamètres) dans les logs
+ self.logger.log_hyperparams(hparams_dict)
+ self.logger.save()
+
+
+# ============================================================================
+# PLOTTER
+# ============================================================================
+
+def make_gif(plot_paths, output_directory='./gif', gif_name='gif'):
+ """
+ Concept :
+ - On a une suite d'images (png) à différentes étapes d'un processus.
+ - On les assemble en GIF pour visualiser l'évolution.
+ """
+ frames = [Image.open(fn) for fn in plot_paths]
+ frames[0].save(
+ os.path.join(output_directory, f'{gif_name}.gif'),
+ format='GIF',
+ append_images=frames[1:],
+ save_all=True,
+ duration=100,
+ loop=0,
+ )
+
+
+class ImPlotter:
+ """
+ Outil pour sauvegarder des grilles d'images (début / fin / étapes intermédiaires)
+ et éventuellement créer un GIF.
+
+ Concept :
+ - visualiser la "trajectoire" de génération : comment une image évolue
+ au fil des étapes de Langevin / diffusion.
+ """
+ def __init__(self, im_dir='./im', gif_dir='./gif', plot_level=3):
+ if not os.path.isdir(im_dir):
+ os.mkdir(im_dir)
+ if not os.path.isdir(gif_dir):
+ os.mkdir(gif_dir)
+ self.im_dir = im_dir
+ self.gif_dir = gif_dir
+ self.num_plots = 100
+ self.num_digits = 20
+ self.plot_level = plot_level
+
+ def plot(self, initial_sample, x_tot_plot, i, n, forward_or_backward):
+ """
+ initial_sample : images de départ
+ x_tot_plot : images générées au fil du temps (trajectoire)
+ i : index d'itération d'entraînement
+ n : itération IPF
+ forward_or_backward : 'f' ou 'b' (sens entraîné)
+ """
+ if self.plot_level > 0:
+ # On limite le nombre d'images affichées pour éviter des fichiers énormes
+ x_tot_plot = x_tot_plot[:, :self.num_plots]
+ name = '{0}_{1}_{2}'.format(forward_or_backward, n, i)
+ im_dir = os.path.join(self.im_dir, name)
+
+ if not os.path.isdir(im_dir):
+ os.mkdir(im_dir)
+
+ # Niveau 1 : sauvegarde grille début + grille fin
+ if self.plot_level > 0:
+ plt.clf()
+ filename_grid_png = os.path.join(im_dir, 'im_grid_first.png')
+ vutils.save_image(initial_sample, filename_grid_png, nrow=10)
+ filename_grid_png = os.path.join(im_dir, 'im_grid_final.png')
+ vutils.save_image(x_tot_plot[-1], filename_grid_png, nrow=10)
+
+ # Niveau 2 : sauvegarde étapes intermédiaires + création GIF
+ if self.plot_level >= 2:
+ plt.clf()
+ plot_paths = []
+ num_steps, num_particles, channels, H, W = x_tot_plot.shape
+ plot_steps = np.linspace(0, num_steps - 1, self.num_plots, dtype=int)
+
+ for k in plot_steps:
+ filename_grid_png = os.path.join(im_dir, 'im_grid_{0}.png'.format(k))
+ plot_paths.append(filename_grid_png)
+ vutils.save_image(x_tot_plot[k], filename_grid_png, nrow=10)
+
+ make_gif(plot_paths, output_directory=self.gif_dir, gif_name=name)
+
+ def __call__(self, initial_sample, x_tot_plot, i, n, forward_or_backward):
+ self.plot(initial_sample, x_tot_plot, i, n, forward_or_backward)
+
+
+# ============================================================================
+# EMA HELPER
+# ============================================================================
+
+class EMAHelper:
+ """
+ EMA = Exponential Moving Average (moyenne glissante) des poids.
+
+ Concept (débutant) :
+ - pendant l'entraînement, les poids bougent beaucoup et peuvent être "bruités"
+ - EMA garde une version "plus stable" du modèle
+ - souvent cette version EMA génère de meilleures images
+ """
+ def __init__(self, mu=0.999, device="cpu"):
+ self.mu = mu
+ self.shadow = {} # copie des poids "lissés"
+ self.device = device
+
+ def register(self, module):
+ # On mémorise une copie des poids au départ
+ if isinstance(module, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+ module = module.module
+ for name, param in module.named_parameters():
+ if param.requires_grad:
+ self.shadow[name] = param.data.clone()
+
+ def update(self, module):
+ # Mise à jour EMA : nouveau = (1-mu)*poids_actuels + mu*poids_ema
+ if isinstance(module, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+ module = module.module
+ for name, param in module.named_parameters():
+ if param.requires_grad:
+ self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data
+
+ def ema(self, module):
+ # Applique les poids EMA à un module (remplace les poids du module)
+ if isinstance(module, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+ module = module.module
+ for name, param in module.named_parameters():
+ if param.requires_grad:
+ param.data.copy_(self.shadow[name].data)
+
+ def ema_copy(self, module):
+ # Crée une copie du modèle et y applique l'EMA (pratique pour échantillonner)
+ if isinstance(module, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+ inner_module = module.module
+ locs = inner_module.locals
+ module_copy = type(inner_module)(*locs).to(self.device)
+ module_copy.load_state_dict(inner_module.state_dict())
+ if isinstance(module, nn.DataParallel):
+ module_copy = nn.DataParallel(module_copy)
+ else:
+ locs = module.locals
+ module_copy = type(module)(*locs).to(self.device)
+ module_copy.load_state_dict(module.state_dict())
+ self.ema(module_copy)
+ return module_copy
+
+ def state_dict(self):
+ return self.shadow
+
+ def load_state_dict(self, state_dict):
+ self.shadow = state_dict
+
+
+# ============================================================================
+# LANGEVIN DYNAMICS
+# ============================================================================
+
+def grad_gauss(x, m, var):
+ """
+ Concept :
+ - gradient d'une distribution Gaussienne
+ - sert ici à pousser les échantillons vers une distribution cible simple
+ (utile pour initialiser / guider certaines étapes)
+ """
+ return -(x - m) / var
+
+
+def ornstein_ulhenbeck(x, gradx, gamma):
+ """
+ Concept :
+ - une mise à jour type "dynamique stochastique"
+ - on avance dans une direction (gradx) + on ajoute un bruit aléatoire
+ """
+ return x + gamma * gradx + torch.sqrt(2 * gamma) * torch.randn(x.shape, device=x.device)
+
+
+class Langevin(torch.nn.Module):
+ """
+ Implémente une trajectoire de sampling par dynamique de Langevin.
+
+ Concept débutant :
+ - on part d'un état initial (images ou bruit)
+ - on applique plusieurs petites mises à jour
+ - à chaque étape, on ajoute un bruit aléatoire
+ - le réseau (net) sert à guider ces mises à jour
+ """
+ def __init__(self, num_steps, shape, gammas, time_sampler, device=None,
+ mean_final=torch.tensor([0., 0.]), var_final=torch.tensor([.5, .5]),
+ mean_match=True):
+ super().__init__()
+ self.mean_match = mean_match
+ self.mean_final = mean_final
+ self.var_final = var_final
+ self.num_steps = num_steps
+ self.d = shape
+ self.gammas = gammas.float()
+
+ # Prépare une version "étendue" des gammas au bon format
+ gammas_vec = torch.ones(self.num_steps, *self.d, device=device)
+ for k in range(num_steps):
+ gammas_vec[k] = gammas[k].float()
+ self.gammas_vec = gammas_vec
+
+ self.device = device if device is not None else gammas.device
+
+ # self.time contient les temps cumulés (une échelle temporelle)
+ self.steps = torch.arange(self.num_steps).to(self.device)
+ self.time = torch.cumsum(self.gammas, 0).to(self.device).float()
+
+ # time_sampler : sert à tirer des pas de temps selon une distribution (pondération)
+ self.time_sampler = time_sampler
+
+ def record_init_langevin(self, init_samples):
+ """
+ Trajectoire spéciale utilisée au tout début (n=1 et fb='b' dans ton code).
+
+ Concept :
+ - générer une trajectoire sans réseau (ou avec une règle simple)
+ - produire aussi 'out' (la "cible" à apprendre) pour l'entraînement
+ """
+ mean_final = self.mean_final
+ var_final = self.var_final
+ x = init_samples
+ N = x.shape[0]
+
+ # steps_expanded = temps associé à chaque étape pour chaque échantillon du batch
+ time = self.time.reshape((1, self.num_steps, 1)).repeat((N, 1, 1))
+ steps_expanded = time
+
+ # x_tot : sauvegarde les états successifs (pour plots / cache)
+ # out : sauvegarde un signal "à prédire" (utilisé comme target pendant l'entraînement)
+ x_tot = torch.Tensor(N, self.num_steps, *self.d).to(x.device)
+ out = torch.Tensor(N, self.num_steps, *self.d).to(x.device)
+ num_iter = self.num_steps
+
+ for k in range(num_iter):
+ gamma = self.gammas[k]
+ # Ici : on pousse x vers une gaussienne cible (mean_final/var_final)
+ gradx = grad_gauss(x, mean_final, var_final)
+ # t_old / t_new : deux évaluations utilisées pour construire la cible 'out'
+ t_old = x + gamma * gradx
+ z = torch.randn(x.shape, device=x.device)
+ x = t_old + torch.sqrt(2 * gamma) * z
+ gradx = grad_gauss(x, mean_final, var_final)
+ t_new = x + gamma * gradx
+ x_tot[:, k, :] = x
+ out[:, k, :] = (t_old - t_new)
+
+ return x_tot, out, steps_expanded
+
+ def record_langevin_seq(self, net, init_samples, t_batch=None, ipf_it=0, sample=False):
+ """
+ Trajectoire standard : le réseau 'net' guide les mises à jour.
+
+ Concept :
+ - on simule une évolution en plusieurs étapes
+ - on enregistre la trajectoire x_tot (pour plots)
+ - et on construit 'out' qui sert de cible pour entraîner l'autre réseau
+ """
+ mean_final = self.mean_final
+ var_final = self.var_final
+ x = init_samples
+ N = x.shape[0]
+ time = self.time.reshape((1, self.num_steps, 1)).repeat((N, 1, 1))
+ steps = time
+ steps_expanded = steps
+
+ x_tot = torch.Tensor(N, self.num_steps, *self.d).to(x.device)
+ out = torch.Tensor(N, self.num_steps, *self.d).to(x.device)
+ num_iter = self.num_steps
+
+ # mean_match : change la façon dont on interprète la sortie du réseau
+ # (soit le réseau donne directement un point "moyen", soit une correction)
+ if self.mean_match:
+ for k in range(num_iter):
+ gamma = self.gammas[k]
+
+ # t_old : prédiction du réseau à l'étape k
+ t_old = net(x, steps[:, k, :])
+
+ # sampling : à la dernière étape, on peut choisir de ne plus ajouter de bruit
+ if sample and (k == num_iter - 1):
+ x = t_old
+ else:
+ z = torch.randn(x.shape, device=x.device)
+ x = t_old + torch.sqrt(2 * gamma) * z
+ t_new = net(x, steps[:, k, :])
+
+ x_tot[:, k, :] = x
+ out[:, k, :] = (t_old - t_new)
+ else:
+ for k in range(num_iter):
+ gamma = self.gammas[k]
+ t_old = x + net(x, steps[:, k, :])
+ if sample and (k == num_iter - 1):
+ x = t_old
+ else:
+ z = torch.randn(x.shape, device=x.device)
+ x = t_old + torch.sqrt(2 * gamma) * z
+ t_new = x + net(x, steps[:, k, :])
+ x_tot[:, k, :] = x
+ out[:, k, :] = (t_old - t_new)
+
+ return x_tot, out, steps_expanded
+
+ def forward(self, net, init_samples, t_batch, ipf_it):
+ return self.record_langevin_seq(net, init_samples, t_batch, ipf_it)
+
+
+# ============================================================================
+# CACHE LOADER
+# ============================================================================
+
+class CacheLoader(Dataset):
+ """
+ Dataset "fabriqué" à la volée.
+
+ Concept :
+ - au lieu d'entraîner directement sur les images, on génère des trajectoires
+ (via Langevin + un réseau) et on les "met en cache"
+ - ça transforme un problème complexe en un dataset supervisé :
+ (x, out, steps) où out est la cible à prédire.
+ """
+ def __init__(self, fb, sample_net, dataloader_b, num_batches, langevin, n,
+ mean, std, batch_size, device='cpu',
+ dataloader_f=None, transfer=False):
+ super().__init__()
+ start = time.time()
+ shape = langevin.d
+ num_steps = langevin.num_steps
+
+ # self.data stocke pour chaque étape :
+ # - x : l'état
+ # - out : la cible d'entraînement associée
+ self.data = torch.zeros(
+ (num_batches, batch_size * num_steps, 2, *shape)).to(device)
+
+ # self.steps_data stocke le temps associé à chaque échantillon/étape
+ self.steps_data = torch.zeros(
+ (num_batches, batch_size * num_steps, 1)).to(device)
+
+ with torch.no_grad():
+ for b in range(num_batches):
+
+ # Choix de la source des échantillons initiaux :
+ # - soit on prend des images du dataset (si fb == 'b')
+ # - soit on prend des images de l'autre domaine si transfer
+ # - soit on part d'un bruit gaussien (sinon)
+ if fb == 'b':
+ batch = next(dataloader_b)[0].to(device)
+ elif fb == 'f' and transfer:
+ batch = next(dataloader_f)[0].to(device)
+ else:
+ batch = mean + std * torch.randn((batch_size, *shape), device=device)
+
+ # Première itération : trajectoire spéciale d'init
+ if (n == 1) and (fb == 'b'):
+ x, out, steps_expanded = langevin.record_init_langevin(batch)
+ else:
+ # Trajectoire guidée par sample_net
+ x, out, steps_expanded = langevin.record_langevin_seq(sample_net, batch, ipf_it=n)
+
+ # On regroupe (x, out) ensemble puis on "aplatit" en une liste d'exemples
+ x = x.unsqueeze(2)
+ out = out.unsqueeze(2)
+ batch_data = torch.cat((x, out), dim=2)
+ flat_data = batch_data.flatten(start_dim=0, end_dim=1)
+ self.data[b] = flat_data
+
+ flat_steps = steps_expanded.flatten(start_dim=0, end_dim=1)
+ self.steps_data[b] = flat_steps
+
+ # Aplatit tout en un grand dataset
+ self.data = self.data.flatten(start_dim=0, end_dim=1)
+ self.steps_data = self.steps_data.flatten(start_dim=0, end_dim=1)
+
+ stop = time.time()
+ print('Cache size: {0}'.format(self.data.shape))
+ print("Load time: {0}".format(stop - start))
+
+ def __getitem__(self, index):
+ # Retourne : état x, cible out, et temps steps
+ item = self.data[index]
+ x = item[0]
+ out = item[1]
+ steps = self.steps_data[index]
+ return x, out, steps
+
+ def __len__(self):
+ return self.data.shape[0]
+
+
+
+# ============================================================================
+# CONFIG GETTERS (modèle, optimiseur, données, plotter, logger)
+# ============================================================================
+
+def get_models():
+ """
+ Construit deux réseaux UNet :
+ - net_f : réseau "forward"
+ - net_b : réseau "backward"
+
+ Concept :
+ - le Schrödinger Bridge entraîne deux directions (aller/retour)
+ - les deux réseaux apprennent à se "répondre" via IPF.
+ """
+ image_size = cfg.IMAGE_SIZE
+ if image_size == 256:
+ channel_mult = (1, 1, 2, 2, 4, 4)
+ elif image_size == 64:
+ channel_mult = (1, 2, 3, 4)
+ elif image_size == 32:
+ channel_mult = (1, 2, 2, 2)
+ else:
+ raise ValueError(f"unsupported image size: {image_size}")
+
+ # On convertit les résolutions en "facteurs de downsampling"
+ attention_ds = []
+ for res in cfg.ATTENTION_RESOLUTIONS.split(","):
+ attention_ds.append(image_size // int(res))
+
+ # Paramètres du UNet pris depuis la config
+ kwargs = {
+ "in_channels": cfg.CHANNELS,
+ "model_channels": cfg.NUM_CHANNELS,
+ "out_channels": cfg.CHANNELS,
+ "num_res_blocks": cfg.NUM_RES_BLOCKS,
+ "attention_resolutions": tuple(attention_ds),
+ "dropout": cfg.DROPOUT,
+ "channel_mult": channel_mult,
+ "num_classes": None,
+ "use_checkpoint": cfg.USE_CHECKPOINT,
+ "num_heads": cfg.NUM_HEADS,
+ "num_heads_upsample": cfg.NUM_HEADS_UPSAMPLE,
+ "use_scale_shift_norm": cfg.USE_SCALE_SHIFT_NORM,
+ }
+ net_f, net_b = UNetModel(**kwargs), UNetModel(**kwargs)
+ return net_f, net_b
+
+
+def get_optimizers(net_f, net_b, lr):
+ """
+ Créé les optimiseurs (ici Adam) pour chaque réseau.
+ """
+ return (torch.optim.Adam(net_f.parameters(), lr=lr),
+ torch.optim.Adam(net_b.parameters(), lr=lr))
+
+
+def get_datasets():
+ """
+ Charge deux datasets séparés :
+ - init_ds : images de départ (HES)
+ - final_ds : images cibles (CD30)
+
+ Concept :
+ - on ne donne pas la paire (HES, CD30) directement
+ - on a deux domaines séparés, et le bridge apprend à passer de l'un à l'autre.
+ """
+ train_transform = [
+ transforms.Resize(cfg.IMAGE_SIZE),
+ transforms.CenterCrop(cfg.IMAGE_SIZE),
+ transforms.ToTensor(),
+ ]
+ if cfg.RANDOM_FLIP:
+ train_transform.insert(2, transforms.RandomHorizontalFlip())
+
+ root = os.path.join(cfg.DATA_DIR, 'dataset_v2')
+
+ init_ds = HES_CD30(root, image_size=cfg.IMAGE_SIZE,
+ domain='HES', transform=cmp(train_transform))
+ final_ds = HES_CD30(root, image_size=cfg.IMAGE_SIZE,
+ domain='CD30', transform=cmp(train_transform))
+
+ # Paramètres d'une distribution simple, utilisée quand on ne part pas d'images réelles
+ mean_final = torch.tensor(0.)
+ var_final = torch.tensor(1. * 10 ** 3)
+
+ return init_ds, final_ds, mean_final, var_final
+
+
+def get_plotter():
+ """Crée l'outil de visualisation."""
+ return ImPlotter(plot_level=cfg.PLOT_LEVEL)
+
+
+def get_logger(name='logs'):
+ """Choisit le logger selon la config."""
+ if cfg.LOGGER == 'CSV':
+ return CSVLogger(directory=cfg.CSV_LOG_DIR, name=name)
+ return Logger()
+
+
+# ============================================================================
+# IPF BASE
+# ============================================================================
+
+# ============================================================================
+# IPF BASE
+# ============================================================================
+
+class IPFBase(torch.nn.Module):
+ """
+ Classe "socle" : prépare tout ce qu'il faut pour entraîner un Schrödinger Bridge.
+
+ Concept (débutant) :
+ - On entraîne 2 réseaux : forward (f) et backward (b)
+ - On alterne leur entraînement via IPF (Iterative Proportional Fitting)
+ - Pour entraîner un réseau, on génère d'abord des trajectoires avec l'autre réseau
+ (CacheLoader), puis on fait une optimisation classique (MSE).
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ # Accelerator simplifie la gestion CPU/GPU/multi-GPU (et parfois mixed precision).
+ self.accelerator = Accelerator(mixed_precision="no", cpu=(cfg.DEVICE == 'cpu'))
+ self.device = self.accelerator.device
+
+ # -------------------------
+ # Hyperparamètres d'entraînement
+ # -------------------------
+ self.n_ipf = cfg.N_IPF # nombre d'itérations IPF (cycles b puis f)
+ self.num_steps = cfg.NUM_STEPS # nombre d'étapes dans la trajectoire Langevin
+ self.batch_size = cfg.BATCH_SIZE # batch size utilisé pour l'optimisation
+ self.num_iter = cfg.NUM_ITER # nombre d'itérations d'optimisation par étape IPF
+ self.grad_clipping = cfg.GRAD_CLIPPING # active/désactive le clipping des gradients
+ self.fast_sampling = cfg.FAST_SAMPLING # option (selon config) pour accélérer le sampling
+ self.lr = cfg.LR # learning rate
+
+ # -------------------------
+ # Construction des pas de temps (gammas)
+ # -------------------------
+ # Concept :
+ # - on définit des petits pas gamma (taille des mises à jour)
+ # - on crée une séquence symétrique (monte puis redescend) pour le bridge
+ n = self.num_steps // 2
+ if cfg.GAMMA_SPACE == 'linspace':
+ gamma_half = np.linspace(cfg.GAMMA_MIN, cfg.GAMMA_MAX, n)
+ elif cfg.GAMMA_SPACE == 'geomspace':
+ gamma_half = np.geomspace(cfg.GAMMA_MIN, cfg.GAMMA_MAX, n)
+ gammas = np.concatenate([gamma_half, np.flip(gamma_half)])
+ gammas = torch.tensor(gammas).to(self.device)
+
+ # T = "durée totale" (somme des pas), utilisée pour transformer/recentrer le temps
+ self.T = torch.sum(gammas)
+
+ # -------------------------
+ # Modèles + EMA
+ # -------------------------
+ # Concept :
+ # - build_models : crée net_f et net_b
+ # - build_ema : initialise les EMA (moyennes glissantes des poids)
+ self.build_models()
+ self.build_ema()
+
+ # -------------------------
+ # Optimiseurs
+ # -------------------------
+ # 1 optimiseur par réseau (forward et backward)
+ self.build_optimizers()
+
+ # -------------------------
+ # Loggers
+ # -------------------------
+ # logger : métriques d'entraînement (loss, grad_norm, ...)
+ # save_logger : métriques liées aux samples/plots
+ self.logger = get_logger()
+ self.save_logger = get_logger('plot_logs')
+
+ # -------------------------
+ # Données (DataLoaders)
+ # -------------------------
+ # Construit les DataLoaders pour :
+ # - échantillonnage / plots
+ # - création de cache (trajectoires)
+ self.build_dataloaders()
+
+ # -------------------------
+ # Time sampler (pondération des temps)
+ # -------------------------
+ # Concept :
+ # - selon cfg.WEIGHT_DISTRIB, on peut donner plus de poids à certaines étapes
+ # lors du tirage (utile pour l'entraînement, selon la stratégie choisie)
+ if cfg.WEIGHT_DISTRIB:
+ alpha = cfg.WEIGHT_DISTRIB_ALPHA
+ prob_vec = (1 + alpha) * torch.sum(gammas) - torch.cumsum(gammas, 0)
+ else:
+ prob_vec = gammas * 0 + 1
+ time_sampler = torch.distributions.categorical.Categorical(prob_vec)
+
+ # -------------------------
+ # Définir la "forme" des données et créer l'objet Langevin
+ # -------------------------
+ # On récupère un batch pour connaître la taille (C,H,W)
+ batch = next(self.save_init_dl)[0]
+ shape = batch[0].shape
+ self.shape = shape
+
+ # Langevin = moteur qui génère des trajectoires + des cibles supervisées (out)
+ self.langevin = Langevin(
+ self.num_steps, shape, gammas,
+ time_sampler, device=self.device,
+ mean_final=self.mean_final, var_final=self.var_final,
+ mean_match=cfg.MEAN_MATCH
+ )
+
+ # -------------------------
+ # Gestion des checkpoints / reprise
+ # -------------------------
+ # Concept :
+ # - possibilité de reprendre à une itération IPF donnée (checkpoint_it)
+ # - et de reprendre sur la passe forward ou backward (checkpoint_pass)
+ date = str(datetime.datetime.now())[0:10]
+ self.name_all = date
+
+ self.checkpoint_run = cfg.CHECKPOINT_RUN
+ if cfg.CHECKPOINT_RUN:
+ self.checkpoint_it = cfg.CHECKPOINT_IT
+ self.checkpoint_pass = cfg.CHECKPOINT_PASS
+ else:
+ self.checkpoint_it = 1
+ self.checkpoint_pass = 'b'
+
+ # Outil de visualisation (grilles + gifs)
+ self.plotter = get_plotter()
+
+ # Création des dossiers de sortie (une seule fois sur le process principal)
+ if self.accelerator.process_index == 0:
+ os.makedirs('./im', exist_ok=True)
+ os.makedirs('./gif', exist_ok=True)
+ os.makedirs('./checkpoints', exist_ok=True)
+
+ # Strides : fréquence de sauvegarde/plots et fréquence de log
+ self.stride = cfg.GIF_STRIDE
+ self.stride_log = cfg.LOG_STRIDE
+
+ def build_models(self, forward_or_backward=None):
+ """
+ Construit les deux réseaux (forward et backward).
+
+ forward_or_backward :
+ - None : construit les deux
+ - 'f' : reconstruit uniquement le forward
+ - 'b' : reconstruit uniquement le backward
+ """
+ net_f, net_b = get_models()
+
+ # Si reprise : charge des checkpoints si fournis dans la config
+ if cfg.CHECKPOINT_RUN:
+ if cfg.SAMPLE_CHECKPOINT_F:
+ net_f.load_state_dict(torch.load(cfg.SAMPLE_CHECKPOINT_F))
+ if cfg.SAMPLE_CHECKPOINT_B:
+ net_b.load_state_dict(torch.load(cfg.SAMPLE_CHECKPOINT_B))
+
+ # Option : paralléliser sur plusieurs GPUs via DataParallel
+ if cfg.DATAPARALLEL:
+ net_f = torch.nn.DataParallel(net_f)
+ net_b = torch.nn.DataParallel(net_b)
+
+ # Création initiale des 2 réseaux
+ if forward_or_backward is None:
+ net_f = net_f.to(self.device)
+ net_b = net_b.to(self.device)
+ self.net = torch.nn.ModuleDict({'f': net_f, 'b': net_b})
+
+ # Remplacement uniquement du forward
+ if forward_or_backward == 'f':
+ net_f = net_f.to(self.device)
+ self.net.update({'f': net_f})
+
+ # Remplacement uniquement du backward
+ if forward_or_backward == 'b':
+ net_b = net_b.to(self.device)
+ self.net.update({'b': net_b})
+
+ def accelerate(self, forward_or_backward):
+ """
+ Prépare le modèle et l'optimiseur avec Accelerator.
+
+ Concept :
+ - selon le contexte, Accelerator wrap le modèle/l'optimiseur
+ pour faire tourner correctement sur le(s) device(s).
+ """
+ (self.net[forward_or_backward], self.optimizer[forward_or_backward]) = self.accelerator.prepare(
+ self.net[forward_or_backward], self.optimizer[forward_or_backward])
+
+ def update_ema(self, forward_or_backward):
+ """
+ Initialise/relance l'EMA pour une direction (f ou b).
+
+ Concept :
+ - EMA garde une version "lissée" des poids
+ - très utile pour faire des samples plus stables
+ """
+ if cfg.EMA:
+ self.ema_helpers[forward_or_backward] = EMAHelper(
+ mu=cfg.EMA_RATE, device=self.device)
+ self.ema_helpers[forward_or_backward].register(
+ self.net[forward_or_backward])
+
+ def build_ema(self):
+ """
+ Crée les EMA pour forward et backward.
+
+ En mode reprise :
+ - on peut initialiser l'EMA à partir de checkpoints de "sample nets".
+ """
+ if cfg.EMA:
+ self.ema_helpers = {}
+ self.update_ema('f')
+ self.update_ema('b')
+
+ if cfg.CHECKPOINT_RUN:
+ sample_net_f, sample_net_b = get_models()
+
+ if cfg.SAMPLE_CHECKPOINT_F:
+ sample_net_f.load_state_dict(torch.load(cfg.SAMPLE_CHECKPOINT_F))
+ if cfg.DATAPARALLEL:
+ sample_net_f = torch.nn.DataParallel(sample_net_f)
+ sample_net_f = sample_net_f.to(self.device)
+ self.ema_helpers['f'].register(sample_net_f)
+
+ if cfg.SAMPLE_CHECKPOINT_B:
+ sample_net_b.load_state_dict(torch.load(cfg.SAMPLE_CHECKPOINT_B))
+ if cfg.DATAPARALLEL:
+ sample_net_b = torch.nn.DataParallel(sample_net_b)
+ sample_net_b = sample_net_b.to(self.device)
+ self.ema_helpers['b'].register(sample_net_b)
+
+ def build_optimizers(self):
+ """
+ Crée les optimiseurs (ici Adam) pour les deux réseaux.
+ """
+ optimizer_f, optimizer_b = get_optimizers(self.net['f'], self.net['b'], self.lr)
+ self.optimizer = {'f': optimizer_f, 'b': optimizer_b}
+
+ def build_dataloaders(self):
+ """
+ Prépare les DataLoaders.
+
+ Concept :
+ - save_* : petits batches pour visualiser/plotter
+ - cache_* : batches utilisés pour construire le dataset de cache (trajectoires)
+ - repeater(...) : permet d'appeler next(...) sans fin (pas de gestion d'epoch)
+ """
+ init_ds, final_ds, mean_final, var_final = get_datasets()
+
+ # Paramètres utilisés si on génère depuis une distribution simple (bruit)
+ self.mean_final = mean_final.to(self.device)
+ self.var_final = var_final.to(self.device)
+ self.std_final = torch.sqrt(var_final).to(self.device)
+
+ # Rend les workers reproductibles (mais différents entre eux)
+ def worker_init_fn(worker_id):
+ np.random.seed(np.random.get_state()[1][0] + worker_id + self.accelerator.process_index)
+
+ self.kwargs = {
+ "num_workers": cfg.NUM_WORKERS,
+ "pin_memory": cfg.PIN_MEMORY,
+ "worker_init_fn": worker_init_fn,
+ "drop_last": True
+ }
+
+ # Domaine initial (HES)
+ self.save_init_dl = DataLoader(init_ds, batch_size=cfg.PLOT_NPAR, shuffle=True, **self.kwargs)
+ self.cache_init_dl = DataLoader(init_ds, batch_size=cfg.CACHE_NPAR, shuffle=True, **self.kwargs)
+
+ # Accelerator prépare les dataloaders (utile en multi-GPU)
+ (self.cache_init_dl, self.save_init_dl) = self.accelerator.prepare(self.cache_init_dl, self.save_init_dl)
+
+ # Itérateurs infinis
+ self.cache_init_dl = repeater(self.cache_init_dl)
+ self.save_init_dl = repeater(self.save_init_dl)
+
+ # Si TRANSFER : on utilise aussi le domaine final (CD30) comme source réelle
+ if cfg.TRANSFER:
+ self.save_final_dl = DataLoader(final_ds, batch_size=cfg.PLOT_NPAR, shuffle=True, **self.kwargs)
+ self.cache_final_dl = DataLoader(final_ds, batch_size=cfg.CACHE_NPAR, shuffle=True, **self.kwargs)
+
+ (self.cache_final_dl, self.save_final_dl) = self.accelerator.prepare(self.cache_final_dl, self.save_final_dl)
+
+ self.cache_final_dl = repeater(self.cache_final_dl)
+ self.save_final_dl = repeater(self.save_final_dl)
+ else:
+ self.cache_final_dl = None
+ self.save_final_dl = None
+
+ def new_cacheloader(self, forward_or_backward, n, use_ema=True):
+ """
+ Crée un DataLoader "cache" pour entraîner une direction.
+
+ Concept :
+ - Pour entraîner 'b', on génère des trajectoires avec 'f'
+ - Pour entraîner 'f', on génère des trajectoires avec 'b'
+ - On stocke ces trajectoires sous forme (x, out, steps) dans CacheLoader
+ """
+ # Direction utilisée pour générer les trajectoires (l'autre réseau)
+ sample_direction = 'f' if forward_or_backward == 'b' else 'b'
+
+ # On préfère souvent le réseau EMA pour sampler (plus stable)
+ if use_ema:
+ sample_net = self.ema_helpers[sample_direction].ema_copy(self.net[sample_direction])
+ else:
+ sample_net = self.net[sample_direction]
+
+ # Construction du dataset de cache selon la direction entraînée
+ if forward_or_backward == 'b':
+ sample_net = self.accelerator.prepare(sample_net)
+ new_dl = CacheLoader(
+ 'b', sample_net, self.cache_init_dl,
+ cfg.NUM_CACHE_BATCHES, self.langevin, n,
+ mean=None, std=None,
+ batch_size=cfg.CACHE_NPAR,
+ device=self.device,
+ dataloader_f=self.cache_final_dl,
+ transfer=cfg.TRANSFER
+ )
+ else:
+ sample_net = self.accelerator.prepare(sample_net)
+ new_dl = CacheLoader(
+ 'f', sample_net, None,
+ cfg.NUM_CACHE_BATCHES, self.langevin, n,
+ mean=self.mean_final, std=self.std_final,
+ batch_size=cfg.CACHE_NPAR,
+ device=self.device,
+ dataloader_f=self.cache_final_dl,
+ transfer=cfg.TRANSFER
+ )
+
+ # DataLoader + préparation Accelerator + itérateur infini
+ new_dl = DataLoader(new_dl, batch_size=self.batch_size)
+ new_dl = self.accelerator.prepare(new_dl)
+ new_dl = repeater(new_dl)
+ return new_dl
+
+ def train(self):
+ # À implémenter dans la classe enfant (IPFSequential)
+ pass
+
+ def save_step(self, i, n, fb):
+ """
+ Sauvegarde périodique + génération de samples pour suivi visuel.
+
+ Concept :
+ - à certains pas (stride), on :
+ 1) sauvegarde les poids
+ 2) génère une trajectoire (sampling)
+ 3) sauvegarde des images (grid) + gifs + logs simples
+ """
+ if self.accelerator.is_local_main_process:
+ if ((i % self.stride == 0) or (i % self.stride == 1)) and (i > 0):
+
+ # Choix du modèle de sampling (EMA ou non)
+ if cfg.EMA:
+ sample_net = self.ema_helpers[fb].ema_copy(self.net[fb])
+ else:
+ sample_net = self.net[fb]
+
+ # -------------------------
+ # 1) Sauvegarde du réseau courant
+ # -------------------------
+ name_net = 'net_' + fb + '_' + str(n) + "_" + str(i) + '.ckpt'
+ name_net_ckpt = './checkpoints/' + name_net
+
+ if cfg.DATAPARALLEL:
+ torch.save(self.net[fb].module.state_dict(), name_net_ckpt)
+ else:
+ torch.save(self.net[fb].state_dict(), name_net_ckpt)
+
+ # -------------------------
+ # 2) Sauvegarde du réseau EMA (si activé)
+ # -------------------------
+ if cfg.EMA:
+ name_net = 'sample_net_' + fb + '_' + str(n) + "_" + str(i) + '.ckpt'
+ name_net_ckpt = './checkpoints/' + name_net
+ if cfg.DATAPARALLEL:
+ torch.save(sample_net.module.state_dict(), name_net_ckpt)
+ else:
+ torch.save(sample_net.state_dict(), name_net_ckpt)
+
+ # -------------------------
+ # 3) Sampling + plots (sans gradient)
+ # -------------------------
+ with torch.no_grad():
+ # Seed fixe pour avoir des images comparables d'une sauvegarde à l'autre
+ self.set_seed(seed=0 + self.accelerator.process_index)
+
+ # Choix des images de départ pour visualiser
+ if fb == 'f':
+ # Pour forward : on part d'images HES
+ batch = next(self.save_init_dl)[0].to(self.device)
+ elif cfg.TRANSFER:
+ # Sinon si TRANSFER : on part d'images CD30 réelles
+ batch = next(self.save_final_dl)[0].to(self.device)
+ else:
+ # Sinon : on part d'un bruit gaussien
+ batch = self.mean_final + self.std_final * torch.randn(
+ (cfg.PLOT_NPAR, *self.shape), device=self.device
+ )
+
+ # Génère une trajectoire complète (sample=True = dernière étape sans bruit)
+ x_tot, out, steps_expanded = self.langevin.record_langevin_seq(
+ sample_net, batch, ipf_it=n, sample=True
+ )
+
+ # Réorganisation des dimensions pour faire des grilles / gifs
+ shape_len = len(x_tot.shape)
+ x_tot = x_tot.permute(1, 0, *list(range(2, shape_len)))
+ x_tot_plot = x_tot.detach()
+
+ # -------------------------
+ # 4) Stats simples pour surveiller la "santé" des sorties
+ # -------------------------
+ init_x = batch.detach().cpu().numpy()
+ final_x = x_tot_plot[-1].detach().cpu().numpy()
+ std_final = np.std(final_x)
+ std_init = np.std(init_x)
+ mean_final = np.mean(final_x)
+ mean_init = np.mean(init_x)
+
+ print('Initial variance: ' + str(std_init ** 2))
+ print('Final variance: ' + str(std_final ** 2))
+
+ # Log des stats de sampling
+ self.save_logger.log_metrics({
+ 'FB': fb,
+ 'init_var': std_init ** 2, 'final_var': std_final ** 2,
+ 'mean_init': mean_init, 'mean_final': mean_final,
+ 'T': self.T,
+ })
+
+ # Sauvegarde images/gif
+ self.plotter(batch, x_tot_plot, i, n, fb)
+
+ def set_seed(self, seed=0):
+ """
+ Fixe les seeds pour reproductibilité.
+ """
+ torch.manual_seed(seed)
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ def clear(self):
+ """
+ Libère la mémoire GPU (utile après cache/sampling).
+ """
+ torch.cuda.empty_cache()
+
+
+# ============================================================================
+# IPF SEQUENTIAL (entraînement)
+# ============================================================================
+
+class IPFSequential(IPFBase):
+ """
+ Implémentation "séquentielle" de l'entraînement IPF.
+
+ Concept (débutant) :
+ - À chaque itération IPF n :
+ 1) on entraîne le réseau backward (b)
+ 2) puis on entraîne le réseau forward (f)
+ - Chaque entraînement utilise un "cache" généré par l'autre réseau.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def ipf_step(self, forward_or_backward, n):
+ """
+ Entraîne UNE direction ('f' ou 'b') pendant num_iter itérations, à l'itération IPF n.
+
+ Étapes conceptuelles :
+ 1) Construire un dataset supervisé (CacheLoader) en générant des trajectoires
+ avec le réseau opposé (direction inverse).
+ 2) Entraîner le réseau courant sur (x -> out) avec une loss MSE.
+ 3) Sauvegarder/plotter régulièrement + rafraîchir le cache de temps en temps.
+ """
+ # 1) Génère le cache : un DataLoader de tuples (x, out, steps)
+ new_dl = self.new_cacheloader(forward_or_backward, n, cfg.EMA)
+
+ # Option : ne pas réutiliser le réseau précédent
+ # Concept :
+ # - si USE_PREV_NET=False, on reconstruit le réseau avant de l'entraîner
+ # - utile si on veut repartir "proprement" à chaque itération IPF
+ if not cfg.USE_PREV_NET:
+ self.build_models(forward_or_backward)
+ self.update_ema(forward_or_backward)
+
+ # (Re)crée l'optimiseur et prépare modèle+optimiseur avec Accelerator
+ self.build_optimizers()
+ self.accelerate(forward_or_backward)
+
+ # 2) Boucle d'optimisation classique
+ for i in tqdm(range(self.num_iter + 1)):
+ # Seed dépendant de (n, i) : reproductibilité + diversité entre itérations
+ self.set_seed(seed=n * self.num_iter + i)
+
+ # Récupère un batch du cache :
+ # - x : état / image intermédiaire
+ # - out : cible supervisée à prédire
+ # - steps_expanded : temps associé
+ x, out, steps_expanded = next(new_dl)
+ x = x.to(self.device)
+ out = out.to(self.device)
+ steps_expanded = steps_expanded.to(self.device)
+
+ # eval_steps : conversion du temps (ici on utilise T - t)
+ # Concept :
+ # - selon la formulation du bridge, la direction peut utiliser un temps "renversé"
+ eval_steps = self.T - steps_expanded
+
+ # Prédiction du réseau :
+ # - mode MEAN_MATCH : la sortie du réseau correspond à un "point moyen",
+ # donc on retire x pour obtenir une correction comparable à out.
+ # - sinon : la sortie du réseau est directement la correction.
+ if cfg.MEAN_MATCH:
+ pred = self.net[forward_or_backward](x, eval_steps) - x
+ else:
+ pred = self.net[forward_or_backward](x, eval_steps)
+
+ # Loss supervisée : on veut que pred ≈ out
+ loss = F.mse_loss(pred, out)
+
+ # Backprop (Accelerator gère correctement le backward selon le contexte)
+ self.accelerator.backward(loss)
+
+ # Option : clipping des gradients pour éviter des gradients trop grands
+ if self.grad_clipping:
+ clipping_param = cfg.GRAD_CLIP
+ total_norm = torch.nn.utils.clip_grad_norm_(
+ self.net[forward_or_backward].parameters(), clipping_param
+ )
+ else:
+ total_norm = 0.
+
+ # Logs périodiques (loss + norme de gradient)
+ if (i % self.stride_log == 0) and (i > 0):
+ self.logger.log_metrics({
+ 'forward_or_backward': forward_or_backward,
+ 'loss': loss,
+ 'grad_norm': total_norm,
+ }, step=i + self.num_iter * n)
+
+ # Step d'optimisation
+ self.optimizer[forward_or_backward].step()
+ self.optimizer[forward_or_backward].zero_grad()
+
+ # Mise à jour EMA (si activé) : version lissée des poids
+ if cfg.EMA:
+ self.ema_helpers[forward_or_backward].update(self.net[forward_or_backward])
+
+ # Sauvegarde + sampling + plots (selon stride)
+ self.save_step(i, n, forward_or_backward)
+
+ # 3) Rafraîchissement du cache
+ # Concept :
+ # - comme le réseau opposé évolue, les trajectoires "idéales" changent aussi
+ # - on reconstruit donc périodiquement un nouveau cache pour rester cohérent
+ if (i % cfg.CACHE_REFRESH_STRIDE == 0) and (i > 0):
+ new_dl = None
+ torch.cuda.empty_cache()
+ new_dl = self.new_cacheloader(forward_or_backward, n, cfg.EMA)
+
+ # Nettoyage
+ new_dl = None
+ self.clear()
+
+ def train(self):
+ """
+ Boucle principale d'entraînement IPF.
+
+ Concept :
+ - On fait d'abord une trajectoire d'initialisation (pour visualiser le point de départ).
+ - Ensuite, pour chaque itération IPF n :
+ - on entraîne 'b' puis 'f' (sauf cas spécial de reprise de checkpoint).
+ """
+
+ # -------------------------
+ # INITIAL FORWARD PASS (visualisation)
+ # -------------------------
+ # On génère une trajectoire d'init (sans réseau) uniquement pour voir à quoi ressemble
+ # le processus au tout début (utile pour vérifier que tout marche).
+ if self.accelerator.is_local_main_process:
+ init_sample = next(self.save_init_dl)[0].to(self.device)
+
+ x_tot, _, _ = self.langevin.record_init_langevin(init_sample)
+
+ # Mise en forme pour plot : (steps, batch, C, H, W)
+ shape_len = len(x_tot.shape)
+ x_tot = x_tot.permute(1, 0, *list(range(2, shape_len)))
+ x_tot_plot = x_tot.detach()
+
+ # Sauvegarde grilles + éventuellement gif
+ self.plotter(init_sample, x_tot_plot, 0, 0, 'f')
+
+ # Libération mémoire
+ x_tot_plot = None
+ x_tot = None
+ torch.cuda.empty_cache()
+
+ # -------------------------
+ # Itérations IPF
+ # -------------------------
+ for n in range(self.checkpoint_it, self.n_ipf + 1):
+ print('IPF iteration: ' + str(n) + '/' + str(self.n_ipf))
+
+ # Reprise : si on doit démarrer sur forward à l'itération de checkpoint
+ if (self.checkpoint_pass == 'f') and (n == self.checkpoint_it):
+ self.ipf_step('f', n)
+ else:
+ # En routine : backward puis forward
+ self.ipf_step('b', n)
+ self.ipf_step('f', n)
+
+
+
+# ============================================================================
+# MAIN
+# ============================================================================
+
+if __name__ == '__main__':
+ """
+ Point d'entrée du script.
+
+ Concept :
+ - Ce bloc ne s'exécute que si on lance le fichier directement :
+ python train_dsb.py
+ - Il sert à afficher la configuration du run (pour vérifier qu'on n'a pas
+ oublié un paramètre), puis à lancer l'entraînement.
+ """
+
+ # Affiche un résumé des paramètres importants du run
+ print('=== DSB Training: HES -> CD30 ===')
+ print(f'Image size : {cfg.IMAGE_SIZE}') # taille des images (ex: 256x256)
+ print(f'Batch size : {cfg.BATCH_SIZE}') # batch size d'entraînement
+ print(f'Num iter : {cfg.NUM_ITER}') # itérations d'optimisation par étape IPF
+ print(f'Num IPF : {cfg.N_IPF}') # nombre d'itérations IPF (cycles b puis f)
+ print(f'Num steps : {cfg.NUM_STEPS}') # nombre d'étapes Langevin / diffusion
+ print(f'Device : {cfg.DEVICE}') # cpu / cuda
+ print(f'Transfer : {cfg.TRANSFER}') # utilise (ou non) des images du domaine final
+ print(f'Data dir : {cfg.DATA_DIR}') # chemin vers les données
+ print('Directory : ' + os.getcwd()) # dossier courant (où seront écrits logs/ckpt)
+
+ # Crée l'objet d'entraînement IPF (prépare modèles, données, langevin, etc.)
+ ipf = IPFSequential()
+
+ # Lance la boucle principale d'entraînement (IPF)
+ ipf.train()
diff --git a/train.sh b/train.sh
new file mode 100644
index 0000000..966e6bd
--- /dev/null
+++ b/train.sh
@@ -0,0 +1,13 @@
+#!/bin/ksh
+#$ -q gpu
+#$ -o result.out
+#$ -j y
+#$ -N diffusion_schrodinger_bridge
+cd $WORKDIR
+cd /beegfs/data/work/imvia/in156281/diffusion_schrodinger_bridge
+source /beegfs/data/work/imvia/in156281/diffusion_schrodinger_bridge/venv/bin/activate
+module load python
+export PYTHONPATH=/work/imvia/in156281/diffusion_schrodinger_bridge/venv/lib/python3.9/site-packages:$PYTHONPATH
+export MPLCONFIGDIR=/work/imvia/in156281/.cache/matplotlib
+
+python train.py
\ No newline at end of file