Skip to content

Conversation

@edoardogiacomello
Copy link
Contributor

@edoardogiacomello edoardogiacomello commented Dec 17, 2025

This PR wants to merge both the pseudo-epoch branch and the faster_prediction branches into main.

It adds the following main functionalities:

  • Pseudo-Epochs: Now the logging and the metric tracking is based on pseudo-epochs. This means that each "epoch" reported in WanDB corresponds to a fixed amount of batches that is configurable in the training configuration file. This has been made fully compatible with multi-GPU training. To make the training metrics comparable across different GPUs, (e.g., compare single GPU training vs 4GPU training) you can set accumulate_grad_batches to accumulate more gradient steps (e.g., set it to 4).
  • New Datamodule Stucture: I refactored DataModule to be dataset-agnostic and to include also test and prediction dataloaders. Dataset specifics belong now to different DatasetConfig classes.
  • New Test and Predict Script: I rewrote the testing and predictions script following a cleaner Lightning pattern. This also allows to run predictions on multiple GPUs. Test MUST still be run on a single GPU. (Since default Samplers/Dataloaders are used and they automatically resample batches to have the same amount of samples on each GPU. This needs to be avoided during testing.)
  • Zarr and Tiff Prediction Output: To ease generation of figures and comparison of models, the prediction script outputs both Zarr (also used as cached when multiprocessing) and Tiff files for the whole input image.
  • LiverFibSem Dataset Support: A new config class has been added to load LiverFibSem. This is possible due to the new DataModule and the Prediction and Test script that now support multiple input/output stacks (that is, if 2 test_keys are used as input, they will output a test/prediction for both)
  • Test / Prediction Masking can be configured easiliy: Prediction and Testing can be now masked on a per-stack basis. The user will select the center slice to predict, the "half" volume if prediction on multiple slice is desired, and an optional stepping (only for testing) to test on "one every N" pixels in all three dimensions to speedup test phase.

Detailed Changelog:

Please refer to commit messages

More Info:

Refer to Notion for more information.

[TODO] Fix device placement for lvae tensors
[Fix] Set logs_every_n_steps to allow alignment across different world_size
[Fix] Radius patience is increased only in semiSL mode
[Add] Integrated Test and Prediction in DataLoader
[Add] Config now has test-related parameters and volume masks
[Add] New test script that saves results and test predictions
[WIP] Prediction script on multiple gpu with integrated TIFF writing
[FIXME] Reintegrate liverfiber datamodule
[Ref] Refactored DataSetConfigs
[Ref] Refactored DataModule to get rid of Dataset-speficic Datamodules in favour of dataset-specific Configs
[Add] Now predict_step and test_step also returns the corresponding test_key for the predicted patch
[Fix] Adapted PredictionDataset to work with multiple images
[Fix] Adapted training script to use a dataset agnostic datamodule instead of betaseg
[Add] Updated Readme with information on how to add new datasets
[Fix] Prediction working with multiple keys and ranks, fixed perf. degradation
[Fix] Test OOM issues.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants