diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index 4a674985d..4af77088e 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -5,7 +5,10 @@ import jax import tensorflow as tf import torch +from torch import Tensor +import torch.nn as nn import torch.distributed as dist +import torch.nn.functional as F from algoperf import spec from algoperf.profiler import Profiler @@ -77,3 +80,38 @@ def update_batch_norm_fn(module: spec.ParameterContainer, module.momentum = 0.0 elif hasattr(module, 'momentum_backup'): module.momentum = module.momentum_backup + + +class CustomDropout(nn.Module): + """A module around torch.nn.functional.dropout.""" + def __init__(self): + super().__init__() + self._supports_custom_dropout = True + + def forward(self, input: Tensor, p: float) -> Tensor: + return F.dropout(input, p, training=self.training) + + +class CustomDropout2d(nn.Module): + """A module around torch.nn.functional.dropout2d.""" + def __init__(self): + super().__init__() + self._supports_custom_dropout = True + + def forward(self, input: Tensor, p: float) -> Tensor: + return F.dropout2d(input, p, training=self.training) + + +class SequentialWithDropout(nn.Sequential): + """Sequential of modules with dropout.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._supports_custom_dropout = True + + def forward(self, x: Tensor, p: float) -> Tensor: + for module in self: + if getattr(module, '_supports_custom_dropout', False): + x = module(x, p) + else: + x = module(x) + return x diff --git a/algoperf/spec.py b/algoperf/spec.py index cf4f1a14e..9670dcb76 100644 --- a/algoperf/spec.py +++ b/algoperf/spec.py @@ -247,7 +247,8 @@ def init_model_fn(self, # ModelAuxiliaryState, # ForwardPassMode, # RandomState, - # bool], + # bool, + # float], # Tensor] @abc.abstractmethod def model_fn(self, @@ -256,7 +257,8 @@ def model_fn(self, model_state: ModelAuxiliaryState, mode: ForwardPassMode, rng: RandomState, - update_batch_norm: bool) -> Tuple[Tensor, ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float) -> Tuple[Tensor, ModelAuxiliaryState]: """Return logits_batch""" # Possible side effect of updating BN. 
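The helpers added to pytorch_utils.py are what the rest of this patch builds on: SequentialWithDropout iterates over its children like nn.Sequential, but passes the call-time dropout probability only to children that set _supports_custom_dropout, while CustomDropout / CustomDropout2d wrap F.dropout / F.dropout2d so the probability is a forward argument rather than a constructor attribute. Correspondingly, spec.model_fn gains a dropout_rate parameter, so the rate is chosen per forward pass instead of being frozen at init_model_fn time. A minimal usage sketch, assuming the patched algoperf.pytorch_utils is importable; the toy layer sizes and tensors below are illustrative only:

import torch
import torch.nn as nn

from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout

# A toy block mixing plain modules with a dropout-aware one. SequentialWithDropout
# forwards the probability `p` only to children that advertise
# `_supports_custom_dropout`; nn.Linear and nn.ReLU are called with `x` alone.
block = SequentialWithDropout(
    nn.Linear(8, 8),
    nn.ReLU(),
    CustomDropout(),  # receives the call-time dropout probability
    nn.Linear(8, 2),
)

x = torch.randn(4, 8)
block.train()
y_train = block(x, 0.3)  # dropout applied with p=0.3
block.eval()
y_eval = block(x, 0.3)   # dropout step is the identity, since training=False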
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py index 7a40f0e81..f0653a665 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -5,20 +5,32 @@ import torch from torch import nn +from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout + +DROPOUT_RATE = 0.0 + class DenseBlock(nn.Module): """Dense block with optional residual connection.""" - def __init__(self, module, resnet=False): super().__init__() self.module = module self.resnet = resnet def forward(self, x): - if self.resnet: - return self.module(x) + x - else: - return self.module(x) + return self.module(x) + x if self.resnet else self.module(x) + + +class DenseBlockWithDropout(nn.Module): + """Dense block with optional residual connection and support for dropout.""" + def __init__(self, module, resnet=False): + super().__init__() + self.module = module + self.resnet = resnet + self._supports_custom_dropout = True + + def forward(self, x, p): + return self.module(x, p) + x if self.resnet else self.module(x, p) class DotInteract(nn.Module): @@ -58,7 +70,6 @@ def __init__(self, mlp_bottom_dims=(256, 256, 256), mlp_top_dims=(256, 256, 256, 256, 1), embed_dim=128, - dropout_rate=0.0, use_layer_norm=False, embedding_init_multiplier=None): super().__init__() @@ -116,17 +127,16 @@ def __init__(self, block.append(nn.Linear(fan_in, fan_out)) if layer_idx < (num_layers_top - 1): block.append(nn.ReLU(inplace=True)) - if (dropout_rate is not None and dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - block.append(nn.Dropout(p=dropout_rate)) - block = nn.Sequential(*block) + if layer_idx == num_layers_top - 2: + block.append(CustomDropout()) + block = SequentialWithDropout(*block) if (layer_idx != 0) and (layer_idx != num_layers_top - 1): - block = DenseBlock(block, resnet=True) + block = DenseBlockWithDropout(block, resnet=True) else: - block = DenseBlock(block) + block = DenseBlockWithDropout(block) mlp_top_blocks.append(block) fan_in = fan_out - self.top_mlp = nn.Sequential(*mlp_top_blocks) + self.top_mlp = SequentialWithDropout(*mlp_top_blocks) for module in self.top_mlp.modules(): if isinstance(module, nn.Linear): @@ -138,7 +148,8 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x): + def forward(self, x, dropout_rate=DROPOUT_RATE): + batch_size = x.shape[0] dense_features, sparse_features = torch.split( @@ -157,7 +168,7 @@ def forward(self, x): top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) # Final MLP.
- logits = self.top_mlp(top_mlp_input) + logits = self.top_mlp(top_mlp_input, dropout_rate) return logits @@ -179,7 +190,6 @@ def __init__(self, mlp_bottom_dims=(512, 256, 128), mlp_top_dims=(1024, 1024, 512, 256, 1), embed_dim=128, - dropout_rate=0.0, use_layer_norm=False, embedding_init_multiplier=None): super().__init__() @@ -242,10 +252,9 @@ def __init__(self, top_mlp_layers.append(nn.ReLU(inplace=True)) if use_layer_norm: top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) - if (dropout_rate is not None and dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - top_mlp_layers.append(nn.Dropout(p=dropout_rate)) - self.top_mlp = nn.Sequential(*top_mlp_layers) + if layer_idx == num_layers_top - 2: + top_mlp_layers.append(CustomDropout()) + self.top_mlp = SequentialWithDropout(*top_mlp_layers) if use_layer_norm: self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) else: @@ -260,7 +269,8 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x): + def forward(self, x, dropout_rate=DROPOUT_RATE): + batch_size = x.shape[0] dense_features, sparse_features = torch.split( @@ -283,5 +293,5 @@ def forward(self, x): dense_features=embedded_dense, sparse_features=embedded_sparse) # Final MLP. - logits = self.top_mlp(concatenated_dense) + logits = self.top_mlp(concatenated_dense, dropout_rate) return logits diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 726aa8705..48c6592f2 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -67,11 +67,7 @@ def loss_fn( def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Only dropout is used.""" - del aux_dropout_rate + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) # Disable cudnn benchmark to avoid OOM errors. 
torch.backends.cudnn.benchmark = False @@ -85,7 +81,6 @@ def init_model_fn( mlp_bottom_dims=self.mlp_bottom_dims, mlp_top_dims=self.mlp_top_dims, embed_dim=self.embed_dim, - dropout_rate=dropout_rate, use_layer_norm=self.use_layer_norm, embedding_init_multiplier=self.embedding_init_multiplier) self._param_shapes = param_utils.pytorch_param_shapes(model) @@ -108,7 +103,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -128,7 +124,7 @@ def model_fn( } with contexts[mode](): - logits_batch = model(inputs) + logits_batch = model(inputs, dropout_rate=dropout_rate) return logits_batch, None diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models.py b/algoperf/workloads/fastmri/fastmri_pytorch/models.py index 28f20bf20..0b8ac5499 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models.py @@ -13,6 +13,9 @@ from torch.nn import functional as F from algoperf import init_utils +from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout + +DROPOUT_RATE = 0.0 class UNet(nn.Module): @@ -27,7 +30,6 @@ def __init__(self, out_chans: int = 1, num_channels: int = 32, num_pool_layers: int = 4, - dropout_rate: Optional[float] = 0.0, use_tanh: bool = False, use_layer_norm: bool = False) -> None: super().__init__() @@ -36,21 +38,19 @@ def __init__(self, self.out_chans = out_chans self.num_channels = num_channels self.num_pool_layers = num_pool_layers - if dropout_rate is None: - dropout_rate = 0.0 + self.down_sample_layers = nn.ModuleList([ ConvBlock(in_chans, num_channels, - dropout_rate, use_tanh, use_layer_norm) ]) ch = num_channels for _ in range(num_pool_layers - 1): self.down_sample_layers.append( - ConvBlock(ch, ch * 2, dropout_rate, use_tanh, use_layer_norm)) + ConvBlock(ch, ch * 2, use_tanh, use_layer_norm)) ch *= 2 - self.conv = ConvBlock(ch, ch * 2, dropout_rate, use_tanh, use_layer_norm) + self.conv = ConvBlock(ch, ch * 2, use_tanh, use_layer_norm) self.up_conv = nn.ModuleList() self.up_transpose_conv = nn.ModuleList() @@ -59,14 +59,14 @@ def __init__(self, self.up_transpose_conv.append( TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) self.up_conv.append( - ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm)) + ConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) ch //= 2 self.up_transpose_conv.append( TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) self.up_conv.append( - nn.Sequential( - ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm), + SequentialWithDropout( + ConvBlock(ch * 2, ch, use_tanh, use_layer_norm), nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), )) @@ -74,24 +74,28 @@ def __init__(self, if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): init_utils.pytorch_default_init(m) - def forward(self, x: Tensor) -> Tensor: + def forward( + self, + x: Tensor, + dropout_rate: float = DROPOUT_RATE) -> Tensor: + stack = [] output = x # apply down-sampling layers for layer in self.down_sample_layers: - output = layer(output) + output = layer(output, dropout_rate) stack.append(output) output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) - output = self.conv(output) + output = self.conv(output, dropout_rate) # apply up-sampling layers for 
transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): downsample_layer = stack.pop() output = transpose_conv(output) - # reflect pad on the right/botton if needed to handle + # reflect pad on the right/bottom if needed to handle # odd input dimensions padding = [0, 0, 0, 0] if output.shape[-1] != downsample_layer.shape[-1]: @@ -102,7 +106,7 @@ def forward(self, x: Tensor) -> Tensor: output = F.pad(output, padding, "reflect") output = torch.cat([output, downsample_layer], dim=1) - output = conv(output) + output = conv(output, dropout_rate) return output @@ -114,10 +118,10 @@ class ConvBlock(nn.Module): def __init__(self, in_chans: int, out_chans: int, - dropout_rate: float, use_tanh: bool, use_layer_norm: bool) -> None: super().__init__() + self._supports_custom_dropout = True if use_layer_norm: norm_layer = partial(nn.GroupNorm, 1, eps=1e-6) @@ -127,19 +131,19 @@ def __init__(self, activation_fn = nn.Tanh() else: activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True) - self.conv_layers = nn.Sequential( + self.conv_layers = SequentialWithDropout( nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), norm_layer(out_chans), activation_fn, - nn.Dropout2d(dropout_rate), + CustomDropout2d(), nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), norm_layer(out_chans), activation_fn, - nn.Dropout2d(dropout_rate), + CustomDropout2d(), ) - def forward(self, x: Tensor) -> Tensor: - return self.conv_layers(x) + def forward(self, x: Tensor, dropout_rate: float) -> Tensor: + return self.conv_layers(x, dropout_rate) class TransposeConvBlock(nn.Module): diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 58943de2f..9b96230fc 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -13,6 +13,7 @@ from algoperf import pytorch_utils from algoperf import spec import algoperf.random_utils as prng +from algoperf.workloads.fastmri.fastmri_pytorch import models from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload @@ -107,17 +108,13 @@ def _build_input_queue(self, def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - del aux_dropout_rate + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = UNet( num_pool_layers=self.num_pool_layers, num_channels=self.num_channels, use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm, - dropout_rate=dropout_rate) + use_layer_norm=self.use_layer_norm) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -138,7 +135,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -158,8 +156,8 @@ def model_fn( with contexts[mode](): logit_batch = model( - augmented_and_preprocessed_input_batch['inputs'].unsqueeze( - 1)).squeeze(1) + augmented_and_preprocessed_input_batch['inputs'].unsqueeze(1), + 
dropout_rate=dropout_rate).squeeze(1) return logit_batch, None diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index ed29271f3..f28eb1762 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -156,12 +156,7 @@ def _build_dataset( def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Dropout is unused.""" - del dropout_rate - del aux_dropout_rate + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) if self.use_silu and self.use_gelu: @@ -194,9 +189,11 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng + del dropout_rate model = params diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index fcf0992d3..60e09edb5 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -14,7 +14,9 @@ from algoperf import init_utils from algoperf import spec -from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention +from algoperf.workloads.wmt.wmt_pytorch.models_dropout import MultiheadAttention + +DROPOUT_RATE = 0.0 def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: @@ -41,18 +43,15 @@ def __init__( self, width: int, mlp_dim: Optional[int] = None, # Defaults to 4x input dim. 
- use_glu: bool = False, - dropout_rate: float = 0.0) -> None: + use_glu: bool = False) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim or 4 * width self.use_glu = use_glu - self.dropout_rate = dropout_rate self.linear1 = nn.Linear(self.width, self.mlp_dim) self.act_fnc = nn.GELU(approximate='tanh') - self.dropout = nn.Dropout(self.dropout_rate) if self.use_glu: self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) @@ -70,7 +69,8 @@ def reset_parameters(self) -> None: if module.bias is not None: module.bias.data.normal_(std=1e-6) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: + x = self.linear1(x) x = self.act_fnc(x) @@ -78,7 +78,7 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: y = self.glu_linear(x) x = x * y - x = self.dropout(x) + x = F.dropout(x, dropout_rate, training=self.training) x = self.linear2(x) return x @@ -88,8 +88,7 @@ class SelfAttention(nn.Module): def __init__(self, width: int, - num_heads: int = 8, - dropout_rate: float = 0.0) -> None: + num_heads: int = 8) -> None: super().__init__() self.width = width @@ -104,7 +103,6 @@ def __init__(self, self.query = nn.Linear(self.width, self.all_head_dim) self.key = nn.Linear(self.width, self.all_head_dim) self.value = nn.Linear(self.width, self.all_head_dim) - self.dropout = nn.Dropout(dropout_rate) self.out = nn.Linear(self.width, self.width) self.reset_parameters() @@ -120,7 +118,8 @@ def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor: x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: + mixed_query_layer = self.query(x) key_layer = self.transpose_for_scores(self.key(x)) @@ -131,7 +130,7 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: attention_scores = attention_scores / math.sqrt(self.head_dim) attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = self.dropout(attention_probs) + attention_probs = F.dropout(attention_probs, dropout_rate, self.training) context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() @@ -149,8 +148,7 @@ def __init__(self, mlp_dim: Optional[int] = None, num_heads: int = 12, use_glu: bool = False, - use_post_layer_norm: bool = False, - dropout_rate: float = 0.0) -> None: + use_post_layer_norm: bool = False) -> None: super().__init__() self.width = width @@ -161,35 +159,34 @@ def __init__(self, self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) self.self_attention1 = SelfAttention(self.width, self.num_heads) - self.dropout = nn.Dropout(dropout_rate) self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) self.mlp3 = MlpBlock( width=self.width, mlp_dim=self.mlp_dim, - use_glu=self.use_glu, - dropout_rate=dropout_rate) + use_glu=self.use_glu) + + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: - def forward(self, x: spec.Tensor) -> spec.Tensor: if not self.use_post_layer_norm: y = self.layer_norm0(x) - y = self.self_attention1(y) - y = self.dropout(y) + y = self.self_attention1(y, dropout_rate) + y = F.dropout(y, dropout_rate, training=self.training) x = x + y y = self.layer_norm2(x) - y = self.mlp3(y) - y = self.dropout(y) + y = self.mlp3(y, dropout_rate) + y = F.dropout(y, dropout_rate, training=self.training) x = x + y else: y = x - y = self.self_attention1(y) - y = self.dropout(y) + y = self.self_attention1(y, dropout_rate) + y = F.dropout(y, 
dropout_rate, training=self.training) x = x + y x = self.layer_norm0(x) y = x - y = self.mlp3(y) - y = self.dropout(y) + y = self.mlp3(y, dropout_rate) + y = F.dropout(y, dropout_rate, training=self.training) x = x + y x = self.layer_norm2(x) return x @@ -204,8 +201,7 @@ def __init__(self, mlp_dim: Optional[int] = None, num_heads: int = 12, use_glu: bool = False, - use_post_layer_norm: bool = False, - dropout_rate: float = 0.0) -> None: + use_post_layer_norm: bool = False) -> None: super().__init__() self.depth = depth @@ -220,8 +216,7 @@ def __init__(self, self.mlp_dim, self.num_heads, self.use_glu, - self.use_post_layer_norm, - dropout_rate) for _ in range(depth) + self.use_post_layer_norm) for _ in range(depth) ]) if not self.use_post_layer_norm: @@ -229,10 +224,10 @@ def __init__(self, else: self.encoder_norm = None - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: # Input Encoder. for block in self.net: - x = block(x) + x = block(x, dropout_rate) if not self.use_post_layer_norm: return self.encoder_norm(x) else: @@ -259,13 +254,13 @@ def __init__(self, self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: n, _, _ = x.shape probe = torch.tile(self.probe, [n, 1, 1]) - x = self.mha(probe, x)[0] + x = self.mha(probe, x, dropout_rate=dropout_rate)[0] y = self.layer_norm(x) - x = x + self.mlp(y) + x = x + self.mlp(y, dropout_rate) return x[:, 0] @@ -285,15 +280,12 @@ def __init__( mlp_dim: Optional[int] = None, # Defaults to 4x input dim. num_heads: int = 12, rep_size: Union[int, bool] = True, - dropout_rate: Optional[float] = 0.0, head_zeroinit: bool = True, use_glu: bool = False, use_post_layer_norm: bool = False, use_map: bool = False, dtype: Any = torch.float32) -> None: super().__init__() - if dropout_rate is None: - dropout_rate = 0.0 self.num_classes = num_classes self.patch_size = patch_size @@ -318,7 +310,6 @@ def __init__( self.patch_size, stride=self.patch_size, padding='valid') - self.dropout = nn.Dropout(p=dropout_rate) self.encoder = Encoder( depth=self.depth, @@ -326,8 +317,7 @@ def __init__( mlp_dim=self.mlp_dim, num_heads=self.num_heads, use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=dropout_rate) + use_post_layer_norm=self.use_post_layer_norm) if self.num_classes: self.head = nn.Linear(self.width, self.num_classes) @@ -355,7 +345,11 @@ def reset_parameters(self) -> None: def get_posemb(self, x: spec.Tensor) -> spec.Tensor: return posemb_sincos_2d(x).type(self.dtype) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward( + self, + x: spec.Tensor, + dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: + # Patch extraction. 
x = self.conv_patch_extract(x) @@ -367,11 +361,11 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: x = torch.transpose(torch.reshape(x, (n, c, h * w)), 1, 2) x = x + pes - x = self.dropout(x) - x = self.encoder(x) + x = F.dropout(x, dropout_rate, training=self.training) + x = self.encoder(x, dropout_rate) if self.use_map: - x = self.map(x) + x = self.map(x, dropout_rate) else: x = torch.mean(x, dim=1) diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 97bb38515..f86a1b1c2 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -23,13 +23,9 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - del aux_dropout_rate + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = models.ViT( - dropout_rate=dropout_rate, num_classes=self._num_classes, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, @@ -55,7 +51,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -74,7 +71,8 @@ def model_fn( } with contexts[mode](): - logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) + logits_batch = model(augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate) return logits_batch, None diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py index db1e24521..a6a60bf95 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from functools import partial import math -from typing import Tuple +from typing import Optional, Tuple import torch from torch import nn @@ -17,6 +17,8 @@ from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ SpecAug +DROPOUT_RATE = 0.1 + @dataclass class ConformerConfig: @@ -26,10 +28,7 @@ class ConformerConfig: num_attention_heads: int = 8 num_encoder_layers: int = 4 attention_dropout_rate: float = 0.0 - attention_residual_dropout_rate: float = 0.1 - conv_residual_dropout_rate: float = 0.0 feed_forward_dropout_rate: float = 0.0 - feed_forward_residual_dropout_rate: float = 0.1 convolution_kernel_size: int = 5 feed_forward_expansion_factor: int = 4 freq_mask_count: int = 2 @@ -39,7 +38,6 @@ class ConformerConfig: time_mask_max_ratio: float = 0.05 time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True - input_dropout_rate: float = 0.1 batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 use_specaug: bool = True @@ -77,11 +75,9 @@ class Subsample(nn.Module): def __init__(self, encoder_dim: int = 0, - input_dropout_rate: float = 0.0, num_bins: int = 80): super().__init__() self.encoder_dim = encoder_dim - self.input_dropout_rate = input_dropout_rate self.conv1 = Conv2dSubsampling( input_channels=1, 
output_channels=encoder_dim) @@ -93,9 +89,9 @@ def __init__(self, out_features=self.encoder_dim, bias=True) self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim) - self.dropout = nn.Dropout(p=self.input_dropout_rate, inplace=True) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate): + output_paddings = input_paddings outputs = inputs[:, None, :, :] @@ -109,7 +105,7 @@ def forward(self, inputs, input_paddings): outputs = self.linear(outputs) outputs = outputs + self.pos_encode(seq_length=outputs.shape[1]) - outputs = self.dropout(outputs) + outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True) return outputs, output_paddings @@ -201,15 +197,8 @@ def __init__(self, config: ConformerConfig): out_features=config.encoder_dim, bias=True) - if config.feed_forward_residual_dropout_rate is None: - feed_forward_residual_dropout_rate = 0.1 - else: - feed_forward_residual_dropout_rate = ( - config.feed_forward_residual_dropout_rate) - self.dropout2 = nn.Dropout( - p=feed_forward_residual_dropout_rate, inplace=True) + def forward(self, inputs, padding_mask, dropout_rate): - def forward(self, inputs, padding_mask): inputs = self.ln(inputs) inputs = self.linear1(inputs) if self.config.activation_function_name == 'swish': @@ -226,7 +215,7 @@ def forward(self, inputs, padding_mask): inputs = inputs * padding_mask inputs = self.linear2(inputs) inputs = inputs * padding_mask - inputs = self.dropout2(inputs) + inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True) return inputs @@ -280,7 +269,7 @@ def __init__(self, config: ConformerConfig): super().__init__() self.embed_dim = config.encoder_dim self.num_heads = config.num_attention_heads - self.dropout = config.attention_dropout_rate + self.attention_dropout_rate = config.attention_dropout_rate self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim) self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim) self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads) @@ -297,7 +286,7 @@ def forward(self, inputs, key_padding_mask=None): key=k, value=v, attn_mask=~key_padding_mask[:, None, None], - dropout_p=self.dropout, + dropout_p=self.attention_dropout_rate, ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim) out = out * self.attention_temperature out = self.out_proj(out) @@ -313,19 +302,14 @@ def __init__(self, config: ConformerConfig): self.ln = LayerNorm(dim=config.encoder_dim) self.self_attention = MHSAwithQS(config) - if config.attention_residual_dropout_rate is None: - attention_residual_dropout_rate = 0.1 - else: - attention_residual_dropout_rate = config.attention_residual_dropout_rate - self.dropout = nn.Dropout(p=attention_residual_dropout_rate, inplace=True) - def forward(self, outputs, paddings): + def forward(self, outputs, paddings, dropout_rate): outputs = self.ln(outputs) outputs = self.self_attention( outputs, key_padding_mask=paddings == 1, ) - outputs = self.dropout(outputs) + outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True) return outputs @@ -405,13 +389,8 @@ def __init__(self, config): groups=config.encoder_dim) self.bn = BatchNorm(config) self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim) - if config.conv_residual_dropout_rate is None: - conv_residual_dropout_rate = 0.0 - else: - conv_residual_dropout_rate = config.conv_residual_dropout_rate - self.dropout = nn.Dropout(p=conv_residual_dropout_rate, inplace=True) - def forward(self, 
inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate): inputs = self.ln(inputs) inputs = F.glu(torch.cat([self.lin1(inputs), self.lin2(inputs)], dim=2)) @@ -433,7 +412,7 @@ def forward(self, inputs, input_paddings): inputs = activation_fn(inputs) inputs = self.lin3(inputs) - inputs = self.dropout(inputs) + inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True) return inputs @@ -450,12 +429,12 @@ def __init__(self, config: ConformerConfig): if config.use_post_layer_norm: self.ln = LayerNorm(dim=config.encoder_dim) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate): padding_mask = 1 - input_paddings[:, :, None] - inputs = inputs + 0.5 * self.ff1(inputs, padding_mask) - inputs = inputs + self.mhsa(inputs, input_paddings) - inputs = inputs + self.conv(inputs, input_paddings) - inputs = inputs + 0.5 * self.ff2(inputs, padding_mask) + inputs = inputs + 0.5 * self.ff1(inputs, padding_mask, dropout_rate) + inputs = inputs + self.mhsa(inputs, input_paddings, dropout_rate) + inputs = inputs + self.conv(inputs, input_paddings, dropout_rate) + inputs = inputs + 0.5 * self.ff2(inputs, padding_mask, dropout_rate) if self.ln: inputs = self.ln(inputs) return inputs @@ -480,13 +459,8 @@ def __init__(self, config: ConformerConfig): time_masks_per_frame=config.time_masks_per_frame, use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames ) - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate self.subsample = Subsample( encoder_dim=config.encoder_dim, - input_dropout_rate=input_dropout_rate, num_bins=preprocessing_config.num_bins) self.conformers = nn.ModuleList( [ConformerBlock(config) for _ in range(config.num_encoder_layers)]) @@ -494,15 +468,15 @@ def __init__(self, config: ConformerConfig): self.ln = LayerNorm(config.encoder_dim) self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): outputs = inputs output_paddings = input_paddings outputs, output_paddings = self.preprocessor(outputs, output_paddings) if self.training and self.config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings) + outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate) for conformer in self.conformers: - outputs = conformer(outputs, output_paddings) + outputs = conformer(outputs, output_paddings, dropout_rate) outputs = self.ln(outputs) outputs = self.lin(outputs) return outputs, output_paddings diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 5ed37957e..0477a7389 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -63,14 +63,8 @@ def attention_temperature(self) -> float: def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Conformer model init function. - - Here we use dropout_rate as residual_dropout_rate, and aux_dropout_rate as - input_dropout_rate. 
- """ + rng: spec.RandomState) -> spec.ModelInitState: + """Conformer model init function.""" torch.random.manual_seed(rng[0]) # Configure torch backends to avoid OOM errors. torch.backends.cudnn.benchmark = False @@ -83,10 +77,6 @@ def init_model_fn( activation_function_name = 'swish' model = models.ConformerEncoderDecoder( models.ConformerConfig( - attention_residual_dropout_rate=dropout_rate, - feed_forward_residual_dropout_rate=dropout_rate, - conv_residual_dropout_rate=dropout_rate, - input_dropout_rate=aux_dropout_rate, use_specaug=self.use_specaug, attention_temperature=self.attention_temperature, use_post_layer_norm=self.use_post_layer_norm, @@ -115,7 +105,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng @@ -136,7 +127,8 @@ def model_fn( with contexts[mode](): inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] logits, logits_paddings = model(inputs.to(DEVICE), - input_paddings.to(DEVICE)) + input_paddings.to(DEVICE), + dropout_rate=dropout_rate) return (logits, logits_paddings), None def _build_input_queue( diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index 84d317326..a8480a343 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -17,6 +17,7 @@ SpecAug USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ +DROPOUT_RATE = 0.1 @dataclass @@ -38,10 +39,6 @@ class DeepspeechConfig: use_dynamic_time_mask_max_frames: bool = True batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 - # If None, defaults to 0.1. - input_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.1. 
- feed_forward_dropout_rate: Optional[float] = 0.1 enable_residual_connections: bool = True enable_decoder_layer_norm: bool = True bidirectional: bool = True @@ -87,13 +84,8 @@ def __init__(self, config: DeepspeechConfig): self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True) - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate - self.dropout = nn.Dropout(p=input_dropout_rate) + def forward(self, inputs, input_paddings, dropout_rate): - def forward(self, inputs, input_paddings): output_paddings = input_paddings outputs = inputs[:, None, :, :] @@ -106,7 +98,7 @@ def forward(self, inputs, input_paddings): subsampled_dims * channels) outputs = self.lin(outputs) - outputs = self.dropout(outputs) + outputs = F.dropout(outputs, dropout_rate, training=self.training) return outputs, output_paddings @@ -205,13 +197,9 @@ def __init__(self, config: DeepspeechConfig): batch_norm_momentum=config.batch_norm_momentum, batch_norm_epsilon=config.batch_norm_epsilon) self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True) - if config.feed_forward_dropout_rate is None: - feed_forward_dropout_rate = 0.1 - else: - feed_forward_dropout_rate = config.feed_forward_dropout_rate - self.dropout = nn.Dropout(p=feed_forward_dropout_rate) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate): + padding_mask = (1 - input_paddings)[:, :, None] if self.config.layernorm_everywhere: inputs = self.normalization_layer(inputs) @@ -226,7 +214,7 @@ def forward(self, inputs, input_paddings): inputs = F.relu(inputs) inputs = inputs * padding_mask - inputs = self.dropout(inputs) + inputs = F.dropout(inputs, dropout_rate, training=self.training) return inputs @@ -363,14 +351,14 @@ def __init__(self, config: DeepspeechConfig): self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): outputs = inputs output_paddings = input_paddings outputs, output_paddings = self.preprocessor(outputs, output_paddings) if self.training and self.config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings) + outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate) for idx in range(self.config.num_lstm_layers): if self.config.enable_residual_connections: outputs = outputs + self.lstms[idx](outputs, output_paddings) @@ -379,9 +367,9 @@ def forward(self, inputs, input_paddings): for idx in range(self.config.num_ffn_layers): if self.config.enable_residual_connections: - outputs = outputs + self.ffns[idx](outputs, output_paddings) + outputs = outputs + self.ffns[idx](outputs, output_paddings, dropout_rate) else: - outputs = self.ffns[idx](outputs, output_paddings) + outputs = self.ffns[idx](outputs, output_paddings, dropout_rate) if self.config.enable_decoder_layer_norm: outputs = self.ln(outputs) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index e5387f5cb..bf345cfc9 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Dict, Tuple import torch from torch.nn.parallel import 
DistributedDataParallel as DDP @@ -6,6 +6,7 @@ from algoperf import param_utils from algoperf import spec from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import models from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ initialize from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ @@ -24,20 +25,12 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Deepspeech model init function. - - Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate - as input_dropout_rate. - """ + rng: spec.RandomState) -> spec.ModelInitState: + """Deepspeech model init function.""" torch.random.manual_seed(rng[0]) model = DeepspeechEncoderDecoder( DeepspeechConfig( - feed_forward_dropout_rate=dropout_rate, use_specaug=self.use_specaug, - input_dropout_rate=aux_dropout_rate, use_tanh=self.use_tanh, enable_residual_connections=self.enable_residual_connections, enable_decoder_layer_norm=self.enable_decoder_layer_norm, @@ -62,6 +55,20 @@ def init_model_fn( else: model = torch.nn.DataParallel(model) return model, None + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + # override super method, changing only the default dropout_rate + return super().model_fn( + params, augmented_and_preprocessed_input_batch, model_state, + mode, rng, update_batch_norm, dropout_rate) def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['lin.weight', 'lin.bias'] diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models.py b/algoperf/workloads/ogbg/ogbg_pytorch/models.py index fe9b29bc1..be5882333 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/models.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/models.py @@ -9,17 +9,20 @@ from torch import nn from algoperf import init_utils +from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout +DROPOUT_RATE = 0.1 -def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): + +def _make_mlp(in_dim, hidden_dims, activation_fn): """Creates a MLP with specified dimensions.""" - layers = nn.Sequential() + layers = SequentialWithDropout() for i, dim in enumerate(hidden_dims): layers.add_module(f'dense_{i}', nn.Linear(in_features=in_dim, out_features=dim)) layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6)) layers.add_module(f'activation_fn_{i}', activation_fn()) - layers.add_module(f'dropout_{i}', nn.Dropout(dropout_rate)) + layers.add_module(f'dropout_{i}', CustomDropout()) in_dim = dim return layers @@ -33,7 +36,6 @@ class GNN(nn.Module): def __init__(self, num_outputs: int = 128, - dropout_rate: Optional[float] = 0.1, activation_fn_name: str = 'relu', latent_dim: int = 256, hidden_dims: Tuple[int] = (256,), @@ -43,8 +45,6 @@ def __init__(self, self.hidden_dims = hidden_dims self.num_message_passing_steps = num_message_passing_steps self.num_outputs = num_outputs - if dropout_rate is None: - dropout_rate = 0.1 # in_features are specifically chosen for the ogbg workload. 
self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim) self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim) @@ -77,17 +77,14 @@ def __init__(self, GraphNetwork( update_edge_fn=_make_mlp(in_dim_edge_fn, self.hidden_dims, - dropout_rate, activation_fn), update_node_fn=_make_mlp(in_dim_node_fn, self.hidden_dims, - dropout_rate, activation_fn), update_global_fn=_make_mlp(last_in_dim, self.hidden_dims, - dropout_rate, activation_fn))) - self.graph_network = nn.Sequential(*graph_network_layers) + self.graph_network = SequentialWithDropout(*graph_network_layers) self.decoder = nn.Linear( in_features=self.hidden_dims[-1], out_features=self.num_outputs) @@ -96,14 +93,18 @@ def __init__(self, if isinstance(m, nn.Linear): init_utils.pytorch_default_init(m) - def forward(self, graph: GraphsTuple) -> torch.Tensor: + def forward( + self, + graph: GraphsTuple, + dropout_rate: float = DROPOUT_RATE) -> torch.Tensor: + graph = graph._replace( globals=torch.zeros([graph.n_node.shape[0], self.num_outputs], device=graph.n_node.device)) graph = graph._replace(nodes=self.node_embedder(graph.nodes)) graph = graph._replace(edges=self.edge_embedder(graph.edges)) - graph = self.graph_network(graph) + graph = self.graph_network(graph, dropout_rate) # Map globals to represent the final result graph = graph._replace(globals=self.decoder(graph.globals)) @@ -145,8 +146,9 @@ def __init__(self, self.update_edge_fn = update_edge_fn self.update_node_fn = update_node_fn self.update_global_fn = update_global_fn + self._supports_custom_dropout = True # supports SequentialWithDropout - def forward(self, graph: GraphsTuple) -> GraphsTuple: + def forward(self, graph: GraphsTuple, dropout_rate: float) -> GraphsTuple: """Applies a configured GraphNetwork to a graph. This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 There is one difference. For the nodes update the class aggregates over the @@ -159,6 +161,7 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: GraphNets, for more information please see the paper. Args: graph: a `GraphsTuple` containing the graph. + dropout_rate: dropout probability value. Returns: Updated `GraphsTuple`. """ @@ -179,7 +182,7 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: edge_fn_inputs = torch.cat( [edges, sent_attributes, received_attributes, global_edge_attributes], dim=-1) - edges = self.update_edge_fn(edge_fn_inputs) + edges = self.update_edge_fn(edge_fn_inputs, dropout_rate) if self.update_node_fn: sent_attributes = tree.tree_map( @@ -194,7 +197,7 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: node_fn_inputs = torch.cat( [nodes, sent_attributes, received_attributes, global_attributes], dim=-1) - nodes = self.update_node_fn(node_fn_inputs) + nodes = self.update_node_fn(node_fn_inputs, dropout_rate) if self.update_global_fn: n_graph = n_node.shape[0] @@ -213,7 +216,7 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: # These pooled nodes are the inputs to the global update fn. 
global_fn_inputs = torch.cat([node_attributes, edge_attributes, globals_], dim=-1) - globals_ = self.update_global_fn(global_fn_inputs) + globals_ = self.update_global_fn(global_fn_inputs, dropout_rate) return GraphsTuple( nodes=nodes, diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py index 45295ac7f..758b36b60 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py @@ -12,6 +12,7 @@ from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.ogbg import metrics +from algoperf.workloads.ogbg.ogbg_pytorch import models from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN from algoperf.workloads.ogbg.workload import BaseOgbgWorkload @@ -138,15 +139,10 @@ def _build_input_queue(self, def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is unused.""" - del aux_dropout_rate + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = GNN( num_outputs=self._num_outputs, - dropout_rate=dropout_rate, hidden_dims=self.hidden_dims, latent_dim=self.latent_dim, num_message_passing_steps=self.num_message_passing_steps, @@ -171,7 +167,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del rng del update_batch_norm # No BN in the GNN model. @@ -191,7 +188,8 @@ def model_fn( } with contexts[mode](): - logits = model(augmented_and_preprocessed_input_batch['inputs']) + logits = model(augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate) return logits, None diff --git a/algoperf/workloads/wmt/wmt_pytorch/models.py b/algoperf/workloads/wmt/wmt_pytorch/models.py index a1c7ce15e..a43df30d4 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models.py @@ -9,6 +9,8 @@ from torch.nn.init import normal_ from torch.nn.init import xavier_uniform_ +DROPOUT_RATE = 0.1 + def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor: """Make a causal mask for self-attention. 
@@ -104,26 +106,18 @@ def __init__(self, nhead: int = 16, d_hid: int = 1024, nlayers: int = 6, - dropout_rate: Optional[float] = 0.1, - attention_dropout_rate: Optional[float] = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, attention_temp: float = 1.0, pre_ln: bool = True): super().__init__() - if dropout_rate is None: - dropout_rate = 0.1 - if attention_dropout_rate is None: - attention_dropout_rate = 0.1 - self.pos_encoder = PositionalEncoding(d_model, dropout_rate) + self.pos_encoder = PositionalEncoding(d_model) self.shared_embedding = nn.Embedding(ntoken, d_model) self.encoder = Encoder(d_model, nhead, d_hid, nlayers, - dropout_rate, - attention_dropout_rate, activation, glu, layer_norm_eps, @@ -133,8 +127,6 @@ def __init__(self, nhead, d_hid, nlayers, - dropout_rate, - attention_dropout_rate, activation, glu, layer_norm_eps, @@ -163,7 +155,8 @@ def forward(self, targets_positions: Optional[Tensor] = None, inputs_segmentation: Optional[Tensor] = None, targets_segmentation: Optional[Tensor] = None, - decode: bool = False) -> Tensor: + decode: bool = False, + dropout_rate: float = DROPOUT_RATE) -> Tensor: """ Args: src: Tensor, shape [batch_size, seq_len] @@ -173,16 +166,19 @@ def forward(self, inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len] targets_segmentation: Optional[Tensor], shape [batch_size, seq_len] decode: bool + dropout_rate: float Returns: output Tensor of shape [batch_size, seq_len, ntoken] """ if src.size(0) != tgt.size(0): raise RuntimeError('The batch size of src and tgt must be equal.') + memory = self.encoder( src, inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation) + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate) output = self.decoder( tgt, memory, @@ -190,7 +186,8 @@ def forward(self, targets_positions=targets_positions, inputs_segmentation=inputs_segmentation, targets_segmentation=targets_segmentation, - decode=decode) + decode=decode, + dropout_rate=dropout_rate) return output @@ -229,12 +226,15 @@ def __init__(self, self.enable_nested_tensor = enable_nested_tensor self.mask_check = mask_check - def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: + def forward(self, src: Tensor, + mask: Optional[Tensor] = None, + dropout_rate: Optional[float] = 0.0) -> Tensor: """Pass the input through the encoder layers in turn. Args: src: the sequence to the encoder (required). mask: the mask for the src sequence (optional). + dropout_rate: the dropout probability (optional). Shape: see the docs in Transformer class. @@ -243,7 +243,7 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: convert_to_nested = False for mod in self.layers: - output = mod(output, src_mask=mask) + output = mod(output, src_mask=mask, dropout_rate=dropout_rate) if convert_to_nested: output = output.to_padded_tensor(0.) 
@@ -261,8 +261,6 @@ def __init__(self, nhead: int = 16, d_hid: int = 1024, nlayers: int = 6, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -276,8 +274,6 @@ def __init__(self, d_model, nhead, d_hid, - dropout_rate, - attention_dropout_rate=attention_dropout_rate, activation=activation, glu=glu, layer_norm_eps=layer_norm_eps, @@ -290,12 +286,13 @@ def __init__(self, def forward(self, src: Tensor, inputs_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None) -> Tensor: + inputs_segmentation: Optional[Tensor] = None, + dropout_rate: Optional[float] = 0.0) -> Tensor: src = src.to(torch.int) src_mask = make_src_mask(src, inputs_segmentation, self.nhead) src = self.shared_embedding(src) - src = self.pos_encoder(src, inputs_positions) - memory = self.encoder(src, mask=src_mask) + src = self.pos_encoder(src, inputs_positions, dropout_rate=dropout_rate) + memory = self.encoder(src, mask=src_mask, dropout_rate=dropout_rate) return memory @@ -306,8 +303,6 @@ def __init__(self, nhead: int = 16, d_hid: int = 1024, nlayers: int = 6, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -320,8 +315,6 @@ def __init__(self, self.decoder = TransformerDecoder(d_model, nhead, d_hid, - dropout_rate, - attention_dropout_rate, activation, glu, layer_norm_eps, @@ -339,7 +332,8 @@ def forward( targets_segmentation: Optional[Tensor] = None, decode: bool = False, max_len: Optional[int] = None, - cache: Optional[dict] = None) -> Any: + cache: Optional[dict] = None, + dropout_rate: Optional[float] = 0.0) -> Any: tgt = tgt.to(torch.int) tgt_mask, memory_mask = make_tgt_and_memory_mask( tgt, src, inputs_segmentation, targets_segmentation, @@ -347,7 +341,7 @@ def forward( if not decode: tgt = shift_right(tgt) tgt = self.shared_embedding(tgt) - tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache) + tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache, dropout_rate=dropout_rate) if decode: tgt, cache = tgt output = self.decoder( @@ -357,7 +351,8 @@ def forward( memory_mask=memory_mask, decode=decode, max_len=max_len, - cache=cache) + cache=cache, + dropout_rate=dropout_rate) if decode: output, cache = output normalize = math.sqrt(output.shape[-1]) @@ -371,10 +366,8 @@ class PositionalEncoding(nn.Module): def __init__(self, d_model: int, - dropout_rate: float = 0.1, max_len: int = 256): super().__init__() - self.dropout = nn.Dropout(p=dropout_rate) position = torch.arange(max_len).unsqueeze(1) scale_factor = -math.log(10000.0) / (d_model // 2 - 1) @@ -389,7 +382,8 @@ def forward( x: Tensor, inputs_positions: Optional[Tensor] = None, decode: bool = False, - cache: Optional[Dict[str, Dict[str, Tensor]]] = None + cache: Optional[Dict[str, Dict[str, Tensor]]] = None, + dropout_rate: Optional[float] = 0.0 ) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]: """ Args: @@ -397,6 +391,7 @@ def forward( inputs_positions: Tensor (shape [batch_size, seq_len]) or None decode: bool cache: Dict[str, Dict[str, Tensor]] or None + dropout_rate: Optional[float] Returns: Tensor or Tuple[Tensor, Dict[str, Dict[str, Tensor]]] """ @@ -412,14 +407,14 @@ def forward( } pe = self.pe[0, cache[name]['cache_index'], :] cache[name]['cache_index'] += 1 - return self.dropout(x + pe), cache + return F.dropout(x + pe, 
dropout_rate, self.training), cache if inputs_positions is None: # normal unpacked case: pe = self.pe[:, :x.size(1), :] else: # for packed data we need to use known position indices: pe = self.pe[0, inputs_positions, :] - return self.dropout(x + pe) + return F.dropout(x + pe, dropout_rate, self.training) # TransformerEncoderLayer and TransformerDecoderLayer are taken from: @@ -438,7 +433,6 @@ class TransformerEncoderLayer(nn.Module): nhead: the number of heads in the multiheadattention models (default=16). dim_feedforward: the dimension of the feedforward network model (default=1024). - dropout_rate: the dropout_rate value (default=0.1). activation: the activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable (default=F.relu). layer_norm_eps: the eps value in layer normalization components @@ -457,8 +451,6 @@ def __init__(self, d_model: int = 1024, nhead: int = 16, dim_feedforward: int = 1024, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -472,7 +464,6 @@ def __init__(self, d_model, nhead, self_attn=True, - dropout_rate=attention_dropout_rate, attention_temp=attention_temp, bias=False, **factory_kwargs) @@ -482,50 +473,55 @@ def __init__(self, self.glu = glu if self.glu: self.linear_glu = nn.Linear(d_model, dim_feedforward, **factory_kwargs) - self.dropout = nn.Dropout(dropout_rate) self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) self.pre_ln = pre_ln self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.dropout1 = nn.Dropout(dropout_rate) - self.dropout2 = nn.Dropout(dropout_rate) self.activation = activation - def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: + def forward(self, + src: Tensor, + src_mask: Optional[Tensor] = None, + dropout_rate: Optional[float] = 0.0) -> Tensor: r"""Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). src_mask: the mask for the src sequence (optional). - + dropout_rate: the dropout probability value (optional). Shape: see the docs in Transformer class. 
""" x = src if self.pre_ln: - x = x + self._sa_block(self.norm1(x), src_mask) - x = x + self._ff_block(self.norm2(x)) + x = x + self._sa_block(self.norm1(x), src_mask, dropout_rate) + x = x + self._ff_block(self.norm2(x), dropout_rate) else: - x = self.norm1(x + self._sa_block(x, src_mask)) - x = self.norm2(x + self._ff_block(x)) + x = self.norm1(x + self._sa_block(x, src_mask, dropout_rate)) + x = self.norm2(x + self._ff_block(x, dropout_rate)) return x # Self-attention block: - def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tensor: - x, _ = self.self_attn(x, attn_mask=attn_mask) - return self.dropout1(x) + def _sa_block(self, + x: Tensor, + attn_mask: Optional[Tensor], + dropout_rate: Optional[float] = 0.0) -> Tensor: + x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate) + return F.dropout(x, dropout_rate, training=self.training) # Feed forward block: - def _ff_block(self, inputs: Tensor) -> Tensor: + def _ff_block(self, + inputs: Tensor, + dropout_rate: Optional[float] = 0.0) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) x = x * y - x = self.linear2(self.dropout(x)) - return self.dropout2(x) + x = self.linear2(F.dropout(x, dropout_rate, training=self.training)) + return F.dropout(x, dropout_rate, training=self.training) # Modified to use cache for autoregressive decoding and custom @@ -537,7 +533,6 @@ class TransformerDecoder(nn.Module): nhead: the number of heads in the multiheadattention models (default=16) d_hid: the dimension of the feedforward network model (default=1024) - dropout_rate: the dropout_rate value (default=0.1) layer_norm_eps: the eps value in layer normalization components (default=1e-6). decoder_layer: an instance of the TransformerDecoderLayer() class @@ -555,8 +550,6 @@ def __init__(self, d_model, nhead, d_hid, - dropout_rate, - attention_dropout_rate, activation, glu, layer_norm_eps, @@ -569,8 +562,6 @@ def __init__(self, d_model, nhead, d_hid, - dropout_rate, - attention_dropout_rate, activation, glu, layer_norm_eps=layer_norm_eps, @@ -587,7 +578,8 @@ def forward(self, memory_mask: Optional[Tensor] = None, decode: bool = False, max_len: Optional[int] = None, - cache: Optional[dict] = None) -> Any: + cache: Optional[dict] = None, + dropout_rate: Optional[float] = 0.0) -> Any: r"""Pass the inputs (and mask) through the decoder layer in turn. Args: tgt: the sequence to the decoder (required). @@ -596,6 +588,7 @@ def forward(self, memory_mask: the mask for the memory sequence (optional). decode: whether to use cache for autoregressive decoding or not. max_len: maximum sequence length, necessary for decoding cache. + dropout_rate: the dropout probability value (optional) Shape: see the docs in Transformer class. """ @@ -610,7 +603,8 @@ def forward(self, decode=decode, max_len=max_len, cache=cache, - index=idx) + index=idx, + dropout_rate=dropout_rate) if self.norm is not None: output = self.norm(output) @@ -636,7 +630,6 @@ class TransformerDecoderLayer(nn.Module): nhead: the number of heads in the multiheadattention models (default=16). dim_feedforward: the dimension of the feedforward network model (default=1024). - dropout_rate: the dropout_rate value (default=0.1). activation: the activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable (default=F.relu). 
layer_norm_eps: the eps value in layer normalization components @@ -656,8 +649,6 @@ def __init__(self, d_model: int = 1024, nhead: int = 16, dim_feedforward: int = 1024, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -671,7 +662,6 @@ def __init__(self, d_model, nhead, self_attn=True, - dropout_rate=attention_dropout_rate, attention_temp=attention_temp, bias=False, **factory_kwargs) @@ -679,7 +669,6 @@ def __init__(self, d_model, nhead, self_attn=False, - dropout_rate=attention_dropout_rate, attention_temp=attention_temp, bias=False, **factory_kwargs) @@ -691,16 +680,12 @@ def __init__(self, self.linear_glu = nn.Linear(dim_feedforward, dim_feedforward, **factory_kwargs) - self.dropout = nn.Dropout(dropout_rate) self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) self.pre_ln = pre_ln self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.dropout1 = nn.Dropout(dropout_rate) - self.dropout2 = nn.Dropout(dropout_rate) - self.dropout3 = nn.Dropout(dropout_rate) self.activation = activation @@ -713,7 +698,8 @@ def forward( # pylint: disable=arguments-renamed decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + index: Optional[int] = None, + dropout_rate: Optional[float] = 0.0) -> Any: r"""Pass the inputs (and mask) through the decoder layer. Args: tgt: the sequence to the decoder layer (required). @@ -722,6 +708,7 @@ def forward( # pylint: disable=arguments-renamed memory_mask: the mask for the memory sequence (optional). decode: wether to use cache for autoregressive decoding or not. max_len: maximum sequence length, necessary for decoding cache. + dropout_rate: the dropout probability value (optional) Shape: see the docs in Transformer class. 
""" @@ -735,10 +722,11 @@ def forward( # pylint: disable=arguments-renamed decode=decode, max_len=max_len, cache=cache, - index=index) + index=index, + dropout_rate=dropout_rate) x = x + sa_out - x = x + self._mha_block(self.norm2(x), memory, memory_mask) - x = x + self._ff_block(self.norm3(x)) + x = x + self._mha_block(self.norm2(x), memory, memory_mask, dropout_rate) + x = x + self._ff_block(self.norm3(x), dropout_rate) else: sa_out, cache = self._sa_block( x, @@ -746,10 +734,11 @@ def forward( # pylint: disable=arguments-renamed decode=decode, max_len=max_len, cache=cache, - index=index) + index=index, + dropout_rate=dropout_rate) x = self.norm1(x + sa_out) - x = self.norm2(x + self._mha_block(x, memory, memory_mask)) - x = self.norm3(x + self._ff_block(x)) + x = self.norm2(x + self._mha_block(x, memory, memory_mask, dropout_rate)) + x = self.norm3(x + self._ff_block(x, dropout_rate)) return x, cache @@ -761,30 +750,38 @@ def _sa_block( # pylint: disable=arguments-renamed decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + index: Optional[int] = None, + dropout_rate: Optional[float] = 0.0) -> Any: x, cache = self.self_attn( x, attn_mask=attn_mask, decode=decode, max_len=max_len, cache=cache, - index=index) - return self.dropout1(x), cache + index=index, + dropout_rate=dropout_rate) + return F.dropout(x, dropout_rate, self.training), cache # Multihead attention block: def _mha_block(self, x: Tensor, mem: Tensor, - attn_mask: Optional[Tensor]) -> Tensor: - x, _ = self.multihead_attn(x, mem, attn_mask=attn_mask) - return self.dropout2(x) + attn_mask: Optional[Tensor], + dropout_rate: Optional[float] = 0.0) -> Tensor: + x, _ = self.multihead_attn( + x, + mem, + attn_mask=attn_mask, + dropout_rate=dropout_rate) + return F.dropout(x, dropout_rate, self.training) # Feed forward block. - def _ff_block(self, inputs: Tensor) -> Tensor: + def _ff_block(self, inputs: Tensor, + dropout_rate: Optional[float] = 0.0) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) x = x * y - x = self.linear2(self.dropout(x)) - return self.dropout3(x) + x = self.linear2(F.dropout(x, dropout_rate, self.training)) + return F.dropout(x, dropout_rate, self.training) class MultiheadAttention(nn.Module): @@ -802,8 +799,6 @@ class MultiheadAttention(nn.Module): ``embed_dim // num_heads``). self_attn: Whether self attention or encoder-decoder attention is used. Default: ``True``. - dropout_rate: Dropout probability on ``attn_output_weights``. - Default: ``0.0`` (no dropout_rate). bias: If specified, adds bias to input / output projection layers. Default: ``False``. device: The device of the module. @@ -817,7 +812,6 @@ def __init__(self, embed_dim: int, num_heads: int, self_attn: bool = True, - dropout_rate: float = 0., attention_temp: float = 1.0, bias: bool = False, device: Optional[torch.device] = None, @@ -826,7 +820,6 @@ def __init__(self, self.embed_dim = embed_dim self.num_heads = num_heads self.self_attn = self_attn - self.dropout = dropout_rate self.head_dim = embed_dim // num_heads self.attention_temp = attention_temp assert self.head_dim * num_heads == self.embed_dim, \ @@ -861,7 +854,8 @@ def forward(self, decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + index: Optional[int] = None, + dropout_rate: Optional[float] = 0.0) -> Any: # TODO: (nico) remove default?! 
r""" Args: x: Batch of input sequences of shape @@ -887,6 +881,7 @@ def forward(self, max_len: maximum sequence length, necessary for decoding cache. cache: cache dictionary for autoregressive decoding. index: index of the current decoding step, necessary for decoding cache. + dropout_rate: dropout probability on ``attn_output_weights``. Outputs: - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where :math:`L` is the target sequence length, :math:`N` is the batch size, @@ -976,12 +971,12 @@ def forward(self, attn_mask = new_attn_mask # Adjust dropout_rate probability. - dropout_rate = self.dropout if self.training else 0.0 + attn_dropout_rate = dropout_rate if self.training else 0.0 # Calculate attention. q = self.attention_temp * q attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask, dropout_rate) + q, k, v, attn_mask, attn_dropout_rate) # Rearrange for output projection. attn_output = attn_output.transpose(1, 2).contiguous().view( bsz, tgt_len, embed_dim) diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py index d0716d6c8..4c787becc 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/workload.py +++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py @@ -17,6 +17,7 @@ from algoperf import spec from algoperf.workloads.wmt import bleu from algoperf.workloads.wmt.wmt_pytorch import decode +from algoperf.workloads.wmt.wmt_pytorch import models from algoperf.workloads.wmt.wmt_pytorch.models import Transformer from algoperf.workloads.wmt.workload import BaseWmtWorkload @@ -167,10 +168,7 @@ def translate_and_calculate_bleu(self, def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is used as attention_dropout_rate.""" + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) if self.activation == 'relu': @@ -181,8 +179,6 @@ def init_model_fn( raise ValueError(f'Unknown activation function {self.activation}.') model = Transformer( - dropout_rate=dropout_rate, - attention_dropout_rate=aux_dropout_rate, pre_ln=self.pre_ln, attention_temp=self.attention_temp, activation=activation, @@ -207,7 +203,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -233,7 +230,8 @@ def model_fn( inputs_segmentation=augmented_and_preprocessed_input_batch.get( 'inputs_segmentation', None), targets_segmentation=augmented_and_preprocessed_input_batch.get( - 'targets_segmentation', None)) + 'targets_segmentation', None), + dropout_rate=dropout_rate) return logits_batch, None diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..db56b17cf --- /dev/null +++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py @@ -0,0 +1,92 @@ +""" +Runs fwd pass with random input for our DLRM models and compares outputs. 
+Run it as: + python3 tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch +import os + +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import ( + DLRMResNet as OriginalDLRMResNet, + DlrmSmall as OriginalDlrmSmall, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models_dropout import ( + DLRMResNet as CustomDLRMResNet, + DlrmSmall as CustomDlrmSmall, +) + + +BATCH, DENSE, SPARSE = 16, 13, 26 +FEATURES = DENSE + SPARSE +VOCAB = 1000 +DEVICE = 'cuda' +SEED = 1996 + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) + +class ModelEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='DLRMResNet, p=0.0', model='dlrm_resnet', dropout_rate=0.0), + dict(testcase_name='DlrmSmall, p=0.0', model='dlrm_small', dropout_rate=0.0), + dict(testcase_name='DLRMResNet, p=0.1', model='dlrm_resnet', dropout_rate=0.1), + dict(testcase_name='DlrmSmall, p=0.1', model='dlrm_small', dropout_rate=0.1), + dict(testcase_name='DLRMResNet, p=1.0', model='dlrm_resnet', dropout_rate=1.0), + dict(testcase_name='DlrmSmall, p=1.0', model='dlrm_small', dropout_rate=1.0), + ) + def test_forward(self, model, dropout_rate): + OrigCls, CustCls = ( + (OriginalDLRMResNet, CustomDLRMResNet) + if model == 'dlrm_resnet' + else (OriginalDlrmSmall, CustomDlrmSmall) + ) + + torch.manual_seed(SEED) + orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustCls(vocab_size=VOCAB).to(DEVICE) + + x = torch.randn(BATCH, FEATURES, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(SEED); y1 = orig(x) + torch.manual_seed(SEED); y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED); y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'), + dict(testcase_name='DlrmSmall, default', model='dlrm_small'), + ) + def test_default_dropout(self, model): + """Test default dropout_rate.""" + OrigCls, CustCls = ( + (OriginalDLRMResNet, CustomDLRMResNet) + if model == 'dlrm_resnet' + else (OriginalDlrmSmall, CustomDlrmSmall) + ) + + torch.manual_seed(SEED) + orig = OrigCls(vocab_size=VOCAB).to(DEVICE) + torch.manual_seed(SEED) + cust = CustCls(vocab_size=VOCAB).to(DEVICE) + + x = torch.randn(BATCH, FEATURES, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..0d3d52980 --- /dev/null +++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py @@ -0,0 +1,109 @@ +""" +Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. 
+Run it as: + python3 tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch +import os + +from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet as OriginalUNet +from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import UNet as CustomUNet + +BATCH, IN_CHANS, H, W = 4, 1, 256, 256 +OUT_CHANS, C, LAYERS = 1, 32, 4 +DEVICE = 'cuda' +TORCH_COMPILE = False +SEED = 1996 + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) + +class FastMRIModeEquivalenceTest(parameterized.TestCase): + + def fwd_pass(self, orig, cust, dropout_rate): + x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(0); y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.1', dropout_rate=0.1), + dict(testcase_name='p=0.7', dropout_rate=0.7), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_dropout_values(self, dropout_rate): + """Test different values of dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + self.fwd_pass(orig, cust, dropout_rate) + + + @parameterized.named_parameters( + dict(testcase_name='default', use_tanh=False, use_layer_norm=False), + dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False), + dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True), + dict(testcase_name='both', use_tanh=True, use_layer_norm=True), + ) + def test_arch_configs(self, use_tanh, use_layer_norm): + """Test different architecture configurations, fixed dropout_rate.""" + dropout_rate = 0.1 + + torch.manual_seed(SEED) + orig = OriginalUNet( + IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate, + use_tanh=use_tanh, use_layer_norm=use_layer_norm + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomUNet( + IN_CHANS, OUT_CHANS, C, LAYERS, + use_tanh=use_tanh, use_layer_norm=use_layer_norm + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + self.fwd_pass(orig, cust, dropout_rate) + + @parameterized.named_parameters( + dict(testcase_name=''), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) + cust.load_state_dict(orig.state_dict()) # sync weights + + x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) + +if __name__ == '__main__': + 
absltest.main() diff --git a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..d19fad0ba --- /dev/null +++ b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py @@ -0,0 +1,127 @@ +""" +Runs fwd pass with random input for ImageNet ViT models and compares outputs. +Run it as: + python3 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch +import os +import itertools + +from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import ViT as OriginalVit +from algoperf.workloads.imagenet_vit.imagenet_pytorch.models_dropout import ViT as CustomVit + +# Model / test hyper-params +BATCH, C, H, W = 4, 3, 224, 224 # input shape (N,C,H,W) +WIDTH, DEPTH, HEADS = 256, 4, 8 +DROPOUT_RATE = None +DEVICE = 'cuda' +SEED = 1996 + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) + +class ImageNetVitModeEquivalenceTest(parameterized.TestCase): + + def fwd_pass(self, orig, cust, dropout_rate): + x = torch.randn(BATCH, C, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(0); y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.1', dropout_rate=0.1), + dict(testcase_name='p=0.6', dropout_rate=0.6), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_dropout_values(self, dropout_rate): + """Test different dropout_rates.""" + + torch.manual_seed(SEED) + orig = OriginalVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + dropout_rate=dropout_rate, + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + self.fwd_pass(orig, cust, dropout_rate) + + + @parameterized.named_parameters([ + dict( + testcase_name=f"GLU={use_glu}_LN={use_post_ln}_MAP={use_map}", + use_glu=use_glu, + use_post_ln=use_post_ln, + use_map=use_map, + ) + for use_glu, use_post_ln, use_map in itertools.product([False, True], repeat=3) + ]) + def test_arch(self, use_glu, use_post_ln, use_map): + """Test different architecture configurations, fixed dropout_rate.""" + dropout_rate = 0.1 + + torch.manual_seed(SEED) + orig = OriginalVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + use_glu=use_glu, + use_post_layer_norm=use_post_ln, + use_map=use_map, + dropout_rate=dropout_rate, + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + use_glu=use_glu, + use_post_layer_norm=use_post_ln, + use_map=use_map, + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + self.fwd_pass(orig, cust, dropout_rate) + + @parameterized.named_parameters( + dict(testcase_name=''), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE) + torch.manual_seed(SEED) + cust = 
CustomVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE) + cust.load_state_dict(orig.state_dict()) # sync weights + + x = torch.randn(BATCH, C, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..4a1252a39 --- /dev/null +++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py @@ -0,0 +1,85 @@ +""" +Runs fwd pass with random input for LIBRISPEECH Conformer models and compares outputs. +Run with: + python3 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py + +NOTE: we don't test for default dropout_rate values, since they changed. +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch +import os + +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( + ConformerConfig as OriginalConfig, + ConformerEncoderDecoder as OriginalModel +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import( + ConformerConfig as CustomConfig, + ConformerEncoderDecoder as CustomModel, +) + +N_LAYERS = 3 +B, T = 32, 36_000 +DEVICE = 'cuda' + +os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(mode=True) +SEED = 1996 + + +class ConformerEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.2', dropout_rate=0.2), + dict(testcase_name='p=0.7', dropout_rate=0.7), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_forward(self, dropout_rate): + + torch.manual_seed(SEED) + orig = OriginalModel( + OriginalConfig( + num_encoder_layers=N_LAYERS, + attention_residual_dropout_rate=dropout_rate, + conv_residual_dropout_rate=dropout_rate, + feed_forward_residual_dropout_rate=dropout_rate, + input_dropout_rate=dropout_rate, + )).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomModel( + CustomConfig( + num_encoder_layers=N_LAYERS + ) + ).to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + y1, p1 = orig(x, paddings) + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings) + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..58ddb354e --- /dev/null +++ b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py @@ -0,0 +1,117 @@ +""" +Runs fwd pass with random input for LIBRISPEECH 
Deepspeech models and compares outputs. +Run with: + python3 tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py + +`dropout_rate` controls the following args: +- `input_dropout_rate` (if None, 0.1) +- `feed_forward_dropout_rate` (if None, 0.1) +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch +import os + +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import ( + DeepspeechEncoderDecoder as OriginalModel, + DeepspeechConfig as OriginalConfig +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import( + DeepspeechEncoderDecoder as CustomModel, + DeepspeechConfig as CustomConfig +) + +B, T = 32, 30_000 +DEVICE = 'cuda' +TORCH_COMPILE = False + +os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(mode=True) +SEED = 1996 + + +class DeepSpeechEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.2', dropout_rate=0.2), + dict(testcase_name='p=0.7', dropout_rate=0.7), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_forward(self, dropout_rate): + """Test different dropout_rate values.""" + + torch.manual_seed(SEED) + orig = OriginalModel( + OriginalConfig( + num_lstm_layers=2, + num_ffn_layers=2, + input_dropout_rate=dropout_rate, + feed_forward_dropout_rate=dropout_rate, + )).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomModel(CustomConfig( + num_lstm_layers=2, + num_ffn_layers=2, + )).to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + y1, p1 = orig(x, paddings) + + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings) + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + + @parameterized.named_parameters( + dict(testcase_name=''), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalModel(OriginalConfig( num_lstm_layers=2, num_ffn_layers=2)).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomModel(CustomConfig( num_lstm_layers=2, num_ffn_layers=2)).to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights + + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(SEED); y1, p1 = orig(x, paddings) + torch.manual_seed(SEED); y2, p2 = cust(x, paddings) + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..aaca6cebd --- 
/dev/null +++ b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py @@ -0,0 +1,103 @@ +""" +Runs fwd pass with random graphs for OGBG GNN models and compares outputs. +Run with: + python3 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch, os, random, numpy as np +from jraph import GraphsTuple + +from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as OriginalModel +from algoperf.workloads.ogbg.ogbg_pytorch.models_dropout import GNN as CustomModel + +B, N, E = 8, 20, 40 # graphs, nodes/graph, edges/graph +NODE_FDIM, EDGE_FDIM = 9, 3 # expected feature dims +DEVICE = 'cuda' + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) +SEED = 1996 + + +def _rand_graph(): + total_nodes, total_edges = B * N, B * E + nodes = torch.randn(total_nodes, NODE_FDIM, device=DEVICE) + edges = torch.randn(total_edges, EDGE_FDIM, device=DEVICE) + senders, receivers = [], [] + for i in range(B): + offset = i * N + s = torch.randint(N, (E,), device=DEVICE) + offset + r = torch.randint(N, (E,), device=DEVICE) + offset + senders.append(s), receivers.append(r) + senders = torch.cat(senders); receivers = torch.cat(receivers) + n_node = torch.full((B,), N, device=DEVICE, dtype=torch.int32) + n_edge = torch.full((B,), E, device=DEVICE, dtype=torch.int32) + return GraphsTuple(nodes, edges, receivers, senders, None, n_node, n_edge) + + +class GNNEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='0.0', dropout_rate=0.0), + dict(testcase_name='0.2', dropout_rate=0.2), + dict(testcase_name='0.7', dropout_rate=0.7), + dict(testcase_name='1.0', dropout_rate=1.0), + ) + def test_forward(self, dropout_rate): + """Test different dropout_rates.""" + + orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE) + cust = CustomModel().to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights + + graph = _rand_graph() + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y1 = orig(graph) + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(graph, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(graph) + assert_close(y1, y2, atol=0, rtol=0) + + + @parameterized.named_parameters( + dict(testcase_name=''), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + + orig = OriginalModel().to(DEVICE) + cust = CustomModel().to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights + + graph = _rand_graph() + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y1 = orig(graph) + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(graph) + + assert_close(y1, y2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..9675f1df2 --- /dev/null +++ b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py @@ -0,0 +1,110 @@ +""" +Runs 
fwd pass with random input for WMT Transformer models and compares outputs. +Run with: + python3 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch, os, random, numpy as np + +from algoperf.workloads.wmt.wmt_pytorch.models import ( + Transformer as OriginalModel, +) +from algoperf.workloads.wmt.wmt_pytorch.models_dropout import ( + Transformer as CustomModel, +) + +B, SRC_LEN, TGT_LEN, NTOK = 16, 80, 80, 32_000 +DEVICE = "cuda" +SEED = 1996 + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) + + +def _rand_tokens(bs, seqlen): + return torch.randint(1, NTOK, (bs, seqlen), device=DEVICE) + + +class TransformerEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + # NOTE: removed dropout=1.0 since it will generate nan in scaled_dot_product_attention + dict(testcase_name="0.0", dropout_rate=0.0, compile=False), + dict(testcase_name="0.2", dropout_rate=0.2, compile=False), + dict(testcase_name="0.7", dropout_rate=0.7, compile=False), + dict(testcase_name="p=0.0_compile", dropout_rate=0.0, compile=True), + dict(testcase_name="p=0.2_compile", dropout_rate=0.2, compile=True), + dict(testcase_name="p=0.7_compile", dropout_rate=0.7, compile=True), + ) + def test_dropout_value(self, dropout_rate, compile): + + orig = OriginalModel( + dropout_rate=dropout_rate, + attention_dropout_rate=dropout_rate + ).to(DEVICE) + cust = CustomModel().to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if compile: + orig = torch.compile(orig) + cust = torch.compile(cust) + + src = _rand_tokens(B, SRC_LEN) + tgt = _rand_tokens(B, TGT_LEN) + + for mode in ("train", "eval"): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y1 = orig(src, tgt) + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(src, tgt, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(src, tgt) + assert_close(y1, y2, atol=0, rtol=0) + + + @parameterized.named_parameters( + dict(testcase_name="default", compile=False), + dict(testcase_name="default_compile", compile=True), + ) + def test_default(self, compile): + + orig = OriginalModel().to(DEVICE) + cust = CustomModel().to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if compile: + orig = torch.compile(orig) + cust = torch.compile(cust) + + src = _rand_tokens(B, SRC_LEN) + tgt = _rand_tokens(B, TGT_LEN) + + for mode in ("train", "eval"): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y1 = orig(src, tgt) + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(src, tgt) + + assert_close(y1, y2, atol=0, rtol=0) + + +if __name__ == "__main__": + absltest.main()
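A minimal usage sketch of the new call-time dropout API follows. It is illustrative only: the helper name `forward_with_dropout`, the objects `workload`, `params`, `batch`, `model_state`, and `rng`, and the value 0.1 are assumptions standing in for what a submission already has; only the `model_fn` signature (as in the WMT workload change above) is taken from this diff.

from algoperf import spec


def forward_with_dropout(workload: spec.Workload,
                         params: spec.ParameterContainer,
                         batch: dict,
                         model_state: spec.ModelAuxiliaryState,
                         rng: spec.RandomState) -> spec.Tensor:
  # Dropout is now chosen per call instead of being fixed in init_model_fn;
  # omitting dropout_rate falls back to the workload's models.DROPOUT_RATE.
  logits, _ = workload.model_fn(
      params,
      batch,
      model_state,
      spec.ForwardPassMode.TRAIN,
      rng,
      update_batch_norm=True,
      dropout_rate=0.1)  # hypothetical tuned value, not part of this diff
  return logits

At evaluation time the same call would use spec.ForwardPassMode.EVAL and simply omit dropout_rate, which is the default-argument path that the "omit dropout at eval" branches in the tests above exercise.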