RankSEG is a Python module designed for segmentation tasks, aiming to maximize Dice or IoU metrics based on estimated probabilities.
- GitHub repo: https://github.com/statmlben/rankseg
- Slides: https://slides.com/statmlben/rankseg
- Paper: JMLR-v24-22-0712
- Poster: ICML2024
Most segmentation methods rely on IoU and Dice as evaluation metrics. At inference time, however, they typically threshold the estimated probabilities at 0.5 or take the argmax to produce the segmentation prediction, and neither rule directly optimizes the IoU or Dice metric.
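As a toy sanity check (made-up probabilities, three pixels, assuming independent ground-truth labels), the sketch below enumerates all possible ground truths and shows that the 0.5-threshold prediction need not maximize the expected Dice score. This is only an illustration of the gap, not the RankSEG algorithm itself.

```python
# Toy illustration (made-up probabilities): thresholding at 0.5 need not
# maximize the expected Dice score.
from itertools import product

p = [0.60, 0.55, 0.45]  # estimated foreground probabilities, sorted by rank

def expected_dice(pred):
    """Expected Dice of a fixed prediction under independent pixel labels."""
    total = 0.0
    for gt in product([0, 1], repeat=len(p)):
        prob = 1.0
        for p_i, g_i in zip(p, gt):
            prob *= p_i if g_i else (1 - p_i)
        inter = sum(a & b for a, b in zip(pred, gt))
        denom = sum(pred) + sum(gt)
        total += prob * (1.0 if denom == 0 else 2 * inter / denom)
    return total

for k in range(len(p) + 1):  # predict the top-k ranked pixels
    pred = [1] * k + [0] * (len(p) - k)
    print(k, round(expected_dice(pred), 3))
# Here k = 3 yields a higher expected Dice than the 0.5-threshold choice (k = 2).
```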
Our method, RankDice (and its IoU counterpart, RankIoU), instead constructs predictions that directly maximize the Dice/IoU metric.
- Nearly guarantees improved Dice and IoU performance!
- Seamlessly integrates with any pretrained segmentation neural network, with no retraining required (see the example after the usage snippet below).
- Provides a ready-to-use Python function `rank_dice`.
```bash
git clone https://github.com/statmlben/rankseg.git
pip install -r requirements.txt
```

```python
## `out_prob` (batch_size, num_class, width, height) is the output probability for each pixel based on a trained neural network
from rankseg import rank_dice

predict_rd, tau_rd, cutpoint_rd = rank_dice(out_prob, app=2, device='cuda')
```
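The sketch below illustrates the pretrained-network workflow end to end. It uses a torchvision DeepLabV3 model purely as an example backbone (an assumption of this sketch, not a requirement of RankSEG) and calls `rank_dice` exactly as in the snippet above; check the `rank_dice` docstring for the accepted arguments and devices.

```python
# Minimal sketch: plug rank_dice onto a pretrained segmentation network.
# Assumptions: torchvision is installed, pretrained weights can be downloaded,
# and rank_dice is called as in the snippet above (use device='cpu' if no GPU).
import torch
import torchvision
from rankseg import rank_dice

model = torchvision.models.segmentation.deeplabv3_resnet50(weights="DEFAULT").eval()

# Dummy batch of two RGB images; replace with properly normalized real images.
images = torch.rand(2, 3, 256, 256)

with torch.no_grad():
    logits = model(images)["out"]            # (batch_size, num_class, height, width)
    out_prob = torch.softmax(logits, dim=1)  # per-pixel class probabilities

predict_rd, tau_rd, cutpoint_rd = rank_dice(out_prob, app=2, device='cuda')
```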
```python
## `out_prob` (batch_size, num_class, width, height) is the output probability for each pixel based on a trained neural network
import torch

## Threshold
predict_T = torch.where(out_prob > .5, True, False)

## Argmax
idx = torch.argmax(out_prob.data, dim=1, keepdim=True)
predict_max = torch.zeros_like(out_prob.data, dtype=bool).scatter_(1, idx, True)
```
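When a ground-truth mask is available, the three rules above can be compared on the same metric. The following is a minimal sketch, not the evaluation code of this repository: `target_onehot` is a hypothetical boolean ground-truth tensor with the same shape as `out_prob`, and `predict_rd` is assumed to be a hard segmentation of that same shape.

```python
# Minimal comparison sketch; `target_onehot` is a hypothetical boolean ground-truth
# tensor of shape (batch_size, num_class, width, height), matching `out_prob`.
import torch

def mean_dice(pred, target, eps=1e-8):
    """Plain Dice averaged over batch and classes (not the repo's evaluation code)."""
    pred, target = pred.float(), target.float()
    inter = (pred * target).sum(dim=(-2, -1))
    denom = pred.sum(dim=(-2, -1)) + target.sum(dim=(-2, -1))
    return ((2 * inter + eps) / (denom + eps)).mean()

for name, pred in [("rankdice", predict_rd), ("threshold", predict_T), ("argmax", predict_max)]:
    print(name, mean_dice(pred, target_onehot).item())
```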
```bash
## rankdice
$ python test.py -r saved/cityscapes/PSPNet/CrossEntropyLoss2d/T/05-04_13-08/checkpoint-epoch300.pth -p "rankdice"

TEST, Pred (rankdice) | Loss: 0.159, PixelAcc: 0.99, Mean IoU: 0.51, Mean Dice 0.59 |: 100%|██████| 84/84 [01:03<00:00, 1.33it/s]

## TESTING Results for Model: PSPNet + Loss: CrossEntropyLoss2d + predict: rankdice ##
test_loss : 0.15925
Pixel_Accuracy : 0.9879999756813049
Mean_IoU : 0.5099999904632568
Mean_Dice : 0.5929999947547913
Class_IoU : {0: 0.771, 1: 0.508, 2: 0.767, 3: 0.164, 4: 0.117, 5: 0.317, 6: 0.283, 7: 0.401, 8: 0.841, 9: 0.231, 10: 0.778, 11: 0.4, 12: 0.292, 13: 0.766, 14: 0.233, 15: 0.465, 16: 0.315, 17: 0.177, 18: 0.326}
Class_Dice : {0: 0.856, 1: 0.608, 2: 0.851, 3: 0.21, 4: 0.158, 5: 0.46, 6: 0.374, 7: 0.514, 8: 0.903, 9: 0.294, 10: 0.845, 11: 0.495, 12: 0.372, 13: 0.84, 14: 0.265, 15: 0.513, 16: 0.358, 17: 0.222, 18: 0.419}
```
```bash
## max
$ python test.py -r saved/cityscapes/PSPNet/CrossEntropyLoss2d/T/05-04_13-08/checkpoint-epoch300.pth -p "max"

TEST, Pred (max) | Loss: 0.159, PixelAcc: 0.99, Mean IoU: 0.49, Mean Dice 0.56 |: 100%|███████████| 84/84 [00:12<00:00, 6.52it/s]

## TESTING Results for Model: PSPNet + Loss: CrossEntropyLoss2d + predict: max ##
test_loss : 0.15925
Pixel_Accuracy : 0.9879999756813049
Mean_IoU : 0.48500001430511475
Mean_Dice : 0.5649999976158142
Class_IoU : {0: 0.768, 1: 0.489, 2: 0.759, 3: 0.133, 4: 0.099, 5: 0.295, 6: 0.257, 7: 0.387, 8: 0.836, 9: 0.208, 10: 0.769, 11: 0.372, 12: 0.272, 13: 0.751, 14: 0.204, 15: 0.395, 16: 0.268, 17: 0.152, 18: 0.303}
Class_Dice : {0: 0.854, 1: 0.585, 2: 0.844, 3: 0.172, 4: 0.136, 5: 0.428, 6: 0.341, 7: 0.498, 8: 0.9, 9: 0.268, 10: 0.835, 11: 0.464, 12: 0.351, 13: 0.826, 14: 0.233, 15: 0.437, 16: 0.308, 17: 0.193, 18: 0.392}
```
```bash
## threshold at 0.5
$ python test.py -r saved/cityscapes/PSPNet/CrossEntropyLoss2d/T/05-04_13-08/checkpoint-epoch300.pth -p "T"

TEST, Pred (T) | Loss: 0.159, PixelAcc: 0.99, Mean IoU: 0.50, Mean Dice 0.57 |: 100%|█████████████| 84/84 [00:13<00:00, 6.45it/s]

## TESTING Results for Model: PSPNet + Loss: CrossEntropyLoss2d + predict: T ##
test_loss : 0.15925
Pixel_Accuracy : 0.9890000224113464
Mean_IoU : 0.4959999918937683
Mean_Dice : 0.574999988079071
Class_IoU : {0: 0.772, 1: 0.478, 2: 0.762, 3: 0.136, 4: 0.109, 5: 0.29, 6: 0.265, 7: 0.39, 8: 0.841, 9: 0.201, 10: 0.77, 11: 0.363, 12: 0.273, 13: 0.769, 14: 0.219, 15: 0.422, 16: 0.307, 17: 0.158, 18: 0.325}
Class_Dice : {0: 0.857, 1: 0.573, 2: 0.846, 3: 0.174, 4: 0.147, 5: 0.419, 6: 0.349, 7: 0.499, 8: 0.902, 9: 0.257, 10: 0.836, 11: 0.451, 12: 0.351, 13: 0.841, 14: 0.247, 15: 0.468, 16: 0.349, 17: 0.197, 18: 0.414}
```

- `Threshold`, `Argmax` and `rankDice` are performed based on the same network (in the `Model` column) trained by the same loss (in the `Loss` column).
- Averaged mDice and mIoU metrics based on state-of-the-art models/losses on the fine-annotated CityScapes val set. '/' indicates not applicable since the proposed `RankDice`/`mRankDice` requires a strictly proper loss. The best performance in each model/loss is bold-faced.
- All trained neural networks and their `config.json` with different `network` and `loss` are saved in this link (12G folder: `network/loss/.../*.pth` + `config.json`).
RankDice/mRankDicerequires a strictly proper loss. The best performance in each model/loss is bold-faced. - All trained neural networks and their
config.jsonwith differentnetworkandlossare saved in this link (12G folder: network/loss/.../*.pth+config.json)
| Model | Loss | Threshold (at 0.5) <br> (mDice, mIoU) | Argmax <br> (mDice, mIoU) | mRankDice (our) <br> (mDice, mIoU) |
|---|---|---|---|---|
| DeepLab-V3+ | CE | (56.00, 48.40) | (54.20, 46.60) | (57.80, 49.80) |
| (resnet101) | Focal | (54.10, 46.60) | (53.30, 45.60) | (56.50, 48.70) |
| | BCE | (49.80, 24.90) | (44.20, 22.10) | (54.00, 27.00) |
| | Soft-Dice | (39.50, 35.90) | (39.50, 35.90) | / |
| | B-Soft-Dice | (41.00, 20.50) | (27.60, 13.80) | / |
| | LovaszSoftmax | (55.20, 47.60) | (52.30, 45.10) | / |
| PSPNet | CE | (57.50, 49.60) | (56.50, 48.50) | (59.30, 51.00) |
| (resnet50) | Focal | (56.00, 48.20) | (55.80, 47.70) | (58.20, 50.00) |
| | BCE | (51.40, 25.70) | (47.60, 23.80) | (55.10, 27.60) |
| | Soft-Dice | (49.10, 43.50) | (48.70, 43.20) | / |
| | B-Soft-Dice | (46.30, 23.10) | (32.70, 16.40) | / |
| | LovaszSoftmax | (56.80, 48.90) | (55.40, 47.70) | / |
| FCN8 | CE | (51.40, 43.70) | (50.50, 42.60) | (53.50, 45.30) |
| (resnet101) | Focal | (48.50, 41.20) | (49.60, 41.60) | (51.50, 43.70) |
| | BCE | (39.40, 19.70) | (39.40, 19.70) | (41.30, 20.60) |
| | Soft-Dice | (28.30, 24.30) | (28.30, 24.30) | / |
| | B-Soft-Dice | (29.10, 14.60) | (29.10, 14.60) | / |
| | LovaszSoftmax | (48.10, 40.40) | (42.90, 35.80) | / |
- `Threshold`, `Argmax` and `rankDice` are performed based on the same network (in the `Model` column) trained by the same loss (in the `Loss` column).
- Averaged mDice and mIoU based on state-of-the-art models/losses on the PASCAL VOC 2012 val set. '---' indicates that either the performance is significantly worse or the training is unstable, and '/' indicates not applicable since the proposed `RankDice`/`mRankDice` requires a strictly proper loss. The best performance in each model-loss pair is bold-faced.
- All trained neural networks with different `network` and `loss` are saved in this link (22G folder: `network/loss/.../*.pth`).
| Model | Loss | Threshold (at 0.5) <br> (mDice, mIoU) | Argmax <br> (mDice, mIoU) | mRankDice (our) <br> (mDice, mIoU) |
|---|---|---|---|---|
| DeepLab-V3+ | CE | (63.60, 56.70) | (61.90, 55.30) | (64.01, 57.01) |
| (resnet101) | Focal | (62.70, 55.01) | (60.50, 53.20) | (62.90, 55.10) |
| | BCE | (63.30, 31.70) | (59.90, 29.90) | (64.60, 32.30) |
| | Soft-Dice | --- | --- | / |
| | B-Soft-Dice | --- | --- | / |
| | LovaszSoftmax | (57.70, 51.60) | (56.20, 50.30) | / |
| PSPNet | CE | (64.60, 57.10) | (63.20, 55.90) | (65.40, 57.80) |
| (resnet50) | Focal | (64.00, 56.10) | (63.90, 56.10) | (66.60, 58.50) |
| | BCE | (64.20, 32.10) | (65.20, 32.60) | (67.10, 33.50) |
| | Soft-Dice | (59.60, 54.00) | (58.80, 53.20) | / |
| | B-Soft-Dice | (63.30, 31.60) | (54.00, 27.00) | / |
| | LovaszSoftmax | (62.00, 55.20) | (60.80, 54.10) | / |
| FCN8 | CE | (49.50, 41.90) | (45.30, 38.40) | (50.40, 42.70) |
| (resnet101) | Focal | (50.40, 41.80) | (47.20, 39.30) | (51.50, 42.50) |
| | BCE | (46.20, 23.10) | (44.20, 22.10) | (47.70, 23.80) |
| | Soft-Dice | --- | --- | / |
| | B-Soft-Dice | --- | --- | / |
| | LovaszSoftmax | (39.80, 34.30) | (37.30, 32.20) | / |
- `Threshold`, `Argmax` and `rankDice` are performed based on the same network (in the `Model` column) trained by the same loss (in the `Loss` column). `Threshold` and `Argmax` are exactly the same in binary segmentation.
- Averaged mDice and mIoU based on state-of-the-art models/losses on the Kvasir-SEG dataset. '---' indicates that either the performance is significantly worse or the training is unstable, and '/' indicates not applicable since the proposed `RankDice`/`mRankDice` requires a strictly proper loss. The best performance in each model-loss pair is bold-faced.
| Model | Loss | Threshold/Argmax <br> (Dice, IoU) | mRankDice (our) <br> (Dice, IoU) |
|---|---|---|---|
| DeepLab-V3+ | CE | (87.9, 80.7) | (88.3, 80.9) |
| (resnet101) | Focal | (86.5, 87.3) | / |
| | Soft-Dice | (85.7, 77.8) | / |
| | LovaszSoftmax | (84.3, 77.3) | / |
| PSPNet | CE | (86.3, 79.2) | (87.1, 79.8) |
| (resnet50) | Focal | (83.8, 75.4) | / |
| | Soft-Dice | (83.5, 75.9) | / |
| | LovaszSoftmax | (86.0, 79.2) | / |
| FCN8 | CE | (81.9, 73.5) | (82.1, 73.6) |
| (resnet101) | Focal | (78.5, 69.0) | / |
| | Soft-Dice | --- | --- |
| | LovaszSoftmax | (82.0, 73.4) | / |
- All empirical results on different losses and models can be found here
If you want to replicate the experiments in our paper, please check the folder `./pytorch-segmentation-rankseg` and its README file, Pytorch-segmentation-rankseg.
If you like RankSEG, please star 🌟 the repository and cite the following paper:
```bibtex
@article{dai2023rankseg,
  title={RankSEG: A Consistent Ranking-based Framework for Segmentation},
  author={Dai, Ben and Li, Chunlin},
  journal={Journal of Machine Learning Research},
  volume={24},
  number={224},
  pages={1--50},
  year={2023}
}
```

If you find this repository helpful, please star our repo 🌟.
Thank you so much for the support from our stargazers.