
Dropout fix [PyTorch]: Move dropout_rate from model init to model fwd pass #873


Merged
Changes from all commits (23 commits)
b306076
dropout fix criteo, fastmri, vit, conf
Niccolo-Ajroldi Jun 10, 2025
3e7a396
dropout fix deepspeech, ogbg
Niccolo-Ajroldi Jun 11, 2025
e80add4
remove attention_dropout_rate from wmt
Niccolo-Ajroldi Jun 11, 2025
84b1bd1
dropout fix on wmt
Niccolo-Ajroldi Jun 11, 2025
af08bb9
fix dropout, ALL tested
Niccolo-Ajroldi Jun 11, 2025
7a6651a
add dropout equivalence tests
Niccolo-Ajroldi Jun 11, 2025
a7ff3d1
moved custom dropout to pytorch_utils
Niccolo-Ajroldi Jun 11, 2025
f26ab02
remove aux_dropout from pytorch workloads
Niccolo-Ajroldi Jun 11, 2025
e0a0e62
criteo rm dropout from init
Niccolo-Ajroldi Jun 12, 2025
1e2f379
criteo rm dropout from init
Niccolo-Ajroldi Jun 12, 2025
f10e3dc
criteo rm dropout from init
Niccolo-Ajroldi Jun 12, 2025
027b053
criteo rm dropout from init
Niccolo-Ajroldi Jun 12, 2025
74c43aa
fastmri rm dropout from init
Niccolo-Ajroldi Jun 12, 2025
64276ef
vit rm dropout at init
Niccolo-Ajroldi Jun 12, 2025
44029d2
vit rm dropout at init
Niccolo-Ajroldi Jun 12, 2025
44ffec1
add default dropout test
Niccolo-Ajroldi Jun 12, 2025
9d12fa6
add default dropout test
Niccolo-Ajroldi Jun 12, 2025
ac45a9f
conformer: rm dropout_rate from init
Niccolo-Ajroldi Jun 12, 2025
31d64f6
rm dropout_rate at init from all workloads
Niccolo-Ajroldi Jun 12, 2025
0128c9f
pipe dropout to model_fn, set default in workload
Niccolo-Ajroldi Jun 13, 2025
a7cba1a
remove aux_dropout from pytorch workloads
Niccolo-Ajroldi Jun 13, 2025
d8e39b0
fix to model_fn default dropout_rate
Niccolo-Ajroldi Jun 15, 2025
7a00158
rm models_dropout torch files
Niccolo-Ajroldi Jun 15, 2025
38 changes: 38 additions & 0 deletions algoperf/pytorch_utils.py
@@ -5,7 +5,10 @@
import jax
import tensorflow as tf
import torch
from torch import Tensor
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F

from algoperf import spec
from algoperf.profiler import Profiler
@@ -77,3 +80,38 @@ def update_batch_norm_fn(module: spec.ParameterContainer,
module.momentum = 0.0
elif hasattr(module, 'momentum_backup'):
module.momentum = module.momentum_backup


class CustomDropout(nn.Module):
"""A module around torch.nn.functional.dropout."""
def __init__(self):
super().__init__()
self._supports_custom_dropout = True

def forward(self, input: Tensor, p: float) -> Tensor:
return F.dropout(input, p, training=self.training)


class CustomDropout2d(nn.Module):
"""A module around torch.nn.functional.dropout2d."""
def __init__(self):
super().__init__()
self._supports_custom_dropout = True

def forward(self, input: Tensor, p: float) -> Tensor:
return F.dropout2d(input, p, training=self.training)


class SequentialWithDropout(nn.Sequential):
"""Sequential of modules with dropout."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._supports_custom_dropout = True

def forward(self, x: Tensor, p: float) -> Tensor:
for module in self:
if getattr(module, '_supports_custom_dropout', False):
x = module(x, p)
else:
x = module(x)
return x
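
A minimal usage sketch (not part of the diff) of how these helpers compose; the layer sizes and the dropout probability are illustrative:

import torch
import torch.nn as nn

from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout

# Build a small block whose dropout probability is supplied per call.
block = SequentialWithDropout(
    nn.Linear(16, 16),
    nn.ReLU(),
    CustomDropout(),  # picks up the per-call dropout probability
)
x = torch.randn(4, 16)
block.train()
y_train = block(x, 0.1)  # dropout active with p=0.1
block.eval()
y_eval = block(x, 0.1)   # dropout is a no-op in eval mode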
6 changes: 4 additions & 2 deletions algoperf/spec.py
@@ -247,7 +247,8 @@ def init_model_fn(self,
# ModelAuxiliaryState,
# ForwardPassMode,
# RandomState,
# bool],
# bool,
# float],
# Tensor]
@abc.abstractmethod
def model_fn(self,
@@ -256,7 +257,8 @@ def model_fn(self,
model_state: ModelAuxiliaryState,
mode: ForwardPassMode,
rng: RandomState,
update_batch_norm: bool) -> Tuple[Tensor, ModelAuxiliaryState]:
update_batch_norm: bool,
dropout_rate: float) -> Tuple[Tensor, ModelAuxiliaryState]:
"""Return logits_batch"""
# Possible side effect of updating BN.

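
A hedged call-site sketch of the updated abstract signature; the workload instance, batch, and the 0.1 value are illustrative and not taken from this diff:

logits_batch, new_model_state = workload.model_fn(
    params,
    batch,
    model_state,
    spec.ForwardPassMode.TRAIN,
    rng,
    update_batch_norm=True,
    dropout_rate=0.1)  # dropout is now threaded through the forward pass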
54 changes: 32 additions & 22 deletions algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py
@@ -5,20 +5,32 @@
import torch
from torch import nn

from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout

DROPOUT_RATE = 0.0


class DenseBlock(nn.Module):
"""Dense block with optional residual connection.""" ""

def __init__(self, module, resnet=False):
super().__init__()
self.module = module
self.resnet = resnet

def forward(self, x):
if self.resnet:
return self.module(x) + x
else:
return self.module(x)
return self.module(x) + x if self.resnet else self.module(x)


class DenseBlockWithDropout(nn.Module):
"""Dense block with optional residual connection and support for dropout."""
def __init__(self, module, resnet=False):
super().__init__()
self.module = module
self.resnet = resnet
self._supports_custom_dropout = True

def forward(self, x, p):
return self.module(x, p) + x if self.resnet else self.module(x, p)


class DotInteract(nn.Module):
@@ -58,7 +70,6 @@ def __init__(self,
mlp_bottom_dims=(256, 256, 256),
mlp_top_dims=(256, 256, 256, 256, 1),
embed_dim=128,
dropout_rate=0.0,
use_layer_norm=False,
embedding_init_multiplier=None):
super().__init__()
@@ -116,17 +127,16 @@ def __init__(self,
block.append(nn.Linear(fan_in, fan_out))
if layer_idx < (num_layers_top - 1):
block.append(nn.ReLU(inplace=True))
if (dropout_rate is not None and dropout_rate > 0.0 and
layer_idx == num_layers_top - 2):
block.append(nn.Dropout(p=dropout_rate))
block = nn.Sequential(*block)
if layer_idx == num_layers_top - 2:
block.append(CustomDropout())
block = SequentialWithDropout(*block)
if (layer_idx != 0) and (layer_idx != num_layers_top - 1):
block = DenseBlock(block, resnet=True)
block = DenseBlockWithDropout(block, resnet=True)
else:
block = DenseBlock(block)
block = DenseBlockWithDropout(block)
mlp_top_blocks.append(block)
fan_in = fan_out
self.top_mlp = nn.Sequential(*mlp_top_blocks)
self.top_mlp = SequentialWithDropout(*mlp_top_blocks)

for module in self.top_mlp.modules():
if isinstance(module, nn.Linear):
@@ -138,7 +148,8 @@ def __init__(self,
0.,
math.sqrt(1. / module.out_features))

def forward(self, x):
def forward(self, x, dropout_rate=DROPOUT_RATE):

batch_size = x.shape[0]

dense_features, sparse_features = torch.split(
@@ -157,7 +168,7 @@ def forward(self, x):
top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1)

# Final MLP.
logits = self.top_mlp(top_mlp_input)
logits = self.top_mlp(top_mlp_input, dropout_rate)
return logits


@@ -179,7 +190,6 @@ def __init__(self,
mlp_bottom_dims=(512, 256, 128),
mlp_top_dims=(1024, 1024, 512, 256, 1),
embed_dim=128,
dropout_rate=0.0,
use_layer_norm=False,
embedding_init_multiplier=None):
super().__init__()
@@ -242,10 +252,9 @@ def __init__(self,
top_mlp_layers.append(nn.ReLU(inplace=True))
if use_layer_norm:
top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6))
if (dropout_rate is not None and dropout_rate > 0.0 and
layer_idx == num_layers_top - 2):
top_mlp_layers.append(nn.Dropout(p=dropout_rate))
self.top_mlp = nn.Sequential(*top_mlp_layers)
if layer_idx == num_layers_top - 2:
top_mlp_layers.append(CustomDropout())
self.top_mlp = SequentialWithDropout(*top_mlp_layers)
if use_layer_norm:
self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6)
else:
@@ -260,7 +269,8 @@ def __init__(self,
0.,
math.sqrt(1. / module.out_features))

def forward(self, x):
def forward(self, x, dropout_rate=DROPOUT_RATE):

batch_size = x.shape[0]

dense_features, sparse_features = torch.split(
@@ -283,5 +293,5 @@ def forward(self, x):
dense_features=embedded_dense, sparse_features=embedded_sparse)

# Final MLP.
logits = self.top_mlp(concatenated_dense)
logits = self.top_mlp(concatenated_dense, dropout_rate)
return logits
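
A short sketch (not part of the diff) of the residual path in DenseBlockWithDropout, assuming the helpers from algoperf.pytorch_utils and the class defined above; dimensions and the 0.1 value are illustrative:

import torch
import torch.nn as nn

from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout

inner = SequentialWithDropout(nn.Linear(256, 256), nn.ReLU(), CustomDropout())
block = DenseBlockWithDropout(inner, resnet=True)
out = block(torch.randn(4, 256), 0.1)  # computes module(x, p) + x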
12 changes: 4 additions & 8 deletions algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
@@ -67,11 +67,7 @@ def loss_fn(

def init_model_fn(
self,
rng: spec.RandomState,
dropout_rate: Optional[float] = None,
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
"""Only dropout is used."""
del aux_dropout_rate
rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
# Disable cudnn benchmark to avoid OOM errors.
torch.backends.cudnn.benchmark = False
Expand All @@ -85,7 +81,6 @@ def init_model_fn(
mlp_bottom_dims=self.mlp_bottom_dims,
mlp_top_dims=self.mlp_top_dims,
embed_dim=self.embed_dim,
dropout_rate=dropout_rate,
use_layer_norm=self.use_layer_norm,
embedding_init_multiplier=self.embedding_init_multiplier)
self._param_shapes = param_utils.pytorch_param_shapes(model)
@@ -108,7 +103,8 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
@@ -128,7 +124,7 @@ def model_fn(
}

with contexts[mode]():
logits_batch = model(inputs)
logits_batch = model(inputs, dropout_rate=dropout_rate)

return logits_batch, None

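
A hedged sketch of the new calling convention for this workload; the class name, batch, and values are illustrative, not taken from the diff:

workload = Criteo1TbDlrmSmallWorkload()
params, model_state = workload.init_model_fn(rng)  # no dropout arguments at init
logits, _ = workload.model_fn(
    params,
    batch,
    model_state,
    spec.ForwardPassMode.TRAIN,
    rng,
    update_batch_norm=False,
    dropout_rate=0.1)  # falls back to models.DROPOUT_RATE when omitted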
44 changes: 24 additions & 20 deletions algoperf/workloads/fastmri/fastmri_pytorch/models.py
@@ -13,6 +13,9 @@
from torch.nn import functional as F

from algoperf import init_utils
from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout

DROPOUT_RATE = 0.0


class UNet(nn.Module):
@@ -27,7 +30,6 @@ def __init__(self,
out_chans: int = 1,
num_channels: int = 32,
num_pool_layers: int = 4,
dropout_rate: Optional[float] = 0.0,
use_tanh: bool = False,
use_layer_norm: bool = False) -> None:
super().__init__()
@@ -36,21 +38,19 @@ def __init__(self,
self.out_chans = out_chans
self.num_channels = num_channels
self.num_pool_layers = num_pool_layers
if dropout_rate is None:
dropout_rate = 0.0

self.down_sample_layers = nn.ModuleList([
ConvBlock(in_chans,
num_channels,
dropout_rate,
use_tanh,
use_layer_norm)
])
ch = num_channels
for _ in range(num_pool_layers - 1):
self.down_sample_layers.append(
ConvBlock(ch, ch * 2, dropout_rate, use_tanh, use_layer_norm))
ConvBlock(ch, ch * 2, use_tanh, use_layer_norm))
ch *= 2
self.conv = ConvBlock(ch, ch * 2, dropout_rate, use_tanh, use_layer_norm)
self.conv = ConvBlock(ch, ch * 2, use_tanh, use_layer_norm)

self.up_conv = nn.ModuleList()
self.up_transpose_conv = nn.ModuleList()
@@ -59,39 +59,43 @@ def __init__(self,
self.up_transpose_conv.append(
TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
self.up_conv.append(
ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm))
ConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
ch //= 2

self.up_transpose_conv.append(
TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
self.up_conv.append(
nn.Sequential(
ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm),
SequentialWithDropout(
ConvBlock(ch * 2, ch, use_tanh, use_layer_norm),
nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
))

for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
init_utils.pytorch_default_init(m)

def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
dropout_rate: float = DROPOUT_RATE) -> Tensor:

stack = []
output = x

# apply down-sampling layers
for layer in self.down_sample_layers:
output = layer(output)
output = layer(output, dropout_rate)
stack.append(output)
output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)

output = self.conv(output)
output = self.conv(output, dropout_rate)

# apply up-sampling layers
for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
downsample_layer = stack.pop()
output = transpose_conv(output)

# reflect pad on the right/botton if needed to handle
# reflect pad on the right/bottom if needed to handle
# odd input dimensions
padding = [0, 0, 0, 0]
if output.shape[-1] != downsample_layer.shape[-1]:
@@ -102,7 +106,7 @@ def forward(self, x: Tensor) -> Tensor:
output = F.pad(output, padding, "reflect")

output = torch.cat([output, downsample_layer], dim=1)
output = conv(output)
output = conv(output, dropout_rate)

return output

@@ -114,10 +118,10 @@ class ConvBlock(nn.Module):
def __init__(self,
in_chans: int,
out_chans: int,
dropout_rate: float,
use_tanh: bool,
use_layer_norm: bool) -> None:
super().__init__()
self._supports_custom_dropout = True

if use_layer_norm:
norm_layer = partial(nn.GroupNorm, 1, eps=1e-6)
@@ -127,19 +131,19 @@ def __init__(self,
activation_fn = nn.Tanh()
else:
activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.conv_layers = nn.Sequential(
self.conv_layers = SequentialWithDropout(
nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
norm_layer(out_chans),
activation_fn,
nn.Dropout2d(dropout_rate),
CustomDropout2d(),
nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
norm_layer(out_chans),
activation_fn,
nn.Dropout2d(dropout_rate),
CustomDropout2d(),
)

def forward(self, x: Tensor) -> Tensor:
return self.conv_layers(x)
def forward(self, x: Tensor, dropout_rate: float) -> Tensor:
return self.conv_layers(x, dropout_rate)


class TransposeConvBlock(nn.Module):
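
A minimal sketch (not part of the diff), assuming the UNet defaults shown above; the input shape is illustrative (batch, channels, height, width):

import torch

from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet

unet = UNet()  # in_chans=1, out_chans=1, num_channels=32, num_pool_layers=4
unet.train()
x = torch.randn(2, 1, 64, 64)
out = unet(x, dropout_rate=0.1)  # applied inside each ConvBlock via CustomDropout2d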