From b3060762eda8a47c6409bf5804ccca8d53686e0c Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 10 Jun 2025 18:42:28 +0200 Subject: [PATCH 01/23] dropout fix criteo, fastmri, vit, conf --- .../criteo1tb_pytorch/models_dropout.py | 298 ++++++++++ .../models_functional_dropout.py | 308 +++++++++++ algoperf/workloads/dropout_modules.py | 41 ++ .../fastmri/fastmri_pytorch/models_dropout.py | 167 ++++++ .../imagenet_pytorch/models_dropout.py | 395 +++++++++++++ .../librispeech_pytorch/models_dropout.py | 518 ++++++++++++++++++ 6 files changed, 1727 insertions(+) create mode 100644 algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py create mode 100644 algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py create mode 100644 algoperf/workloads/dropout_modules.py create mode 100644 algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py create mode 100644 algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py create mode 100644 algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py new file mode 100644 index 000000000..8042ec31e --- /dev/null +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py @@ -0,0 +1,298 @@ +"""Pytorch implementation of DLRM-Small.""" + +import math + +import torch +import torch.nn.functional as F +from torch import nn + +from algoperf.workloads.dropout_modules import CustomDropout, SequentialWithDropout + + +class DenseBlock(nn.Module): + """Dense block with optional residual connection.""" "" + def __init__(self, module, resnet=False): + super().__init__() + self.module = module + self.resnet = resnet + + def forward(self, x): + return self.module(x) + x if self.resnet else self.module(x) + + +class DenseBlockWithDropout(nn.Module): + """Dense block with optional residual connection and support for dropout.""" + def __init__(self, module, resnet=False): + super().__init__() + self.module = module + self.resnet = resnet + self._supports_custom_dropout = True + + def forward(self, x, p=None): + return self.module(x, p) + x if self.resnet else self.module(x, p) + + +class DotInteract(nn.Module): + """Performs feature interaction operation between dense or sparse features.""" + + def __init__(self, num_sparse_features): + super().__init__() + self.triu_indices = torch.triu_indices(num_sparse_features + 1, + num_sparse_features + 1) + + def forward(self, dense_features, sparse_features): + combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features), + dim=1) + interactions = torch.bmm(combined_values, + torch.transpose(combined_values, 1, 2)) + interactions_flat = interactions[:, + self.triu_indices[0], + self.triu_indices[1]] + return torch.cat((dense_features, interactions_flat), dim=1) + + +class DLRMResNet(nn.Module): + """Define a DLRM-Small model. + + Parameters: + vocab_size: vocab size of embedding table. + num_dense_features: number of dense features as the bottom mlp input. + mlp_bottom_dims: dimensions of dense layers of the bottom mlp. + mlp_top_dims: dimensions of dense layers of the top mlp. + embed_dim: embedding dimension. 
+ """ + + def __init__(self, + vocab_size, + num_dense_features=13, + num_sparse_features=26, + mlp_bottom_dims=(256, 256, 256), + mlp_top_dims=(256, 256, 256, 256, 1), + embed_dim=128, + dropout_rate=0.0, + use_layer_norm=False, + embedding_init_multiplier=None): + super().__init__() + self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) + self.num_dense_features = num_dense_features + self.num_sparse_features = num_sparse_features + self.mlp_bottom_dims = mlp_bottom_dims + self.mlp_top_dims = mlp_top_dims + self.embed_dim = embed_dim + + # Ideally, we should use the pooled embedding implementation from + # `TorchRec`. However, in order to have identical implementation + # with that of Jax, we define a single embedding matrix. + num_chunks = 4 + assert vocab_size % num_chunks == 0 + self.embedding_table_chucks = [] + scale = 1.0 / torch.sqrt(self.vocab_size) + for i in range(num_chunks): + chunk = nn.Parameter( + torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) + chunk.data.uniform_(0, 1) + chunk.data = scale * chunk.data + self.register_parameter(f'embedding_chunk_{i}', chunk) + self.embedding_table_chucks.append(chunk) + + input_dim = self.num_dense_features + bot_mlp_blocks = [] + for layer_idx, dense_dim in enumerate(self.mlp_bottom_dims): + block = [] + block.append(nn.Linear(input_dim, dense_dim)) + block.append(nn.ReLU(inplace=True)) + block = nn.Sequential(*block) + if layer_idx > 0: + block = DenseBlock(block, resnet=True) + else: + block = DenseBlock(block) + bot_mlp_blocks.append(block) + input_dim = dense_dim + self.bot_mlp = nn.Sequential(*bot_mlp_blocks) + + for module in self.bot_mlp.modules(): + if isinstance(module, nn.Linear): + limit = math.sqrt(6. / (module.in_features + module.out_features)) + nn.init.uniform_(module.weight.data, -limit, limit) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + # Number of sparse features = 26 + fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1] + num_layers_top = len(self.mlp_top_dims) + mlp_top_blocks = [] + for layer_idx, fan_out in enumerate(self.mlp_top_dims): + block = [] + block.append(nn.Linear(fan_in, fan_out)) + if layer_idx < (num_layers_top - 1): + block.append(nn.ReLU(inplace=True)) + if (dropout_rate is not None and dropout_rate > 0.0 and + layer_idx == num_layers_top - 2): + block.append(CustomDropout()) # (nico) + block = SequentialWithDropout(*block) # (nico) + if (layer_idx != 0) and (layer_idx != num_layers_top - 1): + block = DenseBlockWithDropout(block, resnet=True) + else: + block = DenseBlockWithDropout(block) + mlp_top_blocks.append(block) + fan_in = fan_out + self.top_mlp = SequentialWithDropout(*mlp_top_blocks) # (nico) + + for module in self.top_mlp.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_( + module.weight.data, + 0., + math.sqrt(2. / (module.in_features + module.out_features))) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + def forward(self, x, dropout_rate): + batch_size = x.shape[0] + + dense_features, sparse_features = torch.split( + x, [self.num_dense_features, self.num_sparse_features], 1) + + # Bottom MLP. + embedded_dense = self.bot_mlp(dense_features) + + # Sparse feature processing. 
+ sparse_features = sparse_features.to(dtype=torch.int32) + idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size + embedding_table = torch.cat(self.embedding_table_chucks, dim=0) + embedded_sparse = embedding_table[idx_lookup] + embedded_sparse = torch.reshape(embedded_sparse, + [batch_size, 26 * self.embed_dim]) + top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) + + # Final MLP. + logits = self.top_mlp(top_mlp_input, dropout_rate) + return logits + + +class DlrmSmall(nn.Module): + """Define a DLRM-Small model. + + Parameters: + vocab_size: vocab size of embedding table. + num_dense_features: number of dense features as the bottom mlp input. + mlp_bottom_dims: dimensions of dense layers of the bottom mlp. + mlp_top_dims: dimensions of dense layers of the top mlp. + embed_dim: embedding dimension. + """ + + def __init__(self, + vocab_size, + num_dense_features=13, + num_sparse_features=26, + mlp_bottom_dims=(512, 256, 128), + mlp_top_dims=(1024, 1024, 512, 256, 1), + embed_dim=128, + dropout_rate=0.0, + use_layer_norm=False, + embedding_init_multiplier=None): + super().__init__() + self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) + self.num_dense_features = num_dense_features + self.num_sparse_features = num_sparse_features + self.mlp_bottom_dims = mlp_bottom_dims + self.mlp_top_dims = mlp_top_dims + self.embed_dim = embed_dim + self.embedding_init_multiplier = embedding_init_multiplier + + # Ideally, we should use the pooled embedding implementation from + # `TorchRec`. However, in order to have identical implementation + # with that of Jax, we define a single embedding matrix. + num_chunks = 4 + assert vocab_size % num_chunks == 0 + self.embedding_table_chucks = [] + + if self.embedding_init_multiplier is None: + scale = 1.0 / torch.sqrt(self.vocab_size) + else: + scale = self.embedding_init_multiplier + + for i in range(num_chunks): + chunk = nn.Parameter( + torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) + chunk.data.uniform_(0, 1) + chunk.data = scale * chunk.data + self.register_parameter(f'embedding_chunk_{i}', chunk) + self.embedding_table_chucks.append(chunk) + + input_dim = self.num_dense_features + bottom_mlp_layers = [] + for dense_dim in self.mlp_bottom_dims: + bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim)) + bottom_mlp_layers.append(nn.ReLU(inplace=True)) + if use_layer_norm: + bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) + input_dim = dense_dim + self.bot_mlp = nn.Sequential(*bottom_mlp_layers) + for module in self.bot_mlp.modules(): + if isinstance(module, nn.Linear): + limit = math.sqrt(6. / (module.in_features + module.out_features)) + nn.init.uniform_(module.weight.data, -limit, limit) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) + + # TODO: Write down the formula here instead of the constant. 
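+ # For the default arguments this is
+ # mlp_bottom_dims[-1] + (num_sparse_features + 1) * (num_sparse_features + 2) // 2
+ # = 128 + 27 * 28 // 2 = 506, i.e. the bottom-MLP output concatenated with the
+ # upper-triangular (diagonal included) pairwise interactions from DotInteract.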
+ input_dims = 506 + num_layers_top = len(self.mlp_top_dims) + top_mlp_layers = [] + for layer_idx, fan_out in enumerate(self.mlp_top_dims): + fan_in = input_dims if layer_idx == 0 \ + else self.mlp_top_dims[layer_idx - 1] + top_mlp_layers.append(nn.Linear(fan_in, fan_out)) + if layer_idx < (num_layers_top - 1): + top_mlp_layers.append(nn.ReLU(inplace=True)) + if use_layer_norm: + top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) + if (dropout_rate is not None and dropout_rate > 0.0 and + layer_idx == num_layers_top - 2): + top_mlp_layers.append(CustomDropout()) # (nico) + self.top_mlp = SequentialWithDropout(*top_mlp_layers) # (nico) + if use_layer_norm: + self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) + else: + self.embed_ln = None + for module in self.top_mlp.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_( + module.weight.data, + 0., + math.sqrt(2. / (module.in_features + module.out_features))) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + def forward(self, x, dropout_rate): + batch_size = x.shape[0] + + dense_features, sparse_features = torch.split( + x, [self.num_dense_features, self.num_sparse_features], 1) + + # Bottom MLP. + embedded_dense = self.bot_mlp(dense_features) + + # Sparse feature processing. + sparse_features = sparse_features.to(dtype=torch.int32) + idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size + embedding_table = torch.cat(self.embedding_table_chucks, dim=0) + embedded_sparse = embedding_table[idx_lookup] + embedded_sparse = torch.reshape(embedded_sparse, + [batch_size, -1, self.embed_dim]) + if self.embed_ln: + embedded_sparse = self.embed_ln(embedded_sparse) + # Dot product interactions. + concatenated_dense = self.dot_interact( + dense_features=embedded_dense, sparse_features=embedded_sparse) + + # Final MLP. + logits = self.top_mlp(concatenated_dense, dropout_rate) + return logits diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py new file mode 100644 index 000000000..346e0e72a --- /dev/null +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py @@ -0,0 +1,308 @@ +"""Pytorch implementation of DLRM-Small.""" + +import math + +import torch +import torch.nn.functional as F +from torch import nn + + +class DenseBlock(nn.Module): + """Dense block with optional residual connection.""" "" + + def __init__(self, module, resnet=False): + super().__init__() + self.module = module + self.resnet = resnet + + def forward(self, x): + if self.resnet: + return self.module(x) + x + else: + return self.module(x) + + +class DotInteract(nn.Module): + """Performs feature interaction operation between dense or sparse features.""" + + def __init__(self, num_sparse_features): + super().__init__() + self.triu_indices = torch.triu_indices(num_sparse_features + 1, + num_sparse_features + 1) + + def forward(self, dense_features, sparse_features): + combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features), + dim=1) + interactions = torch.bmm(combined_values, + torch.transpose(combined_values, 1, 2)) + interactions_flat = interactions[:, + self.triu_indices[0], + self.triu_indices[1]] + return torch.cat((dense_features, interactions_flat), dim=1) + + +class DLRMResNet(nn.Module): + """Define a DLRM-Small model. + + Parameters: + vocab_size: vocab size of embedding table. + num_dense_features: number of dense features as the bottom mlp input. 
+ mlp_bottom_dims: dimensions of dense layers of the bottom mlp. + mlp_top_dims: dimensions of dense layers of the top mlp. + embed_dim: embedding dimension. + """ + + def __init__(self, + vocab_size, + num_dense_features=13, + num_sparse_features=26, + mlp_bottom_dims=(256, 256, 256), + mlp_top_dims=(256, 256, 256, 256, 1), + embed_dim=128, + # dropout_rate=0.0, + use_layer_norm=False, + embedding_init_multiplier=None): + super().__init__() + self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) + self.num_dense_features = num_dense_features + self.num_sparse_features = num_sparse_features + self.mlp_bottom_dims = mlp_bottom_dims + self.mlp_top_dims = mlp_top_dims + self.embed_dim = embed_dim + + # Ideally, we should use the pooled embedding implementation from + # `TorchRec`. However, in order to have identical implementation + # with that of Jax, we define a single embedding matrix. + num_chunks = 4 + assert vocab_size % num_chunks == 0 + self.embedding_table_chucks = [] + scale = 1.0 / torch.sqrt(self.vocab_size) + for i in range(num_chunks): + chunk = nn.Parameter( + torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) + chunk.data.uniform_(0, 1) + chunk.data = scale * chunk.data + self.register_parameter(f'embedding_chunk_{i}', chunk) + self.embedding_table_chucks.append(chunk) + + input_dim = self.num_dense_features + bot_mlp_blocks = [] + for layer_idx, dense_dim in enumerate(self.mlp_bottom_dims): + block = [] + block.append(nn.Linear(input_dim, dense_dim)) + block.append(nn.ReLU(inplace=True)) + block = nn.Sequential(*block) + if layer_idx > 0: + block = DenseBlock(block, resnet=True) + else: + block = DenseBlock(block) + bot_mlp_blocks.append(block) + input_dim = dense_dim + self.bot_mlp = nn.Sequential(*bot_mlp_blocks) + + for module in self.bot_mlp.modules(): + if isinstance(module, nn.Linear): + limit = math.sqrt(6. / (module.in_features + module.out_features)) + nn.init.uniform_(module.weight.data, -limit, limit) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + # Number of sparse features = 26 + fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1] + num_layers_top = len(self.mlp_top_dims) + mlp_top_blocks = [] + for layer_idx, fan_out in enumerate(self.mlp_top_dims): + block = [] + block.append(nn.Linear(fan_in, fan_out)) + if layer_idx < (num_layers_top - 1): + block.append(nn.ReLU(inplace=True)) + # if (dropout_rate is not None and dropout_rate > 0.0 and + # layer_idx == num_layers_top - 2): + # block.append(nn.Dropout(p=dropout_rate)) + block = nn.Sequential(*block) + if (layer_idx != 0) and (layer_idx != num_layers_top - 1): + block = DenseBlock(block, resnet=True) + else: + block = DenseBlock(block) + mlp_top_blocks.append(block) + fan_in = fan_out + self.top_mlp = nn.Sequential(*mlp_top_blocks) + + for module in self.top_mlp.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_( + module.weight.data, + 0., + math.sqrt(2. / (module.in_features + module.out_features))) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + def forward(self, x, dropout_rate): + batch_size = x.shape[0] + + dense_features, sparse_features = torch.split( + x, [self.num_dense_features, self.num_sparse_features], 1) + + # Bottom MLP. + embedded_dense = self.bot_mlp(dense_features) + + # Sparse feature processing. 
+ sparse_features = sparse_features.to(dtype=torch.int32) + idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size + embedding_table = torch.cat(self.embedding_table_chucks, dim=0) + embedded_sparse = embedding_table[idx_lookup] + embedded_sparse = torch.reshape(embedded_sparse, + [batch_size, 26 * self.embed_dim]) + top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) + + # Final MLP (horrible!!). + h = top_mlp_input + num_layers_top = len(self.mlp_top_dims) + for layer_idx, block in enumerate(self.top_mlp): + # block.module is nn.Sequential([...]) + seq = block.module + # 1) linear + out = seq[0](h) + # 2) ReLU (if present) + if layer_idx < (num_layers_top - 1): + out = seq[1](out) + # 3) functional dropout at penult layer + if dropout_rate > 0 and layer_idx == num_layers_top - 2: + out = F.dropout(out, dropout_rate, training=self.training) + # 4) wrap in residual if needed + h = out + h if block.resnet else out + return h + + +class DlrmSmall(nn.Module): + """Define a DLRM-Small model. + + Parameters: + vocab_size: vocab size of embedding table. + num_dense_features: number of dense features as the bottom mlp input. + mlp_bottom_dims: dimensions of dense layers of the bottom mlp. + mlp_top_dims: dimensions of dense layers of the top mlp. + embed_dim: embedding dimension. + """ + + def __init__(self, + vocab_size, + num_dense_features=13, + num_sparse_features=26, + mlp_bottom_dims=(512, 256, 128), + mlp_top_dims=(1024, 1024, 512, 256, 1), + embed_dim=128, + # dropout_rate=0.0, + use_layer_norm=False, + embedding_init_multiplier=None): + super().__init__() + self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) + self.num_dense_features = num_dense_features + self.num_sparse_features = num_sparse_features + self.mlp_bottom_dims = mlp_bottom_dims + self.mlp_top_dims = mlp_top_dims + self.embed_dim = embed_dim + self.embedding_init_multiplier = embedding_init_multiplier + + # Ideally, we should use the pooled embedding implementation from + # `TorchRec`. However, in order to have identical implementation + # with that of Jax, we define a single embedding matrix. + num_chunks = 4 + assert vocab_size % num_chunks == 0 + self.embedding_table_chucks = [] + + if self.embedding_init_multiplier is None: + scale = 1.0 / torch.sqrt(self.vocab_size) + else: + scale = self.embedding_init_multiplier + + for i in range(num_chunks): + chunk = nn.Parameter( + torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) + chunk.data.uniform_(0, 1) + chunk.data = scale * chunk.data + self.register_parameter(f'embedding_chunk_{i}', chunk) + self.embedding_table_chucks.append(chunk) + + input_dim = self.num_dense_features + bottom_mlp_layers = [] + for dense_dim in self.mlp_bottom_dims: + bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim)) + bottom_mlp_layers.append(nn.ReLU(inplace=True)) + if use_layer_norm: + bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) + input_dim = dense_dim + self.bot_mlp = nn.Sequential(*bottom_mlp_layers) + for module in self.bot_mlp.modules(): + if isinstance(module, nn.Linear): + limit = math.sqrt(6. / (module.in_features + module.out_features)) + nn.init.uniform_(module.weight.data, -limit, limit) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) + + # TODO: Write down the formula here instead of the constant. 
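+ # 506 = mlp_bottom_dims[-1] + 27 * 28 // 2 for the default arguments (see DotInteract).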
+ input_dims = 506 + num_layers_top = len(self.mlp_top_dims) + top_mlp_layers = [] + for layer_idx, fan_out in enumerate(self.mlp_top_dims): + fan_in = input_dims if layer_idx == 0 \ + else self.mlp_top_dims[layer_idx - 1] + top_mlp_layers.append(nn.Linear(fan_in, fan_out)) + if layer_idx < (num_layers_top - 1): + top_mlp_layers.append(nn.ReLU(inplace=True)) + if use_layer_norm: + top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) + # if (dropout_rate is not None and dropout_rate > 0.0 and + # layer_idx == num_layers_top - 2): + # top_mlp_layers.append(nn.Dropout(p=dropout_rate)) + self.top_mlp = nn.Sequential(*top_mlp_layers) + if use_layer_norm: + self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) + else: + self.embed_ln = None + for module in self.top_mlp.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_( + module.weight.data, + 0., + math.sqrt(2. / (module.in_features + module.out_features))) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + def forward(self, x, dropout_rate): + batch_size = x.shape[0] + + dense_features, sparse_features = torch.split( + x, [self.num_dense_features, self.num_sparse_features], 1) + + # Bottom MLP. + embedded_dense = self.bot_mlp(dense_features) + + # Sparse feature processing. + sparse_features = sparse_features.to(dtype=torch.int32) + idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size + embedding_table = torch.cat(self.embedding_table_chucks, dim=0) + embedded_sparse = embedding_table[idx_lookup] + embedded_sparse = torch.reshape(embedded_sparse, + [batch_size, -1, self.embed_dim]) + if self.embed_ln: + embedded_sparse = self.embed_ln(embedded_sparse) + # Dot product interactions. + concatenated_dense = self.dot_interact( + dense_features=embedded_dense, sparse_features=embedded_sparse) + + # Final MLP: run each layer, and after the penultimate layer do functional dropout + h = concatenated_dense + N = len(self.top_mlp) + for idx, layer in enumerate(self.top_mlp): + h = layer(h) + # insert dropout exactly where nn.Dropout used to live + if dropout_rate > 0 and idx == N - 2: + h = F.dropout(h, dropout_rate, training=self.training) + return h diff --git a/algoperf/workloads/dropout_modules.py b/algoperf/workloads/dropout_modules.py new file mode 100644 index 000000000..3917b75bf --- /dev/null +++ b/algoperf/workloads/dropout_modules.py @@ -0,0 +1,41 @@ +"""Custom classes to support a dynamic modulized dropout, see issue??TODO""" + +from torch import Tensor +from torch import nn +import torch.nn.functional as F + + +class CustomDropout(nn.Module): + """A module around torch.nn.functional.dropout.""" + def __init__(self): + super().__init__() + self._supports_custom_dropout = True + + def forward(self, input: Tensor, p: float) -> Tensor: + return F.dropout(input, p, training=self.training) + + +class CustomDropout2d(nn.Module): + """A module around torch.nn.functional.dropout2d.""" + def __init__(self): + super().__init__() + self._supports_custom_dropout = True + + def forward(self, input: Tensor, p: float) -> Tensor: + return F.dropout2d(input, p, training=self.training) + + +class SequentialWithDropout(nn.Sequential): + """Sequential of modules with dropout.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._supports_custom_dropout = True + + def forward(self, x, p): + for module in self: + # if isinstance(module, (CustomDropout, SequentialWithDropout, DenseBlockWithDropout)): + if getattr(module, '_supports_custom_dropout', False): # TODO 
(nico): improve + x = module(x, p) + else: + x = module(x) + return x diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py new file mode 100644 index 000000000..5862f6352 --- /dev/null +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py @@ -0,0 +1,167 @@ +"""U-Net Model. + +Adapted from fastMRI: +https://github.com/facebookresearch/fastMRI/blob/main/fastmri/models/unet.py +""" + +from functools import partial +from typing import Optional + +import torch +from torch import nn +from torch import Tensor +from torch.nn import functional as F + +from algoperf import init_utils +from algoperf.workloads.dropout_modules import CustomDropout2d, SequentialWithDropout + + + +class UNet(nn.Module): + r"""U-Net model from + `"U-net: Convolutional networks + for biomedical image segmentation" + `_. + """ + + def __init__(self, + in_chans: int = 1, + out_chans: int = 1, + num_channels: int = 32, + num_pool_layers: int = 4, + use_tanh: bool = False, + use_layer_norm: bool = False) -> None: + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.num_channels = num_channels + self.num_pool_layers = num_pool_layers + self.down_sample_layers = nn.ModuleList([ + ConvBlock(in_chans, + num_channels, + use_tanh, + use_layer_norm) + ]) + ch = num_channels + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append( + ConvBlock(ch, ch * 2, use_tanh, use_layer_norm)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, use_tanh, use_layer_norm) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append( + TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) + self.up_conv.append( + ConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) + ch //= 2 + + self.up_transpose_conv.append( + TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) + self.up_conv.append( + SequentialWithDropout( + ConvBlock(ch * 2, ch, use_tanh, use_layer_norm), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + )) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + init_utils.pytorch_default_init(m) + + def forward(self, x: Tensor, dropout_rate: float) -> Tensor: + stack = [] + output = x + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output, dropout_rate) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output = self.conv(output, dropout_rate) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/bottom if needed to handle + # odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output, dropout_rate) + + return output + + +class ConvBlock(nn.Module): + # A Convolutional Block that consists of two convolution layers each + # followed by instance normalization, LeakyReLU activation and dropout_rate. 
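+ # Dropout is applied through CustomDropout2d modules inside a
+ # SequentialWithDropout container, so the rate is supplied at call time,
+ # e.g. (illustrative) out = conv_block(x, dropout_rate), rather than being
+ # fixed when the block is constructed.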
+ + def __init__(self, + in_chans: int, + out_chans: int, + use_tanh: bool, + use_layer_norm: bool) -> None: + super().__init__() + self._supports_custom_dropout = True + + if use_layer_norm: + norm_layer = partial(nn.GroupNorm, 1, eps=1e-6) + else: + norm_layer = nn.InstanceNorm2d + if use_tanh: + activation_fn = nn.Tanh() + else: + activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True) + self.conv_layers = SequentialWithDropout( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + norm_layer(out_chans), + activation_fn, + CustomDropout2d(), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + norm_layer(out_chans), + activation_fn, + CustomDropout2d(), + ) + + def forward(self, x: Tensor, dropout_rate: float) -> Tensor: + return self.conv_layers(x, dropout_rate) + + +class TransposeConvBlock(nn.Module): + # A Transpose Convolutional Block that consists of one convolution transpose + # layers followed by instance normalization and LeakyReLU activation. + + def __init__( + self, + in_chans: int, + out_chans: int, + use_tanh: bool, + use_layer_norm: bool, + ): + super().__init__() + if use_tanh: + activation_fn = nn.Tanh() + else: + activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True) + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + activation_fn, + ) + + def forward(self, x: Tensor) -> Tensor: + return self.layers(x) diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py new file mode 100644 index 000000000..f5e315fd7 --- /dev/null +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py @@ -0,0 +1,395 @@ +"""PyTorch implementation of refactored and simplified ViT. + +Adapted from: +https://github.com/huggingface/transformers/tree/main/src/transformers/models/vit +and https://github.com/lucidrains/vit-pytorch. +""" + +import math +from typing import Any, Optional, Tuple, Union + +import torch +from torch import nn +import torch.nn.functional as F + +from algoperf import init_utils +from algoperf import spec +from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention + + +def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: + """Follows the MoCo v3 logic.""" + _, width, h, w = patches.shape + device = patches.device + y, x = torch.meshgrid(torch.arange(h, device=device), + torch.arange(w, device=device), indexing='ij') + + if width % 4 != 0: + raise ValueError('Width must be mult of 4 for sincos posemb.') + omega = torch.arange(width // 4, device=device) / (width // 4 - 1) + omega = 1. / (temperature**omega) + y = y.flatten()[:, None] * omega[None, :] + x = x.flatten()[:, None] * omega[None, :] + pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) + return pe[None, :, :] + + +class MlpBlock(nn.Module): + """Transformer MLP / feed-forward block.""" + + def __init__( + self, + width: int, + mlp_dim: Optional[int] = None, # Defaults to 4x input dim. 
+ use_glu: bool = False, + dropout_rate: float = 0.0) -> None: + super().__init__() + + self.width = width + self.mlp_dim = mlp_dim or 4 * width + self.use_glu = use_glu + self.dropout_rate = dropout_rate + + self.linear1 = nn.Linear(self.width, self.mlp_dim) + self.act_fnc = nn.GELU(approximate='tanh') + + if self.use_glu: + self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) + else: + self.glu_linear = None + + self.linear2 = nn.Linear(self.mlp_dim, self.width) + + self.reset_parameters() + + def reset_parameters(self) -> None: + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight.data) + if module.bias is not None: + module.bias.data.normal_(std=1e-6) + + def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: + if dropout_rate is None: + dropout_rate = self.dropout_rate + + x = self.linear1(x) + x = self.act_fnc(x) + + if self.use_glu: + y = self.glu_linear(x) + x = x * y + + x = F.dropout(x, dropout_rate, training=self.training) + x = self.linear2(x) + return x + + +class SelfAttention(nn.Module): + """Self-attention special case of multi-head dot-product attention.""" + + def __init__(self, + width: int, + num_heads: int = 8, + dropout_rate: float = 0.0) -> None: + super().__init__() + + self.width = width + self.num_heads = num_heads + + assert width % num_heads == 0, ( + 'Memory dimension must be divisible by number of heads.') + + self.head_dim = int(width / num_heads) + self.all_head_dim = self.num_heads * self.head_dim + self.dropout_rate = dropout_rate + + self.query = nn.Linear(self.width, self.all_head_dim) + self.key = nn.Linear(self.width, self.all_head_dim) + self.value = nn.Linear(self.width, self.all_head_dim) + self.out = nn.Linear(self.width, self.width) + self.reset_parameters() + + def reset_parameters(self) -> None: + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight.data) + if module.bias is not None: + nn.init.constant_(module.bias.data, 0.) 
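+
+ # transpose_for_scores (below) reshapes [batch, seq_len, all_head_dim]
+ # activations to [batch, num_heads, seq_len, head_dim] so that attention is
+ # computed independently per head.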
+ + def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor: + new_x_shape = x.size()[:-1] + (self.num_heads, self.head_dim) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: + if dropout_rate is None: + dropout_rate = self.dropout_rate + + mixed_query_layer = self.query(x) + + key_layer = self.transpose_for_scores(self.key(x)) + value_layer = self.transpose_for_scores(self.value(x)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.head_dim) + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = F.dropout(attention_probs, dropout_rate, training=self.training) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_dim,) + context_layer = context_layer.view(new_context_layer_shape) + out = self.out(context_layer) + return out + + +class Encoder1DBlock(nn.Module): + """Single transformer encoder block (MHSA + MLP).""" + + def __init__(self, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12, + use_glu: bool = False, + use_post_layer_norm: bool = False, + dropout_rate: float = 0.0) -> None: + super().__init__() + + self.width = width + self.mlp_dim = mlp_dim + self.num_heads = num_heads + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm + self.dropout_rate = dropout_rate + + self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) + self.self_attention1 = SelfAttention(self.width, self.num_heads) + self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) + self.mlp3 = MlpBlock( + width=self.width, + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dropout_rate=dropout_rate) + + def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: + if dropout_rate is None: + dropout_rate = self.dropout_rate + + if not self.use_post_layer_norm: + y = self.layer_norm0(x) + y = self.self_attention1(y) + y = F.dropout(y, dropout_rate, training=self.training) + x = x + y + + y = self.layer_norm2(x) + y = self.mlp3(y) + y = F.dropout(y, dropout_rate, training=self.training) + x = x + y + else: + y = x + y = self.self_attention1(y) + y = F.dropout(y, dropout_rate, training=self.training) + x = x + y + x = self.layer_norm0(x) + + y = x + y = self.mlp3(y) + y = F.dropout(y, dropout_rate, training=self.training) + x = x + y + x = self.layer_norm2(x) + return x + + +class Encoder(nn.Module): + """Transformer Model Encoder for sequence to sequence translation.""" + + def __init__(self, + depth: int, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12, + use_glu: bool = False, + use_post_layer_norm: bool = False, + dropout_rate: float = 0.0) -> None: + super().__init__() + + self.depth = depth + self.width = width + self.mlp_dim = mlp_dim + self.num_heads = num_heads + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm + + self.net = nn.ModuleList([ + Encoder1DBlock(self.width, + self.mlp_dim, + self.num_heads, + self.use_glu, + self.use_post_layer_norm, + dropout_rate) for _ in range(depth) + ]) + + if not self.use_post_layer_norm: + self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) + else: + self.encoder_norm = None + + def forward(self, x: spec.Tensor) -> spec.Tensor: + # Input Encoder. 
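+ # Note: the blocks are called without an explicit dropout_rate here, so each
+ # Encoder1DBlock falls back to the rate it was constructed with.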
+ for block in self.net: + x = block(x) + if not self.use_post_layer_norm: + return self.encoder_norm(x) + else: + return x + + +class MAPHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12): + super().__init__() + self.width = width + self.mlp_dim = mlp_dim + self.num_heads = num_heads + + self.probe = nn.Parameter(torch.zeros((1, 1, self.width))) + nn.init.xavier_uniform_(self.probe.data) + + self.mha = MultiheadAttention( + self.width, num_heads=self.num_heads, self_attn=False, bias=True) + self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) + self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) + + def forward(self, x: spec.Tensor) -> spec.Tensor: + n, _, _ = x.shape + probe = torch.tile(self.probe, [n, 1, 1]) + + x = self.mha(probe, x)[0] + y = self.layer_norm(x) + x = x + self.mlp(y) + return x[:, 0] + + +class ViT(nn.Module): + """ViT model.""" + + image_height: int = 224 + image_width: int = 224 + channels: int = 3 + + def __init__( + self, + num_classes: int = 1000, + patch_size: Tuple[int, int] = (16, 16), + width: int = 768, + depth: int = 12, + mlp_dim: Optional[int] = None, # Defaults to 4x input dim. + num_heads: int = 12, + rep_size: Union[int, bool] = True, + dropout_rate: Optional[float] = 0.0, + head_zeroinit: bool = True, + use_glu: bool = False, + use_post_layer_norm: bool = False, + use_map: bool = False, + dtype: Any = torch.float32) -> None: + super().__init__() + if dropout_rate is None: + dropout_rate = 0.0 + + self.num_classes = num_classes + self.patch_size = patch_size + self.width = width + self.depth = depth + self.mlp_dim = mlp_dim + self.num_heads = num_heads + self.rep_size = rep_size + self.head_zeroinit = head_zeroinit + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm + self.use_map = use_map + self.dtype = dtype + self.dropout_rate = dropout_rate + + if self.rep_size: + rep_size = self.width if self.rep_size is True else self.rep_size + self.pre_logits = nn.Linear(self.width, rep_size) + + self.conv_patch_extract = nn.Conv2d( + self.channels, + self.width, + self.patch_size, + stride=self.patch_size, + padding='valid') + + self.encoder = Encoder( + depth=self.depth, + width=self.width, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + dropout_rate=dropout_rate) + + if self.num_classes: + self.head = nn.Linear(self.width, self.num_classes) + + if self.use_map: + self.map = MAPHead(self.width, self.mlp_dim, self.num_heads) + else: + self.map = None + + self.reset_parameters() + + def reset_parameters(self) -> None: + init_utils.pytorch_default_init(self.conv_patch_extract) + + if self.rep_size: + init_utils.pytorch_default_init(self.pre_logits) + + if self.num_classes: + if self.head_zeroinit: + nn.init.constant_(self.head.weight.data, 0.) + nn.init.constant_(self.head.bias.data, 0.) + else: + init_utils.pytorch_default_init(self.head) + + def get_posemb(self, x: spec.Tensor) -> spec.Tensor: + return posemb_sincos_2d(x).type(self.dtype) + + def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: + if dropout_rate is None: + dropout_rate = self.dropout_rate + + # Patch extraction. + x = self.conv_patch_extract(x) + + # Add posemb before adding extra token. + n, c, h, w = x.shape + pes = self.get_posemb(x) + + # Reshape to match Jax's ViT implementation. 
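+ # [n, c, h, w] -> [n, h * w, c]: flatten the patch grid and move channels last
+ # so the positional embeddings can be added per patch token.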
+ x = torch.transpose(torch.reshape(x, (n, c, h * w)), 1, 2) + x = x + pes + + x = F.dropout(x, dropout_rate, training=self.training) + x = self.encoder(x) + + if self.use_map: + x = self.map(x) + else: + x = torch.mean(x, dim=1) + + if self.rep_size: + x = torch.tanh(self.pre_logits(x)) + + if self.num_classes: + x = self.head(x) + + return x diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py new file mode 100644 index 000000000..da66dfe43 --- /dev/null +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py @@ -0,0 +1,518 @@ +"""This is a pytorch implementation mirroring: +https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py. +""" + +from dataclasses import dataclass +from functools import partial +import math +from typing import Optional, Tuple + +import torch +from torch import nn +from torch.nn import init +import torch.nn.functional as F + +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ + preprocessor +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ + SpecAug + + +@dataclass +class ConformerConfig: + """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int = 1024 + encoder_dim: int = 512 + num_attention_heads: int = 8 + num_encoder_layers: int = 4 + attention_dropout_rate: float = 0.0 + # If None, defaults to 0.1. + attention_residual_dropout_rate: Optional[float] = 0.1 + # If None, defaults to 0.0. + conv_residual_dropout_rate: Optional[float] = 0.0 + feed_forward_dropout_rate: float = 0.0 + # If None, defaults to 0.1. + feed_forward_residual_dropout_rate: Optional[float] = 0.1 + convolution_kernel_size: int = 5 + feed_forward_expansion_factor: int = 4 + freq_mask_count: int = 2 + freq_mask_max_bins: int = 27 + time_mask_count: int = 10 + time_mask_max_frames: int = 40 + time_mask_max_ratio: float = 0.05 + time_masks_per_frame: float = 0.0 + use_dynamic_time_mask_max_frames: bool = True + # If None, defaults to 0.1. 
+ input_dropout_rate: Optional[float] = 0.1 + batch_norm_momentum: float = 1 - 0.999 + batch_norm_epsilon: float = 0.001 + use_specaug: bool = True + attention_temperature: float = 1.0 + activation_function_name: str = 'swish' + use_post_layer_norm: bool = True + + +def initialize(m): + if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d): + init.xavier_uniform_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.MultiheadAttention): + init.xavier_uniform_(m.in_proj_weight) + for i in m.children(): + initialize(i) + + +class LayerNorm(nn.Module): + + def __init__(self, dim, epsilon=1e-6): + super().__init__() + self.dim = dim + + self.scale = nn.Parameter(torch.zeros(self.dim)) + self.bias = nn.Parameter(torch.zeros(self.dim)) + self.epsilon = epsilon + + def forward(self, x): + return F.layer_norm(x, (self.dim,), 1 + self.scale, self.bias, self.epsilon) + + +class Subsample(nn.Module): + + def __init__(self, + encoder_dim: int = 0, + input_dropout_rate: float = 0.0, + num_bins: int = 80): + super().__init__() + self.encoder_dim = encoder_dim + self.input_dropout_rate = input_dropout_rate + + self.conv1 = Conv2dSubsampling( + input_channels=1, output_channels=encoder_dim) + self.conv2 = Conv2dSubsampling( + input_channels=encoder_dim, output_channels=encoder_dim) + + self.linear = nn.Linear( + in_features=self.encoder_dim * num_bins // 4, + out_features=self.encoder_dim, + bias=True) + self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim) + + def forward(self, inputs, input_paddings, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.input_dropout_rate + + output_paddings = input_paddings + outputs = inputs[:, None, :, :] + + outputs, output_paddings = self.conv1(outputs, output_paddings) + outputs, output_paddings = self.conv2(outputs, output_paddings) + + batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape + outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size, + subsampled_lengths, + subsampled_dims * channels) + + outputs = self.linear(outputs) + outputs = outputs + self.pos_encode(seq_length=outputs.shape[1]) + outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True) + + return outputs, output_paddings + + +class Conv2dSubsampling(nn.Module): + + def __init__(self, + input_channels: int, + output_channels: int, + filter_stride: Tuple[int] = (2, 2), + padding: str = 'SAME'): + super().__init__() + + self.input_channels = input_channels + self.output_channels = output_channels + self.filter_stride = filter_stride + self.padding = padding + + self.filter_shape = (output_channels, input_channels, 3, 3) + + self.kernel = nn.Parameter( + torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) + self.bias = nn.Parameter(torch.zeros(output_channels)) + self.register_buffer('paddings_kernel', torch.ones([1, 1, 1])) + + def get_same_padding(self, input_shape): + in_height, in_width = input_shape[2:] + stride_height, stride_width = self.filter_stride + filter_height, filter_width = 3, 3 + if in_height % stride_height == 0: + pad_along_height = max(filter_height - stride_height, 0) + else: + pad_along_height = max(filter_height - (in_height % stride_height), 0) + if in_width % stride_width == 0: + pad_along_width = max(filter_width - stride_width, 0) + else: + pad_along_width = max(filter_width - (in_width % stride_width), 0) + pad_top = pad_along_height // 2 + pad_bottom = pad_along_height - pad_top + pad_left = pad_along_width // 2 + pad_right = pad_along_width - 
pad_left + return (pad_left, pad_right, pad_top, pad_bottom) + + def forward(self, inputs, paddings): + groups = inputs.shape[1] // self.input_channels + + if self.padding == 'SAME': + in_ = F.pad(inputs, self.get_same_padding(inputs.shape)) + else: + in_ = inputs + outputs = F.conv2d( + input=in_, + weight=self.kernel, + bias=self.bias, + stride=self.filter_stride, + dilation=(1, 1), + groups=groups) + + outputs = F.relu(outputs) + + input_length = paddings.shape[1] + stride = self.filter_stride[0] + pad_len = (input_length + stride - 1) // stride * stride - input_length + padded_paddings = F.pad( + paddings[:, None, :], (0, pad_len), mode='constant', value=0) + out_padding = F.conv1d( + input=padded_paddings, + weight=self.paddings_kernel, + stride=self.filter_stride[:1]) + out_padding = out_padding.squeeze(dim=1) + outputs = outputs * (1 - out_padding[:, None, :, None]) + return outputs, out_padding + + +class FeedForwardModule(nn.Module): + + def __init__(self, config: ConformerConfig): + super().__init__() + self.config = config + + self.ln = LayerNorm(dim=config.encoder_dim) + self.linear1 = nn.Linear( + in_features=config.encoder_dim, + out_features=config.encoder_dim * config.feed_forward_expansion_factor, + bias=True) + self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate, inplace=True) + self.linear2 = nn.Linear( + in_features=config.encoder_dim * config.feed_forward_expansion_factor, + out_features=config.encoder_dim, + bias=True) + + if config.feed_forward_residual_dropout_rate is None: + self.feed_forward_residual_dropout_rate = 0.1 + else: + self.feed_forward_residual_dropout_rate = config.feed_forward_residual_dropout_rate + + def forward(self, inputs, padding_mask, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.feed_forward_residual_dropout_rate + + inputs = self.ln(inputs) + inputs = self.linear1(inputs) + if self.config.activation_function_name == 'swish': + activation_fn = F.silu + elif self.config.activation_function_name == 'gelu': + # Use tanh approximation of GELU which is default for jax + activation_fn = partial(F.gelu, approximate='tanh') + else: + raise ValueError('Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{self.config.activation_function_name}') + inputs = activation_fn(inputs) + inputs = self.dropout1(inputs) + inputs = inputs * padding_mask + inputs = self.linear2(inputs) + inputs = inputs * padding_mask + inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True) + + return inputs + + +class AddPositionalEmbedding(nn.Module): + + def __init__(self, + min_timescale: int = 1, + max_timescale: int = 10_000, + embedding_dim: int = 512): + super().__init__() + self.min_timescale = min_timescale + self.max_timescale = max_timescale + self.embedding_dim = embedding_dim + num_timescales = self.embedding_dim // 2 + log_timescale_increment = math.log( + float(self.max_timescale) / float(self.min_timescale)) / ( + num_timescales - 1) + inv_timescales = self.min_timescale * \ + torch.exp(torch.arange(num_timescales, dtype=torch.float32) + * -log_timescale_increment) + self.register_buffer('inv_timescales', inv_timescales[None, None, :]) + + def forward(self, seq_length): + position = torch.arange( + end=seq_length, dtype=torch.float32, device=self.inv_timescales.device) + scaled_time = position[None, :, None] * self.inv_timescales + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) + if self.embedding_dim % 2: + signal = torch.cat( + 
[signal, torch.zeros(signal.shape[0], signal.shape[1], 1)], dim=2) + return signal + + +class QueryScaler(nn.Module): + + def __init__(self, dim): + super().__init__() + self.dim = dim + self.scale = nn.Parameter(torch.zeros(self.dim)) + + def forward(self, inputs): + r_softplus_0 = 1.442695041 + scale = r_softplus_0 * F.softplus(self.scale) + return inputs * scale + + +class MHSAwithQS(nn.Module): + + def __init__(self, config: ConformerConfig): + super().__init__() + self.embed_dim = config.encoder_dim + self.num_heads = config.num_attention_heads + self.attention_dropout_rate = config.attention_dropout_rate + self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim) + self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim) + self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads) + self.attention_temperature = config.attention_temperature + + def forward(self, inputs, key_padding_mask=None): + batch_size, seq_len, embed_dim = inputs.shape + q, k, v = self.in_proj(inputs).split(self.embed_dim, dim=2) + q = self.qs(q.view(batch_size, seq_len, self.num_heads, -1)).transpose(1, 2) + k = k.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) + v = v.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) + out = F.scaled_dot_product_attention( + query=q, + key=k, + value=v, + attn_mask=~key_padding_mask[:, None, None], + dropout_p=self.attention_dropout_rate, + ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim) + out = out * self.attention_temperature + out = self.out_proj(out) + return out + + +class MultiHeadedSelfAttention(nn.Module): + + def __init__(self, config: ConformerConfig): + super().__init__() + + self.config = config + + self.ln = LayerNorm(dim=config.encoder_dim) + self.self_attention = MHSAwithQS(config) + if config.attention_residual_dropout_rate is None: + self.attention_residual_dropout_rate = 0.1 + else: + self.attention_residual_dropout_rate = config.attention_residual_dropout_rate + + def forward(self, outputs, paddings, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.attention_residual_dropout_rate + + outputs = self.ln(outputs) + outputs = self.self_attention( + outputs, + key_padding_mask=paddings == 1, + ) + outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True) + return outputs + + +class BatchNorm(nn.Module): + + def __init__(self, config: ConformerConfig): + super().__init__() + running_mean = torch.zeros(config.encoder_dim) + running_var = torch.ones(config.encoder_dim) + self.register_buffer('running_mean', running_mean) + self.register_buffer('running_var', running_var) + self.scale = nn.Parameter(torch.zeros(config.encoder_dim)) + self.bias = nn.Parameter(torch.zeros(config.encoder_dim)) + + self.register_buffer('dim', torch.FloatTensor([config.encoder_dim])) + self.momentum = config.batch_norm_momentum + self.epsilon = config.batch_norm_epsilon + + def forward(self, inputs, input_paddings): + #inputs: NHD + #padding: NH + """ + Alternatively: + inputs[input_paddings==0] = F.batch_norm( + input = inputs[input_paddings==0], + running_mean = self.running_mean, + running_var = self.running_var, + weight = 1+self.scale, + bias = self.bias, + training = self.training, + momentum=1-self.momentum, + eps=self.epsilon + ) + inputs.masked_fill(input_paddings[...,None] != 0, 0) + return inputs + """ + mask = 1 - input_paddings[:, :, None] + if self.training: + count = mask.sum() + masked_inp = inputs.masked_fill(mask == 0, 0) + mean = (masked_inp).sum(dim=(0, 
1)) / count + var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count + + self.running_mean = (1 - self.momentum) * self.running_mean + ( + self.momentum) * mean.detach() + self.running_var = (1 - self.momentum) * self.running_var + ( + self.momentum) * var.detach() + + else: + mean = self.running_mean + var = self.running_var + v = (1 + self.scale) * torch.rsqrt(var + self.epsilon) + bn = (inputs - mean) * v + self.bias + output = bn.masked_fill(mask == 0, 0) + return output + + +class ConvolutionBlock(nn.Module): + + def __init__(self, config): + super().__init__() + + self.config = config + self.ln = LayerNorm(dim=config.encoder_dim) + self.lin1 = nn.Linear( + in_features=config.encoder_dim, out_features=config.encoder_dim) + self.lin2 = nn.Linear( + in_features=config.encoder_dim, out_features=config.encoder_dim) + + self.conv1 = nn.Conv1d( + in_channels=config.encoder_dim, + out_channels=config.encoder_dim, + kernel_size=(config.convolution_kernel_size,), + stride=(1,), + padding='same', + bias=False, + groups=config.encoder_dim) + self.bn = BatchNorm(config) + self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim) + if config.conv_residual_dropout_rate is None: + self.conv_residual_dropout_rate = 0.0 + else: + self.conv_residual_dropout_rate = config.conv_residual_dropout_rate + + def forward(self, inputs, input_paddings, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.conv_residual_dropout_rate + + inputs = self.ln(inputs) + + inputs = F.glu(torch.cat([self.lin1(inputs), self.lin2(inputs)], dim=2)) + inputs = inputs * (1 - input_paddings[:, :, None]) + + inputs = inputs.permute(0, 2, 1) + inputs = self.conv1(inputs) + inputs = inputs.permute(0, 2, 1) + + inputs = self.bn(inputs, input_paddings) + if self.config.activation_function_name == 'swish': + activation_fn = F.silu + elif self.config.activation_function_name == 'gelu': + activation_fn = F.gelu + else: + raise ValueError('Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{self.config.activation_function_name}') + inputs = activation_fn(inputs) + inputs = self.lin3(inputs) + + inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True) + return inputs + + +class ConformerBlock(nn.Module): + + def __init__(self, config: ConformerConfig): + super().__init__() + + self.ff1 = FeedForwardModule(config) + self.mhsa = MultiHeadedSelfAttention(config) + self.conv = ConvolutionBlock(config) + self.ff2 = FeedForwardModule(config) + self.ln = None + if config.use_post_layer_norm: + self.ln = LayerNorm(dim=config.encoder_dim) + + def forward(self, inputs, input_paddings, dropout_rate=None): + padding_mask = 1 - input_paddings[:, :, None] + inputs = inputs + 0.5 * self.ff1(inputs, padding_mask, dropout_rate) + inputs = inputs + self.mhsa(inputs, input_paddings, dropout_rate) + inputs = inputs + self.conv(inputs, input_paddings, dropout_rate) + inputs = inputs + 0.5 * self.ff2(inputs, padding_mask, dropout_rate) + if self.ln: + inputs = self.ln(inputs) + return inputs + + +class ConformerEncoderDecoder(nn.Module): + + def __init__(self, config: ConformerConfig): + super().__init__() + self.config = config + preprocessing_config = preprocessor.PreprocessorConfig() + self.preprocessor = preprocessor.MelFilterbankFrontend( + preprocessing_config, + per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, + per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR) + self.specaug = SpecAug( + freq_mask_count=config.freq_mask_count, + 
freq_mask_max_bins=config.freq_mask_max_bins, + time_mask_count=config.time_mask_count, + time_mask_max_frames=config.time_mask_max_frames, + time_mask_max_ratio=config.time_mask_max_ratio, + time_masks_per_frame=config.time_masks_per_frame, + use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames + ) + if config.input_dropout_rate is None: + input_dropout_rate = 0.1 + else: + input_dropout_rate = config.input_dropout_rate + self.subsample = Subsample( + encoder_dim=config.encoder_dim, + input_dropout_rate=input_dropout_rate, + num_bins=preprocessing_config.num_bins) + self.conformers = nn.ModuleList( + [ConformerBlock(config) for _ in range(config.num_encoder_layers)]) + + self.ln = LayerNorm(config.encoder_dim) + self.lin = nn.Linear(config.encoder_dim, config.vocab_size) + + def forward(self, inputs, input_paddings, dropout_rate=None): + outputs = inputs + output_paddings = input_paddings + outputs, output_paddings = self.preprocessor(outputs, output_paddings) + if self.training and self.config.use_specaug: + outputs, output_paddings = self.specaug(outputs, output_paddings) + outputs, output_paddings = self.subsample(outputs, output_paddings) + for conformer in self.conformers: + outputs = conformer(outputs, output_paddings, dropout_rate) + outputs = self.ln(outputs) + outputs = self.lin(outputs) + return outputs, output_paddings From 3e7a3967ba5f4845826e971ba5f0c49fa4c031b9 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 11 Jun 2025 11:22:23 +0200 Subject: [PATCH 02/23] dropout fix deepspeech, ogbg --- algoperf/workloads/dropout_modules.py | 5 +- .../librispeech_pytorch/models_dropout.py | 395 ++++++++++++++++++ .../ogbg/ogbg_pytorch/models_dropout.py | 314 ++++++++++++++ 3 files changed, 711 insertions(+), 3 deletions(-) create mode 100644 algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py create mode 100644 algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py diff --git a/algoperf/workloads/dropout_modules.py b/algoperf/workloads/dropout_modules.py index 3917b75bf..6cec3f7ad 100644 --- a/algoperf/workloads/dropout_modules.py +++ b/algoperf/workloads/dropout_modules.py @@ -31,10 +31,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._supports_custom_dropout = True - def forward(self, x, p): + def forward(self, x: Tensor, p: float) -> Tensor: for module in self: - # if isinstance(module, (CustomDropout, SequentialWithDropout, DenseBlockWithDropout)): - if getattr(module, '_supports_custom_dropout', False): # TODO (nico): improve + if getattr(module, '_supports_custom_dropout', False): x = module(x, p) else: x = module(x) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py new file mode 100644 index 000000000..e68a820ed --- /dev/null +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py @@ -0,0 +1,395 @@ +"""This is a pytorch implementation mirroring: +https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py. 
+""" + +from dataclasses import dataclass +import os +from typing import Optional, Tuple + +import torch +from torch import nn +import torch.distributed.nn as dist_nn +import torch.nn.functional as F + +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ + preprocessor +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ + SpecAug + +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ + + +@dataclass +class DeepspeechConfig: + """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int = 1024 + encoder_dim: int = 512 + num_lstm_layers: int = 6 + num_ffn_layers: int = 3 + conv_subsampling_factor: int = 2 + conv_subsampling_layers: int = 2 + use_specaug: bool = True + freq_mask_count: int = 2 + freq_mask_max_bins: int = 27 + time_mask_count: int = 10 + time_mask_max_frames: int = 40 + time_mask_max_ratio: float = 0.05 + time_masks_per_frame: float = 0.0 + use_dynamic_time_mask_max_frames: bool = True + batch_norm_momentum: float = 1 - 0.999 + batch_norm_epsilon: float = 0.001 + # If None, defaults to 0.1. + input_dropout_rate: Optional[float] = 0.1 + # If None, defaults to 0.1. + feed_forward_dropout_rate: Optional[float] = 0.1 + enable_residual_connections: bool = True + enable_decoder_layer_norm: bool = True + bidirectional: bool = True + use_tanh: bool = False + layernorm_everywhere: bool = False + + +class LayerNorm(nn.Module): + + def __init__(self, dim, epsilon=1e-6): + super().__init__() + self.dim = dim + + self.scale = nn.Parameter(torch.zeros(self.dim)) + self.bias = nn.Parameter(torch.zeros(self.dim)) + self.epsilon = epsilon + + def forward(self, x): + mean = x.mean(dim=-1, keepdims=True) + var = x.var(dim=-1, unbiased=False, keepdims=True) + + normed_x = (x - mean) * torch.rsqrt(var + self.epsilon) + normed_x *= (1 + self.scale) + normed_x += self.bias + + return normed_x + + +class Subsample(nn.Module): + + def __init__(self, config: DeepspeechConfig): + super().__init__() + encoder_dim = config.encoder_dim + + self.encoder_dim = encoder_dim + + self.conv1 = Conv2dSubsampling( + input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh) + self.conv2 = Conv2dSubsampling( + input_channels=encoder_dim, + output_channels=encoder_dim, + use_tanh=config.use_tanh) + + self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True) + + if config.input_dropout_rate is None: + self.input_dropout_rate = 0.1 + else: + self.input_dropout_rate = config.input_dropout_rate + + def forward(self, inputs, input_paddings, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.input_dropout_rate + + output_paddings = input_paddings + outputs = inputs[:, None, :, :] + + outputs, output_paddings = self.conv1(outputs, output_paddings) + outputs, output_paddings = self.conv2(outputs, output_paddings) + + batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape + outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size, + subsampled_lengths, + subsampled_dims * channels) + + outputs = self.lin(outputs) + outputs = F.dropout(outputs, dropout_rate, training=self.training) + + return outputs, output_paddings + + +class Conv2dSubsampling(nn.Module): + + def __init__(self, + input_channels: int, + output_channels: int, + filter_stride: Tuple[int] = (2, 2), + padding: str = 'SAME', + batch_norm_momentum: float = 0.999, + batch_norm_epsilon: float = 0.001, + use_tanh: bool = False): + super().__init__() + + self.input_channels = input_channels + self.output_channels = 
output_channels + self.filter_stride = filter_stride + self.padding = padding + + self.filter_shape = (output_channels, input_channels, 3, 3) + + self.kernel = nn.Parameter( + nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) + self.bias = nn.Parameter(torch.zeros(output_channels)) + + self.use_tanh = use_tanh + + def get_same_padding(self, input_shape): + in_height, in_width = input_shape[2:] + stride_height, stride_width = self.filter_stride + filter_height, filter_width = 3, 3 + if in_height % stride_height == 0: + pad_along_height = max(filter_height - stride_height, 0) + else: + pad_along_height = max(filter_height - (in_height % stride_height), 0) + if in_width % stride_width == 0: + pad_along_width = max(filter_width - stride_width, 0) + else: + pad_along_width = max(filter_width - (in_width % stride_width), 0) + pad_top = pad_along_height // 2 + pad_bottom = pad_along_height - pad_top + pad_left = pad_along_width // 2 + pad_right = pad_along_width - pad_left + return (pad_left, pad_right, pad_top, pad_bottom) + + def forward(self, inputs, paddings): + groups = inputs.shape[1] // self.input_channels + + if self.padding == 'SAME': + in_ = F.pad(inputs, self.get_same_padding(inputs.shape)) + else: + in_ = inputs + outputs = F.conv2d( + input=in_, + weight=self.kernel, + bias=self.bias, + stride=self.filter_stride, + dilation=(1, 1), + groups=groups) + + if self.use_tanh: + outputs = F.tanh(outputs) + else: + outputs = F.relu(outputs) + + input_length = paddings.shape[1] + stride = self.filter_stride[0] + pad_len = (input_length + stride - 1) // stride * stride - input_length + out_padding = F.conv1d( + input=torch.cat([ + paddings[:, None, :], + torch.zeros( + size=(paddings.shape[0], 1, pad_len), device=paddings.device) + ], + dim=2), + weight=torch.ones([1, 1, 1], device=paddings.device), + stride=self.filter_stride[:1]) + out_padding = out_padding.squeeze(dim=1) + outputs = outputs * (1 - out_padding[:, None, :, None]) + return outputs, out_padding + + +class FeedForwardModule(nn.Module): + + def __init__(self, config: DeepspeechConfig): + super().__init__() + self.config = config + + if config.layernorm_everywhere: + self.normalization_layer = LayerNorm(config.encoder_dim) + else: + self.bn_normalization_layer = BatchNorm( + dim=config.encoder_dim, + batch_norm_momentum=config.batch_norm_momentum, + batch_norm_epsilon=config.batch_norm_epsilon) + self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True) + if config.feed_forward_dropout_rate is None: + self.feed_forward_dropout_rate = 0.1 + else: + self.feed_forward_dropout_rate = config.feed_forward_dropout_rate + + def forward(self, inputs, input_paddings, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.feed_forward_dropout_rate + + padding_mask = (1 - input_paddings)[:, :, None] + if self.config.layernorm_everywhere: + inputs = self.normalization_layer(inputs) + else: # batchnorm + inputs = self.bn_normalization_layer(inputs, input_paddings) + + inputs = self.lin(inputs) + + if self.config.use_tanh: + inputs = F.tanh(inputs) + else: + inputs = F.relu(inputs) + + inputs = inputs * padding_mask + inputs = F.dropout(inputs, dropout_rate, training=self.training) + + return inputs + + +class BatchNorm(nn.Module): + + def __init__(self, dim, batch_norm_momentum, batch_norm_epsilon): + super().__init__() + running_mean = torch.zeros(dim) + running_var = torch.ones(dim) + self.register_buffer('running_mean', running_mean) + self.register_buffer('running_var', running_var) + self.weight = 
nn.Parameter(torch.zeros(dim)) + self.bias = nn.Parameter(torch.zeros(dim)) + + self.momentum = batch_norm_momentum + self.epsilon = batch_norm_epsilon + self.dim = dim + + def forward(self, inputs, input_paddings): + #inputs: NHD + #padding: NH + mask = 1 - input_paddings[:, :, None] + if self.training: + count = mask.sum() + masked_inp = inputs.masked_fill(mask == 0, 0) + sum_ = (masked_inp).sum(dim=(0, 1)) + if USE_PYTORCH_DDP: + sum_ = dist_nn.all_reduce(sum_) + count = dist_nn.all_reduce(count) + mean = sum_ / count + + sum_ = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) + if USE_PYTORCH_DDP: + sum_ = dist_nn.all_reduce(sum_) + var = sum_ / count + + self.running_mean = (1 - self.momentum) * self.running_mean + ( + self.momentum) * mean.detach() + self.running_var = (1 - self.momentum) * self.running_var + ( + self.momentum) * var.detach() + else: + mean = self.running_mean + var = self.running_var + v = (1 + self.weight) * torch.rsqrt(var + self.epsilon) + bn = (inputs - mean) * v + self.bias + output = bn.masked_fill(mask == 0, 0) + return output + + +class BatchRNN(nn.Module): + + def __init__(self, config: DeepspeechConfig): + super().__init__() + self.config = config + hidden_size = config.encoder_dim + input_size = config.encoder_dim + bidirectional = config.bidirectional + self.bidirectional = bidirectional + + if config.layernorm_everywhere: + self.normalization_layer = LayerNorm(config.encoder_dim) + else: + self.bn_normalization_layer = BatchNorm(config.encoder_dim, + config.batch_norm_momentum, + config.batch_norm_epsilon) + + if bidirectional: + self.lstm = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size // 2, + bidirectional=True, + batch_first=True) + else: + self.lstm = nn.LSTM( + input_size=input_size, hidden_size=hidden_size, batch_first=True) + + def forward(self, inputs, input_paddings): + if self.config.layernorm_everywhere: + inputs = self.normalization_layer(inputs) + else: + inputs = self.bn_normalization_layer(inputs, input_paddings) + lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy() + packed_inputs = torch.nn.utils.rnn.pack_padded_sequence( + inputs, lengths, batch_first=True, enforce_sorted=False) + packed_outputs, _ = self.lstm(packed_inputs) + outputs, _ = torch.nn.utils.rnn.pad_packed_sequence( + packed_outputs, batch_first=True) + if outputs.shape[1] < inputs.shape[1]: + outputs = torch.cat([ + outputs, + torch.zeros( + size=(outputs.shape[0], + inputs.shape[1] - outputs.shape[1], + outputs.shape[2]), + device=outputs.device) + ], + dim=1) + return outputs + + +class DeepspeechEncoderDecoder(nn.Module): + + def __init__(self, config: DeepspeechConfig): + super().__init__() + self.config = config + + self.specaug = SpecAug( + freq_mask_count=config.freq_mask_count, + freq_mask_max_bins=config.freq_mask_max_bins, + time_mask_count=config.time_mask_count, + time_mask_max_frames=config.time_mask_max_frames, + time_mask_max_ratio=config.time_mask_max_ratio, + time_masks_per_frame=config.time_masks_per_frame, + use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames + ) + preprocessing_config = preprocessor.PreprocessorConfig() + self.preprocessor = preprocessor.MelFilterbankFrontend( + preprocessing_config, + per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, + per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR) + + self.subsample = Subsample(config=config) + + self.lstms = nn.ModuleList( + [BatchRNN(config) for _ in range(config.num_lstm_layers)]) + self.ffns = nn.ModuleList( + 
[FeedForwardModule(config) for _ in range(config.num_ffn_layers)]) + + if config.enable_decoder_layer_norm: + self.ln = LayerNorm(config.encoder_dim) + else: + self.ln = nn.Identity() + + self.lin = nn.Linear(config.encoder_dim, config.vocab_size) + + def forward(self, inputs, input_paddings, dropout_rate=None): + outputs = inputs + output_paddings = input_paddings + + outputs, output_paddings = self.preprocessor(outputs, output_paddings) + if self.training and self.config.use_specaug: + outputs, output_paddings = self.specaug(outputs, output_paddings) + outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate) + for idx in range(self.config.num_lstm_layers): + if self.config.enable_residual_connections: + outputs = outputs + self.lstms[idx](outputs, output_paddings) + else: + outputs = self.lstms[idx](outputs, output_paddings) + + for idx in range(self.config.num_ffn_layers): + if self.config.enable_residual_connections: + outputs = outputs + self.ffns[idx](outputs, output_paddings) + else: + outputs = self.ffns[idx](outputs, output_paddings, dropout_rate) + + if self.config.enable_decoder_layer_norm: + outputs = self.ln(outputs) + + outputs = self.lin(outputs) + + return outputs, output_paddings diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py new file mode 100644 index 000000000..1d89ea9e7 --- /dev/null +++ b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py @@ -0,0 +1,314 @@ +# Ported to PyTorch from +# https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. +from functools import partial +from typing import Callable, Optional, Tuple + +import jax.tree_util as tree +from jraph import GraphsTuple +import torch +from torch import nn + +from algoperf import init_utils +from algoperf.workloads.dropout_modules import CustomDropout, SequentialWithDropout + + +def _make_mlp(in_dim, hidden_dims, activation_fn): + """Creates a MLP with specified dimensions.""" + layers = SequentialWithDropout() + for i, dim in enumerate(hidden_dims): + layers.add_module(f'dense_{i}', + nn.Linear(in_features=in_dim, out_features=dim)) + layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6)) + layers.add_module(f'activation_fn_{i}', activation_fn()) + layers.add_module(f'dropout_{i}', CustomDropout()) + in_dim = dim + return layers + + +class GNN(nn.Module): + """Defines a graph network. + + The model assumes the input data is a jraph.GraphsTuple without global + variables. The final prediction will be encoded in the globals. + """ + + def __init__(self, + num_outputs: int = 128, + dropout_rate: Optional[float] = 0.1, + activation_fn_name: str = 'relu', + latent_dim: int = 256, + hidden_dims: Tuple[int] = (256,), + num_message_passing_steps: int = 5) -> None: + super().__init__() + self.latent_dim = latent_dim + self.hidden_dims = hidden_dims + self.num_message_passing_steps = num_message_passing_steps + self.num_outputs = num_outputs + if dropout_rate is None: + self.dropout_rate = 0.1 + # in_features are specifically chosen for the ogbg workload. 
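+    # (The ogbg node features are 9-dimensional and the edge features are
+    # 3-dimensional, hence in_features=9 and in_features=3 in the embedders
+    # below, both projecting to latent_dim.)
+    # Note that self.dropout_rate is only assigned when dropout_rate is None,
+    # so forward() calls that omit dropout_rate assume the model was built
+    # with dropout_rate=None. A rough usage sketch (names hypothetical,
+    # `batch` standing in for a jraph.GraphsTuple from the ogbg pipeline):
+    #   model = GNN(num_outputs=128, dropout_rate=None)
+    #   logits = model(batch)                    # falls back to 0.1
+    #   logits = model(batch, dropout_rate=0.0)  # per-call override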
+ self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim) + self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim) + + if activation_fn_name == 'relu': + activation_fn = nn.ReLU + elif activation_fn_name == 'gelu': + activation_fn = partial(nn.GELU, approximate='tanh') + elif activation_fn_name == 'silu': + activation_fn = nn.SiLU + else: + raise ValueError( + f'Invalid activation function name: {self.activation_fn_name}') + + graph_network_layers = [] + for st in range(self.num_message_passing_steps): + # Constants in in_dims are based on forward call of GraphNetwork: + # specifically update_edge_fn update_node_fn and update_global_fn. + if st == 0: + in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs + in_dim_node_fn = self.latent_dim + self.hidden_dims[ + -1] * 2 + self.num_outputs + last_in_dim = self.hidden_dims[-1] * 2 + self.num_outputs + else: + in_dim_edge_fn = self.hidden_dims[-1] * 4 + in_dim_node_fn = self.hidden_dims[-1] * 4 + last_in_dim = self.hidden_dims[-1] * 3 + + graph_network_layers.append( + GraphNetwork( + update_edge_fn=_make_mlp(in_dim_edge_fn, + self.hidden_dims, + activation_fn), + update_node_fn=_make_mlp(in_dim_node_fn, + self.hidden_dims, + activation_fn), + update_global_fn=_make_mlp(last_in_dim, + self.hidden_dims, + activation_fn))) + self.graph_network = SequentialWithDropout(*graph_network_layers) + + self.decoder = nn.Linear( + in_features=self.hidden_dims[-1], out_features=self.num_outputs) + + for m in self.modules(): + if isinstance(m, nn.Linear): + init_utils.pytorch_default_init(m) + + def forward(self, graph: GraphsTuple, dropout_rate=None) -> torch.Tensor: + if dropout_rate is None: + dropout_rate = self.dropout_rate + + graph = graph._replace( + globals=torch.zeros([graph.n_node.shape[0], self.num_outputs], + device=graph.n_node.device)) + graph = graph._replace(nodes=self.node_embedder(graph.nodes)) + graph = graph._replace(edges=self.edge_embedder(graph.edges)) + + graph = self.graph_network(graph, dropout_rate) + + # Map globals to represent the final result + graph = graph._replace(globals=self.decoder(graph.globals)) + + return graph.globals + + +class GraphNetwork(nn.Module): + """Returns a method that applies a configured GraphNetwork. + This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 + There is one difference. For the nodes update the class aggregates over the + sender edges and receiver edges separately. This is a bit more general + than the algorithm described in the paper. The original behaviour can be + recovered by using only the receiver edge aggregations for the update. + In addition this implementation supports softmax attention over incoming + edge features. + Example usage:: + gn = GraphNetwork(update_edge_function, + update_node_function, **kwargs) + # Conduct multiple rounds of message passing with the same parameters: + for _ in range(num_message_passing_steps): + graph = gn(graph) + Args: + update_edge_fn: function used to update the edges or None to deactivate edge + updates. + update_node_fn: function used to update the nodes or None to deactivate node + updates. + update_global_fn: function used to update the globals or None to deactivate + globals updates. + Returns: + A method that applies the configured GraphNetwork. 
+ """ + + def __init__(self, + update_edge_fn: Optional[Callable] = None, + update_node_fn: Optional[Callable] = None, + update_global_fn: Optional[Callable] = None) -> None: + super().__init__() + self.update_edge_fn = update_edge_fn + self.update_node_fn = update_node_fn + self.update_global_fn = update_global_fn + self._supports_custom_dropout = True # supports SequentialWithDropout + + def forward(self, graph: GraphsTuple, dropout_rate=None) -> GraphsTuple: + """Applies a configured GraphNetwork to a graph. + This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 + There is one difference. For the nodes update the class aggregates over the + sender edges and receiver edges separately. This is a bit more general + the algorithm described in the paper. The original behaviour can be + recovered by using only the receiver edge aggregations for the update. + In addition this implementation supports softmax attention over incoming + edge features. + Many popular Graph Neural Networks can be implemented as special cases of + GraphNets, for more information please see the paper. + Args: + graph: a `GraphsTuple` containing the graph. + Returns: + Updated `GraphsTuple`. + """ + nodes, edges, receivers, senders, globals_, n_node, n_edge = graph + sum_n_node = tree.tree_leaves(nodes)[0].shape[0] + if not tree.tree_all( + tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)): + raise ValueError( + 'All node arrays in nest must contain the same number of nodes.') + + sent_attributes = tree.tree_map(lambda n: n[senders], nodes) + received_attributes = tree.tree_map(lambda n: n[receivers], nodes) + # Here we scatter the global features to the corresponding edges, + # giving us tensors of shape [num_edges, global_feat]. + global_edge_attributes = tree.tree_map( + lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_) + if self.update_edge_fn: + edge_fn_inputs = torch.cat( + [edges, sent_attributes, received_attributes, global_edge_attributes], + dim=-1) + edges = self.update_edge_fn(edge_fn_inputs, dropout_rate) + + if self.update_node_fn: + sent_attributes = tree.tree_map( + lambda e: scatter_sum(e, senders, dim=0, dim_size=sum_n_node), edges) + received_attributes = tree.tree_map( + lambda e: scatter_sum(e, receivers, dim=0, dim_size=sum_n_node), + edges) + # Here we scatter the global features to the corresponding nodes, + # giving us tensors of shape [num_nodes, global_feat]. + global_attributes = tree.tree_map( + lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_) + node_fn_inputs = torch.cat( + [nodes, sent_attributes, received_attributes, global_attributes], + dim=-1) + nodes = self.update_node_fn(node_fn_inputs, dropout_rate) + + if self.update_global_fn: + n_graph = n_node.shape[0] + graph_idx = torch.arange(n_graph, device=graph.n_node.device) + # To aggregate nodes and edges from each graph to global features, + # we first construct tensors that map the node to the corresponding graph. + # For example, if you have `n_node=[1,2]`, we construct the tensor + # [0, 1, 1]. We then do the same for edges. + node_gr_idx = torch.repeat_interleave(graph_idx, n_node, dim=0) + edge_gr_idx = torch.repeat_interleave(graph_idx, n_edge, dim=0) + # We use the aggregation function to pool the nodes/edges per graph. 
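+      # scatter_sum with dim_size=n_graph adds each node/edge row into the
+      # slot of its graph, so node_attributes and edge_attributes below end up
+      # with one row per graph (continuing the n_node=[1, 2] example above,
+      # node rows {0} and {1, 2} are summed into graphs 0 and 1 respectively).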
+ node_attributes = tree.tree_map( + lambda n: scatter_sum(n, node_gr_idx, dim=0, dim_size=n_graph), nodes) + edge_attributes = tree.tree_map( + lambda e: scatter_sum(e, edge_gr_idx, dim=0, dim_size=n_graph), edges) + # These pooled nodes are the inputs to the global update fn. + global_fn_inputs = torch.cat([node_attributes, edge_attributes, globals_], + dim=-1) + globals_ = self.update_global_fn(global_fn_inputs, dropout_rate) + + return GraphsTuple( + nodes=nodes, + edges=edges, + receivers=receivers, + senders=senders, + globals=globals_, + n_node=n_node, + n_edge=n_edge) + + +# Forked from +# github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py. +def scatter_sum(src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + r""" + | + .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ + master/docs/source/_figures/add.svg?sanitize=true + :align: center + :width: 400px + | + Reduces all values from the :attr:`src` tensor into :attr:`out` at the + indices specified in the :attr:`index` tensor along a given axis + :attr:`dim`. + For each value in :attr:`src`, its output index is specified by its index + in :attr:`src` for dimensions outside of :attr:`dim` and by the + corresponding value in :attr:`index` for dimension :attr:`dim`. + The applied reduction is here defined as a sum. + Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional + tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` + and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional + tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. + Moreover, the values of :attr:`index` must be between :math:`0` and + :math:`y - 1`, although no specific ordering of indices is required. + The :attr:`index` tensor supports broadcasting in case its dimensions do + not match with :attr:`src`. + For one-dimensional tensors, the operation computes + .. math:: + \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j + where :math:`\sum_j` is over :math:`j` such that + :math:`\mathrm{index}_j = i`. + .. note:: + This operation is implemented via atomic operations on the GPU and is + therefore **non-deterministic** since the order of parallel operations + to the same value is undetermined. + For floating-point variables, this results in a source of variance in + the result. + :param src: The source tensor. + :param index: The indices of elements to scatter. + :param dim: The axis along which to index. (default: :obj:`-1`) + :param out: The destination tensor. + :param dim_size: If :attr:`out` is not given, automatically create output + with size :attr:`dim_size` at dimension :attr:`dim`. + If :attr:`dim_size` is not given, a minimal sized output tensor + according to :obj:`index.max() + 1` is returned. + :rtype: :class:`Tensor` + .. code-block:: python + src = torch.randn(10, 6, 64) + index = torch.tensor([0, 1, 0, 1, 2, 1]) + # Broadcasting in the first and last dim. + out = scatter_sum(src, index, dim=1) + print(out.size()) + .. 
code-block:: + torch.Size([10, 3, 64]) + """ + index = broadcast(index, src, dim) + if out is None: + size = list(src.size()) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = torch.zeros(size, dtype=src.dtype, device=src.device) + return out.scatter_add_(dim, index, src) + else: + return out.scatter_add_(dim, index, src) + + +# Forked from +# github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/utils.py. +def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + src = src.expand(other.size()) + return src From e80add440b4ab281414d7babf4f5492c16d758e2 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 11 Jun 2025 11:23:05 +0200 Subject: [PATCH 03/23] remove attention_dropout_rate from wmt --- .../wmt/wmt_pytorch/models_dropout.py | 981 ++++++++++++++++++ 1 file changed, 981 insertions(+) create mode 100644 algoperf/workloads/wmt/wmt_pytorch/models_dropout.py diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py new file mode 100644 index 000000000..588d06abf --- /dev/null +++ b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py @@ -0,0 +1,981 @@ +import copy +import math +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F +from torch.nn.init import normal_ +from torch.nn.init import xavier_uniform_ + + +def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor: + """Make a causal mask for self-attention. + + Args: + x: input array of shape `[batch..., len]` + device: device to store the idxs + + Returns: + A `[batch..., len, len]` shaped causal attention mask. + """ + idxs = torch.broadcast_to( + torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape) + return torch.greater_equal(idxs.unsqueeze(-1), idxs.unsqueeze(-2)) + + +def make_src_mask(src, inputs_segmentation, nhead): + """Utility for creating src mask and adjust it for PyTorch Transformer API.""" + src_mask = torch.mul((src > 0).unsqueeze(-1), (src > 0).unsqueeze(-2)) + # Add segmentation block-diagonal attention mask if using segmented data. + if inputs_segmentation is not None: + src_mask = torch.logical_and( + src_mask, + torch.eq( + inputs_segmentation.unsqueeze(-1), + inputs_segmentation.unsqueeze(-2))) + # Flip values and ensure numerical stability. + src_mask = torch.repeat_interleave( + torch.logical_not(src_mask), repeats=nhead, dim=0) + new_src_mask = torch.zeros_like(src_mask, dtype=torch.float32) + new_src_mask.masked_fill_(src_mask, -1e10) + return new_src_mask + + +def make_tgt_and_memory_mask(tgt, + src, + inputs_segmentation, + targets_segmentation, + decode, + nhead): + """ Utility for creating target and memory mask and adjust them for PyTorch + Transformer API.""" + if not decode: + tgt_mask = torch.logical_and( + torch.mul((tgt > 0).unsqueeze(-1), (tgt > 0).unsqueeze(-2)), + make_causal_mask(tgt, device=tgt.device)) + memory_mask = torch.mul((tgt > 0).unsqueeze(-1), (src > 0).unsqueeze(-2)) + else: + tgt_mask = None + memory_mask = torch.mul((torch.ones_like(tgt) > 0).unsqueeze(-1), + (src > 0).unsqueeze(-2)) + # Add segmentation block-diagonal attention masks if using segmented data. 
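+  # Segment ids are compared pairwise via torch.eq on the unsqueezed tensors,
+  # so a target position can only attend to positions from the same packed
+  # example; all cross-segment entries are masked out by the logical_and below.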
+ if inputs_segmentation is not None: + tgt_mask = torch.logical_and( + tgt_mask, + torch.eq( + targets_segmentation.unsqueeze(-1), + targets_segmentation.unsqueeze(-2))) + memory_mask = torch.logical_and( + memory_mask, + torch.eq( + targets_segmentation.unsqueeze(-1), + inputs_segmentation.unsqueeze(-2))) + # Flip values and ensure numerical stability. + memory_mask = torch.repeat_interleave( + torch.logical_not(memory_mask), repeats=nhead, dim=0) + new_memory_mask = torch.zeros_like(memory_mask, dtype=torch.float32) + new_memory_mask.masked_fill_(memory_mask, -1e10) + if tgt_mask is not None: + tgt_mask = torch.repeat_interleave( + torch.logical_not(tgt_mask), repeats=nhead, dim=0) + new_tgt_mask = torch.zeros_like(tgt_mask, dtype=torch.float32) + new_tgt_mask.masked_fill_(tgt_mask, -1e10) + tgt_mask = new_tgt_mask + return tgt_mask, new_memory_mask + + +def shift_right(x, axis=1): + """Shift the input to the right by padding on axis 1.""" + pad_widths = [(0, 0)] * len(x.shape) + pad_widths[axis] = (1, 0) + pad_widths = tuple(t for tup in reversed(pad_widths) for t in tup) + padded = F.pad(x, pad_widths, mode='constant') + return padded[:, :-1] + + +class Transformer(nn.Module): + """Transformer architecture based on the model from the WMT Jax workload.""" + + def __init__(self, + ntoken: int = 32000, + d_model: int = 1024, + nhead: int = 16, + d_hid: int = 1024, + nlayers: int = 6, + dropout_rate: Optional[float] = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + attention_temp: float = 1.0, + pre_ln: bool = True): + super().__init__() + if dropout_rate is None: + dropout_rate = 0.1 + self.pos_encoder = PositionalEncoding(d_model, dropout_rate) + self.shared_embedding = nn.Embedding(ntoken, d_model) + self.encoder = Encoder(d_model, + nhead, + d_hid, + nlayers, + dropout_rate, + activation, + glu, + layer_norm_eps, + attention_temp, + pre_ln) + self.decoder = Decoder(d_model, + nhead, + d_hid, + nlayers, + dropout_rate, + activation, + glu, + layer_norm_eps, + attention_temp, + pre_ln) + # Share positional encoding and embedding between encoder and decoder. 
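+    # The embedding is also tied to the output projection: Decoder.forward
+    # multiplies its output by shared_embedding.weight.T and divides by
+    # sqrt(d_model), so encoder, decoder, and the final logits layer all share
+    # one embedding matrix.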
+ self.encoder.pos_encoder = self.pos_encoder + self.encoder.shared_embedding = self.shared_embedding + self.decoder.pos_encoder = self.pos_encoder + self.decoder.shared_embedding = self.shared_embedding + + self._reset_parameters() + + def _reset_parameters(self): + """Initiate parameters in the transformer model.""" + for module in self.modules(): + if isinstance(module, nn.Linear): + xavier_uniform_(module.weight) + if module.bias is not None: + normal_(module.bias, std=1e-6) + + def forward(self, + src: Tensor, + tgt: Tensor, + inputs_positions: Optional[Tensor] = None, + targets_positions: Optional[Tensor] = None, + inputs_segmentation: Optional[Tensor] = None, + targets_segmentation: Optional[Tensor] = None, + decode: bool = False) -> Tensor: + """ + Args: + src: Tensor, shape [batch_size, seq_len] + tgt: Tensor, shape [batch_size, seq_len] + inputs_positions: Optional[Tensor], shape [batch_size, seq_len] + targets_positions: Optional[Tensor], shape [batch_size, seq_len] + inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len] + targets_segmentation: Optional[Tensor], shape [batch_size, seq_len] + decode: bool + + Returns: + output Tensor of shape [batch_size, seq_len, ntoken] + """ + if src.size(0) != tgt.size(0): + raise RuntimeError('The batch size of src and tgt must be equal.') + memory = self.encoder( + src, + inputs_positions=inputs_positions, + inputs_segmentation=inputs_segmentation) + output = self.decoder( + tgt, + memory, + src, # just for calculating the padding mask + targets_positions=targets_positions, + inputs_segmentation=inputs_segmentation, + targets_segmentation=targets_segmentation, + decode=decode) + return output + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers. Users can build the + BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class. + num_layers: the number of sub-encoder-layers in the encoder. + norm: the layer normalization component (optional). + enable_nested_tensor: if True, input will automatically convert to + nested tensor (and convert back on output). This will improve + the overall performance of TransformerEncoder when padding + rate is high. + + Examples:: + >>> encoder_layer = nn.TransformerEncoderLayer(12, 8) + >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, 6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + __constants__ = ['norm'] + + def __init__(self, + encoder_layer, + num_layers, + norm=None, + enable_nested_tensor=True, + mask_check=True): + super().__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for _ in range(num_layers)]) + self.num_layers = num_layers + self.norm = norm + self.enable_nested_tensor = enable_nested_tensor + self.mask_check = mask_check + + def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: + """Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + + Shape: + see the docs in Transformer class. + """ + output = src + convert_to_nested = False + + for mod in self.layers: + output = mod(output, src_mask=mask) + + if convert_to_nested: + output = output.to_padded_tensor(0.) 
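+    # convert_to_nested is hard-coded to False above, so the to_padded_tensor
+    # branch never executes; it is presumably kept to mirror the structure of
+    # the upstream torch.nn.TransformerEncoder this class is adapted from.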
+ + if self.norm is not None: + output = self.norm(output) + + return output + + +class Encoder(nn.Module): + + def __init__(self, + d_model: int = 1024, + nhead: int = 16, + d_hid: int = 1024, + nlayers: int = 6, + dropout_rate: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + attention_temp: float = 1.0, + pre_ln: bool = True): + super().__init__() + self.nhead = nhead + self.shared_embedding = None + self.pos_encoder = None + encoder_layer = TransformerEncoderLayer( + d_model, + nhead, + d_hid, + dropout_rate, + activation=activation, + glu=glu, + layer_norm_eps=layer_norm_eps, + attention_temp=attention_temp, + pre_ln=pre_ln) + encoder_norm = ( + nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None) + self.encoder = TransformerEncoder(encoder_layer, nlayers, encoder_norm) + + def forward(self, + src: Tensor, + inputs_positions: Optional[Tensor] = None, + inputs_segmentation: Optional[Tensor] = None) -> Tensor: + src = src.to(torch.int) + src_mask = make_src_mask(src, inputs_segmentation, self.nhead) + src = self.shared_embedding(src) + src = self.pos_encoder(src, inputs_positions) + memory = self.encoder(src, mask=src_mask) + return memory + + +class Decoder(nn.Module): + + def __init__(self, + d_model: int = 1024, + nhead: int = 16, + d_hid: int = 1024, + nlayers: int = 6, + dropout_rate: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + attention_temp: float = 1.0, + pre_ln: bool = True): + super().__init__() + self.nhead = nhead + self.shared_embedding = None + self.pos_encoder = None + self.decoder = TransformerDecoder(d_model, + nhead, + d_hid, + dropout_rate, + activation, + glu, + layer_norm_eps, + nlayers, + attention_temp, + pre_ln) + + def forward( + self, + tgt: Tensor, + memory: Tensor, + src: Tensor, # just for calculating the padding mask + targets_positions: Optional[Tensor] = None, + inputs_segmentation: Optional[Tensor] = None, + targets_segmentation: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None) -> Any: + tgt = tgt.to(torch.int) + tgt_mask, memory_mask = make_tgt_and_memory_mask( + tgt, src, inputs_segmentation, targets_segmentation, + decode, self.nhead) + if not decode: + tgt = shift_right(tgt) + tgt = self.shared_embedding(tgt) + tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache) + if decode: + tgt, cache = tgt + output = self.decoder( + tgt, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + decode=decode, + max_len=max_len, + cache=cache) + if decode: + output, cache = output + normalize = math.sqrt(output.shape[-1]) + output = torch.matmul(output, self.shared_embedding.weight.T) / normalize + if decode: + return output, cache + return output + + +class PositionalEncoding(nn.Module): + + def __init__(self, + d_model: int, + dropout_rate: float = 0.1, + max_len: int = 256): + super().__init__() + self.dropout = nn.Dropout(p=dropout_rate) + + position = torch.arange(max_len).unsqueeze(1) + scale_factor = -math.log(10000.0) / (d_model // 2 - 1) + div_term = torch.exp(torch.arange(d_model // 2) * scale_factor) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, :d_model // 2] = torch.sin(position * div_term) + pe[0, :, d_model // 2:2 * (d_model // 2)] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward( + self, + x: Tensor, + inputs_positions: Optional[Tensor] = None, + 
decode: bool = False, + cache: Optional[Dict[str, Dict[str, Tensor]]] = None + ) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]: + """ + Args: + x: Tensor (shape [batch_size, seq_len, embedding_dim]) + inputs_positions: Tensor (shape [batch_size, seq_len]) or None + decode: bool + cache: Dict[str, Dict[str, Tensor]] or None + Returns: + Tensor or Tuple[Tensor, Dict[str, Dict[str, Tensor]]] + """ + # We use a cache position index for tracking decoding position. + if decode: + name = self._get_name() + if cache is None: + cache = { + name: { + 'cache_index': + torch.tensor(0, dtype=torch.long, device=self.pe.device), + }, + } + pe = self.pe[0, cache[name]['cache_index'], :] + cache[name]['cache_index'] += 1 + return self.dropout(x + pe), cache + if inputs_positions is None: + # normal unpacked case: + pe = self.pe[:, :x.size(1), :] + else: + # for packed data we need to use known position indices: + pe = self.pe[0, inputs_positions, :] + return self.dropout(x + pe) + + +# TransformerEncoderLayer and TransformerDecoderLayer are taken from: +# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py +# Main difference is the use of custom MultiheadAttention modules. +class TransformerEncoderLayer(nn.Module): + r"""TransformerEncoderLayer is made up of self-attn and feedforward network. + This standard encoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, + Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all + you need. In Advances in Neural Information Processing Systems, + pages 6000-6010. Users may modify or implement in a different way during + application. + Args: + d_model: the number of expected features in the input (default=1024). + nhead: the number of heads in the multiheadattention models (default=16). + dim_feedforward: the dimension of the feedforward network model + (default=1024). + dropout_rate: the dropout_rate value (default=0.1). + activation: the activation function of the intermediate layer, can be a + string ("relu" or "gelu") or a unary callable (default=F.relu). + layer_norm_eps: the eps value in layer normalization components + (default=1e-6). + pre_ln: if ``True``, layer norm is done prior to attention and + feedforward operations, respectivaly. Otherwise it's done after. + Default: ``True``. + Examples:: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(32, 10, 512) + >>> out = encoder_layer(src) + """ + __constants__ = ['pre_ln'] + + def __init__(self, + d_model: int = 1024, + nhead: int = 16, + dim_feedforward: int = 1024, + dropout_rate: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + attention_temp: float = 1.0, + pre_ln: bool = True, + device=None, + dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + self_attn=True, + dropout_rate=dropout_rate, + attention_temp=attention_temp, + bias=False, + **factory_kwargs) + + # Implementation of Feedforward model. 
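+    # With glu=True the feed-forward block is gated: _ff_block multiplies
+    # activation(linear1(x)) elementwise by linear_glu(x) before linear2,
+    # giving a GLU-style variant of the standard two-layer MLP.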
+ self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) + self.glu = glu + if self.glu: + self.linear_glu = nn.Linear(d_model, dim_feedforward, **factory_kwargs) + self.dropout = nn.Dropout(dropout_rate) + self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) + + self.pre_ln = pre_ln + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.dropout1 = nn.Dropout(dropout_rate) + self.dropout2 = nn.Dropout(dropout_rate) + + self.activation = activation + + def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + + Shape: + see the docs in Transformer class. + """ + x = src + if self.pre_ln: + x = x + self._sa_block(self.norm1(x), src_mask) + x = x + self._ff_block(self.norm2(x)) + else: + x = self.norm1(x + self._sa_block(x, src_mask)) + x = self.norm2(x + self._ff_block(x)) + + return x + + # Self-attention block: + def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tensor: + x, _ = self.self_attn(x, attn_mask=attn_mask) + return self.dropout1(x) + + # Feed forward block: + def _ff_block(self, inputs: Tensor) -> Tensor: + x = self.activation(self.linear1(inputs)) + if self.glu: + y = self.linear_glu(inputs) + x = x * y + x = self.linear2(self.dropout(x)) + return self.dropout2(x) + + +# Modified to use cache for autoregressive decoding and custom +# MultiheadAttention modules. +class TransformerDecoder(nn.Module): + r"""TransformerDecoder is a stack of N decoder layers + Args: + d_model: the number of expected features in the input (default=1024) + nhead: the number of heads in the multiheadattention models (default=16) + d_hid: the dimension of the feedforward network model + (default=1024) + dropout_rate: the dropout_rate value (default=0.1) + layer_norm_eps: the eps value in layer normalization components + (default=1e-6). + decoder_layer: an instance of the TransformerDecoderLayer() class + num_layers: the number of sub-decoder-layers in the decoder + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(12, 8) + >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, 6) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = transformer_decoder(tgt, memory) + """ + __constants__ = ['norm'] + + def __init__(self, + d_model, + nhead, + d_hid, + dropout_rate, + activation, + glu, + layer_norm_eps, + num_layers, + attention_temp, + pre_ln): + super().__init__() + self.layers = nn.ModuleList([ + TransformerDecoderLayer( + d_model, + nhead, + d_hid, + dropout_rate, + activation, + glu, + layer_norm_eps=layer_norm_eps, + attention_temp=attention_temp, + pre_ln=pre_ln) for _ in range(num_layers) + ]) + self.num_layers = num_layers + self.norm = (nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None) + + def forward(self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None) -> Any: + r"""Pass the inputs (and mask) through the decoder layer in turn. + Args: + tgt: the sequence to the decoder (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). 
+ memory_mask: the mask for the memory sequence (optional). + decode: whether to use cache for autoregressive decoding or not. + max_len: maximum sequence length, necessary for decoding cache. + Shape: + see the docs in Transformer class. + """ + output = tgt + + for idx, mod in enumerate(self.layers): + output, cache = mod( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + decode=decode, + max_len=max_len, + cache=cache, + index=idx) + + if self.norm is not None: + output = self.norm(output) + + if decode: + return output, cache + return output + + +# Modified to use cache for autoregressive decoding and custom +# MultiheadAttention modules. +class TransformerDecoderLayer(nn.Module): + r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and + feedforward network. + This standard decoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, + Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all + you need. In Advances in Neural Information Processing Systems, + pages 6000-6010. Users may modify or implement in a different way during + application. + Args: + d_model: the number of expected features in the input (default=1024). + nhead: the number of heads in the multiheadattention models (default=16). + dim_feedforward: the dimension of the feedforward network model + (default=1024). + dropout_rate: the dropout_rate value (default=0.1). + activation: the activation function of the intermediate layer, can be a + string ("relu" or "gelu") or a unary callable (default=F.relu). + layer_norm_eps: the eps value in layer normalization components + (default=1e-6). + pre_ln: if ``True``, layer norm is done prior to self attention, + multihead attention and feedforward operations, respectivaly. + Otherwise it's done after. Default: ``True``. + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(32, 10, 512) + >>> tgt = torch.rand(32, 20, 512) + >>> out = decoder_layer(tgt, memory) + """ + __constants__ = ['pre_ln'] + + def __init__(self, + d_model: int = 1024, + nhead: int = 16, + dim_feedforward: int = 1024, + dropout_rate: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + pre_ln: bool = True, + attention_temp: float = 1.0, + device=None, + dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + self_attn=True, + dropout_rate=dropout_rate, + attention_temp=attention_temp, + bias=False, + **factory_kwargs) + self.multihead_attn = MultiheadAttention( + d_model, + nhead, + self_attn=False, + dropout_rate=dropout_rate, + attention_temp=attention_temp, + bias=False, + **factory_kwargs) + + # Implementation of Feedforward model. 
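+    # Unlike the encoder layer, linear_glu here maps dim_feedforward to
+    # dim_feedforward while _ff_block feeds it the d_model-sized input, so the
+    # glu=True path implicitly assumes d_model == dim_feedforward (which holds
+    # for the defaults above).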
+ self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) + self.glu = glu + if self.glu: + self.linear_glu = nn.Linear(dim_feedforward, + dim_feedforward, + **factory_kwargs) + self.dropout = nn.Dropout(dropout_rate) + self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) + + self.pre_ln = pre_ln + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.dropout1 = nn.Dropout(dropout_rate) + self.dropout2 = nn.Dropout(dropout_rate) + self.dropout3 = nn.Dropout(dropout_rate) + + self.activation = activation + + def forward( # pylint: disable=arguments-renamed + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + index: Optional[int] = None) -> Any: + r"""Pass the inputs (and mask) through the decoder layer. + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + decode: wether to use cache for autoregressive decoding or not. + max_len: maximum sequence length, necessary for decoding cache. + Shape: + see the docs in Transformer class. + """ + # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf + + x = tgt + if self.pre_ln: + sa_out, cache = self._sa_block( + self.norm1(x), + tgt_mask, + decode=decode, + max_len=max_len, + cache=cache, + index=index) + x = x + sa_out + x = x + self._mha_block(self.norm2(x), memory, memory_mask) + x = x + self._ff_block(self.norm3(x)) + else: + sa_out, cache = self._sa_block( + x, + tgt_mask, + decode=decode, + max_len=max_len, + cache=cache, + index=index) + x = self.norm1(x + sa_out) + x = self.norm2(x + self._mha_block(x, memory, memory_mask)) + x = self.norm3(x + self._ff_block(x)) + + return x, cache + + # Self-attention block: + def _sa_block( # pylint: disable=arguments-renamed + self, + x: Tensor, + attn_mask: Optional[Tensor], + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + index: Optional[int] = None) -> Any: + x, cache = self.self_attn( + x, + attn_mask=attn_mask, + decode=decode, + max_len=max_len, + cache=cache, + index=index) + return self.dropout1(x), cache + + # Multihead attention block: + def _mha_block(self, x: Tensor, mem: Tensor, + attn_mask: Optional[Tensor]) -> Tensor: + x, _ = self.multihead_attn(x, mem, attn_mask=attn_mask) + return self.dropout2(x) + + # Feed forward block. + def _ff_block(self, inputs: Tensor) -> Tensor: + x = self.activation(self.linear1(inputs)) + if self.glu: + y = self.linear_glu(inputs) + x = x * y + x = self.linear2(self.dropout(x)) + return self.dropout3(x) + + +class MultiheadAttention(nn.Module): + r"""Allows the model to jointly attend to information + from different representation subspaces. Supports self-attention and + encoder-decoder attention. + See `Attention Is All You Need `_. + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. + Args: + embed_dim: Total dimension of the model. + num_heads: Number of parallel attention heads. Note that ``embed_dim`` will + be split across ``num_heads`` (i.e. 
each head will have dimension + ``embed_dim // num_heads``). + self_attn: Whether self attention or encoder-decoder attention is used. + Default: ``True``. + dropout_rate: Dropout probability on ``attn_output_weights``. + Default: ``0.0`` (no dropout_rate). + bias: If specified, adds bias to input / output projection layers. + Default: ``False``. + device: The device of the module. + dtype: The dtype of the module. + Examples:: + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, cache = multihead_attn(x) + """ + + def __init__(self, + embed_dim: int, + num_heads: int, + self_attn: bool = True, + dropout_rate: float = 0., + attention_temp: float = 1.0, + bias: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.self_attn = self_attn + self.dropout = dropout_rate + self.head_dim = embed_dim // num_heads + self.attention_temp = attention_temp + assert self.head_dim * num_heads == self.embed_dim, \ + 'embed_dim must be divisible by num_heads.' + + factory_kwargs = {'device': device, 'dtype': dtype} + if self_attn: + # Self-attention. + self.in_proj = nn.Linear( + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + else: + # Encoder-decoder attention. + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) + self.kv_proj = nn.Linear( + embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) + + self._reset_parameters() + + def _reset_parameters(self): + """Initiate parameters in the MultiheadAttention module.""" + for module in self.modules(): + if isinstance(module, nn.Linear): + xavier_uniform_(module.weight) + if module.bias is not None: + normal_(module.bias, std=1e-6) + + def forward(self, + x: Tensor, + mem: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + index: Optional[int] = None) -> Any: + r""" + Args: + x: Batch of input sequences of shape + (batch size, sequence length, embedding dimensionality) for self + attention mechanism. See "Attention Is All You Need" for more details. + mem: Batch of input sequences of shape + (batch size, sequence length, embedding dimensionality) for + encoder-decoder attention. See "Attention Is All You Need" for more + details. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain + positions. Must be of shape :math:`(L, S)` or + :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the + batch size, :math:`L` is the target sequence length, and :math:`S` + is the source sequence length. A 2D mask will be broadcasted across + the batch while a 3D mask allows for a different mask for each entry + in the batch. Binary, byte, and float masks are supported. + For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a byte mask, + a non-zero value indicates that the corresponding position is not + allowed to attend. For a float mask, the mask values will be added to + the attention weight. + decode: wether to use cache for autoregressive decoding or not. + max_len: maximum sequence length, necessary for decoding cache. + cache: cache dictionary for autoregressive decoding. + index: index of the current decoding step, necessary for decoding cache. 
+ Outputs: + - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where + :math:`L` is the target sequence length, :math:`N` is the batch size, + and :math:`E` is the embedding dimension ``embed_dim``. + - **cache** - For autoregressive decoding. + """ + # Shape: (batch size, sequence length, embedding dimensionality) + bsz, seq_len, embed_dim = x.size() + # In projection. + if self.self_attn: + q, k, v = self.in_proj(x).split(self.embed_dim, dim=2) + else: + q = self.q_proj(x) + k, v = self.kv_proj(mem).split(self.embed_dim, dim=2) + # This is 1 (!= seq_len) during autoreregressive decoding. + tgt_len = q.size(1) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + name = f'decoder.layers.{index}.self_attn' + loc_cache = cache[name] if decode and name in cache else None + if decode: + if loc_cache is None: + loc_cache = { + 'cached_key': + torch.zeros((bsz, max_len, embed_dim), + dtype=k.dtype, + device=k.device), + 'cached_value': + torch.zeros((bsz, max_len, embed_dim), + dtype=v.dtype, + device=v.device), + 'cache_index': + torch.tensor(0, dtype=torch.long, device=k.device), + } + cached_key = loc_cache['cached_key'] + cached_value = loc_cache['cached_value'] + cache_index = loc_cache['cache_index'] + # Shape check of cached keys against query input. + expected_shape = (bsz, 1, embed_dim) + if expected_shape != x.shape: + raise ValueError('Autoregressive cache shape error, expected query ' + f'shape {expected_shape} instead got {x.shape}.') + # Update key, value caches with our new 1d spatial slices. + cached_key[:, cache_index:cache_index + 1, :] = k + cached_value[:, cache_index:cache_index + 1, :] = v + k = cached_key + v = cached_value + cache_index += 1 + # Causal mask for cached decoder self-attention: + # our single query position should only attend to those key + # positions that have already been generated and cached, + # not the remaining zero elements. + if attn_mask is not None: + raise ValueError('Attention mask has to be None for decode == True.') + attn_mask = (torch.arange(max_len, device=k.device) >= + cache_index).reshape(1, max_len) + + # Update sequence length to account for complete sequence. + seq_len = k.size(1) + + # Rearrange q, k, v for multihead attention. + q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + # Check dtype and shape of attention mask. + if not decode and attn_mask is not None: + assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ + f'Float and bool dtypes are supported, not {attn_mask.dtype}.' + # Ensure attn_mask's dim is 3. + if attn_mask.dim() == 3: + correct_3d_size = (bsz * self.num_heads, tgt_len, seq_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError(f'The shape of attn_mask is {attn_mask.shape}, ' + f'but should be {correct_3d_size}.') + else: + raise RuntimeError( + f"attn_mask's dimension {attn_mask.dim()} is not supported") + # Reshape attention mask to be consistent with q, k, v. + attn_mask = attn_mask.view(bsz, self.num_heads, tgt_len, seq_len) + + # Convert attention mask to float. + if attn_mask is not None and attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, -1e10) + attn_mask = new_attn_mask + + # Adjust dropout_rate probability. 
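+    # F.scaled_dot_product_attention applies dropout unconditionally, so the
+    # rate is zeroed at eval time here; the 1/sqrt(head_dim) scaling is also
+    # handled inside SDPA, with attention_temp applied as an extra factor on q.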
+ dropout_rate = self.dropout if self.training else 0.0 + + # Calculate attention. + q = self.attention_temp * q + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask, dropout_rate) + # Rearrange for output projection. + attn_output = attn_output.transpose(1, 2).contiguous().view( + bsz, tgt_len, embed_dim) + # Output projection. + attn_output = self.out_proj(attn_output) + + if decode: + cache[name] = loc_cache + + return attn_output, cache From 84b1bd19bb1083947adfa87a9488b074a6b170ac Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 11 Jun 2025 15:37:33 +0200 Subject: [PATCH 04/23] dropout fix on wmt --- .../wmt/wmt_pytorch/models_dropout.py | 168 ++++++++++-------- 1 file changed, 91 insertions(+), 77 deletions(-) diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py index 588d06abf..c5014d87d 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py @@ -112,14 +112,15 @@ def __init__(self, pre_ln: bool = True): super().__init__() if dropout_rate is None: - dropout_rate = 0.1 - self.pos_encoder = PositionalEncoding(d_model, dropout_rate) + self.dropout_rate = 0.1 + else: + self.dropout_rate = dropout_rate + self.pos_encoder = PositionalEncoding(d_model) self.shared_embedding = nn.Embedding(ntoken, d_model) self.encoder = Encoder(d_model, nhead, d_hid, nlayers, - dropout_rate, activation, glu, layer_norm_eps, @@ -129,7 +130,6 @@ def __init__(self, nhead, d_hid, nlayers, - dropout_rate, activation, glu, layer_norm_eps, @@ -158,7 +158,8 @@ def forward(self, targets_positions: Optional[Tensor] = None, inputs_segmentation: Optional[Tensor] = None, targets_segmentation: Optional[Tensor] = None, - decode: bool = False) -> Tensor: + decode: bool = False, + dropout_rate: Optional[float] = None) -> Tensor: """ Args: src: Tensor, shape [batch_size, seq_len] @@ -168,16 +169,22 @@ def forward(self, inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len] targets_segmentation: Optional[Tensor], shape [batch_size, seq_len] decode: bool + dropout_rate: Optional[float] Returns: output Tensor of shape [batch_size, seq_len, ntoken] """ if src.size(0) != tgt.size(0): raise RuntimeError('The batch size of src and tgt must be equal.') + + if dropout_rate is None: + dropout_rate = self.dropout_rate + memory = self.encoder( src, inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation) + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate) output = self.decoder( tgt, memory, @@ -185,7 +192,8 @@ def forward(self, targets_positions=targets_positions, inputs_segmentation=inputs_segmentation, targets_segmentation=targets_segmentation, - decode=decode) + decode=decode, + dropout_rate=dropout_rate) return output @@ -224,12 +232,15 @@ def __init__(self, self.enable_nested_tensor = enable_nested_tensor self.mask_check = mask_check - def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: + def forward(self, src: Tensor, + mask: Optional[Tensor] = None, + dropout_rate: Optional[float] = None) -> Tensor: """Pass the input through the encoder layers in turn. Args: src: the sequence to the encoder (required). mask: the mask for the src sequence (optional). + dropout_rate: the dropout probability (optional) Shape: see the docs in Transformer class. 
@@ -238,7 +249,7 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: convert_to_nested = False for mod in self.layers: - output = mod(output, src_mask=mask) + output = mod(output, src_mask=mask, dropout_rate=dropout_rate) if convert_to_nested: output = output.to_padded_tensor(0.) @@ -256,7 +267,6 @@ def __init__(self, nhead: int = 16, d_hid: int = 1024, nlayers: int = 6, - dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -270,7 +280,6 @@ def __init__(self, d_model, nhead, d_hid, - dropout_rate, activation=activation, glu=glu, layer_norm_eps=layer_norm_eps, @@ -283,12 +292,13 @@ def __init__(self, def forward(self, src: Tensor, inputs_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None) -> Tensor: + inputs_segmentation: Optional[Tensor] = None, + dropout_rate: Optional[float] = None) -> Tensor: src = src.to(torch.int) src_mask = make_src_mask(src, inputs_segmentation, self.nhead) src = self.shared_embedding(src) - src = self.pos_encoder(src, inputs_positions) - memory = self.encoder(src, mask=src_mask) + src = self.pos_encoder(src, inputs_positions, dropout_rate=dropout_rate) + memory = self.encoder(src, mask=src_mask, dropout_rate=dropout_rate) return memory @@ -299,7 +309,6 @@ def __init__(self, nhead: int = 16, d_hid: int = 1024, nlayers: int = 6, - dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -312,7 +321,6 @@ def __init__(self, self.decoder = TransformerDecoder(d_model, nhead, d_hid, - dropout_rate, activation, glu, layer_norm_eps, @@ -330,7 +338,8 @@ def forward( targets_segmentation: Optional[Tensor] = None, decode: bool = False, max_len: Optional[int] = None, - cache: Optional[dict] = None) -> Any: + cache: Optional[dict] = None, + dropout_rate: Optional[float] = None) -> Any: tgt = tgt.to(torch.int) tgt_mask, memory_mask = make_tgt_and_memory_mask( tgt, src, inputs_segmentation, targets_segmentation, @@ -338,7 +347,7 @@ def forward( if not decode: tgt = shift_right(tgt) tgt = self.shared_embedding(tgt) - tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache) + tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache, dropout_rate=dropout_rate) if decode: tgt, cache = tgt output = self.decoder( @@ -348,7 +357,8 @@ def forward( memory_mask=memory_mask, decode=decode, max_len=max_len, - cache=cache) + cache=cache, + dropout_rate=dropout_rate) if decode: output, cache = output normalize = math.sqrt(output.shape[-1]) @@ -362,10 +372,8 @@ class PositionalEncoding(nn.Module): def __init__(self, d_model: int, - dropout_rate: float = 0.1, max_len: int = 256): super().__init__() - self.dropout = nn.Dropout(p=dropout_rate) position = torch.arange(max_len).unsqueeze(1) scale_factor = -math.log(10000.0) / (d_model // 2 - 1) @@ -380,7 +388,8 @@ def forward( x: Tensor, inputs_positions: Optional[Tensor] = None, decode: bool = False, - cache: Optional[Dict[str, Dict[str, Tensor]]] = None + cache: Optional[Dict[str, Dict[str, Tensor]]] = None, + dropout_rate: Optional[float] = None ) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]: """ Args: @@ -403,14 +412,14 @@ def forward( } pe = self.pe[0, cache[name]['cache_index'], :] cache[name]['cache_index'] += 1 - return self.dropout(x + pe), cache + return F.dropout(x + pe, dropout_rate, self.training), cache if inputs_positions is None: # normal unpacked case: pe = 
self.pe[:, :x.size(1), :] else: # for packed data we need to use known position indices: pe = self.pe[0, inputs_positions, :] - return self.dropout(x + pe) + return F.dropout(x + pe, dropout_rate, self.training) # TransformerEncoderLayer and TransformerDecoderLayer are taken from: @@ -448,7 +457,6 @@ def __init__(self, d_model: int = 1024, nhead: int = 16, dim_feedforward: int = 1024, - dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -462,7 +470,6 @@ def __init__(self, d_model, nhead, self_attn=True, - dropout_rate=dropout_rate, attention_temp=attention_temp, bias=False, **factory_kwargs) @@ -472,50 +479,55 @@ def __init__(self, self.glu = glu if self.glu: self.linear_glu = nn.Linear(d_model, dim_feedforward, **factory_kwargs) - self.dropout = nn.Dropout(dropout_rate) self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) self.pre_ln = pre_ln self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.dropout1 = nn.Dropout(dropout_rate) - self.dropout2 = nn.Dropout(dropout_rate) self.activation = activation - def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: + def forward(self, + src: Tensor, + src_mask: Optional[Tensor] = None, + dropout_rate: Optional[float] = None) -> Tensor: r"""Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). src_mask: the mask for the src sequence (optional). - + dropout_rate: the dropout probability value (optional). Shape: see the docs in Transformer class. """ x = src if self.pre_ln: - x = x + self._sa_block(self.norm1(x), src_mask) - x = x + self._ff_block(self.norm2(x)) + x = x + self._sa_block(self.norm1(x), src_mask, dropout_rate) + x = x + self._ff_block(self.norm2(x), dropout_rate) else: - x = self.norm1(x + self._sa_block(x, src_mask)) - x = self.norm2(x + self._ff_block(x)) + x = self.norm1(x + self._sa_block(x, src_mask, dropout_rate)) + x = self.norm2(x + self._ff_block(x, dropout_rate)) return x # Self-attention block: - def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tensor: - x, _ = self.self_attn(x, attn_mask=attn_mask) - return self.dropout1(x) + def _sa_block(self, + x: Tensor, + attn_mask: Optional[Tensor], + dropout_rate: Optional[float] = None) -> Tensor: + x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate) + return F.dropout(x, dropout_rate, training=self.training) # Feed forward block: - def _ff_block(self, inputs: Tensor) -> Tensor: + def _ff_block(self, + inputs: Tensor, + dropout_rate: Optional[float] = None) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) x = x * y - x = self.linear2(self.dropout(x)) - return self.dropout2(x) + x = self.linear2(F.dropout(x, dropout_rate, training=self.training)) + return F.dropout(x, dropout_rate, training=self.training) # Modified to use cache for autoregressive decoding and custom @@ -527,7 +539,6 @@ class TransformerDecoder(nn.Module): nhead: the number of heads in the multiheadattention models (default=16) d_hid: the dimension of the feedforward network model (default=1024) - dropout_rate: the dropout_rate value (default=0.1) layer_norm_eps: the eps value in layer normalization components (default=1e-6). 
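# A minimal sketch of the substitution pattern used throughout this patch, not
# part of the diff: stored nn.Dropout modules are replaced by functional calls
# so the rate can be supplied at forward time. Given the same RNG state the two
# forms produce identical outputs; the 0.1 rate and tensor shape are arbitrary.
import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(4, 8)
drop = nn.Dropout(0.1)            # rate frozen at construction time
torch.manual_seed(0)
y_module = drop(x)
torch.manual_seed(0)
y_functional = F.dropout(x, 0.1, training=drop.training)  # rate passed per call
assert torch.equal(y_module, y_functional)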
decoder_layer: an instance of the TransformerDecoderLayer() class @@ -545,7 +556,6 @@ def __init__(self, d_model, nhead, d_hid, - dropout_rate, activation, glu, layer_norm_eps, @@ -558,7 +568,6 @@ def __init__(self, d_model, nhead, d_hid, - dropout_rate, activation, glu, layer_norm_eps=layer_norm_eps, @@ -575,7 +584,8 @@ def forward(self, memory_mask: Optional[Tensor] = None, decode: bool = False, max_len: Optional[int] = None, - cache: Optional[dict] = None) -> Any: + cache: Optional[dict] = None, + dropout_rate: Optional[float] = None) -> Any: r"""Pass the inputs (and mask) through the decoder layer in turn. Args: tgt: the sequence to the decoder (required). @@ -584,6 +594,7 @@ def forward(self, memory_mask: the mask for the memory sequence (optional). decode: whether to use cache for autoregressive decoding or not. max_len: maximum sequence length, necessary for decoding cache. + dropout_rate: the dropout probability value (optional) Shape: see the docs in Transformer class. """ @@ -598,7 +609,8 @@ def forward(self, decode=decode, max_len=max_len, cache=cache, - index=idx) + index=idx, + dropout_rate=dropout_rate) if self.norm is not None: output = self.norm(output) @@ -624,7 +636,6 @@ class TransformerDecoderLayer(nn.Module): nhead: the number of heads in the multiheadattention models (default=16). dim_feedforward: the dimension of the feedforward network model (default=1024). - dropout_rate: the dropout_rate value (default=0.1). activation: the activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable (default=F.relu). layer_norm_eps: the eps value in layer normalization components @@ -644,7 +655,6 @@ def __init__(self, d_model: int = 1024, nhead: int = 16, dim_feedforward: int = 1024, - dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -658,7 +668,6 @@ def __init__(self, d_model, nhead, self_attn=True, - dropout_rate=dropout_rate, attention_temp=attention_temp, bias=False, **factory_kwargs) @@ -666,7 +675,6 @@ def __init__(self, d_model, nhead, self_attn=False, - dropout_rate=dropout_rate, attention_temp=attention_temp, bias=False, **factory_kwargs) @@ -678,16 +686,12 @@ def __init__(self, self.linear_glu = nn.Linear(dim_feedforward, dim_feedforward, **factory_kwargs) - self.dropout = nn.Dropout(dropout_rate) self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) self.pre_ln = pre_ln self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.dropout1 = nn.Dropout(dropout_rate) - self.dropout2 = nn.Dropout(dropout_rate) - self.dropout3 = nn.Dropout(dropout_rate) self.activation = activation @@ -700,7 +704,8 @@ def forward( # pylint: disable=arguments-renamed decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + index: Optional[int] = None, + dropout_rate: Optional[float] = None) -> Any: r"""Pass the inputs (and mask) through the decoder layer. Args: tgt: the sequence to the decoder layer (required). @@ -709,6 +714,7 @@ def forward( # pylint: disable=arguments-renamed memory_mask: the mask for the memory sequence (optional). decode: wether to use cache for autoregressive decoding or not. max_len: maximum sequence length, necessary for decoding cache. 
+ dropout_rate: the dropout probability value (optional) Shape: see the docs in Transformer class. """ @@ -722,10 +728,11 @@ def forward( # pylint: disable=arguments-renamed decode=decode, max_len=max_len, cache=cache, - index=index) + index=index, + dropout_rate=dropout_rate) x = x + sa_out - x = x + self._mha_block(self.norm2(x), memory, memory_mask) - x = x + self._ff_block(self.norm3(x)) + x = x + self._mha_block(self.norm2(x), memory, memory_mask, dropout_rate) + x = x + self._ff_block(self.norm3(x), dropout_rate) else: sa_out, cache = self._sa_block( x, @@ -733,10 +740,11 @@ def forward( # pylint: disable=arguments-renamed decode=decode, max_len=max_len, cache=cache, - index=index) + index=index, + dropout_rate=dropout_rate) x = self.norm1(x + sa_out) - x = self.norm2(x + self._mha_block(x, memory, memory_mask)) - x = self.norm3(x + self._ff_block(x)) + x = self.norm2(x + self._mha_block(x, memory, memory_mask, dropout_rate)) + x = self.norm3(x + self._ff_block(x, dropout_rate)) return x, cache @@ -748,30 +756,38 @@ def _sa_block( # pylint: disable=arguments-renamed decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + index: Optional[int] = None, + dropout_rate: Optional[float] = None) -> Any: x, cache = self.self_attn( x, attn_mask=attn_mask, decode=decode, max_len=max_len, cache=cache, - index=index) - return self.dropout1(x), cache + index=index, + dropout_rate=dropout_rate) + return F.dropout(x, dropout_rate, self.training), cache # Multihead attention block: def _mha_block(self, x: Tensor, mem: Tensor, - attn_mask: Optional[Tensor]) -> Tensor: - x, _ = self.multihead_attn(x, mem, attn_mask=attn_mask) - return self.dropout2(x) + attn_mask: Optional[Tensor], + dropout_rate: Optional[float] = None) -> Tensor: + x, _ = self.multihead_attn( + x, + mem, + attn_mask=attn_mask, + dropout_rate=dropout_rate) + return F.dropout(x, dropout_rate, self.training) # Feed forward block. - def _ff_block(self, inputs: Tensor) -> Tensor: + def _ff_block(self, inputs: Tensor, + dropout_rate: Optional[float] = None) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) x = x * y - x = self.linear2(self.dropout(x)) - return self.dropout3(x) + x = self.linear2(F.dropout(x, dropout_rate, self.training)) + return F.dropout(x, dropout_rate, self.training) class MultiheadAttention(nn.Module): @@ -789,8 +805,6 @@ class MultiheadAttention(nn.Module): ``embed_dim // num_heads``). self_attn: Whether self attention or encoder-decoder attention is used. Default: ``True``. - dropout_rate: Dropout probability on ``attn_output_weights``. - Default: ``0.0`` (no dropout_rate). bias: If specified, adds bias to input / output projection layers. Default: ``False``. device: The device of the module. 
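# Sketch, not part of the diff: MultiheadAttention likewise takes the rate as a
# forward() argument; as shown further below it is zeroed outside training mode
# before being handed to scaled_dot_product_attention. Sizes are arbitrary.
import torch
from algoperf.workloads.wmt.wmt_pytorch.models_dropout import MultiheadAttention

attn = MultiheadAttention(embed_dim=16, num_heads=4, self_attn=True)
x = torch.randn(2, 5, 16)              # (batch, sequence, embed_dim)
y, _ = attn(x, dropout_rate=0.1)       # stochastic while attn.training is True
attn.eval()
y_eval, _ = attn(x, dropout_rate=0.1)  # dropout disabled in eval mode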
@@ -804,7 +818,6 @@ def __init__(self, embed_dim: int, num_heads: int, self_attn: bool = True, - dropout_rate: float = 0., attention_temp: float = 1.0, bias: bool = False, device: Optional[torch.device] = None, @@ -813,7 +826,6 @@ def __init__(self, self.embed_dim = embed_dim self.num_heads = num_heads self.self_attn = self_attn - self.dropout = dropout_rate self.head_dim = embed_dim // num_heads self.attention_temp = attention_temp assert self.head_dim * num_heads == self.embed_dim, \ @@ -848,7 +860,8 @@ def forward(self, decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + index: Optional[int] = None, + dropout_rate: Optional[float] = None) -> Any: r""" Args: x: Batch of input sequences of shape @@ -874,6 +887,7 @@ def forward(self, max_len: maximum sequence length, necessary for decoding cache. cache: cache dictionary for autoregressive decoding. index: index of the current decoding step, necessary for decoding cache. + dropout_rate: dropout probability on ``attn_output_weights``. Outputs: - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where :math:`L` is the target sequence length, :math:`N` is the batch size, @@ -963,12 +977,12 @@ def forward(self, attn_mask = new_attn_mask # Adjust dropout_rate probability. - dropout_rate = self.dropout if self.training else 0.0 + attn_dropout_rate = dropout_rate if self.training else 0.0 # Calculate attention. q = self.attention_temp * q attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask, dropout_rate) + q, k, v, attn_mask, attn_dropout_rate) # Rearrange for output projection. attn_output = attn_output.transpose(1, 2).contiguous().view( bsz, tgt_len, embed_dim) From af08bb91f93266d12e6ffeab7045b7c57cb9143f Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 11 Jun 2025 17:19:44 +0200 Subject: [PATCH 05/23] fix dropout, ALL tested --- .../criteo1tb_pytorch/models_dropout.py | 35 +- .../models_functional_dropout.py | 308 ------------------ .../fastmri/fastmri_pytorch/models_dropout.py | 13 +- .../librispeech_pytorch/models_dropout.py | 2 +- .../librispeech_pytorch/models_dropout.py | 2 +- 5 files changed, 37 insertions(+), 323 deletions(-) delete mode 100644 algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py index 8042ec31e..d8d7393e4 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py @@ -79,6 +79,10 @@ def __init__(self, self.mlp_bottom_dims = mlp_bottom_dims self.mlp_top_dims = mlp_top_dims self.embed_dim = embed_dim + if dropout_rate is None: + self.dropout_rate = 0.0 + else: + self.dropout_rate = dropout_rate # Ideally, we should use the pooled embedding implementation from # `TorchRec`. 
However, in order to have identical implementation @@ -127,17 +131,16 @@ def __init__(self, block.append(nn.Linear(fan_in, fan_out)) if layer_idx < (num_layers_top - 1): block.append(nn.ReLU(inplace=True)) - if (dropout_rate is not None and dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - block.append(CustomDropout()) # (nico) - block = SequentialWithDropout(*block) # (nico) + if layer_idx == num_layers_top - 2: + block.append(CustomDropout()) + block = SequentialWithDropout(*block) if (layer_idx != 0) and (layer_idx != num_layers_top - 1): block = DenseBlockWithDropout(block, resnet=True) else: block = DenseBlockWithDropout(block) mlp_top_blocks.append(block) fan_in = fan_out - self.top_mlp = SequentialWithDropout(*mlp_top_blocks) # (nico) + self.top_mlp = SequentialWithDropout(*mlp_top_blocks) for module in self.top_mlp.modules(): if isinstance(module, nn.Linear): @@ -149,7 +152,10 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x, dropout_rate): + def forward(self, x, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.dropout_rate + batch_size = x.shape[0] dense_features, sparse_features = torch.split( @@ -201,6 +207,11 @@ def __init__(self, self.mlp_top_dims = mlp_top_dims self.embed_dim = embed_dim self.embedding_init_multiplier = embedding_init_multiplier + self.dropout_rate = dropout_rate + if dropout_rate is None: + self.dropout_rate = 0.0 + else: + self.dropout_rate = dropout_rate # Ideally, we should use the pooled embedding implementation from # `TorchRec`. However, in order to have identical implementation @@ -253,10 +264,9 @@ def __init__(self, top_mlp_layers.append(nn.ReLU(inplace=True)) if use_layer_norm: top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) - if (dropout_rate is not None and dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - top_mlp_layers.append(CustomDropout()) # (nico) - self.top_mlp = SequentialWithDropout(*top_mlp_layers) # (nico) + if layer_idx == num_layers_top - 2: + top_mlp_layers.append(CustomDropout()) + self.top_mlp = SequentialWithDropout(*top_mlp_layers) if use_layer_norm: self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) else: @@ -271,7 +281,10 @@ def __init__(self, 0., math.sqrt(1. 
/ module.out_features)) - def forward(self, x, dropout_rate): + def forward(self, x, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.dropout_rate + batch_size = x.shape[0] dense_features, sparse_features = torch.split( diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py deleted file mode 100644 index 346e0e72a..000000000 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py +++ /dev/null @@ -1,308 +0,0 @@ -"""Pytorch implementation of DLRM-Small.""" - -import math - -import torch -import torch.nn.functional as F -from torch import nn - - -class DenseBlock(nn.Module): - """Dense block with optional residual connection.""" "" - - def __init__(self, module, resnet=False): - super().__init__() - self.module = module - self.resnet = resnet - - def forward(self, x): - if self.resnet: - return self.module(x) + x - else: - return self.module(x) - - -class DotInteract(nn.Module): - """Performs feature interaction operation between dense or sparse features.""" - - def __init__(self, num_sparse_features): - super().__init__() - self.triu_indices = torch.triu_indices(num_sparse_features + 1, - num_sparse_features + 1) - - def forward(self, dense_features, sparse_features): - combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features), - dim=1) - interactions = torch.bmm(combined_values, - torch.transpose(combined_values, 1, 2)) - interactions_flat = interactions[:, - self.triu_indices[0], - self.triu_indices[1]] - return torch.cat((dense_features, interactions_flat), dim=1) - - -class DLRMResNet(nn.Module): - """Define a DLRM-Small model. - - Parameters: - vocab_size: vocab size of embedding table. - num_dense_features: number of dense features as the bottom mlp input. - mlp_bottom_dims: dimensions of dense layers of the bottom mlp. - mlp_top_dims: dimensions of dense layers of the top mlp. - embed_dim: embedding dimension. - """ - - def __init__(self, - vocab_size, - num_dense_features=13, - num_sparse_features=26, - mlp_bottom_dims=(256, 256, 256), - mlp_top_dims=(256, 256, 256, 256, 1), - embed_dim=128, - # dropout_rate=0.0, - use_layer_norm=False, - embedding_init_multiplier=None): - super().__init__() - self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) - self.num_dense_features = num_dense_features - self.num_sparse_features = num_sparse_features - self.mlp_bottom_dims = mlp_bottom_dims - self.mlp_top_dims = mlp_top_dims - self.embed_dim = embed_dim - - # Ideally, we should use the pooled embedding implementation from - # `TorchRec`. However, in order to have identical implementation - # with that of Jax, we define a single embedding matrix. 
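# Sketch, not part of the deleted file or the diff: after this patch both DLRM
# variants kept in models_dropout.py read the rate from forward() and fall back
# to the value stored at construction (0.0 when None). Batch and vocab sizes
# below are arbitrary; the input layout is the 13 dense + 26 sparse features.
import torch
from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models_dropout import DlrmSmall

model = DlrmSmall(vocab_size=1000, dropout_rate=0.1)
x = torch.randn(16, 13 + 26)
logits = model(x)                      # uses the stored rate of 0.1
logits = model(x, dropout_rate=0.5)    # per-call override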
- num_chunks = 4 - assert vocab_size % num_chunks == 0 - self.embedding_table_chucks = [] - scale = 1.0 / torch.sqrt(self.vocab_size) - for i in range(num_chunks): - chunk = nn.Parameter( - torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) - chunk.data.uniform_(0, 1) - chunk.data = scale * chunk.data - self.register_parameter(f'embedding_chunk_{i}', chunk) - self.embedding_table_chucks.append(chunk) - - input_dim = self.num_dense_features - bot_mlp_blocks = [] - for layer_idx, dense_dim in enumerate(self.mlp_bottom_dims): - block = [] - block.append(nn.Linear(input_dim, dense_dim)) - block.append(nn.ReLU(inplace=True)) - block = nn.Sequential(*block) - if layer_idx > 0: - block = DenseBlock(block, resnet=True) - else: - block = DenseBlock(block) - bot_mlp_blocks.append(block) - input_dim = dense_dim - self.bot_mlp = nn.Sequential(*bot_mlp_blocks) - - for module in self.bot_mlp.modules(): - if isinstance(module, nn.Linear): - limit = math.sqrt(6. / (module.in_features + module.out_features)) - nn.init.uniform_(module.weight.data, -limit, limit) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) - - # Number of sparse features = 26 - fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1] - num_layers_top = len(self.mlp_top_dims) - mlp_top_blocks = [] - for layer_idx, fan_out in enumerate(self.mlp_top_dims): - block = [] - block.append(nn.Linear(fan_in, fan_out)) - if layer_idx < (num_layers_top - 1): - block.append(nn.ReLU(inplace=True)) - # if (dropout_rate is not None and dropout_rate > 0.0 and - # layer_idx == num_layers_top - 2): - # block.append(nn.Dropout(p=dropout_rate)) - block = nn.Sequential(*block) - if (layer_idx != 0) and (layer_idx != num_layers_top - 1): - block = DenseBlock(block, resnet=True) - else: - block = DenseBlock(block) - mlp_top_blocks.append(block) - fan_in = fan_out - self.top_mlp = nn.Sequential(*mlp_top_blocks) - - for module in self.top_mlp.modules(): - if isinstance(module, nn.Linear): - nn.init.normal_( - module.weight.data, - 0., - math.sqrt(2. / (module.in_features + module.out_features))) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) - - def forward(self, x, dropout_rate): - batch_size = x.shape[0] - - dense_features, sparse_features = torch.split( - x, [self.num_dense_features, self.num_sparse_features], 1) - - # Bottom MLP. - embedded_dense = self.bot_mlp(dense_features) - - # Sparse feature processing. - sparse_features = sparse_features.to(dtype=torch.int32) - idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size - embedding_table = torch.cat(self.embedding_table_chucks, dim=0) - embedded_sparse = embedding_table[idx_lookup] - embedded_sparse = torch.reshape(embedded_sparse, - [batch_size, 26 * self.embed_dim]) - top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) - - # Final MLP (horrible!!). - h = top_mlp_input - num_layers_top = len(self.mlp_top_dims) - for layer_idx, block in enumerate(self.top_mlp): - # block.module is nn.Sequential([...]) - seq = block.module - # 1) linear - out = seq[0](h) - # 2) ReLU (if present) - if layer_idx < (num_layers_top - 1): - out = seq[1](out) - # 3) functional dropout at penult layer - if dropout_rate > 0 and layer_idx == num_layers_top - 2: - out = F.dropout(out, dropout_rate, training=self.training) - # 4) wrap in residual if needed - h = out + h if block.resnet else out - return h - - -class DlrmSmall(nn.Module): - """Define a DLRM-Small model. - - Parameters: - vocab_size: vocab size of embedding table. 
- num_dense_features: number of dense features as the bottom mlp input. - mlp_bottom_dims: dimensions of dense layers of the bottom mlp. - mlp_top_dims: dimensions of dense layers of the top mlp. - embed_dim: embedding dimension. - """ - - def __init__(self, - vocab_size, - num_dense_features=13, - num_sparse_features=26, - mlp_bottom_dims=(512, 256, 128), - mlp_top_dims=(1024, 1024, 512, 256, 1), - embed_dim=128, - # dropout_rate=0.0, - use_layer_norm=False, - embedding_init_multiplier=None): - super().__init__() - self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) - self.num_dense_features = num_dense_features - self.num_sparse_features = num_sparse_features - self.mlp_bottom_dims = mlp_bottom_dims - self.mlp_top_dims = mlp_top_dims - self.embed_dim = embed_dim - self.embedding_init_multiplier = embedding_init_multiplier - - # Ideally, we should use the pooled embedding implementation from - # `TorchRec`. However, in order to have identical implementation - # with that of Jax, we define a single embedding matrix. - num_chunks = 4 - assert vocab_size % num_chunks == 0 - self.embedding_table_chucks = [] - - if self.embedding_init_multiplier is None: - scale = 1.0 / torch.sqrt(self.vocab_size) - else: - scale = self.embedding_init_multiplier - - for i in range(num_chunks): - chunk = nn.Parameter( - torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) - chunk.data.uniform_(0, 1) - chunk.data = scale * chunk.data - self.register_parameter(f'embedding_chunk_{i}', chunk) - self.embedding_table_chucks.append(chunk) - - input_dim = self.num_dense_features - bottom_mlp_layers = [] - for dense_dim in self.mlp_bottom_dims: - bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim)) - bottom_mlp_layers.append(nn.ReLU(inplace=True)) - if use_layer_norm: - bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) - input_dim = dense_dim - self.bot_mlp = nn.Sequential(*bottom_mlp_layers) - for module in self.bot_mlp.modules(): - if isinstance(module, nn.Linear): - limit = math.sqrt(6. / (module.in_features + module.out_features)) - nn.init.uniform_(module.weight.data, -limit, limit) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) - - self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) - - # TODO: Write down the formula here instead of the constant. - input_dims = 506 - num_layers_top = len(self.mlp_top_dims) - top_mlp_layers = [] - for layer_idx, fan_out in enumerate(self.mlp_top_dims): - fan_in = input_dims if layer_idx == 0 \ - else self.mlp_top_dims[layer_idx - 1] - top_mlp_layers.append(nn.Linear(fan_in, fan_out)) - if layer_idx < (num_layers_top - 1): - top_mlp_layers.append(nn.ReLU(inplace=True)) - if use_layer_norm: - top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) - # if (dropout_rate is not None and dropout_rate > 0.0 and - # layer_idx == num_layers_top - 2): - # top_mlp_layers.append(nn.Dropout(p=dropout_rate)) - self.top_mlp = nn.Sequential(*top_mlp_layers) - if use_layer_norm: - self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) - else: - self.embed_ln = None - for module in self.top_mlp.modules(): - if isinstance(module, nn.Linear): - nn.init.normal_( - module.weight.data, - 0., - math.sqrt(2. / (module.in_features + module.out_features))) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. 
/ module.out_features)) - - def forward(self, x, dropout_rate): - batch_size = x.shape[0] - - dense_features, sparse_features = torch.split( - x, [self.num_dense_features, self.num_sparse_features], 1) - - # Bottom MLP. - embedded_dense = self.bot_mlp(dense_features) - - # Sparse feature processing. - sparse_features = sparse_features.to(dtype=torch.int32) - idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size - embedding_table = torch.cat(self.embedding_table_chucks, dim=0) - embedded_sparse = embedding_table[idx_lookup] - embedded_sparse = torch.reshape(embedded_sparse, - [batch_size, -1, self.embed_dim]) - if self.embed_ln: - embedded_sparse = self.embed_ln(embedded_sparse) - # Dot product interactions. - concatenated_dense = self.dot_interact( - dense_features=embedded_dense, sparse_features=embedded_sparse) - - # Final MLP: run each layer, and after the penultimate layer do functional dropout - h = concatenated_dense - N = len(self.top_mlp) - for idx, layer in enumerate(self.top_mlp): - h = layer(h) - # insert dropout exactly where nn.Dropout used to live - if dropout_rate > 0 and idx == N - 2: - h = F.dropout(h, dropout_rate, training=self.training) - return h diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py index 5862f6352..8954cb737 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py @@ -29,6 +29,7 @@ def __init__(self, out_chans: int = 1, num_channels: int = 32, num_pool_layers: int = 4, + dropout_rate: Optional[float] = 0.0, use_tanh: bool = False, use_layer_norm: bool = False) -> None: super().__init__() @@ -37,6 +38,11 @@ def __init__(self, self.out_chans = out_chans self.num_channels = num_channels self.num_pool_layers = num_pool_layers + if dropout_rate is None: + self.dropout_rate = 0.0 + else: + self.dropout_rate = dropout_rate + self.down_sample_layers = nn.ModuleList([ ConvBlock(in_chans, num_channels, @@ -72,7 +78,10 @@ def __init__(self, if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): init_utils.pytorch_default_init(m) - def forward(self, x: Tensor, dropout_rate: float) -> Tensor: + def forward(self, x: Tensor, dropout_rate: Optional[float] = None) -> Tensor: + if dropout_rate is None: + dropout_rate = self.dropout_rate + stack = [] output = x @@ -136,7 +145,7 @@ def __init__(self, CustomDropout2d(), ) - def forward(self, x: Tensor, dropout_rate: float) -> Tensor: + def forward(self, x: Tensor, dropout_rate: Optional[float] = None) -> Tensor: return self.conv_layers(x, dropout_rate) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py index da66dfe43..9ff662fb8 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py @@ -510,7 +510,7 @@ def forward(self, inputs, input_paddings, dropout_rate=None): outputs, output_paddings = self.preprocessor(outputs, output_paddings) if self.training and self.config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings) + outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate) for conformer in self.conformers: outputs = conformer(outputs, output_paddings, dropout_rate) outputs 
= self.ln(outputs) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py index e68a820ed..8797aa578 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py @@ -383,7 +383,7 @@ def forward(self, inputs, input_paddings, dropout_rate=None): for idx in range(self.config.num_ffn_layers): if self.config.enable_residual_connections: - outputs = outputs + self.ffns[idx](outputs, output_paddings) + outputs = outputs + self.ffns[idx](outputs, output_paddings, dropout_rate) else: outputs = self.ffns[idx](outputs, output_paddings, dropout_rate) From 7a6651a69953af4655e0f7c50b8c7fefe71aa9ca Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 11 Jun 2025 17:20:39 +0200 Subject: [PATCH 06/23] add dropout equivalence tests --- .../test_model_equivalence.py | 77 ++++++++++++ .../fastmri_pytorch/test_model_equivalence.py | 98 +++++++++++++++ .../test_model_equivalence.py | 112 ++++++++++++++++++ .../test_model_equivalence.py | 91 ++++++++++++++ .../test_model_equivalence.py | 89 ++++++++++++++ .../ogbg_pytorch/test_model_equivalence.py | 76 ++++++++++++ .../wmt_pytorch/test_model_equivalence.py | 83 +++++++++++++ 7 files changed, 626 insertions(+) create mode 100644 tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py create mode 100644 tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py create mode 100644 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py create mode 100644 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py create mode 100644 tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py create mode 100644 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py create mode 100644 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..b9b1232ef --- /dev/null +++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py @@ -0,0 +1,77 @@ +""" +Runs fwd pass with random input for our DLRM models and compares outputs. 
+Run it as: + python3 tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch +import os + +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import ( + DLRMResNet as OriginalDLRMResNet, + DlrmSmall as OriginalDlrmSmall, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models_dropout import ( + DLRMResNet as CustomDLRMResNet, + DlrmSmall as CustomDlrmSmall, +) + + +BATCH, DENSE, SPARSE = 16, 13, 26 +FEATURES = DENSE + SPARSE +VOCAB = 1000 +DEVICE = 'cuda' +TORCH_COMPILE = False +SEED = 1996 + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) + +class ModelEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='DLRMResNet, p=None', model='dlrm_resnet', dropout_rate=None), + dict(testcase_name='DlrmSmall, p=None', model='dlrm_small', dropout_rate=None), + dict(testcase_name='DLRMResNet, p=0.0', model='dlrm_resnet', dropout_rate=0.0), + dict(testcase_name='DlrmSmall, p=0.0', model='dlrm_small', dropout_rate=0.0), + dict(testcase_name='DLRMResNet, p=0.1', model='dlrm_resnet', dropout_rate=0.1), + dict(testcase_name='DlrmSmall, p=0.1', model='dlrm_small', dropout_rate=0.1), + dict(testcase_name='DLRMResNet, p=1.0', model='dlrm_resnet', dropout_rate=1.0), + dict(testcase_name='DlrmSmall, p=1.0', model='dlrm_small', dropout_rate=1.0), + ) + def test_forward(self, model, dropout_rate): + OrigCls, CustCls = ( + (OriginalDLRMResNet, CustomDLRMResNet) + if model == 'dlrm_resnet' + else (OriginalDlrmSmall, CustomDlrmSmall) + ) + + # Test initalizing custom model with a None dropout_rate + for custom_init_dropout_rate in [dropout_rate, None]: + + torch.manual_seed(SEED) + orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate) + orig.to(DEVICE) + + torch.manual_seed(SEED) + cust = CustCls(vocab_size=VOCAB, dropout_rate=custom_init_dropout_rate) + cust.to(DEVICE) + + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + x = torch.randn(BATCH, FEATURES, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(SEED); y1 = orig(x) + torch.manual_seed(SEED); y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..6339ff21b --- /dev/null +++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py @@ -0,0 +1,98 @@ +""" +Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. 
+Run it as: + python3 tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch +import os + +from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet as OriginalUNet +from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import UNet as CustomUNet + +BATCH, IN_CHANS, H, W = 4, 1, 256, 256 +OUT_CHANS, C, LAYERS = 1, 32, 4 +DEVICE = 'cuda' +TORCH_COMPILE = True +SEED = 1996 + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) + +class FastMRIModeEquivalenceTest(parameterized.TestCase): + + def fwd_pass(self, orig, cust, dropout_rate): + x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(0) + y1 = orig(x) + torch.manual_seed(0) + y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name='p=None', dropout_rate=None), + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.1', dropout_rate=0.1), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_dropout_values(self, dropout_rate): + """Test different values of dropout_rate.""" + + # Test initalizing custom model with a None dropout_rate + for custom_init_dropout_rate in [dropout_rate, None]: + + torch.manual_seed(SEED) + orig = OriginalUNet( + IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomUNet( + IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=custom_init_dropout_rate + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + self.fwd_pass(orig, cust, dropout_rate) + + + @parameterized.named_parameters( + dict(testcase_name='default', use_tanh=False, use_layer_norm=False), + dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False), + dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True), + dict(testcase_name='both', use_tanh=True, use_layer_norm=True), + ) + def test_arch_setups(self, use_tanh, use_layer_norm): + """Test different architecture configurations, fixed dropout_rate.""" + dropout_rate = 0.1 + + torch.manual_seed(SEED) + orig = OriginalUNet( + IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate, + use_tanh=use_tanh, use_layer_norm=use_layer_norm + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomUNet( + IN_CHANS, OUT_CHANS, C, LAYERS, + use_tanh=use_tanh, use_layer_norm=use_layer_norm + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + self.fwd_pass(orig, cust, dropout_rate) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..56644f152 --- /dev/null +++ b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py @@ -0,0 +1,112 @@ +""" +Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. 
+Run it as: + python3 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch +import os +import itertools + +from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import ViT as OriginalVit +from algoperf.workloads.imagenet_vit.imagenet_pytorch.models_dropout import ViT as CustomVit + +# Model / test hyper-params +BATCH, C, H, W = 4, 3, 224, 224 # input shape (N,C,H,W) +WIDTH, DEPTH, HEADS = 256, 4, 8 +DROPOUT_RATE = None +DEVICE = 'cuda' +SEED = 1996 + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) + +class ImageNetVitModeEquivalenceTest(parameterized.TestCase): + + def fwd_pass(self, orig, cust, dropout_rate): + x = torch.randn(BATCH, C, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name='p=None', dropout_rate=None), + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.1', dropout_rate=0.1), + dict(testcase_name='p=0.6', dropout_rate=0.6), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_dropout_values(self, dropout_rate): + """Test different dropout_values.""" + + # Test initalizing custom model with a None dropout_rate + for custom_init_dropout_rate in [dropout_rate, None]: + + torch.manual_seed(SEED) + orig = OriginalVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + dropout_rate=dropout_rate, + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + dropout_rate=custom_init_dropout_rate, + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + self.fwd_pass(orig, cust, dropout_rate) + + + @parameterized.named_parameters([ + dict( + testcase_name=f"GLU={use_glu}_LN={use_post_ln}_MAP={use_map}", + use_glu=use_glu, + use_post_ln=use_post_ln, + use_map=use_map, + ) + for use_glu, use_post_ln, use_map in itertools.product([False, True], repeat=3) + ]) + def test_arch(self, use_glu, use_post_ln, use_map): + """Test different architecture configurations, fixed dropout_rate.""" + dropout_rate = 0.1 + + torch.manual_seed(SEED) + orig = OriginalVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + use_glu=use_glu, + use_post_layer_norm=use_post_ln, + use_map=use_map, + dropout_rate=dropout_rate, + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + use_glu=use_glu, + use_post_layer_norm=use_post_ln, + use_map=use_map, + dropout_rate=None, + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + self.fwd_pass(orig, cust, dropout_rate) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..19525a98b --- /dev/null +++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py @@ -0,0 +1,91 @@ +""" +Runs fwd pass with random input for LIBRISPEECH Conformer models and compares outputs. 
+Run with: + python3 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py + +`dropout_rate` controls the following args: +- `attention_residual_dropout_rate` (if None, 0.1 +- `conv_residual_dropout_rate` (if None, 0.0) +- `feed_forward_residual_dropout_rate` (if None, 0.1) +- `input_dropout_rate` (if None, 0.1) +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch +import os + +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( + # ConformerConfig, + ConformerEncoderDecoder as OriginalConf +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import( + ConformerEncoderDecoder as CustomConf, + ConformerConfig, +) + +B, T = 32, 36_000 +DEVICE = 'cuda' + +os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(mode=True) +SEED = 1996 + + +class ConformerEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='p=None', dropout_rate=None), + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.2', dropout_rate=0.2), + dict(testcase_name='p=0.7', dropout_rate=0.7), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_forward(self, dropout_rate): + + # Test initalizing custom model with a None dropout_rate + for custom_init_dropout_rate in [None, dropout_rate]: + + torch.manual_seed(SEED) + orig = OriginalConf( + ConformerConfig( + num_encoder_layers=3, + attention_residual_dropout_rate=dropout_rate, + conv_residual_dropout_rate=dropout_rate, + feed_forward_residual_dropout_rate=dropout_rate, + input_dropout_rate=dropout_rate, + )).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomConf( + ConformerConfig( + num_encoder_layers=3, + attention_residual_dropout_rate=custom_init_dropout_rate, + conv_residual_dropout_rate=custom_init_dropout_rate, + feed_forward_residual_dropout_rate=custom_init_dropout_rate, + input_dropout_rate=custom_init_dropout_rate + )).to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + y1, p1 = orig(x, paddings) + + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..e31f4a7eb --- /dev/null +++ b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py @@ -0,0 +1,89 @@ +""" +Runs fwd pass with random input for LIBRISPEECH Deepspeech models and compares outputs. 
+Run with: + python3 tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py + +`dropout_rate` controls the following args: +- `input_dropout_rate` (if None, 0.1 +- `feed_forward_dropout_rate` (if None, 0.1) +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch +import os + +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import ( + DeepspeechEncoderDecoder as OriginalModel +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import( + DeepspeechEncoderDecoder as CustomModel, + DeepspeechConfig, +) + +B, T = 32, 30_000 +DEVICE = 'cuda' +TORCH_COMPILE = True + +os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(mode=True) +SEED = 1996 + + +class DeepSpeechEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='p=None', dropout_rate=None), + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.2', dropout_rate=0.2), + dict(testcase_name='p=0.7', dropout_rate=0.7), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_forward(self, dropout_rate): + + # Test initalizing custom model with a None dropout_rate + for custom_init_dropout_rate in [None, dropout_rate]: + + torch.manual_seed(SEED) + orig = OriginalModel( + DeepspeechConfig( + num_lstm_layers=2, + num_ffn_layers=2, + input_dropout_rate=dropout_rate, + feed_forward_dropout_rate=dropout_rate, + )).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomModel(DeepspeechConfig( + num_lstm_layers=2, + num_ffn_layers=2, + input_dropout_rate=custom_init_dropout_rate, + feed_forward_dropout_rate=custom_init_dropout_rate, + )).to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + y1, p1 = orig(x, paddings) + + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..cc1857705 --- /dev/null +++ b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py @@ -0,0 +1,76 @@ +""" +Runs fwd pass with random graphs for OGBG GNN models and compares outputs. 
+Run with: + python3 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch, os, random, numpy as np +from jraph import GraphsTuple + +from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as OriginalModel +from algoperf.workloads.ogbg.ogbg_pytorch.models_dropout import GNN as CustomModel + +B, N, E = 8, 20, 40 # graphs, nodes/graph, edges/graph +NODE_FDIM, EDGE_FDIM = 9, 3 # expected feature dims +DEVICE = 'cuda' + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) +SEED = 1996 + + +def _rand_graph(): + total_nodes, total_edges = B * N, B * E + nodes = torch.randn(total_nodes, NODE_FDIM, device=DEVICE) + edges = torch.randn(total_edges, EDGE_FDIM, device=DEVICE) + senders, receivers = [], [] + for i in range(B): + offset = i * N + s = torch.randint(N, (E,), device=DEVICE) + offset + r = torch.randint(N, (E,), device=DEVICE) + offset + senders.append(s), receivers.append(r) + senders = torch.cat(senders); receivers = torch.cat(receivers) + n_node = torch.full((B,), N, device=DEVICE, dtype=torch.int32) + n_edge = torch.full((B,), E, device=DEVICE, dtype=torch.int32) + return GraphsTuple(nodes, edges, receivers, senders, None, n_node, n_edge) + + +class GNNEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='None', dropout_rate=None), + dict(testcase_name='0.0', dropout_rate=0.0), + dict(testcase_name='0.2', dropout_rate=0.2), + dict(testcase_name='0.7', dropout_rate=0.7), + dict(testcase_name='1.0', dropout_rate=1.0), + ) + def test_forward(self, dropout_rate): + + # Test initalizing custom model with a None dropout_rate + for custom_init_dropout_rate in [None, dropout_rate]: + + orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE) + cust = CustomModel(dropout_rate=custom_init_dropout_rate).to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights + + graph = _rand_graph() + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y1 = orig(graph) + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(graph, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..9aca717d9 --- /dev/null +++ b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py @@ -0,0 +1,83 @@ +""" +Runs fwd pass with random input for WMT Transformer models and compares outputs. 
+Run with: + python3 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch, os, random, numpy as np + +from algoperf.workloads.wmt.wmt_pytorch.models import ( + Transformer as OriginalModel, +) +from algoperf.workloads.wmt.wmt_pytorch.models_dropout import ( + Transformer as CustomModel, +) + +B, SRC_LEN, TGT_LEN, NTOK = 16, 80, 80, 32_000 +DEVICE = "cuda" +SEED = 1996 + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) + + +def _rand_tokens(bs, seqlen): + return torch.randint(1, NTOK, (bs, seqlen), device=DEVICE) + + +class TransformerEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + # NOTE: removed dropout=1.0 will generate nan in scaled_dot_product_attention + + dict(testcase_name="None", dropout_rate=None, compile=False), + dict(testcase_name="0.0", dropout_rate=0.0, compile=False), + dict(testcase_name="0.2", dropout_rate=0.2, compile=False), + dict(testcase_name="0.7", dropout_rate=0.7, compile=False), + + dict(testcase_name="p=None, compile", dropout_rate=None, compile=True), + dict(testcase_name="p=0.0, compile", dropout_rate=0.0, compile=True), + dict(testcase_name="p=0.2, compile", dropout_rate=0.2, compile=True), + dict(testcase_name="p=0.7, compile", dropout_rate=0.7, compile=True), + ) + def test_forward(self, dropout_rate, compile): + + # Test initalizing custom model with a None dropout_rate + for custom_init_dropout_rate in [None, dropout_rate]: + + orig = OriginalModel( + dropout_rate=dropout_rate, + attention_dropout_rate=dropout_rate + ).to(DEVICE) + cust = CustomModel( + dropout_rate=custom_init_dropout_rate + ).to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if compile: + orig = torch.compile(orig) + cust = torch.compile(cust) + + src = _rand_tokens(B, SRC_LEN) + tgt = _rand_tokens(B, TGT_LEN) + + for mode in ("train", "eval"): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y1 = orig(src, tgt) + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(src, tgt, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + + +if __name__ == "__main__": + absltest.main() From a7ff3d1ab09c57807f4d0c7b219803407c085c69 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 11 Jun 2025 17:39:46 +0200 Subject: [PATCH 07/23] moved custom dropout to pytorch_utils --- algoperf/pytorch_utils.py | 38 ++++++++++++++++++ .../criteo1tb_pytorch/models_dropout.py | 2 +- algoperf/workloads/dropout_modules.py | 40 ------------------- .../fastmri/fastmri_pytorch/models_dropout.py | 2 +- .../ogbg/ogbg_pytorch/models_dropout.py | 2 +- 5 files changed, 41 insertions(+), 43 deletions(-) delete mode 100644 algoperf/workloads/dropout_modules.py diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index 4a674985d..4af77088e 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -5,7 +5,10 @@ import jax import tensorflow as tf import torch +from torch import Tensor +import torch.nn as nn import torch.distributed as dist +import torch.nn.functional as F from algoperf import spec from algoperf.profiler import Profiler @@ -77,3 +80,38 @@ def update_batch_norm_fn(module: spec.ParameterContainer, module.momentum = 0.0 elif hasattr(module, 'momentum_backup'): module.momentum = 
module.momentum_backup + + +class CustomDropout(nn.Module): + """A module around torch.nn.functional.dropout.""" + def __init__(self): + super().__init__() + self._supports_custom_dropout = True + + def forward(self, input: Tensor, p: float) -> Tensor: + return F.dropout(input, p, training=self.training) + + +class CustomDropout2d(nn.Module): + """A module around torch.nn.functional.dropout2d.""" + def __init__(self): + super().__init__() + self._supports_custom_dropout = True + + def forward(self, input: Tensor, p: float) -> Tensor: + return F.dropout2d(input, p, training=self.training) + + +class SequentialWithDropout(nn.Sequential): + """Sequential of modules with dropout.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._supports_custom_dropout = True + + def forward(self, x: Tensor, p: float) -> Tensor: + for module in self: + if getattr(module, '_supports_custom_dropout', False): + x = module(x, p) + else: + x = module(x) + return x diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py index d8d7393e4..065ebd1f8 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from torch import nn -from algoperf.workloads.dropout_modules import CustomDropout, SequentialWithDropout +from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout class DenseBlock(nn.Module): diff --git a/algoperf/workloads/dropout_modules.py b/algoperf/workloads/dropout_modules.py deleted file mode 100644 index 6cec3f7ad..000000000 --- a/algoperf/workloads/dropout_modules.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Custom classes to support a dynamic modulized dropout, see issue??TODO""" - -from torch import Tensor -from torch import nn -import torch.nn.functional as F - - -class CustomDropout(nn.Module): - """A module around torch.nn.functional.dropout.""" - def __init__(self): - super().__init__() - self._supports_custom_dropout = True - - def forward(self, input: Tensor, p: float) -> Tensor: - return F.dropout(input, p, training=self.training) - - -class CustomDropout2d(nn.Module): - """A module around torch.nn.functional.dropout2d.""" - def __init__(self): - super().__init__() - self._supports_custom_dropout = True - - def forward(self, input: Tensor, p: float) -> Tensor: - return F.dropout2d(input, p, training=self.training) - - -class SequentialWithDropout(nn.Sequential): - """Sequential of modules with dropout.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._supports_custom_dropout = True - - def forward(self, x: Tensor, p: float) -> Tensor: - for module in self: - if getattr(module, '_supports_custom_dropout', False): - x = module(x, p) - else: - x = module(x) - return x diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py index 8954cb737..260cb7e44 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py @@ -13,7 +13,7 @@ from torch.nn import functional as F from algoperf import init_utils -from algoperf.workloads.dropout_modules import CustomDropout2d, SequentialWithDropout +from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py 
b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py index 1d89ea9e7..b86b88caa 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py @@ -9,7 +9,7 @@ from torch import nn from algoperf import init_utils -from algoperf.workloads.dropout_modules import CustomDropout, SequentialWithDropout +from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout def _make_mlp(in_dim, hidden_dims, activation_fn): From f26ab02d987a839c481dc74d3d15c1d920165d38 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 11 Jun 2025 18:14:26 +0200 Subject: [PATCH 08/23] remove aux_dropout from pytorch workloads --- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 4 +--- .../workloads/fastmri/fastmri_pytorch/workload.py | 4 +--- .../imagenet_resnet/imagenet_pytorch/workload.py | 4 +--- .../imagenet_vit/imagenet_pytorch/workload.py | 4 +--- .../librispeech_pytorch/workload.py | 11 +++-------- .../librispeech_pytorch/workload.py | 11 +++-------- algoperf/workloads/ogbg/ogbg_pytorch/workload.py | 5 +---- algoperf/workloads/wmt/wmt_pytorch/workload.py | 6 ++---- 8 files changed, 13 insertions(+), 36 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 726aa8705..638022a5e 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -68,10 +68,8 @@ def loss_fn( def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + dropout_rate: Optional[float] = None) -> spec.ModelInitState: """Only dropout is used.""" - del aux_dropout_rate torch.random.manual_seed(rng[0]) # Disable cudnn benchmark to avoid OOM errors. 
torch.backends.cudnn.benchmark = False diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 58943de2f..9582325e1 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -108,9 +108,7 @@ def _build_input_queue(self, def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - del aux_dropout_rate + dropout_rate: Optional[float] = None) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = UNet( num_pool_layers=self.num_pool_layers, diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index ed29271f3..372cac7fa 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -157,11 +157,9 @@ def _build_dataset( def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + dropout_rate: Optional[float] = None) -> spec.ModelInitState: """Dropout is unused.""" del dropout_rate - del aux_dropout_rate torch.random.manual_seed(rng[0]) if self.use_silu and self.use_gelu: diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 97bb38515..1a6bb1381 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -24,9 +24,7 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - del aux_dropout_rate + dropout_rate: Optional[float] = None) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = models.ViT( dropout_rate=dropout_rate, diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 5ed37957e..39f33f4aa 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -64,13 +64,8 @@ def attention_temperature(self) -> float: def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Conformer model init function. - - Here we use dropout_rate as residual_dropout_rate, and aux_dropout_rate as - input_dropout_rate. - """ + dropout_rate: Optional[float] = None) -> spec.ModelInitState: + """Conformer model init function.""" torch.random.manual_seed(rng[0]) # Configure torch backends to avoid OOM errors. 
torch.backends.cudnn.benchmark = False @@ -86,7 +81,7 @@ def init_model_fn( attention_residual_dropout_rate=dropout_rate, feed_forward_residual_dropout_rate=dropout_rate, conv_residual_dropout_rate=dropout_rate, - input_dropout_rate=aux_dropout_rate, + input_dropout_rate=dropout_rate, use_specaug=self.use_specaug, attention_temperature=self.attention_temperature, use_post_layer_norm=self.use_post_layer_norm, diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index e5387f5cb..932ba9392 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -25,19 +25,14 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Deepspeech model init function. - - Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate - as input_dropout_rate. - """ + dropout_rate: Optional[float] = None) -> spec.ModelInitState: + """Deepspeech model init function.""" torch.random.manual_seed(rng[0]) model = DeepspeechEncoderDecoder( DeepspeechConfig( feed_forward_dropout_rate=dropout_rate, use_specaug=self.use_specaug, - input_dropout_rate=aux_dropout_rate, + input_dropout_rate=dropout_rate, use_tanh=self.use_tanh, enable_residual_connections=self.enable_residual_connections, enable_decoder_layer_norm=self.enable_decoder_layer_norm, diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py index 45295ac7f..1dd85951d 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py @@ -139,10 +139,7 @@ def _build_input_queue(self, def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is unused.""" - del aux_dropout_rate + dropout_rate: Optional[float] = None) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = GNN( num_outputs=self._num_outputs, diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py index d0716d6c8..64eea73b7 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/workload.py +++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py @@ -168,9 +168,7 @@ def translate_and_calculate_bleu(self, def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is used as attention_dropout_rate.""" + dropout_rate: Optional[float] = None) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) if self.activation == 'relu': @@ -182,7 +180,7 @@ def init_model_fn( model = Transformer( dropout_rate=dropout_rate, - attention_dropout_rate=aux_dropout_rate, + attention_dropout_rate=dropout_rate, pre_ln=self.pre_ln, attention_temp=self.attention_temp, activation=activation, From e0a0e624b7cd70fea51a054d85cd1b418076cdef Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 10:57:15 +0200 Subject: [PATCH 09/23] criteo rm dropout from init --- .../criteo1tb_pytorch/models_dropout.py | 21 +++-------- .../criteo1tb/criteo1tb_pytorch/workload.py | 5 +-- .../test_model_equivalence.py | 35 ++++++++----------- 3 
files changed, 20 insertions(+), 41 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py index 065ebd1f8..2ac5c2d1b 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py @@ -8,6 +8,8 @@ from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout +DEFAULT_DROPOUT_RATE = 0.0 + class DenseBlock(nn.Module): """Dense block with optional residual connection.""" "" @@ -69,7 +71,6 @@ def __init__(self, mlp_bottom_dims=(256, 256, 256), mlp_top_dims=(256, 256, 256, 256, 1), embed_dim=128, - dropout_rate=0.0, use_layer_norm=False, embedding_init_multiplier=None): super().__init__() @@ -79,10 +80,6 @@ def __init__(self, self.mlp_bottom_dims = mlp_bottom_dims self.mlp_top_dims = mlp_top_dims self.embed_dim = embed_dim - if dropout_rate is None: - self.dropout_rate = 0.0 - else: - self.dropout_rate = dropout_rate # Ideally, we should use the pooled embedding implementation from # `TorchRec`. However, in order to have identical implementation @@ -152,9 +149,7 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.dropout_rate + def forward(self, x, dropout_rate=DEFAULT_DROPOUT_RATE): batch_size = x.shape[0] @@ -196,7 +191,6 @@ def __init__(self, mlp_bottom_dims=(512, 256, 128), mlp_top_dims=(1024, 1024, 512, 256, 1), embed_dim=128, - dropout_rate=0.0, use_layer_norm=False, embedding_init_multiplier=None): super().__init__() @@ -207,11 +201,6 @@ def __init__(self, self.mlp_top_dims = mlp_top_dims self.embed_dim = embed_dim self.embedding_init_multiplier = embedding_init_multiplier - self.dropout_rate = dropout_rate - if dropout_rate is None: - self.dropout_rate = 0.0 - else: - self.dropout_rate = dropout_rate # Ideally, we should use the pooled embedding implementation from # `TorchRec`. However, in order to have identical implementation @@ -281,9 +270,7 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.dropout_rate + def forward(self, x, dropout_rate=DEFAULT_DROPOUT_RATE): batch_size = x.shape[0] diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 638022a5e..b128f5bd5 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -67,9 +67,7 @@ def loss_fn( def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Only dropout is used.""" + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) # Disable cudnn benchmark to avoid OOM errors. 
torch.backends.cudnn.benchmark = False @@ -83,7 +81,6 @@ def init_model_fn( mlp_bottom_dims=self.mlp_bottom_dims, mlp_top_dims=self.mlp_top_dims, embed_dim=self.embed_dim, - dropout_rate=dropout_rate, use_layer_norm=self.use_layer_norm, embedding_init_multiplier=self.embedding_init_multiplier) self._param_shapes = param_utils.pytorch_param_shapes(model) diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py index b9b1232ef..c4f074ff3 100644 --- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py @@ -34,8 +34,6 @@ class ModelEquivalenceTest(parameterized.TestCase): @parameterized.named_parameters( - dict(testcase_name='DLRMResNet, p=None', model='dlrm_resnet', dropout_rate=None), - dict(testcase_name='DlrmSmall, p=None', model='dlrm_small', dropout_rate=None), dict(testcase_name='DLRMResNet, p=0.0', model='dlrm_resnet', dropout_rate=0.0), dict(testcase_name='DlrmSmall, p=0.0', model='dlrm_small', dropout_rate=0.0), dict(testcase_name='DLRMResNet, p=0.1', model='dlrm_resnet', dropout_rate=0.1), @@ -50,27 +48,24 @@ def test_forward(self, model, dropout_rate): else (OriginalDlrmSmall, CustomDlrmSmall) ) - # Test initalizing custom model with a None dropout_rate - for custom_init_dropout_rate in [dropout_rate, None]: + torch.manual_seed(SEED) + orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate) + orig.to(DEVICE) - torch.manual_seed(SEED) - orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate) - orig.to(DEVICE) + torch.manual_seed(SEED) + cust = CustCls(vocab_size=VOCAB) + cust.to(DEVICE) - torch.manual_seed(SEED) - cust = CustCls(vocab_size=VOCAB, dropout_rate=custom_init_dropout_rate) - cust.to(DEVICE) - - if TORCH_COMPILE: - orig = torch.compile(orig); cust = torch.compile(cust) - - x = torch.randn(BATCH, FEATURES, device=DEVICE) + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + x = torch.randn(BATCH, FEATURES, device=DEVICE) - for mode in ('train', 'eval'): - getattr(orig, mode)(); getattr(cust, mode)() - torch.manual_seed(SEED); y1 = orig(x) - torch.manual_seed(SEED); y2 = cust(x, dropout_rate) - assert_close(y1, y2, atol=0, rtol=0) + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(SEED); y1 = orig(x) + torch.manual_seed(SEED); y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) if __name__ == '__main__': From 1e2f379a955bbd75c57042cb6b73750fd3f06eb7 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 11:07:04 +0200 Subject: [PATCH 10/23] criteo rm dropout from init --- .../criteo1tb_pytorch/test_model_equivalence.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py index c4f074ff3..f40f7b3df 100644 --- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py @@ -63,10 +63,15 @@ def test_forward(self, model, dropout_rate): for mode in ('train', 'eval'): getattr(orig, mode)(); getattr(cust, mode)() - torch.manual_seed(SEED); y1 = orig(x) - torch.manual_seed(SEED); y2 = cust(x, dropout_rate) + torch.manual_seed(SEED) + y1 = orig(x) + torch.manual_seed(SEED) + if mode == 'train': + y2 = cust(x, dropout_rate) + else: + y2 = cust(x) assert_close(y1, y2, atol=0, rtol=0) - + if __name__ 
== '__main__': absltest.main() From f10e3dc6bdebdf87a3460b600c729c63512f7397 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 11:11:26 +0200 Subject: [PATCH 11/23] criteo rm dropout from init --- .../dropout_fix/criteo1tb_pytorch/test_model_equivalence.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py index f40f7b3df..32aa5b34f 100644 --- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py @@ -62,7 +62,9 @@ def test_forward(self, model, dropout_rate): x = torch.randn(BATCH, FEATURES, device=DEVICE) for mode in ('train', 'eval'): - getattr(orig, mode)(); getattr(cust, mode)() + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(SEED) y1 = orig(x) torch.manual_seed(SEED) @@ -70,6 +72,7 @@ def test_forward(self, model, dropout_rate): y2 = cust(x, dropout_rate) else: y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) From 027b053e838a57e5a1c33389d0b8f1b0d4269aa2 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 11:25:31 +0200 Subject: [PATCH 12/23] criteo rm dropout from init --- .../workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py index 2ac5c2d1b..b5ee465e2 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py @@ -3,7 +3,6 @@ import math import torch -import torch.nn.functional as F from torch import nn from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout @@ -30,7 +29,7 @@ def __init__(self, module, resnet=False): self.resnet = resnet self._supports_custom_dropout = True - def forward(self, x, p=None): + def forward(self, x, p): return self.module(x, p) + x if self.resnet else self.module(x, p) From 74c43aa20e3c61b488234014816f1bffeec1508d Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 11:26:04 +0200 Subject: [PATCH 13/23] fastmri rm dropout from init --- .../fastmri/fastmri_pytorch/models_dropout.py | 12 ++----- .../fastmri/fastmri_pytorch/workload.py | 6 ++-- .../fastmri_pytorch/test_model_equivalence.py | 32 ++++++++----------- 3 files changed, 19 insertions(+), 31 deletions(-) diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py index 260cb7e44..73b1d81d9 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py @@ -15,6 +15,7 @@ from algoperf import init_utils from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout +DEFAULT_DROPOUT_RATE = 0.0 class UNet(nn.Module): @@ -29,7 +30,6 @@ def __init__(self, out_chans: int = 1, num_channels: int = 32, num_pool_layers: int = 4, - dropout_rate: Optional[float] = 0.0, use_tanh: bool = False, use_layer_norm: bool = False) -> None: super().__init__() @@ -38,10 +38,6 @@ def __init__(self, self.out_chans = out_chans self.num_channels = num_channels self.num_pool_layers = num_pool_layers - if dropout_rate is None: - self.dropout_rate = 0.0 - else: - self.dropout_rate = dropout_rate self.down_sample_layers = nn.ModuleList([ ConvBlock(in_chans, @@ -78,9 
+74,7 @@ def __init__(self, if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): init_utils.pytorch_default_init(m) - def forward(self, x: Tensor, dropout_rate: Optional[float] = None) -> Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate + def forward(self, x: Tensor, dropout_rate: Optional[float] = DEFAULT_DROPOUT_RATE) -> Tensor: stack = [] output = x @@ -145,7 +139,7 @@ def __init__(self, CustomDropout2d(), ) - def forward(self, x: Tensor, dropout_rate: Optional[float] = None) -> Tensor: + def forward(self, x: Tensor, dropout_rate: float) -> Tensor: return self.conv_layers(x, dropout_rate) diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 9582325e1..6da0bb0af 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -107,15 +107,13 @@ def _build_input_queue(self, def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = UNet( num_pool_layers=self.num_pool_layers, num_channels=self.num_channels, use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm, - dropout_rate=dropout_rate) + use_layer_norm=self.use_layer_norm) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py index 6339ff21b..c71ff8980 100644 --- a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py @@ -33,36 +33,32 @@ def fwd_pass(self, orig, cust, dropout_rate): torch.manual_seed(0) y1 = orig(x) torch.manual_seed(0) - y2 = cust(x, dropout_rate) + if mode == 'train': + y2 = cust(x, dropout_rate) + else: + y2 = cust(x) assert_close(y1, y2, atol=0, rtol=0) @parameterized.named_parameters( - dict(testcase_name='p=None', dropout_rate=None), dict(testcase_name='p=0.0', dropout_rate=0.0), dict(testcase_name='p=0.1', dropout_rate=0.1), + dict(testcase_name='p=0.7', dropout_rate=0.7), dict(testcase_name='p=1.0', dropout_rate=1.0), ) def test_dropout_values(self, dropout_rate): """Test different values of dropout_rate.""" - # Test initalizing custom model with a None dropout_rate - for custom_init_dropout_rate in [dropout_rate, None]: - - torch.manual_seed(SEED) - orig = OriginalUNet( - IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate - ).to(DEVICE) + torch.manual_seed(SEED) + orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE) - torch.manual_seed(SEED) - cust = CustomUNet( - IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=custom_init_dropout_rate - ).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) - cust.load_state_dict(orig.state_dict()) # sync weights - if TORCH_COMPILE: - orig = torch.compile(orig); cust = torch.compile(cust) + cust.load_state_dict(orig.state_dict()) # sync weights + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) - self.fwd_pass(orig, cust, dropout_rate) + self.fwd_pass(orig, cust, dropout_rate) @parameterized.named_parameters( @@ -71,7 +67,7 @@ def test_dropout_values(self, dropout_rate): dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True), 
dict(testcase_name='both', use_tanh=True, use_layer_norm=True), ) - def test_arch_setups(self, use_tanh, use_layer_norm): + def test_arch_configs(self, use_tanh, use_layer_norm): """Test different architecture configurations, fixed dropout_rate.""" dropout_rate = 0.1 From 64276ef67cb0ab56bac998c665c82d62e7722087 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 11:44:11 +0200 Subject: [PATCH 14/23] vit rm dropout at init --- .../imagenet_pytorch/models_dropout.py | 72 +++++++------------ .../wmt/wmt_pytorch/models_dropout.py | 2 +- .../test_model_equivalence.py | 72 ++++++++++++------- 3 files changed, 72 insertions(+), 74 deletions(-) diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py index f5e315fd7..8641847b0 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py @@ -14,7 +14,9 @@ from algoperf import init_utils from algoperf import spec -from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention +from algoperf.workloads.wmt.wmt_pytorch.models_dropout import MultiheadAttention + +DEFAULT_DROPOUT_RATE = 0.0 def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: @@ -41,14 +43,12 @@ def __init__( self, width: int, mlp_dim: Optional[int] = None, # Defaults to 4x input dim. - use_glu: bool = False, - dropout_rate: float = 0.0) -> None: + use_glu: bool = False) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim or 4 * width self.use_glu = use_glu - self.dropout_rate = dropout_rate self.linear1 = nn.Linear(self.width, self.mlp_dim) self.act_fnc = nn.GELU(approximate='tanh') @@ -69,9 +69,7 @@ def reset_parameters(self) -> None: if module.bias is not None: module.bias.data.normal_(std=1e-6) - def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: x = self.linear1(x) x = self.act_fnc(x) @@ -90,8 +88,7 @@ class SelfAttention(nn.Module): def __init__(self, width: int, - num_heads: int = 8, - dropout_rate: float = 0.0) -> None: + num_heads: int = 8) -> None: super().__init__() self.width = width @@ -102,7 +99,6 @@ def __init__(self, self.head_dim = int(width / num_heads) self.all_head_dim = self.num_heads * self.head_dim - self.dropout_rate = dropout_rate self.query = nn.Linear(self.width, self.all_head_dim) self.key = nn.Linear(self.width, self.all_head_dim) @@ -122,9 +118,7 @@ def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor: x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: mixed_query_layer = self.query(x) @@ -136,7 +130,7 @@ def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: attention_scores = attention_scores / math.sqrt(self.head_dim) attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = F.dropout(attention_probs, dropout_rate, training=self.training) + attention_probs = F.dropout(attention_probs, dropout_rate, self.training) context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() @@ -154,8 +148,7 @@ def __init__(self, mlp_dim: Optional[int] = None, 
num_heads: int = 12, use_glu: bool = False, - use_post_layer_norm: bool = False, - dropout_rate: float = 0.0) -> None: + use_post_layer_norm: bool = False) -> None: super().__init__() self.width = width @@ -163,7 +156,6 @@ def __init__(self, self.num_heads = num_heads self.use_glu = use_glu self.use_post_layer_norm = use_post_layer_norm - self.dropout_rate = dropout_rate self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) self.self_attention1 = SelfAttention(self.width, self.num_heads) @@ -171,32 +163,29 @@ def __init__(self, self.mlp3 = MlpBlock( width=self.width, mlp_dim=self.mlp_dim, - use_glu=self.use_glu, - dropout_rate=dropout_rate) + use_glu=self.use_glu) - def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: if not self.use_post_layer_norm: y = self.layer_norm0(x) - y = self.self_attention1(y) + y = self.self_attention1(y, dropout_rate) y = F.dropout(y, dropout_rate, training=self.training) x = x + y y = self.layer_norm2(x) - y = self.mlp3(y) + y = self.mlp3(y, dropout_rate) y = F.dropout(y, dropout_rate, training=self.training) x = x + y else: y = x - y = self.self_attention1(y) + y = self.self_attention1(y, dropout_rate) y = F.dropout(y, dropout_rate, training=self.training) x = x + y x = self.layer_norm0(x) y = x - y = self.mlp3(y) + y = self.mlp3(y, dropout_rate) y = F.dropout(y, dropout_rate, training=self.training) x = x + y x = self.layer_norm2(x) @@ -212,8 +201,7 @@ def __init__(self, mlp_dim: Optional[int] = None, num_heads: int = 12, use_glu: bool = False, - use_post_layer_norm: bool = False, - dropout_rate: float = 0.0) -> None: + use_post_layer_norm: bool = False) -> None: super().__init__() self.depth = depth @@ -228,8 +216,7 @@ def __init__(self, self.mlp_dim, self.num_heads, self.use_glu, - self.use_post_layer_norm, - dropout_rate) for _ in range(depth) + self.use_post_layer_norm) for _ in range(depth) ]) if not self.use_post_layer_norm: @@ -237,10 +224,10 @@ def __init__(self, else: self.encoder_norm = None - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: # Input Encoder. for block in self.net: - x = block(x) + x = block(x, dropout_rate) if not self.use_post_layer_norm: return self.encoder_norm(x) else: @@ -267,13 +254,13 @@ def __init__(self, self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: n, _, _ = x.shape probe = torch.tile(self.probe, [n, 1, 1]) - x = self.mha(probe, x)[0] + x = self.mha(probe, x, dropout_rate=dropout_rate)[0] y = self.layer_norm(x) - x = x + self.mlp(y) + x = x + self.mlp(y, dropout_rate) return x[:, 0] @@ -293,15 +280,12 @@ def __init__( mlp_dim: Optional[int] = None, # Defaults to 4x input dim. 
num_heads: int = 12, rep_size: Union[int, bool] = True, - dropout_rate: Optional[float] = 0.0, head_zeroinit: bool = True, use_glu: bool = False, use_post_layer_norm: bool = False, use_map: bool = False, dtype: Any = torch.float32) -> None: super().__init__() - if dropout_rate is None: - dropout_rate = 0.0 self.num_classes = num_classes self.patch_size = patch_size @@ -315,7 +299,6 @@ def __init__( self.use_post_layer_norm = use_post_layer_norm self.use_map = use_map self.dtype = dtype - self.dropout_rate = dropout_rate if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size @@ -334,8 +317,7 @@ def __init__( mlp_dim=self.mlp_dim, num_heads=self.num_heads, use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=dropout_rate) + use_post_layer_norm=self.use_post_layer_norm) if self.num_classes: self.head = nn.Linear(self.width, self.num_classes) @@ -363,9 +345,7 @@ def reset_parameters(self) -> None: def get_posemb(self, x: spec.Tensor) -> spec.Tensor: return posemb_sincos_2d(x).type(self.dtype) - def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate + def forward(self, x: spec.Tensor, dropout_rate=DEFAULT_DROPOUT_RATE) -> spec.Tensor: # Patch extraction. x = self.conv_patch_extract(x) @@ -379,10 +359,10 @@ def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: x = x + pes x = F.dropout(x, dropout_rate, training=self.training) - x = self.encoder(x) + x = self.encoder(x, dropout_rate) if self.use_map: - x = self.map(x) + x = self.map(x, dropout_rate) else: x = torch.mean(x, dim=1) diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py index c5014d87d..6e265cd7f 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py @@ -861,7 +861,7 @@ def forward(self, max_len: Optional[int] = None, cache: Optional[dict] = None, index: Optional[int] = None, - dropout_rate: Optional[float] = None) -> Any: + dropout_rate: Optional[float] = None) -> Any: # TODO: (nico) remove default?! 
r""" Args: x: Batch of input sequences of shape diff --git a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py index 56644f152..32db2e7d4 100644 --- a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py @@ -32,41 +32,41 @@ def fwd_pass(self, orig, cust, dropout_rate): for mode in ('train', 'eval'): getattr(orig, mode)() getattr(cust, mode)() - torch.manual_seed(0); y1 = orig(x) - torch.manual_seed(0); y2 = cust(x, dropout_rate) + torch.manual_seed(0) + y1 = orig(x) + torch.manual_seed(0) + if mode == 'train': + y2 = cust(x, dropout_rate) + else: + y2 = cust(x) assert_close(y1, y2, atol=0, rtol=0) @parameterized.named_parameters( - dict(testcase_name='p=None', dropout_rate=None), dict(testcase_name='p=0.0', dropout_rate=0.0), dict(testcase_name='p=0.1', dropout_rate=0.1), dict(testcase_name='p=0.6', dropout_rate=0.6), dict(testcase_name='p=1.0', dropout_rate=1.0), ) def test_dropout_values(self, dropout_rate): - """Test different dropout_values.""" - - # Test initalizing custom model with a None dropout_rate - for custom_init_dropout_rate in [dropout_rate, None]: - - torch.manual_seed(SEED) - orig = OriginalVit( - width=WIDTH, - depth=DEPTH, - num_heads=HEADS, - dropout_rate=dropout_rate, - ).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustomVit( - width=WIDTH, - depth=DEPTH, - num_heads=HEADS, - dropout_rate=custom_init_dropout_rate, - ).to(DEVICE) - - cust.load_state_dict(orig.state_dict()) # sync weights - self.fwd_pass(orig, cust, dropout_rate) + """Test different dropout_rates.""" + + torch.manual_seed(SEED) + orig = OriginalVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + dropout_rate=dropout_rate, + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + self.fwd_pass(orig, cust, dropout_rate) @parameterized.named_parameters([ @@ -101,12 +101,30 @@ def test_arch(self, use_glu, use_post_ln, use_map): use_glu=use_glu, use_post_layer_norm=use_post_ln, use_map=use_map, - dropout_rate=None, ).to(DEVICE) cust.load_state_dict(orig.state_dict()) # sync weights self.fwd_pass(orig, cust, dropout_rate) + @parameterized.named_parameters( + dict(testcase_name=''), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE) + cust.load_state_dict(orig.state_dict()) # sync weights + + x = torch.randn(BATCH, C, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) + if __name__ == '__main__': absltest.main() From 44029d2ef1b8ac81556e90e40b119b642c2b4e41 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 11:44:16 +0200 Subject: [PATCH 15/23] vit rm dropout at init --- algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 1a6bb1381..20bd3828b 100644 --- 
a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -23,11 +23,9 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = models.ViT( - dropout_rate=dropout_rate, num_classes=self._num_classes, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, From 44ffec1d6821228eb839a0069b21daeaf32cb08c Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 11:47:52 +0200 Subject: [PATCH 16/23] add default dropout test --- .../test_model_equivalence.py | 33 ++++++++++++++----- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py index 32aa5b34f..c59331a3d 100644 --- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py @@ -23,7 +23,6 @@ FEATURES = DENSE + SPARSE VOCAB = 1000 DEVICE = 'cuda' -TORCH_COMPILE = False SEED = 1996 os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" @@ -49,16 +48,11 @@ def test_forward(self, model, dropout_rate): ) torch.manual_seed(SEED) - orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate) - orig.to(DEVICE) + orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate).to(DEVICE) torch.manual_seed(SEED) - cust = CustCls(vocab_size=VOCAB) - cust.to(DEVICE) + cust = CustCls(vocab_size=VOCAB).to(DEVICE) - if TORCH_COMPILE: - orig = torch.compile(orig); cust = torch.compile(cust) - x = torch.randn(BATCH, FEATURES, device=DEVICE) for mode in ('train', 'eval'): @@ -75,6 +69,29 @@ def test_forward(self, model, dropout_rate): assert_close(y1, y2, atol=0, rtol=0) + @parameterized.named_parameters( + dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'), + dict(testcase_name='DlrmSmall, default', model='dlrm_small'), + ) + def test_default_dropout(self, model): + """Test default dropout_rate.""" + OrigCls, CustCls = ( + (OriginalDLRMResNet, CustomDLRMResNet) + if model == 'dlrm_resnet' + else (OriginalDlrmSmall, CustomDlrmSmall) + ) + + torch.manual_seed(SEED) + orig = OrigCls(vocab_size=VOCAB).to(DEVICE) + torch.manual_seed(SEED) + cust = CustCls(vocab_size=VOCAB).to(DEVICE) + + x = torch.randn(BATCH, FEATURES, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) if __name__ == '__main__': absltest.main() From 9d12fa65b1963ab1b77d1cd8787ddad380bf803a Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 11:49:55 +0200 Subject: [PATCH 17/23] add default dropout test --- .../fastmri_pytorch/test_model_equivalence.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py index c71ff8980..6c8ca896c 100644 --- a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py @@ -15,7 +15,7 @@ BATCH, IN_CHANS, H, W = 4, 1, 256, 256 OUT_CHANS, C, LAYERS = 1, 32, 4 DEVICE = 'cuda' -TORCH_COMPILE = True +TORCH_COMPILE = False SEED = 1996 
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" @@ -89,6 +89,24 @@ def test_arch_configs(self, use_tanh, use_layer_norm): self.fwd_pass(orig, cust, dropout_rate) + @parameterized.named_parameters( + dict(testcase_name=''), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) + cust.load_state_dict(orig.state_dict()) # sync weights + + x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) if __name__ == '__main__': absltest.main() From ac45a9fc07aab0b3af28d06ca7540acb06bcc561 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 12:36:14 +0200 Subject: [PATCH 18/23] conformer: rm dropout_rate from init --- .../librispeech_pytorch/models_dropout.py | 41 ++----- .../librispeech_pytorch/workload.py | 7 +- .../test_model_equivalence.py | 106 +++++++++++------- 3 files changed, 74 insertions(+), 80 deletions(-) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py index 9ff662fb8..f77c8a814 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py @@ -17,6 +17,11 @@ from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ SpecAug +DEFAULT_ATTN_RESIDUAL_DROPOUT_RATE = 0.1 +DEFAULT_CONV_RESIDUAL_DROPOUT_RATE = 0.0 +DEFAULT_FFN_RESIDUAL_DROPOUT_RATE = 0.1 +DEFAULT_INPUT_DROPOUT_RATE = 0.1 + @dataclass class ConformerConfig: @@ -26,13 +31,7 @@ class ConformerConfig: num_attention_heads: int = 8 num_encoder_layers: int = 4 attention_dropout_rate: float = 0.0 - # If None, defaults to 0.1. - attention_residual_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.0. - conv_residual_dropout_rate: Optional[float] = 0.0 feed_forward_dropout_rate: float = 0.0 - # If None, defaults to 0.1. - feed_forward_residual_dropout_rate: Optional[float] = 0.1 convolution_kernel_size: int = 5 feed_forward_expansion_factor: int = 4 freq_mask_count: int = 2 @@ -42,8 +41,6 @@ class ConformerConfig: time_mask_max_ratio: float = 0.05 time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True - # If None, defaults to 0.1. 
- input_dropout_rate: Optional[float] = 0.1 batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 use_specaug: bool = True @@ -81,11 +78,9 @@ class Subsample(nn.Module): def __init__(self, encoder_dim: int = 0, - input_dropout_rate: float = 0.0, num_bins: int = 80): super().__init__() self.encoder_dim = encoder_dim - self.input_dropout_rate = input_dropout_rate self.conv1 = Conv2dSubsampling( input_channels=1, output_channels=encoder_dim) @@ -100,7 +95,7 @@ def __init__(self, def forward(self, inputs, input_paddings, dropout_rate=None): if dropout_rate is None: - dropout_rate = self.input_dropout_rate + dropout_rate = DEFAULT_INPUT_DROPOUT_RATE output_paddings = input_paddings outputs = inputs[:, None, :, :] @@ -207,14 +202,9 @@ def __init__(self, config: ConformerConfig): out_features=config.encoder_dim, bias=True) - if config.feed_forward_residual_dropout_rate is None: - self.feed_forward_residual_dropout_rate = 0.1 - else: - self.feed_forward_residual_dropout_rate = config.feed_forward_residual_dropout_rate - def forward(self, inputs, padding_mask, dropout_rate=None): if dropout_rate is None: - dropout_rate = self.feed_forward_residual_dropout_rate + dropout_rate = DEFAULT_FFN_RESIDUAL_DROPOUT_RATE inputs = self.ln(inputs) inputs = self.linear1(inputs) @@ -319,14 +309,10 @@ def __init__(self, config: ConformerConfig): self.ln = LayerNorm(dim=config.encoder_dim) self.self_attention = MHSAwithQS(config) - if config.attention_residual_dropout_rate is None: - self.attention_residual_dropout_rate = 0.1 - else: - self.attention_residual_dropout_rate = config.attention_residual_dropout_rate def forward(self, outputs, paddings, dropout_rate=None): if dropout_rate is None: - dropout_rate = self.attention_residual_dropout_rate + dropout_rate = DEFAULT_ATTN_RESIDUAL_DROPOUT_RATE outputs = self.ln(outputs) outputs = self.self_attention( @@ -413,14 +399,10 @@ def __init__(self, config): groups=config.encoder_dim) self.bn = BatchNorm(config) self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim) - if config.conv_residual_dropout_rate is None: - self.conv_residual_dropout_rate = 0.0 - else: - self.conv_residual_dropout_rate = config.conv_residual_dropout_rate def forward(self, inputs, input_paddings, dropout_rate=None): if dropout_rate is None: - dropout_rate = self.conv_residual_dropout_rate + dropout_rate = DEFAULT_CONV_RESIDUAL_DROPOUT_RATE inputs = self.ln(inputs) @@ -490,13 +472,8 @@ def __init__(self, config: ConformerConfig): time_masks_per_frame=config.time_masks_per_frame, use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames ) - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate self.subsample = Subsample( encoder_dim=config.encoder_dim, - input_dropout_rate=input_dropout_rate, num_bins=preprocessing_config.num_bins) self.conformers = nn.ModuleList( [ConformerBlock(config) for _ in range(config.num_encoder_layers)]) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 39f33f4aa..2d0942fe9 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -63,8 +63,7 @@ def attention_temperature(self) -> float: def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: + rng: spec.RandomState) -> 
spec.ModelInitState: """Conformer model init function.""" torch.random.manual_seed(rng[0]) # Configure torch backends to avoid OOM errors. @@ -78,10 +77,6 @@ def init_model_fn( activation_function_name = 'swish' model = models.ConformerEncoderDecoder( models.ConformerConfig( - attention_residual_dropout_rate=dropout_rate, - feed_forward_residual_dropout_rate=dropout_rate, - conv_residual_dropout_rate=dropout_rate, - input_dropout_rate=dropout_rate, use_specaug=self.use_specaug, attention_temperature=self.attention_temperature, use_post_layer_norm=self.use_post_layer_norm, diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py index 19525a98b..ec8318c9a 100644 --- a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py @@ -16,14 +16,15 @@ import os from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( - # ConformerConfig, - ConformerEncoderDecoder as OriginalConf + ConformerConfig as OriginalConfig, + ConformerEncoderDecoder as OriginalModel ) from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import( - ConformerEncoderDecoder as CustomConf, - ConformerConfig, + ConformerConfig as CustomConfig, + ConformerEncoderDecoder as CustomModel, ) +N_LAYERS = 3 B, T = 32, 36_000 DEVICE = 'cuda' @@ -37,55 +38,76 @@ class ConformerEquivalenceTest(parameterized.TestCase): @parameterized.named_parameters( - dict(testcase_name='p=None', dropout_rate=None), dict(testcase_name='p=0.0', dropout_rate=0.0), dict(testcase_name='p=0.2', dropout_rate=0.2), dict(testcase_name='p=0.7', dropout_rate=0.7), dict(testcase_name='p=1.0', dropout_rate=1.0), ) def test_forward(self, dropout_rate): - - # Test initalizing custom model with a None dropout_rate - for custom_init_dropout_rate in [None, dropout_rate]: + + torch.manual_seed(SEED) + orig = OriginalModel( + OriginalConfig( + num_encoder_layers=N_LAYERS, + attention_residual_dropout_rate=dropout_rate, + conv_residual_dropout_rate=dropout_rate, + feed_forward_residual_dropout_rate=dropout_rate, + input_dropout_rate=dropout_rate, + )).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomModel( + CustomConfig( + num_encoder_layers=N_LAYERS + ) + ).to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() torch.manual_seed(SEED) - orig = OriginalConf( - ConformerConfig( - num_encoder_layers=3, - attention_residual_dropout_rate=dropout_rate, - conv_residual_dropout_rate=dropout_rate, - feed_forward_residual_dropout_rate=dropout_rate, - input_dropout_rate=dropout_rate, - )).to(DEVICE) - + y1, p1 = orig(x, paddings) torch.manual_seed(SEED) - cust = CustomConf( - ConformerConfig( - num_encoder_layers=3, - attention_residual_dropout_rate=custom_init_dropout_rate, - conv_residual_dropout_rate=custom_init_dropout_rate, - feed_forward_residual_dropout_rate=custom_init_dropout_rate, - input_dropout_rate=custom_init_dropout_rate - )).to(DEVICE) - - orig.load_state_dict(cust.state_dict()) # sync weights - - x = torch.randn(B, T, device=DEVICE) - paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) - - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - 
y1, p1 = orig(x, paddings) - - torch.manual_seed(SEED) + if mode == 'train': y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) - - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) + else: + y2, p2 = cust(x, paddings) + + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name=''), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalModel(OriginalConfig(num_encoder_layers=N_LAYERS)).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomModel(CustomConfig(num_encoder_layers=N_LAYERS)).to(DEVICE) + orig.load_state_dict(cust.state_dict()) + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + y1, p1 = orig(x, paddings) + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings) + + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) if __name__ == '__main__': absltest.main() From 31d64f6c3228eb410065b212994d3f02883d19ee Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 13:36:15 +0200 Subject: [PATCH 19/23] rm dropout_rate at init from all workloads --- .../fastmri/fastmri_pytorch/models_dropout.py | 5 +- .../imagenet_pytorch/models_dropout.py | 5 +- .../librispeech_pytorch/models_dropout.py | 24 +---- .../librispeech_pytorch/workload.py | 5 +- .../ogbg/ogbg_pytorch/models_dropout.py | 15 +-- .../workloads/ogbg/ogbg_pytorch/workload.py | 4 +- .../wmt/wmt_pytorch/models_dropout.py | 44 ++++---- .../workloads/wmt/wmt_pytorch/workload.py | 5 +- .../test_model_equivalence.py | 17 ++- .../fastmri_pytorch/test_model_equivalence.py | 15 ++- .../test_model_equivalence.py | 15 ++- .../test_model_equivalence.py | 12 ++- .../test_model_equivalence.py | 100 +++++++++++------- .../ogbg_pytorch/test_model_equivalence.py | 55 +++++++--- .../wmt_pytorch/test_model_equivalence.py | 97 +++++++++++------ 15 files changed, 234 insertions(+), 184 deletions(-) diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py index 73b1d81d9..0e59e1436 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py @@ -74,7 +74,10 @@ def __init__(self, if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): init_utils.pytorch_default_init(m) - def forward(self, x: Tensor, dropout_rate: Optional[float] = DEFAULT_DROPOUT_RATE) -> Tensor: + def forward( + self, + x: Tensor, + dropout_rate: float = DEFAULT_DROPOUT_RATE) -> Tensor: stack = [] output = x diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py index 8641847b0..570cee575 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py @@ -345,7 +345,10 @@ def reset_parameters(self) -> None: def get_posemb(self, x: spec.Tensor) -> spec.Tensor: return posemb_sincos_2d(x).type(self.dtype) - def forward(self, x: spec.Tensor, dropout_rate=DEFAULT_DROPOUT_RATE) -> spec.Tensor: + def forward( + self, + x: spec.Tensor, + dropout_rate: float = DEFAULT_DROPOUT_RATE) -> spec.Tensor: # Patch extraction. 
x = self.conv_patch_extract(x) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py index 8797aa578..21a4df614 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py @@ -17,6 +17,7 @@ SpecAug USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ +DEFAULT_DROPOUT_RATE = 0.1 @dataclass @@ -38,10 +39,6 @@ class DeepspeechConfig: use_dynamic_time_mask_max_frames: bool = True batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 - # If None, defaults to 0.1. - input_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.1. - feed_forward_dropout_rate: Optional[float] = 0.1 enable_residual_connections: bool = True enable_decoder_layer_norm: bool = True bidirectional: bool = True @@ -87,14 +84,7 @@ def __init__(self, config: DeepspeechConfig): self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True) - if config.input_dropout_rate is None: - self.input_dropout_rate = 0.1 - else: - self.input_dropout_rate = config.input_dropout_rate - - def forward(self, inputs, input_paddings, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.input_dropout_rate + def forward(self, inputs, input_paddings, dropout_rate): output_paddings = input_paddings outputs = inputs[:, None, :, :] @@ -207,14 +197,8 @@ def __init__(self, config: DeepspeechConfig): batch_norm_momentum=config.batch_norm_momentum, batch_norm_epsilon=config.batch_norm_epsilon) self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True) - if config.feed_forward_dropout_rate is None: - self.feed_forward_dropout_rate = 0.1 - else: - self.feed_forward_dropout_rate = config.feed_forward_dropout_rate - def forward(self, inputs, input_paddings, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.feed_forward_dropout_rate + def forward(self, inputs, input_paddings, dropout_rate): padding_mask = (1 - input_paddings)[:, :, None] if self.config.layernorm_everywhere: @@ -367,7 +351,7 @@ def __init__(self, config: DeepspeechConfig): self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - def forward(self, inputs, input_paddings, dropout_rate=None): + def forward(self, inputs, input_paddings, dropout_rate=DEFAULT_DROPOUT_RATE): outputs = inputs output_paddings = input_paddings diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 932ba9392..e6ec4764f 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -24,15 +24,12 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: """Deepspeech model init function.""" torch.random.manual_seed(rng[0]) model = DeepspeechEncoderDecoder( DeepspeechConfig( - feed_forward_dropout_rate=dropout_rate, use_specaug=self.use_specaug, - input_dropout_rate=dropout_rate, use_tanh=self.use_tanh, enable_residual_connections=self.enable_residual_connections, enable_decoder_layer_norm=self.enable_decoder_layer_norm, diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py 
b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py index b86b88caa..c8ed23dda 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py @@ -11,6 +11,8 @@ from algoperf import init_utils from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout +DEFAULT_DROPOUT_RATE = 0.1 + def _make_mlp(in_dim, hidden_dims, activation_fn): """Creates a MLP with specified dimensions.""" @@ -34,7 +36,6 @@ class GNN(nn.Module): def __init__(self, num_outputs: int = 128, - dropout_rate: Optional[float] = 0.1, activation_fn_name: str = 'relu', latent_dim: int = 256, hidden_dims: Tuple[int] = (256,), @@ -44,8 +45,6 @@ def __init__(self, self.hidden_dims = hidden_dims self.num_message_passing_steps = num_message_passing_steps self.num_outputs = num_outputs - if dropout_rate is None: - self.dropout_rate = 0.1 # in_features are specifically chosen for the ogbg workload. self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim) self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim) @@ -94,9 +93,10 @@ def __init__(self, if isinstance(m, nn.Linear): init_utils.pytorch_default_init(m) - def forward(self, graph: GraphsTuple, dropout_rate=None) -> torch.Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate + def forward( + self, + graph: GraphsTuple, + dropout_rate: float = DEFAULT_DROPOUT_RATE) -> torch.Tensor: graph = graph._replace( globals=torch.zeros([graph.n_node.shape[0], self.num_outputs], @@ -148,7 +148,7 @@ def __init__(self, self.update_global_fn = update_global_fn self._supports_custom_dropout = True # supports SequentialWithDropout - def forward(self, graph: GraphsTuple, dropout_rate=None) -> GraphsTuple: + def forward(self, graph: GraphsTuple, dropout_rate: float) -> GraphsTuple: """Applies a configured GraphNetwork to a graph. This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 There is one difference. For the nodes update the class aggregates over the @@ -161,6 +161,7 @@ def forward(self, graph: GraphsTuple, dropout_rate=None) -> GraphsTuple: GraphNets, for more information please see the paper. Args: graph: a `GraphsTuple` containing the graph. + dropout_rate: dropout probability value. Returns: Updated `GraphsTuple`. """ diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py index 1dd85951d..7ead696ce 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py @@ -138,12 +138,10 @@ def _build_input_queue(self, def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = GNN( num_outputs=self._num_outputs, - dropout_rate=dropout_rate, hidden_dims=self.hidden_dims, latent_dim=self.latent_dim, num_message_passing_steps=self.num_message_passing_steps, diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py index 6e265cd7f..a5d822669 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py @@ -9,6 +9,8 @@ from torch.nn.init import normal_ from torch.nn.init import xavier_uniform_ +DEFAULT_DROPOUT_RATE = 0.1 + def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor: """Make a causal mask for self-attention. 
@@ -104,17 +106,12 @@ def __init__(self, nhead: int = 16, d_hid: int = 1024, nlayers: int = 6, - dropout_rate: Optional[float] = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, attention_temp: float = 1.0, pre_ln: bool = True): super().__init__() - if dropout_rate is None: - self.dropout_rate = 0.1 - else: - self.dropout_rate = dropout_rate self.pos_encoder = PositionalEncoding(d_model) self.shared_embedding = nn.Embedding(ntoken, d_model) self.encoder = Encoder(d_model, @@ -159,7 +156,7 @@ def forward(self, inputs_segmentation: Optional[Tensor] = None, targets_segmentation: Optional[Tensor] = None, decode: bool = False, - dropout_rate: Optional[float] = None) -> Tensor: + dropout_rate: float = DEFAULT_DROPOUT_RATE) -> Tensor: """ Args: src: Tensor, shape [batch_size, seq_len] @@ -169,7 +166,7 @@ def forward(self, inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len] targets_segmentation: Optional[Tensor], shape [batch_size, seq_len] decode: bool - dropout_rate: Optional[float] + dropout_rate: float Returns: output Tensor of shape [batch_size, seq_len, ntoken] @@ -177,9 +174,6 @@ def forward(self, if src.size(0) != tgt.size(0): raise RuntimeError('The batch size of src and tgt must be equal.') - if dropout_rate is None: - dropout_rate = self.dropout_rate - memory = self.encoder( src, inputs_positions=inputs_positions, @@ -234,13 +228,13 @@ def __init__(self, def forward(self, src: Tensor, mask: Optional[Tensor] = None, - dropout_rate: Optional[float] = None) -> Tensor: + dropout_rate: Optional[float] = 0.0) -> Tensor: """Pass the input through the encoder layers in turn. Args: src: the sequence to the encoder (required). mask: the mask for the src sequence (optional). - dropout_rate: the dropout probability (optional) + dropout_rate: the dropout probability (optional). Shape: see the docs in Transformer class. @@ -293,7 +287,7 @@ def forward(self, src: Tensor, inputs_positions: Optional[Tensor] = None, inputs_segmentation: Optional[Tensor] = None, - dropout_rate: Optional[float] = None) -> Tensor: + dropout_rate: Optional[float] = 0.0) -> Tensor: src = src.to(torch.int) src_mask = make_src_mask(src, inputs_segmentation, self.nhead) src = self.shared_embedding(src) @@ -339,7 +333,7 @@ def forward( decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - dropout_rate: Optional[float] = None) -> Any: + dropout_rate: Optional[float] = 0.0) -> Any: tgt = tgt.to(torch.int) tgt_mask, memory_mask = make_tgt_and_memory_mask( tgt, src, inputs_segmentation, targets_segmentation, @@ -389,7 +383,7 @@ def forward( inputs_positions: Optional[Tensor] = None, decode: bool = False, cache: Optional[Dict[str, Dict[str, Tensor]]] = None, - dropout_rate: Optional[float] = None + dropout_rate: Optional[float] = 0.0 ) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]: """ Args: @@ -397,6 +391,7 @@ def forward( inputs_positions: Tensor (shape [batch_size, seq_len]) or None decode: bool cache: Dict[str, Dict[str, Tensor]] or None + dropout_rate: Optional[float] Returns: Tensor or Tuple[Tensor, Dict[str, Dict[str, Tensor]]] """ @@ -438,7 +433,6 @@ class TransformerEncoderLayer(nn.Module): nhead: the number of heads in the multiheadattention models (default=16). dim_feedforward: the dimension of the feedforward network model (default=1024). - dropout_rate: the dropout_rate value (default=0.1). 
activation: the activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable (default=F.relu). layer_norm_eps: the eps value in layer normalization components @@ -490,7 +484,7 @@ def __init__(self, def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, - dropout_rate: Optional[float] = None) -> Tensor: + dropout_rate: Optional[float] = 0.0) -> Tensor: r"""Pass the input through the encoder layer. Args: @@ -514,14 +508,14 @@ def forward(self, def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor], - dropout_rate: Optional[float] = None) -> Tensor: + dropout_rate: Optional[float] = 0.0) -> Tensor: x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate) return F.dropout(x, dropout_rate, training=self.training) # Feed forward block: def _ff_block(self, inputs: Tensor, - dropout_rate: Optional[float] = None) -> Tensor: + dropout_rate: Optional[float] = 0.0) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) @@ -585,7 +579,7 @@ def forward(self, decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - dropout_rate: Optional[float] = None) -> Any: + dropout_rate: Optional[float] = 0.0) -> Any: r"""Pass the inputs (and mask) through the decoder layer in turn. Args: tgt: the sequence to the decoder (required). @@ -705,7 +699,7 @@ def forward( # pylint: disable=arguments-renamed max_len: Optional[int] = None, cache: Optional[dict] = None, index: Optional[int] = None, - dropout_rate: Optional[float] = None) -> Any: + dropout_rate: Optional[float] = 0.0) -> Any: r"""Pass the inputs (and mask) through the decoder layer. Args: tgt: the sequence to the decoder layer (required). @@ -757,7 +751,7 @@ def _sa_block( # pylint: disable=arguments-renamed max_len: Optional[int] = None, cache: Optional[dict] = None, index: Optional[int] = None, - dropout_rate: Optional[float] = None) -> Any: + dropout_rate: Optional[float] = 0.0) -> Any: x, cache = self.self_attn( x, attn_mask=attn_mask, @@ -771,7 +765,7 @@ def _sa_block( # pylint: disable=arguments-renamed # Multihead attention block: def _mha_block(self, x: Tensor, mem: Tensor, attn_mask: Optional[Tensor], - dropout_rate: Optional[float] = None) -> Tensor: + dropout_rate: Optional[float] = 0.0) -> Tensor: x, _ = self.multihead_attn( x, mem, @@ -781,7 +775,7 @@ def _mha_block(self, x: Tensor, mem: Tensor, # Feed forward block. def _ff_block(self, inputs: Tensor, - dropout_rate: Optional[float] = None) -> Tensor: + dropout_rate: Optional[float] = 0.0) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) @@ -861,7 +855,7 @@ def forward(self, max_len: Optional[int] = None, cache: Optional[dict] = None, index: Optional[int] = None, - dropout_rate: Optional[float] = None) -> Any: # TODO: (nico) remove default?! + dropout_rate: Optional[float] = 0.0) -> Any: # TODO: (nico) remove default?! 
r""" Args: x: Batch of input sequences of shape diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py index 64eea73b7..bb9c3834f 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/workload.py +++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py @@ -167,8 +167,7 @@ def translate_and_calculate_bleu(self, def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) if self.activation == 'relu': @@ -179,8 +178,6 @@ def init_model_fn( raise ValueError(f'Unknown activation function {self.activation}.') model = Transformer( - dropout_rate=dropout_rate, - attention_dropout_rate=dropout_rate, pre_ln=self.pre_ln, attention_temp=self.attention_temp, activation=activation, diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py index c59331a3d..db56b17cf 100644 --- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py @@ -56,18 +56,13 @@ def test_forward(self, model, dropout_rate): x = torch.randn(BATCH, FEATURES, device=DEVICE) for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - y1 = orig(x) - torch.manual_seed(SEED) - if mode == 'train': - y2 = cust(x, dropout_rate) - else: - y2 = cust(x) - + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(SEED); y1 = orig(x) + torch.manual_seed(SEED); y2 = cust(x, dropout_rate) assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED); y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) @parameterized.named_parameters( dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'), diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py index 6c8ca896c..0d3d52980 100644 --- a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py @@ -28,16 +28,13 @@ class FastMRIModeEquivalenceTest(parameterized.TestCase): def fwd_pass(self, orig, cust, dropout_rate): x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - torch.manual_seed(0) - y1 = orig(x) - torch.manual_seed(0) - if mode == 'train': - y2 = cust(x, dropout_rate) - else: - y2 = cust(x) + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x, dropout_rate) assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(0); y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) @parameterized.named_parameters( dict(testcase_name='p=0.0', dropout_rate=0.0), diff --git a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py index 32db2e7d4..d19fad0ba 100644 --- a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py @@ -30,16 +30,13 @@ class ImageNetVitModeEquivalenceTest(parameterized.TestCase): def fwd_pass(self, orig, cust, dropout_rate): x = torch.randn(BATCH, C, H, W, device=DEVICE) for mode in ('train', 'eval'): - getattr(orig, mode)() - 
getattr(cust, mode)() - torch.manual_seed(0) - y1 = orig(x) - torch.manual_seed(0) - if mode == 'train': - y2 = cust(x, dropout_rate) - else: - y2 = cust(x) + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x, dropout_rate) assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(0); y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) @parameterized.named_parameters( dict(testcase_name='p=0.0', dropout_rate=0.0), diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py index ec8318c9a..a4238bbc9 100644 --- a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py @@ -74,14 +74,16 @@ def test_forward(self, dropout_rate): torch.manual_seed(SEED) y1, p1 = orig(x, paddings) torch.manual_seed(SEED) - if mode == 'train': - y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) - else: - y2, p2 = cust(x, paddings) - + y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) assert_close(p1, p2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings) + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) @parameterized.named_parameters( dict(testcase_name=''), diff --git a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py index e31f4a7eb..acdc8c5b3 100644 --- a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py @@ -14,11 +14,12 @@ import os from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import ( - DeepspeechEncoderDecoder as OriginalModel + DeepspeechEncoderDecoder as OriginalModel, + DeepspeechConfig as OriginalConfig ) from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import( DeepspeechEncoderDecoder as CustomModel, - DeepspeechConfig, + DeepspeechConfig as CustomConfig ) B, T = 32, 30_000 @@ -35,55 +36,82 @@ class DeepSpeechEquivalenceTest(parameterized.TestCase): @parameterized.named_parameters( - dict(testcase_name='p=None', dropout_rate=None), dict(testcase_name='p=0.0', dropout_rate=0.0), dict(testcase_name='p=0.2', dropout_rate=0.2), dict(testcase_name='p=0.7', dropout_rate=0.7), dict(testcase_name='p=1.0', dropout_rate=1.0), ) def test_forward(self, dropout_rate): - - # Test initalizing custom model with a None dropout_rate - for custom_init_dropout_rate in [None, dropout_rate]: + """Test different dropout_rate values.""" + torch.manual_seed(SEED) + orig = OriginalModel( + OriginalConfig( + num_lstm_layers=2, + num_ffn_layers=2, + input_dropout_rate=dropout_rate, + feed_forward_dropout_rate=dropout_rate, + )).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomModel(CustomConfig( + num_lstm_layers=2, + num_ffn_layers=2, + )).to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, 
mode)() + torch.manual_seed(SEED) - orig = OriginalModel( - DeepspeechConfig( - num_lstm_layers=2, - num_ffn_layers=2, - input_dropout_rate=dropout_rate, - feed_forward_dropout_rate=dropout_rate, - )).to(DEVICE) + y1, p1 = orig(x, paddings) torch.manual_seed(SEED) - cust = CustomModel(DeepspeechConfig( - num_lstm_layers=2, - num_ffn_layers=2, - input_dropout_rate=custom_init_dropout_rate, - feed_forward_dropout_rate=custom_init_dropout_rate, - )).to(DEVICE) + y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) - orig.load_state_dict(cust.state_dict()) # sync weights - - if TORCH_COMPILE: - orig = torch.compile(orig); cust = torch.compile(cust) - - x = torch.randn(B, T, device=DEVICE) - paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) - - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - y1, p1 = orig(x, paddings) - + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + if mode == 'eval': # one extra test: omit dropout at eval torch.manual_seed(SEED) - y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) - + y2, p2 = cust(x, paddings) assert_close(y1, y2, atol=0, rtol=0) assert_close(p1, p2, atol=0, rtol=0) + @parameterized.named_parameters( + dict(testcase_name=''), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalModel(OriginalConfig( num_lstm_layers=2, num_ffn_layers=2)).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomModel(CustomConfig( num_lstm_layers=2, num_ffn_layers=2)).to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights + + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(SEED); y1, p1 = orig(x, paddings) + torch.manual_seed(SEED); y2, p2 = cust(x, paddings) + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + if __name__ == '__main__': absltest.main() diff --git a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py index cc1857705..aaca6cebd 100644 --- a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py @@ -42,35 +42,62 @@ def _rand_graph(): class GNNEquivalenceTest(parameterized.TestCase): @parameterized.named_parameters( - dict(testcase_name='None', dropout_rate=None), dict(testcase_name='0.0', dropout_rate=0.0), dict(testcase_name='0.2', dropout_rate=0.2), dict(testcase_name='0.7', dropout_rate=0.7), dict(testcase_name='1.0', dropout_rate=1.0), ) def test_forward(self, dropout_rate): + """Test different dropout_rates.""" - # Test initalizing custom model with a None dropout_rate - for custom_init_dropout_rate in [None, dropout_rate]: + orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE) + cust = CustomModel().to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights - orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE) - cust = CustomModel(dropout_rate=custom_init_dropout_rate).to(DEVICE) - orig.load_state_dict(cust.state_dict()) # sync weights + graph = _rand_graph() - graph = _rand_graph() + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() + 
torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y1 = orig(graph) - torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y1 = orig(graph) + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(graph, dropout_rate=dropout_rate) - torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y2 = cust(graph, dropout_rate=dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(graph) assert_close(y1, y2, atol=0, rtol=0) + @parameterized.named_parameters( + dict(testcase_name=''), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + + orig = OriginalModel().to(DEVICE) + cust = CustomModel().to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights + + graph = _rand_graph() + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y1 = orig(graph) + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(graph) + + assert_close(y1, y2, atol=0, rtol=0) + + if __name__ == '__main__': absltest.main() diff --git a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py index 9aca717d9..9675f1df2 100644 --- a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py @@ -32,52 +32,79 @@ def _rand_tokens(bs, seqlen): class TransformerEquivalenceTest(parameterized.TestCase): @parameterized.named_parameters( - # NOTE: removed dropout=1.0 will generate nan in scaled_dot_product_attention - - dict(testcase_name="None", dropout_rate=None, compile=False), + # NOTE: removed dropout=1.0 since it will generate nan in scaled_dot_product_attention dict(testcase_name="0.0", dropout_rate=0.0, compile=False), dict(testcase_name="0.2", dropout_rate=0.2, compile=False), dict(testcase_name="0.7", dropout_rate=0.7, compile=False), - - dict(testcase_name="p=None, compile", dropout_rate=None, compile=True), - dict(testcase_name="p=0.0, compile", dropout_rate=0.0, compile=True), - dict(testcase_name="p=0.2, compile", dropout_rate=0.2, compile=True), - dict(testcase_name="p=0.7, compile", dropout_rate=0.7, compile=True), + dict(testcase_name="p=0.0_compile", dropout_rate=0.0, compile=True), + dict(testcase_name="p=0.2_compile", dropout_rate=0.2, compile=True), + dict(testcase_name="p=0.7_compile", dropout_rate=0.7, compile=True), ) - def test_forward(self, dropout_rate, compile): - - # Test initalizing custom model with a None dropout_rate - for custom_init_dropout_rate in [None, dropout_rate]: - - orig = OriginalModel( - dropout_rate=dropout_rate, - attention_dropout_rate=dropout_rate - ).to(DEVICE) - cust = CustomModel( - dropout_rate=custom_init_dropout_rate - ).to(DEVICE) - - orig.load_state_dict(cust.state_dict()) # sync weights - - if compile: - orig = torch.compile(orig) - cust = torch.compile(cust) + def test_dropout_value(self, dropout_rate, compile): - src = _rand_tokens(B, SRC_LEN) - tgt = _rand_tokens(B, TGT_LEN) + orig = OriginalModel( + dropout_rate=dropout_rate, + attention_dropout_rate=dropout_rate + ).to(DEVICE) + cust = CustomModel().to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if compile: + orig = torch.compile(orig) + cust = torch.compile(cust) - for mode in ("train", "eval"): - getattr(orig, mode)() - getattr(cust, 
mode)() + src = _rand_tokens(B, SRC_LEN) + tgt = _rand_tokens(B, TGT_LEN) - torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y1 = orig(src, tgt) + for mode in ("train", "eval"): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y1 = orig(src, tgt) + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(src, tgt, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + + if mode == 'eval': # one extra test: omit dropout at eval torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y2 = cust(src, tgt, dropout_rate=dropout_rate) - + y2 = cust(src, tgt) assert_close(y1, y2, atol=0, rtol=0) + @parameterized.named_parameters( + dict(testcase_name="default", compile=False), + dict(testcase_name="default_compile", compile=True), + ) + def test_default(self, compile): + + orig = OriginalModel().to(DEVICE) + cust = CustomModel().to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if compile: + orig = torch.compile(orig) + cust = torch.compile(cust) + + src = _rand_tokens(B, SRC_LEN) + tgt = _rand_tokens(B, TGT_LEN) + + for mode in ("train", "eval"): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y1 = orig(src, tgt) + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(src, tgt) + + assert_close(y1, y2, atol=0, rtol=0) + + if __name__ == "__main__": absltest.main() From 0128c9fe4caa2590c60af6c7276ab6789d06ae6d Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 13 Jun 2025 13:50:00 +0200 Subject: [PATCH 20/23] pipe dropout to model_fn, set default in workload --- algoperf/spec.py | 6 +++-- .../criteo1tb_pytorch/models_dropout.py | 6 ++--- .../criteo1tb/criteo1tb_pytorch/workload.py | 6 +++-- .../fastmri/fastmri_pytorch/models_dropout.py | 4 +-- .../fastmri/fastmri_pytorch/workload.py | 9 ++++--- .../imagenet_pytorch/workload.py | 9 +++---- .../imagenet_pytorch/models_dropout.py | 4 +-- .../imagenet_vit/imagenet_pytorch/workload.py | 7 +++-- .../librispeech_pytorch/models_dropout.py | 27 +++++-------------- .../librispeech_pytorch/workload.py | 7 +++-- .../librispeech_pytorch/models_dropout.py | 4 +-- .../ogbg/ogbg_pytorch/models_dropout.py | 4 +-- .../workloads/ogbg/ogbg_pytorch/workload.py | 8 ++++-- .../wmt/wmt_pytorch/models_dropout.py | 4 +-- .../workloads/wmt/wmt_pytorch/workload.py | 8 ++++-- 15 files changed, 60 insertions(+), 53 deletions(-) diff --git a/algoperf/spec.py b/algoperf/spec.py index cf4f1a14e..9670dcb76 100644 --- a/algoperf/spec.py +++ b/algoperf/spec.py @@ -247,7 +247,8 @@ def init_model_fn(self, # ModelAuxiliaryState, # ForwardPassMode, # RandomState, - # bool], + # bool, + # float], # Tensor] @abc.abstractmethod def model_fn(self, @@ -256,7 +257,8 @@ def model_fn(self, model_state: ModelAuxiliaryState, mode: ForwardPassMode, rng: RandomState, - update_batch_norm: bool) -> Tuple[Tensor, ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float) -> Tuple[Tensor, ModelAuxiliaryState]: """Return logits_batch""" # Possible side effect of updating BN. 
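
The `algoperf/spec.py` hunk above is the core interface change of this series: `model_fn` now receives `dropout_rate` at call time instead of having it baked into `init_model_fn`. As a minimal sketch (not part of the patch), a caller under the updated interface might look like the following; the hyperparameter field name `hyperparameters.dropout_rate` is an assumption made only for this illustration.

# Illustrative sketch of invoking the updated model_fn; not part of the patch.
# `hyperparameters.dropout_rate` is an assumed field name for this example.
logits_batch, new_model_state = workload.model_fn(
    params=model_params,
    augmented_and_preprocessed_input_batch=batch,
    model_state=model_state,
    mode=spec.ForwardPassMode.TRAIN,
    rng=model_rng,
    update_batch_norm=True,
    dropout_rate=hyperparameters.dropout_rate)

Omitting `dropout_rate` falls back to the per-workload default that each `workload.py` now defines, which is what the equivalence tests below exercise in their extra eval-mode check.
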
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py index b5ee465e2..f0653a665 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py @@ -7,7 +7,7 @@ from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout -DEFAULT_DROPOUT_RATE = 0.0 +DROPOUT_RATE = 0.0 class DenseBlock(nn.Module): @@ -148,7 +148,7 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x, dropout_rate=DEFAULT_DROPOUT_RATE): + def forward(self, x, dropout_rate=DROPOUT_RATE): batch_size = x.shape[0] @@ -269,7 +269,7 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x, dropout_rate=DEFAULT_DROPOUT_RATE): + def forward(self, x, dropout_rate=DROPOUT_RATE): batch_size = x.shape[0] diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py index b128f5bd5..74cb3e140 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -15,6 +15,7 @@ BaseCriteo1TbDlrmSmallWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() +DROPOUT_RATE = 0.0 class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload): @@ -103,7 +104,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -123,7 +125,7 @@ def model_fn( } with contexts[mode](): - logits_batch = model(inputs) + logits_batch = model(inputs, dropout_rate=dropout_rate) return logits_batch, None diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py index 0e59e1436..0b8ac5499 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py @@ -15,7 +15,7 @@ from algoperf import init_utils from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout -DEFAULT_DROPOUT_RATE = 0.0 +DROPOUT_RATE = 0.0 class UNet(nn.Module): @@ -77,7 +77,7 @@ def __init__(self, def forward( self, x: Tensor, - dropout_rate: float = DEFAULT_DROPOUT_RATE) -> Tensor: + dropout_rate: float = DROPOUT_RATE) -> Tensor: stack = [] output = x diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 6da0bb0af..6374c62d6 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -19,6 +19,8 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() +DROPOUT_RATE = 0.0 + class FastMRIWorkload(BaseFastMRIWorkload): @@ -134,7 +136,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -154,8 +157,8 @@ def model_fn( with contexts[mode](): logit_batch = model( - 
augmented_and_preprocessed_input_batch['inputs'].unsqueeze( - 1)).squeeze(1) + augmented_and_preprocessed_input_batch['inputs'].unsqueeze(1), + dropout_rate=dropout_rate).squeeze(1) return logit_batch, None diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 372cac7fa..f28eb1762 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -156,10 +156,7 @@ def _build_dataset( def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Dropout is unused.""" - del dropout_rate + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) if self.use_silu and self.use_gelu: @@ -192,9 +189,11 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng + del dropout_rate model = params diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py index 570cee575..60e09edb5 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py @@ -16,7 +16,7 @@ from algoperf import spec from algoperf.workloads.wmt.wmt_pytorch.models_dropout import MultiheadAttention -DEFAULT_DROPOUT_RATE = 0.0 +DROPOUT_RATE = 0.0 def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: @@ -348,7 +348,7 @@ def get_posemb(self, x: spec.Tensor) -> spec.Tensor: def forward( self, x: spec.Tensor, - dropout_rate: float = DEFAULT_DROPOUT_RATE) -> spec.Tensor: + dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: # Patch extraction. x = self.conv_patch_extract(x) diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 20bd3828b..8b011071a 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -16,6 +16,7 @@ from algoperf.workloads.imagenet_vit.workload import decode_variant USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() +DROPOUT_RATE = 0.0 # Make sure we inherit from the ViT base workload first. 
@@ -51,7 +52,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -70,7 +72,8 @@ def model_fn( } with contexts[mode](): - logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) + logits_batch = model(augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate) return logits_batch, None diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py index f77c8a814..a6a60bf95 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py @@ -17,10 +17,7 @@ from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ SpecAug -DEFAULT_ATTN_RESIDUAL_DROPOUT_RATE = 0.1 -DEFAULT_CONV_RESIDUAL_DROPOUT_RATE = 0.0 -DEFAULT_FFN_RESIDUAL_DROPOUT_RATE = 0.1 -DEFAULT_INPUT_DROPOUT_RATE = 0.1 +DROPOUT_RATE = 0.1 @dataclass @@ -93,9 +90,7 @@ def __init__(self, bias=True) self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim) - def forward(self, inputs, input_paddings, dropout_rate=None): - if dropout_rate is None: - dropout_rate = DEFAULT_INPUT_DROPOUT_RATE + def forward(self, inputs, input_paddings, dropout_rate): output_paddings = input_paddings outputs = inputs[:, None, :, :] @@ -202,9 +197,7 @@ def __init__(self, config: ConformerConfig): out_features=config.encoder_dim, bias=True) - def forward(self, inputs, padding_mask, dropout_rate=None): - if dropout_rate is None: - dropout_rate = DEFAULT_FFN_RESIDUAL_DROPOUT_RATE + def forward(self, inputs, padding_mask, dropout_rate): inputs = self.ln(inputs) inputs = self.linear1(inputs) @@ -310,10 +303,7 @@ def __init__(self, config: ConformerConfig): self.ln = LayerNorm(dim=config.encoder_dim) self.self_attention = MHSAwithQS(config) - def forward(self, outputs, paddings, dropout_rate=None): - if dropout_rate is None: - dropout_rate = DEFAULT_ATTN_RESIDUAL_DROPOUT_RATE - + def forward(self, outputs, paddings, dropout_rate): outputs = self.ln(outputs) outputs = self.self_attention( outputs, @@ -400,10 +390,7 @@ def __init__(self, config): self.bn = BatchNorm(config) self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim) - def forward(self, inputs, input_paddings, dropout_rate=None): - if dropout_rate is None: - dropout_rate = DEFAULT_CONV_RESIDUAL_DROPOUT_RATE - + def forward(self, inputs, input_paddings, dropout_rate): inputs = self.ln(inputs) inputs = F.glu(torch.cat([self.lin1(inputs), self.lin2(inputs)], dim=2)) @@ -442,7 +429,7 @@ def __init__(self, config: ConformerConfig): if config.use_post_layer_norm: self.ln = LayerNorm(dim=config.encoder_dim) - def forward(self, inputs, input_paddings, dropout_rate=None): + def forward(self, inputs, input_paddings, dropout_rate): padding_mask = 1 - input_paddings[:, :, None] inputs = inputs + 0.5 * self.ff1(inputs, padding_mask, dropout_rate) inputs = inputs + self.mhsa(inputs, input_paddings, dropout_rate) @@ -481,7 +468,7 @@ def __init__(self, config: ConformerConfig): self.ln = LayerNorm(config.encoder_dim) self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - def forward(self, inputs, input_paddings, 
dropout_rate=None): + def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): outputs = inputs output_paddings = input_paddings outputs, output_paddings = self.preprocessor(outputs, output_paddings) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 2d0942fe9..d99bc1608 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -24,6 +24,7 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() MAX_INPUT_LENGTH = 320000 +DROPOUT_RATE = 0.1 class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): @@ -105,7 +106,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng @@ -126,7 +128,8 @@ def model_fn( with contexts[mode](): inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] logits, logits_paddings = model(inputs.to(DEVICE), - input_paddings.to(DEVICE)) + input_paddings.to(DEVICE), + dropout_rate=dropout_rate) return (logits, logits_paddings), None def _build_input_queue( diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py index 21a4df614..a8480a343 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py @@ -17,7 +17,7 @@ SpecAug USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ -DEFAULT_DROPOUT_RATE = 0.1 +DROPOUT_RATE = 0.1 @dataclass @@ -351,7 +351,7 @@ def __init__(self, config: DeepspeechConfig): self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - def forward(self, inputs, input_paddings, dropout_rate=DEFAULT_DROPOUT_RATE): + def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): outputs = inputs output_paddings = input_paddings diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py index c8ed23dda..be5882333 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py @@ -11,7 +11,7 @@ from algoperf import init_utils from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout -DEFAULT_DROPOUT_RATE = 0.1 +DROPOUT_RATE = 0.1 def _make_mlp(in_dim, hidden_dims, activation_fn): @@ -96,7 +96,7 @@ def __init__(self, def forward( self, graph: GraphsTuple, - dropout_rate: float = DEFAULT_DROPOUT_RATE) -> torch.Tensor: + dropout_rate: float = DROPOUT_RATE) -> torch.Tensor: graph = graph._replace( globals=torch.zeros([graph.n_node.shape[0], self.num_outputs], diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py index 7ead696ce..281f4cd08 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py @@ -17,6 +17,8 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() +DROPOUT_RATE = 0.1 + def _pytorch_map(inputs: Any) -> Any: if USE_PYTORCH_DDP: @@ -166,7 +168,8 @@ def model_fn( model_state: 
spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del rng del update_batch_norm # No BN in the GNN model. @@ -186,7 +189,8 @@ def model_fn( } with contexts[mode](): - logits = model(augmented_and_preprocessed_input_batch['inputs']) + logits = model(augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate) return logits, None diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py index a5d822669..a43df30d4 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py @@ -9,7 +9,7 @@ from torch.nn.init import normal_ from torch.nn.init import xavier_uniform_ -DEFAULT_DROPOUT_RATE = 0.1 +DROPOUT_RATE = 0.1 def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor: @@ -156,7 +156,7 @@ def forward(self, inputs_segmentation: Optional[Tensor] = None, targets_segmentation: Optional[Tensor] = None, decode: bool = False, - dropout_rate: float = DEFAULT_DROPOUT_RATE) -> Tensor: + dropout_rate: float = DROPOUT_RATE) -> Tensor: """ Args: src: Tensor, shape [batch_size, seq_len] diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py index bb9c3834f..d30abc4c7 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/workload.py +++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py @@ -22,6 +22,8 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() +DROPOUT_RATE = 0.1 + class WmtWorkload(BaseWmtWorkload): """WMT PyTorch workload.""" @@ -202,7 +204,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -228,7 +231,8 @@ def model_fn( inputs_segmentation=augmented_and_preprocessed_input_batch.get( 'inputs_segmentation', None), targets_segmentation=augmented_and_preprocessed_input_batch.get( - 'targets_segmentation', None)) + 'targets_segmentation', None), + dropout_rate=dropout_rate) return logits_batch, None From a7cba1a1acc9da53b9500cab8755f49c88bdbb2c Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 13 Jun 2025 14:19:36 +0200 Subject: [PATCH 21/23] remove aux_dropout from pytorch workloads --- .../test_model_equivalence.py | 32 +------------------ .../test_model_equivalence.py | 2 +- 2 files changed, 2 insertions(+), 32 deletions(-) diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py index a4238bbc9..4a1252a39 100644 --- a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py @@ -3,11 +3,7 @@ Run with: python3 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py -`dropout_rate` controls the following args: -- `attention_residual_dropout_rate` (if None, 0.1 -- `conv_residual_dropout_rate` (if None, 0.0) -- `feed_forward_residual_dropout_rate` (if None, 0.1) -- `input_dropout_rate` (if None, 
0.1) +NOTE: we don't test for default dropout_rate values, since they changed. """ from absl.testing import absltest, parameterized @@ -85,31 +81,5 @@ def test_forward(self, dropout_rate): assert_close(y1, y2, atol=0, rtol=0) assert_close(p1, p2, atol=0, rtol=0) - @parameterized.named_parameters( - dict(testcase_name=''), - ) - def test_default_dropout(self): - """Test default dropout_rate.""" - - torch.manual_seed(SEED) - orig = OriginalModel(OriginalConfig(num_encoder_layers=N_LAYERS)).to(DEVICE) - torch.manual_seed(SEED) - cust = CustomModel(CustomConfig(num_encoder_layers=N_LAYERS)).to(DEVICE) - orig.load_state_dict(cust.state_dict()) - - x = torch.randn(B, T, device=DEVICE) - paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - y1, p1 = orig(x, paddings) - torch.manual_seed(SEED) - y2, p2 = cust(x, paddings) - - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) - if __name__ == '__main__': absltest.main() diff --git a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py index acdc8c5b3..58ddb354e 100644 --- a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py @@ -24,7 +24,7 @@ B, T = 32, 30_000 DEVICE = 'cuda' -TORCH_COMPILE = True +TORCH_COMPILE = False os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8" torch.backends.cudnn.benchmark = False From d8e39b0da371abbcf311ce1d09e06439bd5a0eec Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Sun, 15 Jun 2025 17:08:50 +0200 Subject: [PATCH 22/23] fix to model_fn default dropout_rate --- .../criteo1tb/criteo1tb_pytorch/workload.py | 3 +-- .../fastmri/fastmri_pytorch/workload.py | 5 ++--- .../imagenet_vit/imagenet_pytorch/workload.py | 3 +-- .../librispeech_pytorch/workload.py | 3 +-- .../librispeech_pytorch/workload.py | 17 ++++++++++++++++- .../workloads/ogbg/ogbg_pytorch/workload.py | 5 ++--- algoperf/workloads/wmt/wmt_pytorch/workload.py | 5 ++--- 7 files changed, 25 insertions(+), 16 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 74cb3e140..48c6592f2 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -15,7 +15,6 @@ BaseCriteo1TbDlrmSmallWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() -DROPOUT_RATE = 0.0 class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload): @@ -105,7 +104,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 6374c62d6..9b96230fc 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -13,14 +13,13 @@ from algoperf import pytorch_utils from algoperf import spec import algoperf.random_utils as prng +from algoperf.workloads.fastmri.fastmri_pytorch import models from 
algoperf.workloads.fastmri.fastmri_pytorch.models import UNet from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() -DROPOUT_RATE = 0.0 - class FastMRIWorkload(BaseFastMRIWorkload): @@ -137,7 +136,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 8b011071a..f86a1b1c2 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -16,7 +16,6 @@ from algoperf.workloads.imagenet_vit.workload import decode_variant USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() -DROPOUT_RATE = 0.0 # Make sure we inherit from the ViT base workload first. @@ -53,7 +52,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index d99bc1608..0477a7389 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -24,7 +24,6 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() MAX_INPUT_LENGTH = 320000 -DROPOUT_RATE = 0.1 class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): @@ -107,7 +106,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index e6ec4764f..bf345cfc9 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Dict, Tuple import torch from torch.nn.parallel import DistributedDataParallel as DDP @@ -6,6 +6,7 @@ from algoperf import param_utils from algoperf import spec from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import models from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ initialize from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ @@ -54,6 +55,20 @@ def init_model_fn( else: model = torch.nn.DataParallel(model) return model, None + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: 
spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + # override super method, changing only the default dropout_rate + return super().model_fn( + params, augmented_and_preprocessed_input_batch, model_state, + mode, rng, update_batch_norm, dropout_rate) def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['lin.weight', 'lin.bias'] diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py index 281f4cd08..758b36b60 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py @@ -12,13 +12,12 @@ from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.ogbg import metrics +from algoperf.workloads.ogbg.ogbg_pytorch import models from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN from algoperf.workloads.ogbg.workload import BaseOgbgWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() -DROPOUT_RATE = 0.1 - def _pytorch_map(inputs: Any) -> Any: if USE_PYTORCH_DDP: @@ -169,7 +168,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del rng del update_batch_norm # No BN in the GNN model. diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py index d30abc4c7..4c787becc 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/workload.py +++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py @@ -17,13 +17,12 @@ from algoperf import spec from algoperf.workloads.wmt import bleu from algoperf.workloads.wmt.wmt_pytorch import decode +from algoperf.workloads.wmt.wmt_pytorch import models from algoperf.workloads.wmt.wmt_pytorch.models import Transformer from algoperf.workloads.wmt.workload import BaseWmtWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() -DROPOUT_RATE = 0.1 - class WmtWorkload(BaseWmtWorkload): """WMT PyTorch workload.""" @@ -205,7 +204,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm From 7a0015830e840211b8002f068b1eb918c8390f5c Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Sun, 15 Jun 2025 17:14:18 +0200 Subject: [PATCH 23/23] rm models_dropout torch files --- .../criteo1tb/criteo1tb_pytorch/models.py | 54 +- .../criteo1tb_pytorch/models_dropout.py | 297 ------ .../fastmri/fastmri_pytorch/models.py | 44 +- .../fastmri/fastmri_pytorch/models_dropout.py | 173 --- .../imagenet_vit/imagenet_pytorch/models.py | 84 +- .../imagenet_pytorch/models_dropout.py | 378 ------- .../librispeech_pytorch/models.py | 70 +- .../librispeech_pytorch/models_dropout.py | 482 --------- .../librispeech_pytorch/models.py | 32 +- .../librispeech_pytorch/models_dropout.py | 379 ------- .../workloads/ogbg/ogbg_pytorch/models.py | 35 +- .../ogbg/ogbg_pytorch/models_dropout.py | 315 ------ algoperf/workloads/wmt/wmt_pytorch/models.py | 181 ++-- 
.../wmt/wmt_pytorch/models_dropout.py | 989 ------------------ 14 files changed, 234 insertions(+), 3279 deletions(-) delete mode 100644 algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py delete mode 100644 algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py delete mode 100644 algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py delete mode 100644 algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py delete mode 100644 algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py delete mode 100644 algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py delete mode 100644 algoperf/workloads/wmt/wmt_pytorch/models_dropout.py diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py index 7a40f0e81..f0653a665 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -5,20 +5,32 @@ import torch from torch import nn +from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout + +DROPOUT_RATE = 0.0 + class DenseBlock(nn.Module): """Dense block with optional residual connection.""" "" - def __init__(self, module, resnet=False): super().__init__() self.module = module self.resnet = resnet def forward(self, x): - if self.resnet: - return self.module(x) + x - else: - return self.module(x) + return self.module(x) + x if self.resnet else self.module(x) + + +class DenseBlockWithDropout(nn.Module): + """Dense block with optional residual connection and support for dropout.""" + def __init__(self, module, resnet=False): + super().__init__() + self.module = module + self.resnet = resnet + self._supports_custom_dropout = True + + def forward(self, x, p): + return self.module(x, p) + x if self.resnet else self.module(x, p) class DotInteract(nn.Module): @@ -58,7 +70,6 @@ def __init__(self, mlp_bottom_dims=(256, 256, 256), mlp_top_dims=(256, 256, 256, 256, 1), embed_dim=128, - dropout_rate=0.0, use_layer_norm=False, embedding_init_multiplier=None): super().__init__() @@ -116,17 +127,16 @@ def __init__(self, block.append(nn.Linear(fan_in, fan_out)) if layer_idx < (num_layers_top - 1): block.append(nn.ReLU(inplace=True)) - if (dropout_rate is not None and dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - block.append(nn.Dropout(p=dropout_rate)) - block = nn.Sequential(*block) + if layer_idx == num_layers_top - 2: + block.append(CustomDropout()) + block = SequentialWithDropout(*block) if (layer_idx != 0) and (layer_idx != num_layers_top - 1): - block = DenseBlock(block, resnet=True) + block = DenseBlockWithDropout(block, resnet=True) else: - block = DenseBlock(block) + block = DenseBlockWithDropout(block) mlp_top_blocks.append(block) fan_in = fan_out - self.top_mlp = nn.Sequential(*mlp_top_blocks) + self.top_mlp = SequentialWithDropout(*mlp_top_blocks) for module in self.top_mlp.modules(): if isinstance(module, nn.Linear): @@ -138,7 +148,8 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x): + def forward(self, x, dropout_rate=DROPOUT_RATE): + batch_size = x.shape[0] dense_features, sparse_features = torch.split( @@ -157,7 +168,7 @@ def forward(self, x): top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) # Final MLP. 
- logits = self.top_mlp(top_mlp_input) + logits = self.top_mlp(top_mlp_input, dropout_rate) return logits @@ -179,7 +190,6 @@ def __init__(self, mlp_bottom_dims=(512, 256, 128), mlp_top_dims=(1024, 1024, 512, 256, 1), embed_dim=128, - dropout_rate=0.0, use_layer_norm=False, embedding_init_multiplier=None): super().__init__() @@ -242,10 +252,9 @@ def __init__(self, top_mlp_layers.append(nn.ReLU(inplace=True)) if use_layer_norm: top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) - if (dropout_rate is not None and dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - top_mlp_layers.append(nn.Dropout(p=dropout_rate)) - self.top_mlp = nn.Sequential(*top_mlp_layers) + if layer_idx == num_layers_top - 2: + top_mlp_layers.append(CustomDropout()) + self.top_mlp = SequentialWithDropout(*top_mlp_layers) if use_layer_norm: self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) else: @@ -260,7 +269,8 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x): + def forward(self, x, dropout_rate=DROPOUT_RATE): + batch_size = x.shape[0] dense_features, sparse_features = torch.split( @@ -283,5 +293,5 @@ def forward(self, x): dense_features=embedded_dense, sparse_features=embedded_sparse) # Final MLP. - logits = self.top_mlp(concatenated_dense) + logits = self.top_mlp(concatenated_dense, dropout_rate) return logits diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py deleted file mode 100644 index f0653a665..000000000 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py +++ /dev/null @@ -1,297 +0,0 @@ -"""Pytorch implementation of DLRM-Small.""" - -import math - -import torch -from torch import nn - -from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout - -DROPOUT_RATE = 0.0 - - -class DenseBlock(nn.Module): - """Dense block with optional residual connection.""" "" - def __init__(self, module, resnet=False): - super().__init__() - self.module = module - self.resnet = resnet - - def forward(self, x): - return self.module(x) + x if self.resnet else self.module(x) - - -class DenseBlockWithDropout(nn.Module): - """Dense block with optional residual connection and support for dropout.""" - def __init__(self, module, resnet=False): - super().__init__() - self.module = module - self.resnet = resnet - self._supports_custom_dropout = True - - def forward(self, x, p): - return self.module(x, p) + x if self.resnet else self.module(x, p) - - -class DotInteract(nn.Module): - """Performs feature interaction operation between dense or sparse features.""" - - def __init__(self, num_sparse_features): - super().__init__() - self.triu_indices = torch.triu_indices(num_sparse_features + 1, - num_sparse_features + 1) - - def forward(self, dense_features, sparse_features): - combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features), - dim=1) - interactions = torch.bmm(combined_values, - torch.transpose(combined_values, 1, 2)) - interactions_flat = interactions[:, - self.triu_indices[0], - self.triu_indices[1]] - return torch.cat((dense_features, interactions_flat), dim=1) - - -class DLRMResNet(nn.Module): - """Define a DLRM-Small model. - - Parameters: - vocab_size: vocab size of embedding table. - num_dense_features: number of dense features as the bottom mlp input. - mlp_bottom_dims: dimensions of dense layers of the bottom mlp. - mlp_top_dims: dimensions of dense layers of the top mlp. - embed_dim: embedding dimension. 
- """ - - def __init__(self, - vocab_size, - num_dense_features=13, - num_sparse_features=26, - mlp_bottom_dims=(256, 256, 256), - mlp_top_dims=(256, 256, 256, 256, 1), - embed_dim=128, - use_layer_norm=False, - embedding_init_multiplier=None): - super().__init__() - self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) - self.num_dense_features = num_dense_features - self.num_sparse_features = num_sparse_features - self.mlp_bottom_dims = mlp_bottom_dims - self.mlp_top_dims = mlp_top_dims - self.embed_dim = embed_dim - - # Ideally, we should use the pooled embedding implementation from - # `TorchRec`. However, in order to have identical implementation - # with that of Jax, we define a single embedding matrix. - num_chunks = 4 - assert vocab_size % num_chunks == 0 - self.embedding_table_chucks = [] - scale = 1.0 / torch.sqrt(self.vocab_size) - for i in range(num_chunks): - chunk = nn.Parameter( - torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) - chunk.data.uniform_(0, 1) - chunk.data = scale * chunk.data - self.register_parameter(f'embedding_chunk_{i}', chunk) - self.embedding_table_chucks.append(chunk) - - input_dim = self.num_dense_features - bot_mlp_blocks = [] - for layer_idx, dense_dim in enumerate(self.mlp_bottom_dims): - block = [] - block.append(nn.Linear(input_dim, dense_dim)) - block.append(nn.ReLU(inplace=True)) - block = nn.Sequential(*block) - if layer_idx > 0: - block = DenseBlock(block, resnet=True) - else: - block = DenseBlock(block) - bot_mlp_blocks.append(block) - input_dim = dense_dim - self.bot_mlp = nn.Sequential(*bot_mlp_blocks) - - for module in self.bot_mlp.modules(): - if isinstance(module, nn.Linear): - limit = math.sqrt(6. / (module.in_features + module.out_features)) - nn.init.uniform_(module.weight.data, -limit, limit) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) - - # Number of sparse features = 26 - fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1] - num_layers_top = len(self.mlp_top_dims) - mlp_top_blocks = [] - for layer_idx, fan_out in enumerate(self.mlp_top_dims): - block = [] - block.append(nn.Linear(fan_in, fan_out)) - if layer_idx < (num_layers_top - 1): - block.append(nn.ReLU(inplace=True)) - if layer_idx == num_layers_top - 2: - block.append(CustomDropout()) - block = SequentialWithDropout(*block) - if (layer_idx != 0) and (layer_idx != num_layers_top - 1): - block = DenseBlockWithDropout(block, resnet=True) - else: - block = DenseBlockWithDropout(block) - mlp_top_blocks.append(block) - fan_in = fan_out - self.top_mlp = SequentialWithDropout(*mlp_top_blocks) - - for module in self.top_mlp.modules(): - if isinstance(module, nn.Linear): - nn.init.normal_( - module.weight.data, - 0., - math.sqrt(2. / (module.in_features + module.out_features))) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) - - def forward(self, x, dropout_rate=DROPOUT_RATE): - - batch_size = x.shape[0] - - dense_features, sparse_features = torch.split( - x, [self.num_dense_features, self.num_sparse_features], 1) - - # Bottom MLP. - embedded_dense = self.bot_mlp(dense_features) - - # Sparse feature processing. 
- sparse_features = sparse_features.to(dtype=torch.int32) - idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size - embedding_table = torch.cat(self.embedding_table_chucks, dim=0) - embedded_sparse = embedding_table[idx_lookup] - embedded_sparse = torch.reshape(embedded_sparse, - [batch_size, 26 * self.embed_dim]) - top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) - - # Final MLP. - logits = self.top_mlp(top_mlp_input, dropout_rate) - return logits - - -class DlrmSmall(nn.Module): - """Define a DLRM-Small model. - - Parameters: - vocab_size: vocab size of embedding table. - num_dense_features: number of dense features as the bottom mlp input. - mlp_bottom_dims: dimensions of dense layers of the bottom mlp. - mlp_top_dims: dimensions of dense layers of the top mlp. - embed_dim: embedding dimension. - """ - - def __init__(self, - vocab_size, - num_dense_features=13, - num_sparse_features=26, - mlp_bottom_dims=(512, 256, 128), - mlp_top_dims=(1024, 1024, 512, 256, 1), - embed_dim=128, - use_layer_norm=False, - embedding_init_multiplier=None): - super().__init__() - self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) - self.num_dense_features = num_dense_features - self.num_sparse_features = num_sparse_features - self.mlp_bottom_dims = mlp_bottom_dims - self.mlp_top_dims = mlp_top_dims - self.embed_dim = embed_dim - self.embedding_init_multiplier = embedding_init_multiplier - - # Ideally, we should use the pooled embedding implementation from - # `TorchRec`. However, in order to have identical implementation - # with that of Jax, we define a single embedding matrix. - num_chunks = 4 - assert vocab_size % num_chunks == 0 - self.embedding_table_chucks = [] - - if self.embedding_init_multiplier is None: - scale = 1.0 / torch.sqrt(self.vocab_size) - else: - scale = self.embedding_init_multiplier - - for i in range(num_chunks): - chunk = nn.Parameter( - torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) - chunk.data.uniform_(0, 1) - chunk.data = scale * chunk.data - self.register_parameter(f'embedding_chunk_{i}', chunk) - self.embedding_table_chucks.append(chunk) - - input_dim = self.num_dense_features - bottom_mlp_layers = [] - for dense_dim in self.mlp_bottom_dims: - bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim)) - bottom_mlp_layers.append(nn.ReLU(inplace=True)) - if use_layer_norm: - bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) - input_dim = dense_dim - self.bot_mlp = nn.Sequential(*bottom_mlp_layers) - for module in self.bot_mlp.modules(): - if isinstance(module, nn.Linear): - limit = math.sqrt(6. / (module.in_features + module.out_features)) - nn.init.uniform_(module.weight.data, -limit, limit) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) - - self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) - - # TODO: Write down the formula here instead of the constant. 
- input_dims = 506 - num_layers_top = len(self.mlp_top_dims) - top_mlp_layers = [] - for layer_idx, fan_out in enumerate(self.mlp_top_dims): - fan_in = input_dims if layer_idx == 0 \ - else self.mlp_top_dims[layer_idx - 1] - top_mlp_layers.append(nn.Linear(fan_in, fan_out)) - if layer_idx < (num_layers_top - 1): - top_mlp_layers.append(nn.ReLU(inplace=True)) - if use_layer_norm: - top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) - if layer_idx == num_layers_top - 2: - top_mlp_layers.append(CustomDropout()) - self.top_mlp = SequentialWithDropout(*top_mlp_layers) - if use_layer_norm: - self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) - else: - self.embed_ln = None - for module in self.top_mlp.modules(): - if isinstance(module, nn.Linear): - nn.init.normal_( - module.weight.data, - 0., - math.sqrt(2. / (module.in_features + module.out_features))) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) - - def forward(self, x, dropout_rate=DROPOUT_RATE): - - batch_size = x.shape[0] - - dense_features, sparse_features = torch.split( - x, [self.num_dense_features, self.num_sparse_features], 1) - - # Bottom MLP. - embedded_dense = self.bot_mlp(dense_features) - - # Sparse feature processing. - sparse_features = sparse_features.to(dtype=torch.int32) - idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size - embedding_table = torch.cat(self.embedding_table_chucks, dim=0) - embedded_sparse = embedding_table[idx_lookup] - embedded_sparse = torch.reshape(embedded_sparse, - [batch_size, -1, self.embed_dim]) - if self.embed_ln: - embedded_sparse = self.embed_ln(embedded_sparse) - # Dot product interactions. - concatenated_dense = self.dot_interact( - dense_features=embedded_dense, sparse_features=embedded_sparse) - - # Final MLP. 
- logits = self.top_mlp(concatenated_dense, dropout_rate) - return logits diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models.py b/algoperf/workloads/fastmri/fastmri_pytorch/models.py index 28f20bf20..0b8ac5499 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models.py @@ -13,6 +13,9 @@ from torch.nn import functional as F from algoperf import init_utils +from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout + +DROPOUT_RATE = 0.0 class UNet(nn.Module): @@ -27,7 +30,6 @@ def __init__(self, out_chans: int = 1, num_channels: int = 32, num_pool_layers: int = 4, - dropout_rate: Optional[float] = 0.0, use_tanh: bool = False, use_layer_norm: bool = False) -> None: super().__init__() @@ -36,21 +38,19 @@ def __init__(self, self.out_chans = out_chans self.num_channels = num_channels self.num_pool_layers = num_pool_layers - if dropout_rate is None: - dropout_rate = 0.0 + self.down_sample_layers = nn.ModuleList([ ConvBlock(in_chans, num_channels, - dropout_rate, use_tanh, use_layer_norm) ]) ch = num_channels for _ in range(num_pool_layers - 1): self.down_sample_layers.append( - ConvBlock(ch, ch * 2, dropout_rate, use_tanh, use_layer_norm)) + ConvBlock(ch, ch * 2, use_tanh, use_layer_norm)) ch *= 2 - self.conv = ConvBlock(ch, ch * 2, dropout_rate, use_tanh, use_layer_norm) + self.conv = ConvBlock(ch, ch * 2, use_tanh, use_layer_norm) self.up_conv = nn.ModuleList() self.up_transpose_conv = nn.ModuleList() @@ -59,14 +59,14 @@ def __init__(self, self.up_transpose_conv.append( TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) self.up_conv.append( - ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm)) + ConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) ch //= 2 self.up_transpose_conv.append( TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) self.up_conv.append( - nn.Sequential( - ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm), + SequentialWithDropout( + ConvBlock(ch * 2, ch, use_tanh, use_layer_norm), nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), )) @@ -74,24 +74,28 @@ def __init__(self, if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): init_utils.pytorch_default_init(m) - def forward(self, x: Tensor) -> Tensor: + def forward( + self, + x: Tensor, + dropout_rate: float = DROPOUT_RATE) -> Tensor: + stack = [] output = x # apply down-sampling layers for layer in self.down_sample_layers: - output = layer(output) + output = layer(output, dropout_rate) stack.append(output) output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) - output = self.conv(output) + output = self.conv(output, dropout_rate) # apply up-sampling layers for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): downsample_layer = stack.pop() output = transpose_conv(output) - # reflect pad on the right/botton if needed to handle + # reflect pad on the right/bottom if needed to handle # odd input dimensions padding = [0, 0, 0, 0] if output.shape[-1] != downsample_layer.shape[-1]: @@ -102,7 +106,7 @@ def forward(self, x: Tensor) -> Tensor: output = F.pad(output, padding, "reflect") output = torch.cat([output, downsample_layer], dim=1) - output = conv(output) + output = conv(output, dropout_rate) return output @@ -114,10 +118,10 @@ class ConvBlock(nn.Module): def __init__(self, in_chans: int, out_chans: int, - dropout_rate: float, use_tanh: bool, use_layer_norm: bool) -> None: super().__init__() + self._supports_custom_dropout = True if use_layer_norm: norm_layer 
= partial(nn.GroupNorm, 1, eps=1e-6) @@ -127,19 +131,19 @@ def __init__(self, activation_fn = nn.Tanh() else: activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True) - self.conv_layers = nn.Sequential( + self.conv_layers = SequentialWithDropout( nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), norm_layer(out_chans), activation_fn, - nn.Dropout2d(dropout_rate), + CustomDropout2d(), nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), norm_layer(out_chans), activation_fn, - nn.Dropout2d(dropout_rate), + CustomDropout2d(), ) - def forward(self, x: Tensor) -> Tensor: - return self.conv_layers(x) + def forward(self, x: Tensor, dropout_rate: float) -> Tensor: + return self.conv_layers(x, dropout_rate) class TransposeConvBlock(nn.Module): diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py deleted file mode 100644 index 0b8ac5499..000000000 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py +++ /dev/null @@ -1,173 +0,0 @@ -"""U-Net Model. - -Adapted from fastMRI: -https://github.com/facebookresearch/fastMRI/blob/main/fastmri/models/unet.py -""" - -from functools import partial -from typing import Optional - -import torch -from torch import nn -from torch import Tensor -from torch.nn import functional as F - -from algoperf import init_utils -from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout - -DROPOUT_RATE = 0.0 - - -class UNet(nn.Module): - r"""U-Net model from - `"U-net: Convolutional networks - for biomedical image segmentation" - `_. - """ - - def __init__(self, - in_chans: int = 1, - out_chans: int = 1, - num_channels: int = 32, - num_pool_layers: int = 4, - use_tanh: bool = False, - use_layer_norm: bool = False) -> None: - super().__init__() - - self.in_chans = in_chans - self.out_chans = out_chans - self.num_channels = num_channels - self.num_pool_layers = num_pool_layers - - self.down_sample_layers = nn.ModuleList([ - ConvBlock(in_chans, - num_channels, - use_tanh, - use_layer_norm) - ]) - ch = num_channels - for _ in range(num_pool_layers - 1): - self.down_sample_layers.append( - ConvBlock(ch, ch * 2, use_tanh, use_layer_norm)) - ch *= 2 - self.conv = ConvBlock(ch, ch * 2, use_tanh, use_layer_norm) - - self.up_conv = nn.ModuleList() - self.up_transpose_conv = nn.ModuleList() - - for _ in range(num_pool_layers - 1): - self.up_transpose_conv.append( - TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) - self.up_conv.append( - ConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) - ch //= 2 - - self.up_transpose_conv.append( - TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) - self.up_conv.append( - SequentialWithDropout( - ConvBlock(ch * 2, ch, use_tanh, use_layer_norm), - nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), - )) - - for m in self.modules(): - if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): - init_utils.pytorch_default_init(m) - - def forward( - self, - x: Tensor, - dropout_rate: float = DROPOUT_RATE) -> Tensor: - - stack = [] - output = x - - # apply down-sampling layers - for layer in self.down_sample_layers: - output = layer(output, dropout_rate) - stack.append(output) - output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) - - output = self.conv(output, dropout_rate) - - # apply up-sampling layers - for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): - downsample_layer = stack.pop() - output = transpose_conv(output) - - # reflect pad on the 
right/bottom if needed to handle - # odd input dimensions - padding = [0, 0, 0, 0] - if output.shape[-1] != downsample_layer.shape[-1]: - padding[1] = 1 # padding right - if output.shape[-2] != downsample_layer.shape[-2]: - padding[3] = 1 # padding bottom - if torch.sum(torch.tensor(padding)) != 0: - output = F.pad(output, padding, "reflect") - - output = torch.cat([output, downsample_layer], dim=1) - output = conv(output, dropout_rate) - - return output - - -class ConvBlock(nn.Module): - # A Convolutional Block that consists of two convolution layers each - # followed by instance normalization, LeakyReLU activation and dropout_rate. - - def __init__(self, - in_chans: int, - out_chans: int, - use_tanh: bool, - use_layer_norm: bool) -> None: - super().__init__() - self._supports_custom_dropout = True - - if use_layer_norm: - norm_layer = partial(nn.GroupNorm, 1, eps=1e-6) - else: - norm_layer = nn.InstanceNorm2d - if use_tanh: - activation_fn = nn.Tanh() - else: - activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True) - self.conv_layers = SequentialWithDropout( - nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), - norm_layer(out_chans), - activation_fn, - CustomDropout2d(), - nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), - norm_layer(out_chans), - activation_fn, - CustomDropout2d(), - ) - - def forward(self, x: Tensor, dropout_rate: float) -> Tensor: - return self.conv_layers(x, dropout_rate) - - -class TransposeConvBlock(nn.Module): - # A Transpose Convolutional Block that consists of one convolution transpose - # layers followed by instance normalization and LeakyReLU activation. - - def __init__( - self, - in_chans: int, - out_chans: int, - use_tanh: bool, - use_layer_norm: bool, - ): - super().__init__() - if use_tanh: - activation_fn = nn.Tanh() - else: - activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True) - self.layers = nn.Sequential( - nn.ConvTranspose2d( - in_chans, out_chans, kernel_size=2, stride=2, bias=False), - nn.InstanceNorm2d(out_chans), - activation_fn, - ) - - def forward(self, x: Tensor) -> Tensor: - return self.layers(x) diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index fcf0992d3..60e09edb5 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -14,7 +14,9 @@ from algoperf import init_utils from algoperf import spec -from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention +from algoperf.workloads.wmt.wmt_pytorch.models_dropout import MultiheadAttention + +DROPOUT_RATE = 0.0 def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: @@ -41,18 +43,15 @@ def __init__( self, width: int, mlp_dim: Optional[int] = None, # Defaults to 4x input dim. 
- use_glu: bool = False, - dropout_rate: float = 0.0) -> None: + use_glu: bool = False) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim or 4 * width self.use_glu = use_glu - self.dropout_rate = dropout_rate self.linear1 = nn.Linear(self.width, self.mlp_dim) self.act_fnc = nn.GELU(approximate='tanh') - self.dropout = nn.Dropout(self.dropout_rate) if self.use_glu: self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) @@ -70,7 +69,8 @@ def reset_parameters(self) -> None: if module.bias is not None: module.bias.data.normal_(std=1e-6) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: + x = self.linear1(x) x = self.act_fnc(x) @@ -78,7 +78,7 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: y = self.glu_linear(x) x = x * y - x = self.dropout(x) + x = F.dropout(x, dropout_rate, training=self.training) x = self.linear2(x) return x @@ -88,8 +88,7 @@ class SelfAttention(nn.Module): def __init__(self, width: int, - num_heads: int = 8, - dropout_rate: float = 0.0) -> None: + num_heads: int = 8) -> None: super().__init__() self.width = width @@ -104,7 +103,6 @@ def __init__(self, self.query = nn.Linear(self.width, self.all_head_dim) self.key = nn.Linear(self.width, self.all_head_dim) self.value = nn.Linear(self.width, self.all_head_dim) - self.dropout = nn.Dropout(dropout_rate) self.out = nn.Linear(self.width, self.width) self.reset_parameters() @@ -120,7 +118,8 @@ def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor: x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: + mixed_query_layer = self.query(x) key_layer = self.transpose_for_scores(self.key(x)) @@ -131,7 +130,7 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: attention_scores = attention_scores / math.sqrt(self.head_dim) attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = self.dropout(attention_probs) + attention_probs = F.dropout(attention_probs, dropout_rate, self.training) context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() @@ -149,8 +148,7 @@ def __init__(self, mlp_dim: Optional[int] = None, num_heads: int = 12, use_glu: bool = False, - use_post_layer_norm: bool = False, - dropout_rate: float = 0.0) -> None: + use_post_layer_norm: bool = False) -> None: super().__init__() self.width = width @@ -161,35 +159,34 @@ def __init__(self, self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) self.self_attention1 = SelfAttention(self.width, self.num_heads) - self.dropout = nn.Dropout(dropout_rate) self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) self.mlp3 = MlpBlock( width=self.width, mlp_dim=self.mlp_dim, - use_glu=self.use_glu, - dropout_rate=dropout_rate) + use_glu=self.use_glu) + + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: - def forward(self, x: spec.Tensor) -> spec.Tensor: if not self.use_post_layer_norm: y = self.layer_norm0(x) - y = self.self_attention1(y) - y = self.dropout(y) + y = self.self_attention1(y, dropout_rate) + y = F.dropout(y, dropout_rate, training=self.training) x = x + y y = self.layer_norm2(x) - y = self.mlp3(y) - y = self.dropout(y) + y = self.mlp3(y, dropout_rate) + y = F.dropout(y, dropout_rate, training=self.training) x = x + y else: y = x - y = self.self_attention1(y) - y = self.dropout(y) + y = self.self_attention1(y, dropout_rate) + y = F.dropout(y, 
dropout_rate, training=self.training) x = x + y x = self.layer_norm0(x) y = x - y = self.mlp3(y) - y = self.dropout(y) + y = self.mlp3(y, dropout_rate) + y = F.dropout(y, dropout_rate, training=self.training) x = x + y x = self.layer_norm2(x) return x @@ -204,8 +201,7 @@ def __init__(self, mlp_dim: Optional[int] = None, num_heads: int = 12, use_glu: bool = False, - use_post_layer_norm: bool = False, - dropout_rate: float = 0.0) -> None: + use_post_layer_norm: bool = False) -> None: super().__init__() self.depth = depth @@ -220,8 +216,7 @@ def __init__(self, self.mlp_dim, self.num_heads, self.use_glu, - self.use_post_layer_norm, - dropout_rate) for _ in range(depth) + self.use_post_layer_norm) for _ in range(depth) ]) if not self.use_post_layer_norm: @@ -229,10 +224,10 @@ def __init__(self, else: self.encoder_norm = None - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: # Input Encoder. for block in self.net: - x = block(x) + x = block(x, dropout_rate) if not self.use_post_layer_norm: return self.encoder_norm(x) else: @@ -259,13 +254,13 @@ def __init__(self, self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: n, _, _ = x.shape probe = torch.tile(self.probe, [n, 1, 1]) - x = self.mha(probe, x)[0] + x = self.mha(probe, x, dropout_rate=dropout_rate)[0] y = self.layer_norm(x) - x = x + self.mlp(y) + x = x + self.mlp(y, dropout_rate) return x[:, 0] @@ -285,15 +280,12 @@ def __init__( mlp_dim: Optional[int] = None, # Defaults to 4x input dim. num_heads: int = 12, rep_size: Union[int, bool] = True, - dropout_rate: Optional[float] = 0.0, head_zeroinit: bool = True, use_glu: bool = False, use_post_layer_norm: bool = False, use_map: bool = False, dtype: Any = torch.float32) -> None: super().__init__() - if dropout_rate is None: - dropout_rate = 0.0 self.num_classes = num_classes self.patch_size = patch_size @@ -318,7 +310,6 @@ def __init__( self.patch_size, stride=self.patch_size, padding='valid') - self.dropout = nn.Dropout(p=dropout_rate) self.encoder = Encoder( depth=self.depth, @@ -326,8 +317,7 @@ def __init__( mlp_dim=self.mlp_dim, num_heads=self.num_heads, use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=dropout_rate) + use_post_layer_norm=self.use_post_layer_norm) if self.num_classes: self.head = nn.Linear(self.width, self.num_classes) @@ -355,7 +345,11 @@ def reset_parameters(self) -> None: def get_posemb(self, x: spec.Tensor) -> spec.Tensor: return posemb_sincos_2d(x).type(self.dtype) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward( + self, + x: spec.Tensor, + dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: + # Patch extraction. 
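# Illustrative usage (not part of this file; the rate values are arbitrary):
# with `dropout_rate` accepted by forward(), the rate is chosen per call
# instead of being fixed at construction, falling back to the module-level
# DROPOUT_RATE when omitted.
#
#   from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import ViT
#   import torch
#
#   model = ViT(num_classes=1000)
#   x = torch.randn(2, 3, 224, 224)
#   out = model(x, dropout_rate=0.1)  # explicit rate for this step
#   out = model(x)                    # uses DROPOUT_RATE (0.0 here)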
x = self.conv_patch_extract(x) @@ -367,11 +361,11 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: x = torch.transpose(torch.reshape(x, (n, c, h * w)), 1, 2) x = x + pes - x = self.dropout(x) - x = self.encoder(x) + x = F.dropout(x, dropout_rate, training=self.training) + x = self.encoder(x, dropout_rate) if self.use_map: - x = self.map(x) + x = self.map(x, dropout_rate) else: x = torch.mean(x, dim=1) diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py deleted file mode 100644 index 60e09edb5..000000000 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py +++ /dev/null @@ -1,378 +0,0 @@ -"""PyTorch implementation of refactored and simplified ViT. - -Adapted from: -https://github.com/huggingface/transformers/tree/main/src/transformers/models/vit -and https://github.com/lucidrains/vit-pytorch. -""" - -import math -from typing import Any, Optional, Tuple, Union - -import torch -from torch import nn -import torch.nn.functional as F - -from algoperf import init_utils -from algoperf import spec -from algoperf.workloads.wmt.wmt_pytorch.models_dropout import MultiheadAttention - -DROPOUT_RATE = 0.0 - - -def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: - """Follows the MoCo v3 logic.""" - _, width, h, w = patches.shape - device = patches.device - y, x = torch.meshgrid(torch.arange(h, device=device), - torch.arange(w, device=device), indexing='ij') - - if width % 4 != 0: - raise ValueError('Width must be mult of 4 for sincos posemb.') - omega = torch.arange(width // 4, device=device) / (width // 4 - 1) - omega = 1. / (temperature**omega) - y = y.flatten()[:, None] * omega[None, :] - x = x.flatten()[:, None] * omega[None, :] - pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) - return pe[None, :, :] - - -class MlpBlock(nn.Module): - """Transformer MLP / feed-forward block.""" - - def __init__( - self, - width: int, - mlp_dim: Optional[int] = None, # Defaults to 4x input dim. 
- use_glu: bool = False) -> None: - super().__init__() - - self.width = width - self.mlp_dim = mlp_dim or 4 * width - self.use_glu = use_glu - - self.linear1 = nn.Linear(self.width, self.mlp_dim) - self.act_fnc = nn.GELU(approximate='tanh') - - if self.use_glu: - self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) - else: - self.glu_linear = None - - self.linear2 = nn.Linear(self.mlp_dim, self.width) - - self.reset_parameters() - - def reset_parameters(self) -> None: - for module in self.modules(): - if isinstance(module, nn.Linear): - nn.init.xavier_uniform_(module.weight.data) - if module.bias is not None: - module.bias.data.normal_(std=1e-6) - - def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: - - x = self.linear1(x) - x = self.act_fnc(x) - - if self.use_glu: - y = self.glu_linear(x) - x = x * y - - x = F.dropout(x, dropout_rate, training=self.training) - x = self.linear2(x) - return x - - -class SelfAttention(nn.Module): - """Self-attention special case of multi-head dot-product attention.""" - - def __init__(self, - width: int, - num_heads: int = 8) -> None: - super().__init__() - - self.width = width - self.num_heads = num_heads - - assert width % num_heads == 0, ( - 'Memory dimension must be divisible by number of heads.') - - self.head_dim = int(width / num_heads) - self.all_head_dim = self.num_heads * self.head_dim - - self.query = nn.Linear(self.width, self.all_head_dim) - self.key = nn.Linear(self.width, self.all_head_dim) - self.value = nn.Linear(self.width, self.all_head_dim) - self.out = nn.Linear(self.width, self.width) - self.reset_parameters() - - def reset_parameters(self) -> None: - for module in self.modules(): - if isinstance(module, nn.Linear): - nn.init.xavier_uniform_(module.weight.data) - if module.bias is not None: - nn.init.constant_(module.bias.data, 0.) 
- - def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor: - new_x_shape = x.size()[:-1] + (self.num_heads, self.head_dim) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: - - mixed_query_layer = self.query(x) - - key_layer = self.transpose_for_scores(self.key(x)) - value_layer = self.transpose_for_scores(self.value(x)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.head_dim) - - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = F.dropout(attention_probs, dropout_rate, self.training) - - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_dim,) - context_layer = context_layer.view(new_context_layer_shape) - out = self.out(context_layer) - return out - - -class Encoder1DBlock(nn.Module): - """Single transformer encoder block (MHSA + MLP).""" - - def __init__(self, - width: int, - mlp_dim: Optional[int] = None, - num_heads: int = 12, - use_glu: bool = False, - use_post_layer_norm: bool = False) -> None: - super().__init__() - - self.width = width - self.mlp_dim = mlp_dim - self.num_heads = num_heads - self.use_glu = use_glu - self.use_post_layer_norm = use_post_layer_norm - - self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) - self.self_attention1 = SelfAttention(self.width, self.num_heads) - self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) - self.mlp3 = MlpBlock( - width=self.width, - mlp_dim=self.mlp_dim, - use_glu=self.use_glu) - - def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: - - if not self.use_post_layer_norm: - y = self.layer_norm0(x) - y = self.self_attention1(y, dropout_rate) - y = F.dropout(y, dropout_rate, training=self.training) - x = x + y - - y = self.layer_norm2(x) - y = self.mlp3(y, dropout_rate) - y = F.dropout(y, dropout_rate, training=self.training) - x = x + y - else: - y = x - y = self.self_attention1(y, dropout_rate) - y = F.dropout(y, dropout_rate, training=self.training) - x = x + y - x = self.layer_norm0(x) - - y = x - y = self.mlp3(y, dropout_rate) - y = F.dropout(y, dropout_rate, training=self.training) - x = x + y - x = self.layer_norm2(x) - return x - - -class Encoder(nn.Module): - """Transformer Model Encoder for sequence to sequence translation.""" - - def __init__(self, - depth: int, - width: int, - mlp_dim: Optional[int] = None, - num_heads: int = 12, - use_glu: bool = False, - use_post_layer_norm: bool = False) -> None: - super().__init__() - - self.depth = depth - self.width = width - self.mlp_dim = mlp_dim - self.num_heads = num_heads - self.use_glu = use_glu - self.use_post_layer_norm = use_post_layer_norm - - self.net = nn.ModuleList([ - Encoder1DBlock(self.width, - self.mlp_dim, - self.num_heads, - self.use_glu, - self.use_post_layer_norm) for _ in range(depth) - ]) - - if not self.use_post_layer_norm: - self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) - else: - self.encoder_norm = None - - def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: - # Input Encoder. 
- for block in self.net: - x = block(x, dropout_rate) - if not self.use_post_layer_norm: - return self.encoder_norm(x) - else: - return x - - -class MAPHead(nn.Module): - """Multihead Attention Pooling.""" - - def __init__(self, - width: int, - mlp_dim: Optional[int] = None, - num_heads: int = 12): - super().__init__() - self.width = width - self.mlp_dim = mlp_dim - self.num_heads = num_heads - - self.probe = nn.Parameter(torch.zeros((1, 1, self.width))) - nn.init.xavier_uniform_(self.probe.data) - - self.mha = MultiheadAttention( - self.width, num_heads=self.num_heads, self_attn=False, bias=True) - self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) - self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) - - def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: - n, _, _ = x.shape - probe = torch.tile(self.probe, [n, 1, 1]) - - x = self.mha(probe, x, dropout_rate=dropout_rate)[0] - y = self.layer_norm(x) - x = x + self.mlp(y, dropout_rate) - return x[:, 0] - - -class ViT(nn.Module): - """ViT model.""" - - image_height: int = 224 - image_width: int = 224 - channels: int = 3 - - def __init__( - self, - num_classes: int = 1000, - patch_size: Tuple[int, int] = (16, 16), - width: int = 768, - depth: int = 12, - mlp_dim: Optional[int] = None, # Defaults to 4x input dim. - num_heads: int = 12, - rep_size: Union[int, bool] = True, - head_zeroinit: bool = True, - use_glu: bool = False, - use_post_layer_norm: bool = False, - use_map: bool = False, - dtype: Any = torch.float32) -> None: - super().__init__() - - self.num_classes = num_classes - self.patch_size = patch_size - self.width = width - self.depth = depth - self.mlp_dim = mlp_dim - self.num_heads = num_heads - self.rep_size = rep_size - self.head_zeroinit = head_zeroinit - self.use_glu = use_glu - self.use_post_layer_norm = use_post_layer_norm - self.use_map = use_map - self.dtype = dtype - - if self.rep_size: - rep_size = self.width if self.rep_size is True else self.rep_size - self.pre_logits = nn.Linear(self.width, rep_size) - - self.conv_patch_extract = nn.Conv2d( - self.channels, - self.width, - self.patch_size, - stride=self.patch_size, - padding='valid') - - self.encoder = Encoder( - depth=self.depth, - width=self.width, - mlp_dim=self.mlp_dim, - num_heads=self.num_heads, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm) - - if self.num_classes: - self.head = nn.Linear(self.width, self.num_classes) - - if self.use_map: - self.map = MAPHead(self.width, self.mlp_dim, self.num_heads) - else: - self.map = None - - self.reset_parameters() - - def reset_parameters(self) -> None: - init_utils.pytorch_default_init(self.conv_patch_extract) - - if self.rep_size: - init_utils.pytorch_default_init(self.pre_logits) - - if self.num_classes: - if self.head_zeroinit: - nn.init.constant_(self.head.weight.data, 0.) - nn.init.constant_(self.head.bias.data, 0.) - else: - init_utils.pytorch_default_init(self.head) - - def get_posemb(self, x: spec.Tensor) -> spec.Tensor: - return posemb_sincos_2d(x).type(self.dtype) - - def forward( - self, - x: spec.Tensor, - dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: - - # Patch extraction. - x = self.conv_patch_extract(x) - - # Add posemb before adding extra token. - n, c, h, w = x.shape - pes = self.get_posemb(x) - - # Reshape to match Jax's ViT implementation. 
- x = torch.transpose(torch.reshape(x, (n, c, h * w)), 1, 2) - x = x + pes - - x = F.dropout(x, dropout_rate, training=self.training) - x = self.encoder(x, dropout_rate) - - if self.use_map: - x = self.map(x, dropout_rate) - else: - x = torch.mean(x, dim=1) - - if self.rep_size: - x = torch.tanh(self.pre_logits(x)) - - if self.num_classes: - x = self.head(x) - - return x diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py index db1e24521..a6a60bf95 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from functools import partial import math -from typing import Tuple +from typing import Optional, Tuple import torch from torch import nn @@ -17,6 +17,8 @@ from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ SpecAug +DROPOUT_RATE = 0.1 + @dataclass class ConformerConfig: @@ -26,10 +28,7 @@ class ConformerConfig: num_attention_heads: int = 8 num_encoder_layers: int = 4 attention_dropout_rate: float = 0.0 - attention_residual_dropout_rate: float = 0.1 - conv_residual_dropout_rate: float = 0.0 feed_forward_dropout_rate: float = 0.0 - feed_forward_residual_dropout_rate: float = 0.1 convolution_kernel_size: int = 5 feed_forward_expansion_factor: int = 4 freq_mask_count: int = 2 @@ -39,7 +38,6 @@ class ConformerConfig: time_mask_max_ratio: float = 0.05 time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True - input_dropout_rate: float = 0.1 batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 use_specaug: bool = True @@ -77,11 +75,9 @@ class Subsample(nn.Module): def __init__(self, encoder_dim: int = 0, - input_dropout_rate: float = 0.0, num_bins: int = 80): super().__init__() self.encoder_dim = encoder_dim - self.input_dropout_rate = input_dropout_rate self.conv1 = Conv2dSubsampling( input_channels=1, output_channels=encoder_dim) @@ -93,9 +89,9 @@ def __init__(self, out_features=self.encoder_dim, bias=True) self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim) - self.dropout = nn.Dropout(p=self.input_dropout_rate, inplace=True) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate): + output_paddings = input_paddings outputs = inputs[:, None, :, :] @@ -109,7 +105,7 @@ def forward(self, inputs, input_paddings): outputs = self.linear(outputs) outputs = outputs + self.pos_encode(seq_length=outputs.shape[1]) - outputs = self.dropout(outputs) + outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True) return outputs, output_paddings @@ -201,15 +197,8 @@ def __init__(self, config: ConformerConfig): out_features=config.encoder_dim, bias=True) - if config.feed_forward_residual_dropout_rate is None: - feed_forward_residual_dropout_rate = 0.1 - else: - feed_forward_residual_dropout_rate = ( - config.feed_forward_residual_dropout_rate) - self.dropout2 = nn.Dropout( - p=feed_forward_residual_dropout_rate, inplace=True) + def forward(self, inputs, padding_mask, dropout_rate): - def forward(self, inputs, padding_mask): inputs = self.ln(inputs) inputs = self.linear1(inputs) if self.config.activation_function_name == 'swish': @@ -226,7 +215,7 @@ def forward(self, inputs, padding_mask): inputs = inputs * padding_mask inputs = self.linear2(inputs) inputs = inputs * padding_mask - 
inputs = self.dropout2(inputs) + inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True) return inputs @@ -280,7 +269,7 @@ def __init__(self, config: ConformerConfig): super().__init__() self.embed_dim = config.encoder_dim self.num_heads = config.num_attention_heads - self.dropout = config.attention_dropout_rate + self.attention_dropout_rate = config.attention_dropout_rate self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim) self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim) self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads) @@ -297,7 +286,7 @@ def forward(self, inputs, key_padding_mask=None): key=k, value=v, attn_mask=~key_padding_mask[:, None, None], - dropout_p=self.dropout, + dropout_p=self.attention_dropout_rate, ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim) out = out * self.attention_temperature out = self.out_proj(out) @@ -313,19 +302,14 @@ def __init__(self, config: ConformerConfig): self.ln = LayerNorm(dim=config.encoder_dim) self.self_attention = MHSAwithQS(config) - if config.attention_residual_dropout_rate is None: - attention_residual_dropout_rate = 0.1 - else: - attention_residual_dropout_rate = config.attention_residual_dropout_rate - self.dropout = nn.Dropout(p=attention_residual_dropout_rate, inplace=True) - def forward(self, outputs, paddings): + def forward(self, outputs, paddings, dropout_rate): outputs = self.ln(outputs) outputs = self.self_attention( outputs, key_padding_mask=paddings == 1, ) - outputs = self.dropout(outputs) + outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True) return outputs @@ -405,13 +389,8 @@ def __init__(self, config): groups=config.encoder_dim) self.bn = BatchNorm(config) self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim) - if config.conv_residual_dropout_rate is None: - conv_residual_dropout_rate = 0.0 - else: - conv_residual_dropout_rate = config.conv_residual_dropout_rate - self.dropout = nn.Dropout(p=conv_residual_dropout_rate, inplace=True) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate): inputs = self.ln(inputs) inputs = F.glu(torch.cat([self.lin1(inputs), self.lin2(inputs)], dim=2)) @@ -433,7 +412,7 @@ def forward(self, inputs, input_paddings): inputs = activation_fn(inputs) inputs = self.lin3(inputs) - inputs = self.dropout(inputs) + inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True) return inputs @@ -450,12 +429,12 @@ def __init__(self, config: ConformerConfig): if config.use_post_layer_norm: self.ln = LayerNorm(dim=config.encoder_dim) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate): padding_mask = 1 - input_paddings[:, :, None] - inputs = inputs + 0.5 * self.ff1(inputs, padding_mask) - inputs = inputs + self.mhsa(inputs, input_paddings) - inputs = inputs + self.conv(inputs, input_paddings) - inputs = inputs + 0.5 * self.ff2(inputs, padding_mask) + inputs = inputs + 0.5 * self.ff1(inputs, padding_mask, dropout_rate) + inputs = inputs + self.mhsa(inputs, input_paddings, dropout_rate) + inputs = inputs + self.conv(inputs, input_paddings, dropout_rate) + inputs = inputs + 0.5 * self.ff2(inputs, padding_mask, dropout_rate) if self.ln: inputs = self.ln(inputs) return inputs @@ -480,13 +459,8 @@ def __init__(self, config: ConformerConfig): time_masks_per_frame=config.time_masks_per_frame, use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames ) - if 
config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate self.subsample = Subsample( encoder_dim=config.encoder_dim, - input_dropout_rate=input_dropout_rate, num_bins=preprocessing_config.num_bins) self.conformers = nn.ModuleList( [ConformerBlock(config) for _ in range(config.num_encoder_layers)]) @@ -494,15 +468,15 @@ def __init__(self, config: ConformerConfig): self.ln = LayerNorm(config.encoder_dim) self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): outputs = inputs output_paddings = input_paddings outputs, output_paddings = self.preprocessor(outputs, output_paddings) if self.training and self.config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings) + outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate) for conformer in self.conformers: - outputs = conformer(outputs, output_paddings) + outputs = conformer(outputs, output_paddings, dropout_rate) outputs = self.ln(outputs) outputs = self.lin(outputs) return outputs, output_paddings diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py deleted file mode 100644 index a6a60bf95..000000000 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py +++ /dev/null @@ -1,482 +0,0 @@ -"""This is a pytorch implementation mirroring: -https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py. -""" - -from dataclasses import dataclass -from functools import partial -import math -from typing import Optional, Tuple - -import torch -from torch import nn -from torch.nn import init -import torch.nn.functional as F - -from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ - preprocessor -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ - SpecAug - -DROPOUT_RATE = 0.1 - - -@dataclass -class ConformerConfig: - """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" - vocab_size: int = 1024 - encoder_dim: int = 512 - num_attention_heads: int = 8 - num_encoder_layers: int = 4 - attention_dropout_rate: float = 0.0 - feed_forward_dropout_rate: float = 0.0 - convolution_kernel_size: int = 5 - feed_forward_expansion_factor: int = 4 - freq_mask_count: int = 2 - freq_mask_max_bins: int = 27 - time_mask_count: int = 10 - time_mask_max_frames: int = 40 - time_mask_max_ratio: float = 0.05 - time_masks_per_frame: float = 0.0 - use_dynamic_time_mask_max_frames: bool = True - batch_norm_momentum: float = 1 - 0.999 - batch_norm_epsilon: float = 0.001 - use_specaug: bool = True - attention_temperature: float = 1.0 - activation_function_name: str = 'swish' - use_post_layer_norm: bool = True - - -def initialize(m): - if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d): - init.xavier_uniform_(m.weight) - if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.MultiheadAttention): - init.xavier_uniform_(m.in_proj_weight) - for i in m.children(): - initialize(i) - - -class LayerNorm(nn.Module): - - def __init__(self, dim, epsilon=1e-6): - super().__init__() - self.dim = dim - - self.scale = nn.Parameter(torch.zeros(self.dim)) - self.bias = nn.Parameter(torch.zeros(self.dim)) - self.epsilon = epsilon - 
- def forward(self, x): - return F.layer_norm(x, (self.dim,), 1 + self.scale, self.bias, self.epsilon) - - -class Subsample(nn.Module): - - def __init__(self, - encoder_dim: int = 0, - num_bins: int = 80): - super().__init__() - self.encoder_dim = encoder_dim - - self.conv1 = Conv2dSubsampling( - input_channels=1, output_channels=encoder_dim) - self.conv2 = Conv2dSubsampling( - input_channels=encoder_dim, output_channels=encoder_dim) - - self.linear = nn.Linear( - in_features=self.encoder_dim * num_bins // 4, - out_features=self.encoder_dim, - bias=True) - self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim) - - def forward(self, inputs, input_paddings, dropout_rate): - - output_paddings = input_paddings - outputs = inputs[:, None, :, :] - - outputs, output_paddings = self.conv1(outputs, output_paddings) - outputs, output_paddings = self.conv2(outputs, output_paddings) - - batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape - outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size, - subsampled_lengths, - subsampled_dims * channels) - - outputs = self.linear(outputs) - outputs = outputs + self.pos_encode(seq_length=outputs.shape[1]) - outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True) - - return outputs, output_paddings - - -class Conv2dSubsampling(nn.Module): - - def __init__(self, - input_channels: int, - output_channels: int, - filter_stride: Tuple[int] = (2, 2), - padding: str = 'SAME'): - super().__init__() - - self.input_channels = input_channels - self.output_channels = output_channels - self.filter_stride = filter_stride - self.padding = padding - - self.filter_shape = (output_channels, input_channels, 3, 3) - - self.kernel = nn.Parameter( - torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) - self.bias = nn.Parameter(torch.zeros(output_channels)) - self.register_buffer('paddings_kernel', torch.ones([1, 1, 1])) - - def get_same_padding(self, input_shape): - in_height, in_width = input_shape[2:] - stride_height, stride_width = self.filter_stride - filter_height, filter_width = 3, 3 - if in_height % stride_height == 0: - pad_along_height = max(filter_height - stride_height, 0) - else: - pad_along_height = max(filter_height - (in_height % stride_height), 0) - if in_width % stride_width == 0: - pad_along_width = max(filter_width - stride_width, 0) - else: - pad_along_width = max(filter_width - (in_width % stride_width), 0) - pad_top = pad_along_height // 2 - pad_bottom = pad_along_height - pad_top - pad_left = pad_along_width // 2 - pad_right = pad_along_width - pad_left - return (pad_left, pad_right, pad_top, pad_bottom) - - def forward(self, inputs, paddings): - groups = inputs.shape[1] // self.input_channels - - if self.padding == 'SAME': - in_ = F.pad(inputs, self.get_same_padding(inputs.shape)) - else: - in_ = inputs - outputs = F.conv2d( - input=in_, - weight=self.kernel, - bias=self.bias, - stride=self.filter_stride, - dilation=(1, 1), - groups=groups) - - outputs = F.relu(outputs) - - input_length = paddings.shape[1] - stride = self.filter_stride[0] - pad_len = (input_length + stride - 1) // stride * stride - input_length - padded_paddings = F.pad( - paddings[:, None, :], (0, pad_len), mode='constant', value=0) - out_padding = F.conv1d( - input=padded_paddings, - weight=self.paddings_kernel, - stride=self.filter_stride[:1]) - out_padding = out_padding.squeeze(dim=1) - outputs = outputs * (1 - out_padding[:, None, :, None]) - return outputs, out_padding - - -class FeedForwardModule(nn.Module): 
- - def __init__(self, config: ConformerConfig): - super().__init__() - self.config = config - - self.ln = LayerNorm(dim=config.encoder_dim) - self.linear1 = nn.Linear( - in_features=config.encoder_dim, - out_features=config.encoder_dim * config.feed_forward_expansion_factor, - bias=True) - self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate, inplace=True) - self.linear2 = nn.Linear( - in_features=config.encoder_dim * config.feed_forward_expansion_factor, - out_features=config.encoder_dim, - bias=True) - - def forward(self, inputs, padding_mask, dropout_rate): - - inputs = self.ln(inputs) - inputs = self.linear1(inputs) - if self.config.activation_function_name == 'swish': - activation_fn = F.silu - elif self.config.activation_function_name == 'gelu': - # Use tanh approximation of GELU which is default for jax - activation_fn = partial(F.gelu, approximate='tanh') - else: - raise ValueError('Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{self.config.activation_function_name}') - inputs = activation_fn(inputs) - inputs = self.dropout1(inputs) - inputs = inputs * padding_mask - inputs = self.linear2(inputs) - inputs = inputs * padding_mask - inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True) - - return inputs - - -class AddPositionalEmbedding(nn.Module): - - def __init__(self, - min_timescale: int = 1, - max_timescale: int = 10_000, - embedding_dim: int = 512): - super().__init__() - self.min_timescale = min_timescale - self.max_timescale = max_timescale - self.embedding_dim = embedding_dim - num_timescales = self.embedding_dim // 2 - log_timescale_increment = math.log( - float(self.max_timescale) / float(self.min_timescale)) / ( - num_timescales - 1) - inv_timescales = self.min_timescale * \ - torch.exp(torch.arange(num_timescales, dtype=torch.float32) - * -log_timescale_increment) - self.register_buffer('inv_timescales', inv_timescales[None, None, :]) - - def forward(self, seq_length): - position = torch.arange( - end=seq_length, dtype=torch.float32, device=self.inv_timescales.device) - scaled_time = position[None, :, None] * self.inv_timescales - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) - if self.embedding_dim % 2: - signal = torch.cat( - [signal, torch.zeros(signal.shape[0], signal.shape[1], 1)], dim=2) - return signal - - -class QueryScaler(nn.Module): - - def __init__(self, dim): - super().__init__() - self.dim = dim - self.scale = nn.Parameter(torch.zeros(self.dim)) - - def forward(self, inputs): - r_softplus_0 = 1.442695041 - scale = r_softplus_0 * F.softplus(self.scale) - return inputs * scale - - -class MHSAwithQS(nn.Module): - - def __init__(self, config: ConformerConfig): - super().__init__() - self.embed_dim = config.encoder_dim - self.num_heads = config.num_attention_heads - self.attention_dropout_rate = config.attention_dropout_rate - self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim) - self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim) - self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads) - self.attention_temperature = config.attention_temperature - - def forward(self, inputs, key_padding_mask=None): - batch_size, seq_len, embed_dim = inputs.shape - q, k, v = self.in_proj(inputs).split(self.embed_dim, dim=2) - q = self.qs(q.view(batch_size, seq_len, self.num_heads, -1)).transpose(1, 2) - k = k.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) - v = v.view(batch_size, seq_len, 
self.num_heads, -1).transpose(1, 2) - out = F.scaled_dot_product_attention( - query=q, - key=k, - value=v, - attn_mask=~key_padding_mask[:, None, None], - dropout_p=self.attention_dropout_rate, - ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim) - out = out * self.attention_temperature - out = self.out_proj(out) - return out - - -class MultiHeadedSelfAttention(nn.Module): - - def __init__(self, config: ConformerConfig): - super().__init__() - - self.config = config - - self.ln = LayerNorm(dim=config.encoder_dim) - self.self_attention = MHSAwithQS(config) - - def forward(self, outputs, paddings, dropout_rate): - outputs = self.ln(outputs) - outputs = self.self_attention( - outputs, - key_padding_mask=paddings == 1, - ) - outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True) - return outputs - - -class BatchNorm(nn.Module): - - def __init__(self, config: ConformerConfig): - super().__init__() - running_mean = torch.zeros(config.encoder_dim) - running_var = torch.ones(config.encoder_dim) - self.register_buffer('running_mean', running_mean) - self.register_buffer('running_var', running_var) - self.scale = nn.Parameter(torch.zeros(config.encoder_dim)) - self.bias = nn.Parameter(torch.zeros(config.encoder_dim)) - - self.register_buffer('dim', torch.FloatTensor([config.encoder_dim])) - self.momentum = config.batch_norm_momentum - self.epsilon = config.batch_norm_epsilon - - def forward(self, inputs, input_paddings): - #inputs: NHD - #padding: NH - """ - Alternatively: - inputs[input_paddings==0] = F.batch_norm( - input = inputs[input_paddings==0], - running_mean = self.running_mean, - running_var = self.running_var, - weight = 1+self.scale, - bias = self.bias, - training = self.training, - momentum=1-self.momentum, - eps=self.epsilon - ) - inputs.masked_fill(input_paddings[...,None] != 0, 0) - return inputs - """ - mask = 1 - input_paddings[:, :, None] - if self.training: - count = mask.sum() - masked_inp = inputs.masked_fill(mask == 0, 0) - mean = (masked_inp).sum(dim=(0, 1)) / count - var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count - - self.running_mean = (1 - self.momentum) * self.running_mean + ( - self.momentum) * mean.detach() - self.running_var = (1 - self.momentum) * self.running_var + ( - self.momentum) * var.detach() - - else: - mean = self.running_mean - var = self.running_var - v = (1 + self.scale) * torch.rsqrt(var + self.epsilon) - bn = (inputs - mean) * v + self.bias - output = bn.masked_fill(mask == 0, 0) - return output - - -class ConvolutionBlock(nn.Module): - - def __init__(self, config): - super().__init__() - - self.config = config - self.ln = LayerNorm(dim=config.encoder_dim) - self.lin1 = nn.Linear( - in_features=config.encoder_dim, out_features=config.encoder_dim) - self.lin2 = nn.Linear( - in_features=config.encoder_dim, out_features=config.encoder_dim) - - self.conv1 = nn.Conv1d( - in_channels=config.encoder_dim, - out_channels=config.encoder_dim, - kernel_size=(config.convolution_kernel_size,), - stride=(1,), - padding='same', - bias=False, - groups=config.encoder_dim) - self.bn = BatchNorm(config) - self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim) - - def forward(self, inputs, input_paddings, dropout_rate): - inputs = self.ln(inputs) - - inputs = F.glu(torch.cat([self.lin1(inputs), self.lin2(inputs)], dim=2)) - inputs = inputs * (1 - input_paddings[:, :, None]) - - inputs = inputs.permute(0, 2, 1) - inputs = self.conv1(inputs) - inputs = inputs.permute(0, 2, 1) - - inputs = self.bn(inputs, 
input_paddings) - if self.config.activation_function_name == 'swish': - activation_fn = F.silu - elif self.config.activation_function_name == 'gelu': - activation_fn = F.gelu - else: - raise ValueError('Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{self.config.activation_function_name}') - inputs = activation_fn(inputs) - inputs = self.lin3(inputs) - - inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True) - return inputs - - -class ConformerBlock(nn.Module): - - def __init__(self, config: ConformerConfig): - super().__init__() - - self.ff1 = FeedForwardModule(config) - self.mhsa = MultiHeadedSelfAttention(config) - self.conv = ConvolutionBlock(config) - self.ff2 = FeedForwardModule(config) - self.ln = None - if config.use_post_layer_norm: - self.ln = LayerNorm(dim=config.encoder_dim) - - def forward(self, inputs, input_paddings, dropout_rate): - padding_mask = 1 - input_paddings[:, :, None] - inputs = inputs + 0.5 * self.ff1(inputs, padding_mask, dropout_rate) - inputs = inputs + self.mhsa(inputs, input_paddings, dropout_rate) - inputs = inputs + self.conv(inputs, input_paddings, dropout_rate) - inputs = inputs + 0.5 * self.ff2(inputs, padding_mask, dropout_rate) - if self.ln: - inputs = self.ln(inputs) - return inputs - - -class ConformerEncoderDecoder(nn.Module): - - def __init__(self, config: ConformerConfig): - super().__init__() - self.config = config - preprocessing_config = preprocessor.PreprocessorConfig() - self.preprocessor = preprocessor.MelFilterbankFrontend( - preprocessing_config, - per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, - per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR) - self.specaug = SpecAug( - freq_mask_count=config.freq_mask_count, - freq_mask_max_bins=config.freq_mask_max_bins, - time_mask_count=config.time_mask_count, - time_mask_max_frames=config.time_mask_max_frames, - time_mask_max_ratio=config.time_mask_max_ratio, - time_masks_per_frame=config.time_masks_per_frame, - use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames - ) - self.subsample = Subsample( - encoder_dim=config.encoder_dim, - num_bins=preprocessing_config.num_bins) - self.conformers = nn.ModuleList( - [ConformerBlock(config) for _ in range(config.num_encoder_layers)]) - - self.ln = LayerNorm(config.encoder_dim) - self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - - def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): - outputs = inputs - output_paddings = input_paddings - outputs, output_paddings = self.preprocessor(outputs, output_paddings) - if self.training and self.config.use_specaug: - outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate) - for conformer in self.conformers: - outputs = conformer(outputs, output_paddings, dropout_rate) - outputs = self.ln(outputs) - outputs = self.lin(outputs) - return outputs, output_paddings diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index 84d317326..a8480a343 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -17,6 +17,7 @@ SpecAug USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ +DROPOUT_RATE = 0.1 @dataclass @@ -38,10 +39,6 @@ class DeepspeechConfig: use_dynamic_time_mask_max_frames: bool = True 
   batch_norm_momentum: float = 1 - 0.999
   batch_norm_epsilon: float = 0.001
-  # If None, defaults to 0.1.
-  input_dropout_rate: Optional[float] = 0.1
-  # If None, defaults to 0.1.
-  feed_forward_dropout_rate: Optional[float] = 0.1
   enable_residual_connections: bool = True
   enable_decoder_layer_norm: bool = True
   bidirectional: bool = True
@@ -87,13 +84,8 @@ def __init__(self, config: DeepspeechConfig):
 
     self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True)
 
-    if config.input_dropout_rate is None:
-      input_dropout_rate = 0.1
-    else:
-      input_dropout_rate = config.input_dropout_rate
-    self.dropout = nn.Dropout(p=input_dropout_rate)
+  def forward(self, inputs, input_paddings, dropout_rate):
 
-  def forward(self, inputs, input_paddings):
     output_paddings = input_paddings
     outputs = inputs[:, None, :, :]
 
@@ -106,7 +98,7 @@ def forward(self, inputs, input_paddings):
                               subsampled_dims * channels)
 
     outputs = self.lin(outputs)
-    outputs = self.dropout(outputs)
+    outputs = F.dropout(outputs, dropout_rate, training=self.training)
 
     return outputs, output_paddings
 
@@ -205,13 +197,9 @@ def __init__(self, config: DeepspeechConfig):
           batch_norm_momentum=config.batch_norm_momentum,
           batch_norm_epsilon=config.batch_norm_epsilon)
     self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True)
-    if config.feed_forward_dropout_rate is None:
-      feed_forward_dropout_rate = 0.1
-    else:
-      feed_forward_dropout_rate = config.feed_forward_dropout_rate
-    self.dropout = nn.Dropout(p=feed_forward_dropout_rate)
 
-  def forward(self, inputs, input_paddings):
+  def forward(self, inputs, input_paddings, dropout_rate):
+
     padding_mask = (1 - input_paddings)[:, :, None]
     if self.config.layernorm_everywhere:
       inputs = self.normalization_layer(inputs)
@@ -226,7 +214,7 @@ def forward(self, inputs, input_paddings):
       inputs = F.relu(inputs)
 
     inputs = inputs * padding_mask
-    inputs = self.dropout(inputs)
+    inputs = F.dropout(inputs, dropout_rate, training=self.training)
 
     return inputs
 
@@ -363,14 +351,14 @@ def __init__(self, config: DeepspeechConfig):
 
     self.lin = nn.Linear(config.encoder_dim, config.vocab_size)
 
-  def forward(self, inputs, input_paddings):
+  def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
     outputs = inputs
     output_paddings = input_paddings
 
     outputs, output_paddings = self.preprocessor(outputs, output_paddings)
     if self.training and self.config.use_specaug:
       outputs, output_paddings = self.specaug(outputs, output_paddings)
-    outputs, output_paddings = self.subsample(outputs, output_paddings)
+    outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate)
     for idx in range(self.config.num_lstm_layers):
       if self.config.enable_residual_connections:
         outputs = outputs + self.lstms[idx](outputs, output_paddings)
@@ -379,9 +367,9 @@ def forward(self, inputs, input_paddings):
 
     for idx in range(self.config.num_ffn_layers):
       if self.config.enable_residual_connections:
-        outputs = outputs + self.ffns[idx](outputs, output_paddings)
+        outputs = outputs + self.ffns[idx](outputs, output_paddings, dropout_rate)
      else:
-        outputs = self.ffns[idx](outputs, output_paddings)
+        outputs = self.ffns[idx](outputs, output_paddings, dropout_rate)
 
     if self.config.enable_decoder_layer_norm:
       outputs = self.ln(outputs)
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
deleted file mode 100644
index a8480a343..000000000
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
+++ /dev/null @@ -1,379 +0,0 @@ -"""This is a pytorch implementation mirroring: -https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py. -""" - -from dataclasses import dataclass -import os -from typing import Optional, Tuple - -import torch -from torch import nn -import torch.distributed.nn as dist_nn -import torch.nn.functional as F - -from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ - preprocessor -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ - SpecAug - -USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ -DROPOUT_RATE = 0.1 - - -@dataclass -class DeepspeechConfig: - """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" - vocab_size: int = 1024 - encoder_dim: int = 512 - num_lstm_layers: int = 6 - num_ffn_layers: int = 3 - conv_subsampling_factor: int = 2 - conv_subsampling_layers: int = 2 - use_specaug: bool = True - freq_mask_count: int = 2 - freq_mask_max_bins: int = 27 - time_mask_count: int = 10 - time_mask_max_frames: int = 40 - time_mask_max_ratio: float = 0.05 - time_masks_per_frame: float = 0.0 - use_dynamic_time_mask_max_frames: bool = True - batch_norm_momentum: float = 1 - 0.999 - batch_norm_epsilon: float = 0.001 - enable_residual_connections: bool = True - enable_decoder_layer_norm: bool = True - bidirectional: bool = True - use_tanh: bool = False - layernorm_everywhere: bool = False - - -class LayerNorm(nn.Module): - - def __init__(self, dim, epsilon=1e-6): - super().__init__() - self.dim = dim - - self.scale = nn.Parameter(torch.zeros(self.dim)) - self.bias = nn.Parameter(torch.zeros(self.dim)) - self.epsilon = epsilon - - def forward(self, x): - mean = x.mean(dim=-1, keepdims=True) - var = x.var(dim=-1, unbiased=False, keepdims=True) - - normed_x = (x - mean) * torch.rsqrt(var + self.epsilon) - normed_x *= (1 + self.scale) - normed_x += self.bias - - return normed_x - - -class Subsample(nn.Module): - - def __init__(self, config: DeepspeechConfig): - super().__init__() - encoder_dim = config.encoder_dim - - self.encoder_dim = encoder_dim - - self.conv1 = Conv2dSubsampling( - input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh) - self.conv2 = Conv2dSubsampling( - input_channels=encoder_dim, - output_channels=encoder_dim, - use_tanh=config.use_tanh) - - self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True) - - def forward(self, inputs, input_paddings, dropout_rate): - - output_paddings = input_paddings - outputs = inputs[:, None, :, :] - - outputs, output_paddings = self.conv1(outputs, output_paddings) - outputs, output_paddings = self.conv2(outputs, output_paddings) - - batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape - outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size, - subsampled_lengths, - subsampled_dims * channels) - - outputs = self.lin(outputs) - outputs = F.dropout(outputs, dropout_rate, training=self.training) - - return outputs, output_paddings - - -class Conv2dSubsampling(nn.Module): - - def __init__(self, - input_channels: int, - output_channels: int, - filter_stride: Tuple[int] = (2, 2), - padding: str = 'SAME', - batch_norm_momentum: float = 0.999, - batch_norm_epsilon: float = 0.001, - use_tanh: bool = False): - super().__init__() - - self.input_channels = input_channels - self.output_channels = output_channels - self.filter_stride = filter_stride - self.padding = padding - - self.filter_shape = (output_channels, input_channels, 3, 3) - - self.kernel = nn.Parameter( - 
nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) - self.bias = nn.Parameter(torch.zeros(output_channels)) - - self.use_tanh = use_tanh - - def get_same_padding(self, input_shape): - in_height, in_width = input_shape[2:] - stride_height, stride_width = self.filter_stride - filter_height, filter_width = 3, 3 - if in_height % stride_height == 0: - pad_along_height = max(filter_height - stride_height, 0) - else: - pad_along_height = max(filter_height - (in_height % stride_height), 0) - if in_width % stride_width == 0: - pad_along_width = max(filter_width - stride_width, 0) - else: - pad_along_width = max(filter_width - (in_width % stride_width), 0) - pad_top = pad_along_height // 2 - pad_bottom = pad_along_height - pad_top - pad_left = pad_along_width // 2 - pad_right = pad_along_width - pad_left - return (pad_left, pad_right, pad_top, pad_bottom) - - def forward(self, inputs, paddings): - groups = inputs.shape[1] // self.input_channels - - if self.padding == 'SAME': - in_ = F.pad(inputs, self.get_same_padding(inputs.shape)) - else: - in_ = inputs - outputs = F.conv2d( - input=in_, - weight=self.kernel, - bias=self.bias, - stride=self.filter_stride, - dilation=(1, 1), - groups=groups) - - if self.use_tanh: - outputs = F.tanh(outputs) - else: - outputs = F.relu(outputs) - - input_length = paddings.shape[1] - stride = self.filter_stride[0] - pad_len = (input_length + stride - 1) // stride * stride - input_length - out_padding = F.conv1d( - input=torch.cat([ - paddings[:, None, :], - torch.zeros( - size=(paddings.shape[0], 1, pad_len), device=paddings.device) - ], - dim=2), - weight=torch.ones([1, 1, 1], device=paddings.device), - stride=self.filter_stride[:1]) - out_padding = out_padding.squeeze(dim=1) - outputs = outputs * (1 - out_padding[:, None, :, None]) - return outputs, out_padding - - -class FeedForwardModule(nn.Module): - - def __init__(self, config: DeepspeechConfig): - super().__init__() - self.config = config - - if config.layernorm_everywhere: - self.normalization_layer = LayerNorm(config.encoder_dim) - else: - self.bn_normalization_layer = BatchNorm( - dim=config.encoder_dim, - batch_norm_momentum=config.batch_norm_momentum, - batch_norm_epsilon=config.batch_norm_epsilon) - self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True) - - def forward(self, inputs, input_paddings, dropout_rate): - - padding_mask = (1 - input_paddings)[:, :, None] - if self.config.layernorm_everywhere: - inputs = self.normalization_layer(inputs) - else: # batchnorm - inputs = self.bn_normalization_layer(inputs, input_paddings) - - inputs = self.lin(inputs) - - if self.config.use_tanh: - inputs = F.tanh(inputs) - else: - inputs = F.relu(inputs) - - inputs = inputs * padding_mask - inputs = F.dropout(inputs, dropout_rate, training=self.training) - - return inputs - - -class BatchNorm(nn.Module): - - def __init__(self, dim, batch_norm_momentum, batch_norm_epsilon): - super().__init__() - running_mean = torch.zeros(dim) - running_var = torch.ones(dim) - self.register_buffer('running_mean', running_mean) - self.register_buffer('running_var', running_var) - self.weight = nn.Parameter(torch.zeros(dim)) - self.bias = nn.Parameter(torch.zeros(dim)) - - self.momentum = batch_norm_momentum - self.epsilon = batch_norm_epsilon - self.dim = dim - - def forward(self, inputs, input_paddings): - #inputs: NHD - #padding: NH - mask = 1 - input_paddings[:, :, None] - if self.training: - count = mask.sum() - masked_inp = inputs.masked_fill(mask == 0, 0) - sum_ = (masked_inp).sum(dim=(0, 1)) - if 
USE_PYTORCH_DDP: - sum_ = dist_nn.all_reduce(sum_) - count = dist_nn.all_reduce(count) - mean = sum_ / count - - sum_ = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) - if USE_PYTORCH_DDP: - sum_ = dist_nn.all_reduce(sum_) - var = sum_ / count - - self.running_mean = (1 - self.momentum) * self.running_mean + ( - self.momentum) * mean.detach() - self.running_var = (1 - self.momentum) * self.running_var + ( - self.momentum) * var.detach() - else: - mean = self.running_mean - var = self.running_var - v = (1 + self.weight) * torch.rsqrt(var + self.epsilon) - bn = (inputs - mean) * v + self.bias - output = bn.masked_fill(mask == 0, 0) - return output - - -class BatchRNN(nn.Module): - - def __init__(self, config: DeepspeechConfig): - super().__init__() - self.config = config - hidden_size = config.encoder_dim - input_size = config.encoder_dim - bidirectional = config.bidirectional - self.bidirectional = bidirectional - - if config.layernorm_everywhere: - self.normalization_layer = LayerNorm(config.encoder_dim) - else: - self.bn_normalization_layer = BatchNorm(config.encoder_dim, - config.batch_norm_momentum, - config.batch_norm_epsilon) - - if bidirectional: - self.lstm = nn.LSTM( - input_size=input_size, - hidden_size=hidden_size // 2, - bidirectional=True, - batch_first=True) - else: - self.lstm = nn.LSTM( - input_size=input_size, hidden_size=hidden_size, batch_first=True) - - def forward(self, inputs, input_paddings): - if self.config.layernorm_everywhere: - inputs = self.normalization_layer(inputs) - else: - inputs = self.bn_normalization_layer(inputs, input_paddings) - lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy() - packed_inputs = torch.nn.utils.rnn.pack_padded_sequence( - inputs, lengths, batch_first=True, enforce_sorted=False) - packed_outputs, _ = self.lstm(packed_inputs) - outputs, _ = torch.nn.utils.rnn.pad_packed_sequence( - packed_outputs, batch_first=True) - if outputs.shape[1] < inputs.shape[1]: - outputs = torch.cat([ - outputs, - torch.zeros( - size=(outputs.shape[0], - inputs.shape[1] - outputs.shape[1], - outputs.shape[2]), - device=outputs.device) - ], - dim=1) - return outputs - - -class DeepspeechEncoderDecoder(nn.Module): - - def __init__(self, config: DeepspeechConfig): - super().__init__() - self.config = config - - self.specaug = SpecAug( - freq_mask_count=config.freq_mask_count, - freq_mask_max_bins=config.freq_mask_max_bins, - time_mask_count=config.time_mask_count, - time_mask_max_frames=config.time_mask_max_frames, - time_mask_max_ratio=config.time_mask_max_ratio, - time_masks_per_frame=config.time_masks_per_frame, - use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames - ) - preprocessing_config = preprocessor.PreprocessorConfig() - self.preprocessor = preprocessor.MelFilterbankFrontend( - preprocessing_config, - per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, - per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR) - - self.subsample = Subsample(config=config) - - self.lstms = nn.ModuleList( - [BatchRNN(config) for _ in range(config.num_lstm_layers)]) - self.ffns = nn.ModuleList( - [FeedForwardModule(config) for _ in range(config.num_ffn_layers)]) - - if config.enable_decoder_layer_norm: - self.ln = LayerNorm(config.encoder_dim) - else: - self.ln = nn.Identity() - - self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - - def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): - outputs = inputs - output_paddings = input_paddings - - outputs, output_paddings = self.preprocessor(outputs, 
output_paddings) - if self.training and self.config.use_specaug: - outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate) - for idx in range(self.config.num_lstm_layers): - if self.config.enable_residual_connections: - outputs = outputs + self.lstms[idx](outputs, output_paddings) - else: - outputs = self.lstms[idx](outputs, output_paddings) - - for idx in range(self.config.num_ffn_layers): - if self.config.enable_residual_connections: - outputs = outputs + self.ffns[idx](outputs, output_paddings, dropout_rate) - else: - outputs = self.ffns[idx](outputs, output_paddings, dropout_rate) - - if self.config.enable_decoder_layer_norm: - outputs = self.ln(outputs) - - outputs = self.lin(outputs) - - return outputs, output_paddings diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models.py b/algoperf/workloads/ogbg/ogbg_pytorch/models.py index fe9b29bc1..be5882333 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/models.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/models.py @@ -9,17 +9,20 @@ from torch import nn from algoperf import init_utils +from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout +DROPOUT_RATE = 0.1 -def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): + +def _make_mlp(in_dim, hidden_dims, activation_fn): """Creates a MLP with specified dimensions.""" - layers = nn.Sequential() + layers = SequentialWithDropout() for i, dim in enumerate(hidden_dims): layers.add_module(f'dense_{i}', nn.Linear(in_features=in_dim, out_features=dim)) layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6)) layers.add_module(f'activation_fn_{i}', activation_fn()) - layers.add_module(f'dropout_{i}', nn.Dropout(dropout_rate)) + layers.add_module(f'dropout_{i}', CustomDropout()) in_dim = dim return layers @@ -33,7 +36,6 @@ class GNN(nn.Module): def __init__(self, num_outputs: int = 128, - dropout_rate: Optional[float] = 0.1, activation_fn_name: str = 'relu', latent_dim: int = 256, hidden_dims: Tuple[int] = (256,), @@ -43,8 +45,6 @@ def __init__(self, self.hidden_dims = hidden_dims self.num_message_passing_steps = num_message_passing_steps self.num_outputs = num_outputs - if dropout_rate is None: - dropout_rate = 0.1 # in_features are specifically chosen for the ogbg workload. 
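# --- Editor's note (illustrative sketch, not part of the patch) -------------
# `CustomDropout` and `SequentialWithDropout`, imported above from
# algoperf.pytorch_utils, are not shown in this hunk. The code below is a
# hypothetical stand-in that matches how they are used here: a dropout module
# whose rate is supplied at call time, and a Sequential that passes that rate
# only to children flagged with `_supports_custom_dropout`.
import torch
import torch.nn.functional as F
from torch import nn

class CustomDropoutSketch(nn.Module):
  """Dropout whose rate is supplied per forward call."""

  def __init__(self):
    super().__init__()
    self._supports_custom_dropout = True

  def forward(self, x: torch.Tensor, p: float) -> torch.Tensor:
    return F.dropout(x, p, training=self.training)

class SequentialWithDropoutSketch(nn.Sequential):
  """Sequential that forwards the dropout rate to modules that accept it."""

  def forward(self, x, p):  # pylint: disable=arguments-differ
    for module in self:
      if getattr(module, '_supports_custom_dropout', False):
        x = module(x, p)
      else:
        x = module(x)
    return x

# Usage, mirroring the _make_mlp blocks above with a call-time rate:
mlp = SequentialWithDropoutSketch(
    nn.Linear(8, 16), nn.LayerNorm(16, eps=1e-6), nn.ReLU(),
    CustomDropoutSketch())
out = mlp(torch.randn(4, 8), 0.1)
# ----------------------------------------------------------------------------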
self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim) self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim) @@ -77,17 +77,14 @@ def __init__(self, GraphNetwork( update_edge_fn=_make_mlp(in_dim_edge_fn, self.hidden_dims, - dropout_rate, activation_fn), update_node_fn=_make_mlp(in_dim_node_fn, self.hidden_dims, - dropout_rate, activation_fn), update_global_fn=_make_mlp(last_in_dim, self.hidden_dims, - dropout_rate, activation_fn))) - self.graph_network = nn.Sequential(*graph_network_layers) + self.graph_network = SequentialWithDropout(*graph_network_layers) self.decoder = nn.Linear( in_features=self.hidden_dims[-1], out_features=self.num_outputs) @@ -96,14 +93,18 @@ def __init__(self, if isinstance(m, nn.Linear): init_utils.pytorch_default_init(m) - def forward(self, graph: GraphsTuple) -> torch.Tensor: + def forward( + self, + graph: GraphsTuple, + dropout_rate: float = DROPOUT_RATE) -> torch.Tensor: + graph = graph._replace( globals=torch.zeros([graph.n_node.shape[0], self.num_outputs], device=graph.n_node.device)) graph = graph._replace(nodes=self.node_embedder(graph.nodes)) graph = graph._replace(edges=self.edge_embedder(graph.edges)) - graph = self.graph_network(graph) + graph = self.graph_network(graph, dropout_rate) # Map globals to represent the final result graph = graph._replace(globals=self.decoder(graph.globals)) @@ -145,8 +146,9 @@ def __init__(self, self.update_edge_fn = update_edge_fn self.update_node_fn = update_node_fn self.update_global_fn = update_global_fn + self._supports_custom_dropout = True # supports SequentialWithDropout - def forward(self, graph: GraphsTuple) -> GraphsTuple: + def forward(self, graph: GraphsTuple, dropout_rate: float) -> GraphsTuple: """Applies a configured GraphNetwork to a graph. This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 There is one difference. For the nodes update the class aggregates over the @@ -159,6 +161,7 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: GraphNets, for more information please see the paper. Args: graph: a `GraphsTuple` containing the graph. + dropout_rate: dropout probability value. Returns: Updated `GraphsTuple`. """ @@ -179,7 +182,7 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: edge_fn_inputs = torch.cat( [edges, sent_attributes, received_attributes, global_edge_attributes], dim=-1) - edges = self.update_edge_fn(edge_fn_inputs) + edges = self.update_edge_fn(edge_fn_inputs, dropout_rate) if self.update_node_fn: sent_attributes = tree.tree_map( @@ -194,7 +197,7 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: node_fn_inputs = torch.cat( [nodes, sent_attributes, received_attributes, global_attributes], dim=-1) - nodes = self.update_node_fn(node_fn_inputs) + nodes = self.update_node_fn(node_fn_inputs, dropout_rate) if self.update_global_fn: n_graph = n_node.shape[0] @@ -213,7 +216,7 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: # These pooled nodes are the inputs to the global update fn. 
global_fn_inputs = torch.cat([node_attributes, edge_attributes, globals_], dim=-1) - globals_ = self.update_global_fn(global_fn_inputs) + globals_ = self.update_global_fn(global_fn_inputs, dropout_rate) return GraphsTuple( nodes=nodes, diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py deleted file mode 100644 index be5882333..000000000 --- a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py +++ /dev/null @@ -1,315 +0,0 @@ -# Ported to PyTorch from -# https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. -from functools import partial -from typing import Callable, Optional, Tuple - -import jax.tree_util as tree -from jraph import GraphsTuple -import torch -from torch import nn - -from algoperf import init_utils -from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout - -DROPOUT_RATE = 0.1 - - -def _make_mlp(in_dim, hidden_dims, activation_fn): - """Creates a MLP with specified dimensions.""" - layers = SequentialWithDropout() - for i, dim in enumerate(hidden_dims): - layers.add_module(f'dense_{i}', - nn.Linear(in_features=in_dim, out_features=dim)) - layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6)) - layers.add_module(f'activation_fn_{i}', activation_fn()) - layers.add_module(f'dropout_{i}', CustomDropout()) - in_dim = dim - return layers - - -class GNN(nn.Module): - """Defines a graph network. - - The model assumes the input data is a jraph.GraphsTuple without global - variables. The final prediction will be encoded in the globals. - """ - - def __init__(self, - num_outputs: int = 128, - activation_fn_name: str = 'relu', - latent_dim: int = 256, - hidden_dims: Tuple[int] = (256,), - num_message_passing_steps: int = 5) -> None: - super().__init__() - self.latent_dim = latent_dim - self.hidden_dims = hidden_dims - self.num_message_passing_steps = num_message_passing_steps - self.num_outputs = num_outputs - # in_features are specifically chosen for the ogbg workload. - self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim) - self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim) - - if activation_fn_name == 'relu': - activation_fn = nn.ReLU - elif activation_fn_name == 'gelu': - activation_fn = partial(nn.GELU, approximate='tanh') - elif activation_fn_name == 'silu': - activation_fn = nn.SiLU - else: - raise ValueError( - f'Invalid activation function name: {self.activation_fn_name}') - - graph_network_layers = [] - for st in range(self.num_message_passing_steps): - # Constants in in_dims are based on forward call of GraphNetwork: - # specifically update_edge_fn update_node_fn and update_global_fn. 
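# --- Editor's note (worked check, not part of the patch) --------------------
# With the defaults above (latent_dim=256, hidden_dims=(256,), num_outputs=128),
# the step-0 input widths computed just below come out to:
#   in_dim_edge_fn = 3 * 256 + 128       = 896  (edges, sent, received, globals)
#   in_dim_node_fn = 256 + 2 * 256 + 128 = 896  (nodes, sent, received, globals)
#   last_in_dim    = 2 * 256 + 128       = 640  (pooled nodes, pooled edges, globals)
# Later steps use hidden_dims[-1] = 256 everywhere: 4 * 256 = 1024 for the edge
# and node MLPs and 3 * 256 = 768 for the global MLP.
# ----------------------------------------------------------------------------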
- if st == 0: - in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs - in_dim_node_fn = self.latent_dim + self.hidden_dims[ - -1] * 2 + self.num_outputs - last_in_dim = self.hidden_dims[-1] * 2 + self.num_outputs - else: - in_dim_edge_fn = self.hidden_dims[-1] * 4 - in_dim_node_fn = self.hidden_dims[-1] * 4 - last_in_dim = self.hidden_dims[-1] * 3 - - graph_network_layers.append( - GraphNetwork( - update_edge_fn=_make_mlp(in_dim_edge_fn, - self.hidden_dims, - activation_fn), - update_node_fn=_make_mlp(in_dim_node_fn, - self.hidden_dims, - activation_fn), - update_global_fn=_make_mlp(last_in_dim, - self.hidden_dims, - activation_fn))) - self.graph_network = SequentialWithDropout(*graph_network_layers) - - self.decoder = nn.Linear( - in_features=self.hidden_dims[-1], out_features=self.num_outputs) - - for m in self.modules(): - if isinstance(m, nn.Linear): - init_utils.pytorch_default_init(m) - - def forward( - self, - graph: GraphsTuple, - dropout_rate: float = DROPOUT_RATE) -> torch.Tensor: - - graph = graph._replace( - globals=torch.zeros([graph.n_node.shape[0], self.num_outputs], - device=graph.n_node.device)) - graph = graph._replace(nodes=self.node_embedder(graph.nodes)) - graph = graph._replace(edges=self.edge_embedder(graph.edges)) - - graph = self.graph_network(graph, dropout_rate) - - # Map globals to represent the final result - graph = graph._replace(globals=self.decoder(graph.globals)) - - return graph.globals - - -class GraphNetwork(nn.Module): - """Returns a method that applies a configured GraphNetwork. - This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 - There is one difference. For the nodes update the class aggregates over the - sender edges and receiver edges separately. This is a bit more general - than the algorithm described in the paper. The original behaviour can be - recovered by using only the receiver edge aggregations for the update. - In addition this implementation supports softmax attention over incoming - edge features. - Example usage:: - gn = GraphNetwork(update_edge_function, - update_node_function, **kwargs) - # Conduct multiple rounds of message passing with the same parameters: - for _ in range(num_message_passing_steps): - graph = gn(graph) - Args: - update_edge_fn: function used to update the edges or None to deactivate edge - updates. - update_node_fn: function used to update the nodes or None to deactivate node - updates. - update_global_fn: function used to update the globals or None to deactivate - globals updates. - Returns: - A method that applies the configured GraphNetwork. - """ - - def __init__(self, - update_edge_fn: Optional[Callable] = None, - update_node_fn: Optional[Callable] = None, - update_global_fn: Optional[Callable] = None) -> None: - super().__init__() - self.update_edge_fn = update_edge_fn - self.update_node_fn = update_node_fn - self.update_global_fn = update_global_fn - self._supports_custom_dropout = True # supports SequentialWithDropout - - def forward(self, graph: GraphsTuple, dropout_rate: float) -> GraphsTuple: - """Applies a configured GraphNetwork to a graph. - This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 - There is one difference. For the nodes update the class aggregates over the - sender edges and receiver edges separately. This is a bit more general - the algorithm described in the paper. The original behaviour can be - recovered by using only the receiver edge aggregations for the update. 
- In addition this implementation supports softmax attention over incoming - edge features. - Many popular Graph Neural Networks can be implemented as special cases of - GraphNets, for more information please see the paper. - Args: - graph: a `GraphsTuple` containing the graph. - dropout_rate: dropout probability value. - Returns: - Updated `GraphsTuple`. - """ - nodes, edges, receivers, senders, globals_, n_node, n_edge = graph - sum_n_node = tree.tree_leaves(nodes)[0].shape[0] - if not tree.tree_all( - tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)): - raise ValueError( - 'All node arrays in nest must contain the same number of nodes.') - - sent_attributes = tree.tree_map(lambda n: n[senders], nodes) - received_attributes = tree.tree_map(lambda n: n[receivers], nodes) - # Here we scatter the global features to the corresponding edges, - # giving us tensors of shape [num_edges, global_feat]. - global_edge_attributes = tree.tree_map( - lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_) - if self.update_edge_fn: - edge_fn_inputs = torch.cat( - [edges, sent_attributes, received_attributes, global_edge_attributes], - dim=-1) - edges = self.update_edge_fn(edge_fn_inputs, dropout_rate) - - if self.update_node_fn: - sent_attributes = tree.tree_map( - lambda e: scatter_sum(e, senders, dim=0, dim_size=sum_n_node), edges) - received_attributes = tree.tree_map( - lambda e: scatter_sum(e, receivers, dim=0, dim_size=sum_n_node), - edges) - # Here we scatter the global features to the corresponding nodes, - # giving us tensors of shape [num_nodes, global_feat]. - global_attributes = tree.tree_map( - lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_) - node_fn_inputs = torch.cat( - [nodes, sent_attributes, received_attributes, global_attributes], - dim=-1) - nodes = self.update_node_fn(node_fn_inputs, dropout_rate) - - if self.update_global_fn: - n_graph = n_node.shape[0] - graph_idx = torch.arange(n_graph, device=graph.n_node.device) - # To aggregate nodes and edges from each graph to global features, - # we first construct tensors that map the node to the corresponding graph. - # For example, if you have `n_node=[1,2]`, we construct the tensor - # [0, 1, 1]. We then do the same for edges. - node_gr_idx = torch.repeat_interleave(graph_idx, n_node, dim=0) - edge_gr_idx = torch.repeat_interleave(graph_idx, n_edge, dim=0) - # We use the aggregation function to pool the nodes/edges per graph. - node_attributes = tree.tree_map( - lambda n: scatter_sum(n, node_gr_idx, dim=0, dim_size=n_graph), nodes) - edge_attributes = tree.tree_map( - lambda e: scatter_sum(e, edge_gr_idx, dim=0, dim_size=n_graph), edges) - # These pooled nodes are the inputs to the global update fn. - global_fn_inputs = torch.cat([node_attributes, edge_attributes, globals_], - dim=-1) - globals_ = self.update_global_fn(global_fn_inputs, dropout_rate) - - return GraphsTuple( - nodes=nodes, - edges=edges, - receivers=receivers, - senders=senders, - globals=globals_, - n_node=n_node, - n_edge=n_edge) - - -# Forked from -# github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py. -def scatter_sum(src: torch.Tensor, - index: torch.Tensor, - dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None) -> torch.Tensor: - r""" - | - .. 
image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ - master/docs/source/_figures/add.svg?sanitize=true - :align: center - :width: 400px - | - Reduces all values from the :attr:`src` tensor into :attr:`out` at the - indices specified in the :attr:`index` tensor along a given axis - :attr:`dim`. - For each value in :attr:`src`, its output index is specified by its index - in :attr:`src` for dimensions outside of :attr:`dim` and by the - corresponding value in :attr:`index` for dimension :attr:`dim`. - The applied reduction is here defined as a sum. - Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional - tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` - and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional - tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. - Moreover, the values of :attr:`index` must be between :math:`0` and - :math:`y - 1`, although no specific ordering of indices is required. - The :attr:`index` tensor supports broadcasting in case its dimensions do - not match with :attr:`src`. - For one-dimensional tensors, the operation computes - .. math:: - \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j - where :math:`\sum_j` is over :math:`j` such that - :math:`\mathrm{index}_j = i`. - .. note:: - This operation is implemented via atomic operations on the GPU and is - therefore **non-deterministic** since the order of parallel operations - to the same value is undetermined. - For floating-point variables, this results in a source of variance in - the result. - :param src: The source tensor. - :param index: The indices of elements to scatter. - :param dim: The axis along which to index. (default: :obj:`-1`) - :param out: The destination tensor. - :param dim_size: If :attr:`out` is not given, automatically create output - with size :attr:`dim_size` at dimension :attr:`dim`. - If :attr:`dim_size` is not given, a minimal sized output tensor - according to :obj:`index.max() + 1` is returned. - :rtype: :class:`Tensor` - .. code-block:: python - src = torch.randn(10, 6, 64) - index = torch.tensor([0, 1, 0, 1, 2, 1]) - # Broadcasting in the first and last dim. - out = scatter_sum(src, index, dim=1) - print(out.size()) - .. code-block:: - torch.Size([10, 3, 64]) - """ - index = broadcast(index, src, dim) - if out is None: - size = list(src.size()) - if dim_size is not None: - size[dim] = dim_size - elif index.numel() == 0: - size[dim] = 0 - else: - size[dim] = int(index.max()) + 1 - out = torch.zeros(size, dtype=src.dtype, device=src.device) - return out.scatter_add_(dim, index, src) - else: - return out.scatter_add_(dim, index, src) - - -# Forked from -# github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/utils.py. -def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): - if dim < 0: - dim = other.dim() + dim - if src.dim() == 1: - for _ in range(0, dim): - src = src.unsqueeze(0) - for _ in range(src.dim(), other.dim()): - src = src.unsqueeze(-1) - src = src.expand(other.size()) - return src diff --git a/algoperf/workloads/wmt/wmt_pytorch/models.py b/algoperf/workloads/wmt/wmt_pytorch/models.py index a1c7ce15e..a43df30d4 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models.py @@ -9,6 +9,8 @@ from torch.nn.init import normal_ from torch.nn.init import xavier_uniform_ +DROPOUT_RATE = 0.1 + def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor: """Make a causal mask for self-attention. 
@@ -104,26 +106,18 @@ def __init__(self, nhead: int = 16, d_hid: int = 1024, nlayers: int = 6, - dropout_rate: Optional[float] = 0.1, - attention_dropout_rate: Optional[float] = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, attention_temp: float = 1.0, pre_ln: bool = True): super().__init__() - if dropout_rate is None: - dropout_rate = 0.1 - if attention_dropout_rate is None: - attention_dropout_rate = 0.1 - self.pos_encoder = PositionalEncoding(d_model, dropout_rate) + self.pos_encoder = PositionalEncoding(d_model) self.shared_embedding = nn.Embedding(ntoken, d_model) self.encoder = Encoder(d_model, nhead, d_hid, nlayers, - dropout_rate, - attention_dropout_rate, activation, glu, layer_norm_eps, @@ -133,8 +127,6 @@ def __init__(self, nhead, d_hid, nlayers, - dropout_rate, - attention_dropout_rate, activation, glu, layer_norm_eps, @@ -163,7 +155,8 @@ def forward(self, targets_positions: Optional[Tensor] = None, inputs_segmentation: Optional[Tensor] = None, targets_segmentation: Optional[Tensor] = None, - decode: bool = False) -> Tensor: + decode: bool = False, + dropout_rate: float = DROPOUT_RATE) -> Tensor: """ Args: src: Tensor, shape [batch_size, seq_len] @@ -173,16 +166,19 @@ def forward(self, inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len] targets_segmentation: Optional[Tensor], shape [batch_size, seq_len] decode: bool + dropout_rate: float Returns: output Tensor of shape [batch_size, seq_len, ntoken] """ if src.size(0) != tgt.size(0): raise RuntimeError('The batch size of src and tgt must be equal.') + memory = self.encoder( src, inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation) + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate) output = self.decoder( tgt, memory, @@ -190,7 +186,8 @@ def forward(self, targets_positions=targets_positions, inputs_segmentation=inputs_segmentation, targets_segmentation=targets_segmentation, - decode=decode) + decode=decode, + dropout_rate=dropout_rate) return output @@ -229,12 +226,15 @@ def __init__(self, self.enable_nested_tensor = enable_nested_tensor self.mask_check = mask_check - def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: + def forward(self, src: Tensor, + mask: Optional[Tensor] = None, + dropout_rate: Optional[float] = 0.0) -> Tensor: """Pass the input through the encoder layers in turn. Args: src: the sequence to the encoder (required). mask: the mask for the src sequence (optional). + dropout_rate: the dropout probability (optional). Shape: see the docs in Transformer class. @@ -243,7 +243,7 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: convert_to_nested = False for mod in self.layers: - output = mod(output, src_mask=mask) + output = mod(output, src_mask=mask, dropout_rate=dropout_rate) if convert_to_nested: output = output.to_padded_tensor(0.) 
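# --- Editor's note (illustrative sketch, not part of the patch) -------------
# Minimal, self-contained example of the pattern used by TransformerEncoder
# above: the stack owns no dropout modules and simply relays the call-time
# rate to every layer, so a single `dropout_rate` argument at the top of the
# model reaches all sub-blocks. All names here are placeholders.
import copy
import torch
import torch.nn.functional as F
from torch import nn

class LayerSketch(nn.Module):
  def __init__(self, d_model: int = 16):
    super().__init__()
    self.lin = nn.Linear(d_model, d_model)

  def forward(self, x: torch.Tensor, dropout_rate: float = 0.0) -> torch.Tensor:
    # Residual block with functional dropout; inactive in eval mode.
    return x + F.dropout(torch.relu(self.lin(x)), dropout_rate,
                         training=self.training)

class StackSketch(nn.Module):
  def __init__(self, layer: nn.Module, num_layers: int = 2):
    super().__init__()
    self.layers = nn.ModuleList(
        [copy.deepcopy(layer) for _ in range(num_layers)])

  def forward(self, x: torch.Tensor, dropout_rate: float = 0.0) -> torch.Tensor:
    for layer in self.layers:
      x = layer(x, dropout_rate=dropout_rate)
    return x

stack = StackSketch(LayerSketch())
y_train = stack(torch.randn(2, 5, 16), dropout_rate=0.1)            # stochastic
y_eval = stack.eval()(torch.randn(2, 5, 16), dropout_rate=0.1)      # dropout is a no-op
# ----------------------------------------------------------------------------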
@@ -261,8 +261,6 @@ def __init__(self, nhead: int = 16, d_hid: int = 1024, nlayers: int = 6, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -276,8 +274,6 @@ def __init__(self, d_model, nhead, d_hid, - dropout_rate, - attention_dropout_rate=attention_dropout_rate, activation=activation, glu=glu, layer_norm_eps=layer_norm_eps, @@ -290,12 +286,13 @@ def __init__(self, def forward(self, src: Tensor, inputs_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None) -> Tensor: + inputs_segmentation: Optional[Tensor] = None, + dropout_rate: Optional[float] = 0.0) -> Tensor: src = src.to(torch.int) src_mask = make_src_mask(src, inputs_segmentation, self.nhead) src = self.shared_embedding(src) - src = self.pos_encoder(src, inputs_positions) - memory = self.encoder(src, mask=src_mask) + src = self.pos_encoder(src, inputs_positions, dropout_rate=dropout_rate) + memory = self.encoder(src, mask=src_mask, dropout_rate=dropout_rate) return memory @@ -306,8 +303,6 @@ def __init__(self, nhead: int = 16, d_hid: int = 1024, nlayers: int = 6, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -320,8 +315,6 @@ def __init__(self, self.decoder = TransformerDecoder(d_model, nhead, d_hid, - dropout_rate, - attention_dropout_rate, activation, glu, layer_norm_eps, @@ -339,7 +332,8 @@ def forward( targets_segmentation: Optional[Tensor] = None, decode: bool = False, max_len: Optional[int] = None, - cache: Optional[dict] = None) -> Any: + cache: Optional[dict] = None, + dropout_rate: Optional[float] = 0.0) -> Any: tgt = tgt.to(torch.int) tgt_mask, memory_mask = make_tgt_and_memory_mask( tgt, src, inputs_segmentation, targets_segmentation, @@ -347,7 +341,7 @@ def forward( if not decode: tgt = shift_right(tgt) tgt = self.shared_embedding(tgt) - tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache) + tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache, dropout_rate=dropout_rate) if decode: tgt, cache = tgt output = self.decoder( @@ -357,7 +351,8 @@ def forward( memory_mask=memory_mask, decode=decode, max_len=max_len, - cache=cache) + cache=cache, + dropout_rate=dropout_rate) if decode: output, cache = output normalize = math.sqrt(output.shape[-1]) @@ -371,10 +366,8 @@ class PositionalEncoding(nn.Module): def __init__(self, d_model: int, - dropout_rate: float = 0.1, max_len: int = 256): super().__init__() - self.dropout = nn.Dropout(p=dropout_rate) position = torch.arange(max_len).unsqueeze(1) scale_factor = -math.log(10000.0) / (d_model // 2 - 1) @@ -389,7 +382,8 @@ def forward( x: Tensor, inputs_positions: Optional[Tensor] = None, decode: bool = False, - cache: Optional[Dict[str, Dict[str, Tensor]]] = None + cache: Optional[Dict[str, Dict[str, Tensor]]] = None, + dropout_rate: Optional[float] = 0.0 ) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]: """ Args: @@ -397,6 +391,7 @@ def forward( inputs_positions: Tensor (shape [batch_size, seq_len]) or None decode: bool cache: Dict[str, Dict[str, Tensor]] or None + dropout_rate: Optional[float] Returns: Tensor or Tuple[Tensor, Dict[str, Dict[str, Tensor]]] """ @@ -412,14 +407,14 @@ def forward( } pe = self.pe[0, cache[name]['cache_index'], :] cache[name]['cache_index'] += 1 - return self.dropout(x + pe), cache + return F.dropout(x + pe, 
dropout_rate, self.training), cache if inputs_positions is None: # normal unpacked case: pe = self.pe[:, :x.size(1), :] else: # for packed data we need to use known position indices: pe = self.pe[0, inputs_positions, :] - return self.dropout(x + pe) + return F.dropout(x + pe, dropout_rate, self.training) # TransformerEncoderLayer and TransformerDecoderLayer are taken from: @@ -438,7 +433,6 @@ class TransformerEncoderLayer(nn.Module): nhead: the number of heads in the multiheadattention models (default=16). dim_feedforward: the dimension of the feedforward network model (default=1024). - dropout_rate: the dropout_rate value (default=0.1). activation: the activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable (default=F.relu). layer_norm_eps: the eps value in layer normalization components @@ -457,8 +451,6 @@ def __init__(self, d_model: int = 1024, nhead: int = 16, dim_feedforward: int = 1024, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -472,7 +464,6 @@ def __init__(self, d_model, nhead, self_attn=True, - dropout_rate=attention_dropout_rate, attention_temp=attention_temp, bias=False, **factory_kwargs) @@ -482,50 +473,55 @@ def __init__(self, self.glu = glu if self.glu: self.linear_glu = nn.Linear(d_model, dim_feedforward, **factory_kwargs) - self.dropout = nn.Dropout(dropout_rate) self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) self.pre_ln = pre_ln self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.dropout1 = nn.Dropout(dropout_rate) - self.dropout2 = nn.Dropout(dropout_rate) self.activation = activation - def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: + def forward(self, + src: Tensor, + src_mask: Optional[Tensor] = None, + dropout_rate: Optional[float] = 0.0) -> Tensor: r"""Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). src_mask: the mask for the src sequence (optional). - + dropout_rate: the dropout probability value (optional). Shape: see the docs in Transformer class. 
""" x = src if self.pre_ln: - x = x + self._sa_block(self.norm1(x), src_mask) - x = x + self._ff_block(self.norm2(x)) + x = x + self._sa_block(self.norm1(x), src_mask, dropout_rate) + x = x + self._ff_block(self.norm2(x), dropout_rate) else: - x = self.norm1(x + self._sa_block(x, src_mask)) - x = self.norm2(x + self._ff_block(x)) + x = self.norm1(x + self._sa_block(x, src_mask, dropout_rate)) + x = self.norm2(x + self._ff_block(x, dropout_rate)) return x # Self-attention block: - def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tensor: - x, _ = self.self_attn(x, attn_mask=attn_mask) - return self.dropout1(x) + def _sa_block(self, + x: Tensor, + attn_mask: Optional[Tensor], + dropout_rate: Optional[float] = 0.0) -> Tensor: + x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate) + return F.dropout(x, dropout_rate, training=self.training) # Feed forward block: - def _ff_block(self, inputs: Tensor) -> Tensor: + def _ff_block(self, + inputs: Tensor, + dropout_rate: Optional[float] = 0.0) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) x = x * y - x = self.linear2(self.dropout(x)) - return self.dropout2(x) + x = self.linear2(F.dropout(x, dropout_rate, training=self.training)) + return F.dropout(x, dropout_rate, training=self.training) # Modified to use cache for autoregressive decoding and custom @@ -537,7 +533,6 @@ class TransformerDecoder(nn.Module): nhead: the number of heads in the multiheadattention models (default=16) d_hid: the dimension of the feedforward network model (default=1024) - dropout_rate: the dropout_rate value (default=0.1) layer_norm_eps: the eps value in layer normalization components (default=1e-6). decoder_layer: an instance of the TransformerDecoderLayer() class @@ -555,8 +550,6 @@ def __init__(self, d_model, nhead, d_hid, - dropout_rate, - attention_dropout_rate, activation, glu, layer_norm_eps, @@ -569,8 +562,6 @@ def __init__(self, d_model, nhead, d_hid, - dropout_rate, - attention_dropout_rate, activation, glu, layer_norm_eps=layer_norm_eps, @@ -587,7 +578,8 @@ def forward(self, memory_mask: Optional[Tensor] = None, decode: bool = False, max_len: Optional[int] = None, - cache: Optional[dict] = None) -> Any: + cache: Optional[dict] = None, + dropout_rate: Optional[float] = 0.0) -> Any: r"""Pass the inputs (and mask) through the decoder layer in turn. Args: tgt: the sequence to the decoder (required). @@ -596,6 +588,7 @@ def forward(self, memory_mask: the mask for the memory sequence (optional). decode: whether to use cache for autoregressive decoding or not. max_len: maximum sequence length, necessary for decoding cache. + dropout_rate: the dropout probability value (optional) Shape: see the docs in Transformer class. """ @@ -610,7 +603,8 @@ def forward(self, decode=decode, max_len=max_len, cache=cache, - index=idx) + index=idx, + dropout_rate=dropout_rate) if self.norm is not None: output = self.norm(output) @@ -636,7 +630,6 @@ class TransformerDecoderLayer(nn.Module): nhead: the number of heads in the multiheadattention models (default=16). dim_feedforward: the dimension of the feedforward network model (default=1024). - dropout_rate: the dropout_rate value (default=0.1). activation: the activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable (default=F.relu). 
layer_norm_eps: the eps value in layer normalization components @@ -656,8 +649,6 @@ def __init__(self, d_model: int = 1024, nhead: int = 16, dim_feedforward: int = 1024, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -671,7 +662,6 @@ def __init__(self, d_model, nhead, self_attn=True, - dropout_rate=attention_dropout_rate, attention_temp=attention_temp, bias=False, **factory_kwargs) @@ -679,7 +669,6 @@ def __init__(self, d_model, nhead, self_attn=False, - dropout_rate=attention_dropout_rate, attention_temp=attention_temp, bias=False, **factory_kwargs) @@ -691,16 +680,12 @@ def __init__(self, self.linear_glu = nn.Linear(dim_feedforward, dim_feedforward, **factory_kwargs) - self.dropout = nn.Dropout(dropout_rate) self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) self.pre_ln = pre_ln self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.dropout1 = nn.Dropout(dropout_rate) - self.dropout2 = nn.Dropout(dropout_rate) - self.dropout3 = nn.Dropout(dropout_rate) self.activation = activation @@ -713,7 +698,8 @@ def forward( # pylint: disable=arguments-renamed decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + index: Optional[int] = None, + dropout_rate: Optional[float] = 0.0) -> Any: r"""Pass the inputs (and mask) through the decoder layer. Args: tgt: the sequence to the decoder layer (required). @@ -722,6 +708,7 @@ def forward( # pylint: disable=arguments-renamed memory_mask: the mask for the memory sequence (optional). decode: wether to use cache for autoregressive decoding or not. max_len: maximum sequence length, necessary for decoding cache. + dropout_rate: the dropout probability value (optional) Shape: see the docs in Transformer class. 
""" @@ -735,10 +722,11 @@ def forward( # pylint: disable=arguments-renamed decode=decode, max_len=max_len, cache=cache, - index=index) + index=index, + dropout_rate=dropout_rate) x = x + sa_out - x = x + self._mha_block(self.norm2(x), memory, memory_mask) - x = x + self._ff_block(self.norm3(x)) + x = x + self._mha_block(self.norm2(x), memory, memory_mask, dropout_rate) + x = x + self._ff_block(self.norm3(x), dropout_rate) else: sa_out, cache = self._sa_block( x, @@ -746,10 +734,11 @@ def forward( # pylint: disable=arguments-renamed decode=decode, max_len=max_len, cache=cache, - index=index) + index=index, + dropout_rate=dropout_rate) x = self.norm1(x + sa_out) - x = self.norm2(x + self._mha_block(x, memory, memory_mask)) - x = self.norm3(x + self._ff_block(x)) + x = self.norm2(x + self._mha_block(x, memory, memory_mask, dropout_rate)) + x = self.norm3(x + self._ff_block(x, dropout_rate)) return x, cache @@ -761,30 +750,38 @@ def _sa_block( # pylint: disable=arguments-renamed decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + index: Optional[int] = None, + dropout_rate: Optional[float] = 0.0) -> Any: x, cache = self.self_attn( x, attn_mask=attn_mask, decode=decode, max_len=max_len, cache=cache, - index=index) - return self.dropout1(x), cache + index=index, + dropout_rate=dropout_rate) + return F.dropout(x, dropout_rate, self.training), cache # Multihead attention block: def _mha_block(self, x: Tensor, mem: Tensor, - attn_mask: Optional[Tensor]) -> Tensor: - x, _ = self.multihead_attn(x, mem, attn_mask=attn_mask) - return self.dropout2(x) + attn_mask: Optional[Tensor], + dropout_rate: Optional[float] = 0.0) -> Tensor: + x, _ = self.multihead_attn( + x, + mem, + attn_mask=attn_mask, + dropout_rate=dropout_rate) + return F.dropout(x, dropout_rate, self.training) # Feed forward block. - def _ff_block(self, inputs: Tensor) -> Tensor: + def _ff_block(self, inputs: Tensor, + dropout_rate: Optional[float] = 0.0) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) x = x * y - x = self.linear2(self.dropout(x)) - return self.dropout3(x) + x = self.linear2(F.dropout(x, dropout_rate, self.training)) + return F.dropout(x, dropout_rate, self.training) class MultiheadAttention(nn.Module): @@ -802,8 +799,6 @@ class MultiheadAttention(nn.Module): ``embed_dim // num_heads``). self_attn: Whether self attention or encoder-decoder attention is used. Default: ``True``. - dropout_rate: Dropout probability on ``attn_output_weights``. - Default: ``0.0`` (no dropout_rate). bias: If specified, adds bias to input / output projection layers. Default: ``False``. device: The device of the module. @@ -817,7 +812,6 @@ def __init__(self, embed_dim: int, num_heads: int, self_attn: bool = True, - dropout_rate: float = 0., attention_temp: float = 1.0, bias: bool = False, device: Optional[torch.device] = None, @@ -826,7 +820,6 @@ def __init__(self, self.embed_dim = embed_dim self.num_heads = num_heads self.self_attn = self_attn - self.dropout = dropout_rate self.head_dim = embed_dim // num_heads self.attention_temp = attention_temp assert self.head_dim * num_heads == self.embed_dim, \ @@ -861,7 +854,8 @@ def forward(self, decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + index: Optional[int] = None, + dropout_rate: Optional[float] = 0.0) -> Any: # TODO: (nico) remove default?! 
r""" Args: x: Batch of input sequences of shape @@ -887,6 +881,7 @@ def forward(self, max_len: maximum sequence length, necessary for decoding cache. cache: cache dictionary for autoregressive decoding. index: index of the current decoding step, necessary for decoding cache. + dropout_rate: dropout probability on ``attn_output_weights``. Outputs: - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where :math:`L` is the target sequence length, :math:`N` is the batch size, @@ -976,12 +971,12 @@ def forward(self, attn_mask = new_attn_mask # Adjust dropout_rate probability. - dropout_rate = self.dropout if self.training else 0.0 + attn_dropout_rate = dropout_rate if self.training else 0.0 # Calculate attention. q = self.attention_temp * q attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask, dropout_rate) + q, k, v, attn_mask, attn_dropout_rate) # Rearrange for output projection. attn_output = attn_output.transpose(1, 2).contiguous().view( bsz, tgt_len, embed_dim) diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py deleted file mode 100644 index a43df30d4..000000000 --- a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py +++ /dev/null @@ -1,989 +0,0 @@ -import copy -import math -from typing import Any, Callable, Dict, Optional, Tuple, Union - -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F -from torch.nn.init import normal_ -from torch.nn.init import xavier_uniform_ - -DROPOUT_RATE = 0.1 - - -def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor: - """Make a causal mask for self-attention. - - Args: - x: input array of shape `[batch..., len]` - device: device to store the idxs - - Returns: - A `[batch..., len, len]` shaped causal attention mask. - """ - idxs = torch.broadcast_to( - torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape) - return torch.greater_equal(idxs.unsqueeze(-1), idxs.unsqueeze(-2)) - - -def make_src_mask(src, inputs_segmentation, nhead): - """Utility for creating src mask and adjust it for PyTorch Transformer API.""" - src_mask = torch.mul((src > 0).unsqueeze(-1), (src > 0).unsqueeze(-2)) - # Add segmentation block-diagonal attention mask if using segmented data. - if inputs_segmentation is not None: - src_mask = torch.logical_and( - src_mask, - torch.eq( - inputs_segmentation.unsqueeze(-1), - inputs_segmentation.unsqueeze(-2))) - # Flip values and ensure numerical stability. - src_mask = torch.repeat_interleave( - torch.logical_not(src_mask), repeats=nhead, dim=0) - new_src_mask = torch.zeros_like(src_mask, dtype=torch.float32) - new_src_mask.masked_fill_(src_mask, -1e10) - return new_src_mask - - -def make_tgt_and_memory_mask(tgt, - src, - inputs_segmentation, - targets_segmentation, - decode, - nhead): - """ Utility for creating target and memory mask and adjust them for PyTorch - Transformer API.""" - if not decode: - tgt_mask = torch.logical_and( - torch.mul((tgt > 0).unsqueeze(-1), (tgt > 0).unsqueeze(-2)), - make_causal_mask(tgt, device=tgt.device)) - memory_mask = torch.mul((tgt > 0).unsqueeze(-1), (src > 0).unsqueeze(-2)) - else: - tgt_mask = None - memory_mask = torch.mul((torch.ones_like(tgt) > 0).unsqueeze(-1), - (src > 0).unsqueeze(-2)) - # Add segmentation block-diagonal attention masks if using segmented data. 
- if inputs_segmentation is not None: - tgt_mask = torch.logical_and( - tgt_mask, - torch.eq( - targets_segmentation.unsqueeze(-1), - targets_segmentation.unsqueeze(-2))) - memory_mask = torch.logical_and( - memory_mask, - torch.eq( - targets_segmentation.unsqueeze(-1), - inputs_segmentation.unsqueeze(-2))) - # Flip values and ensure numerical stability. - memory_mask = torch.repeat_interleave( - torch.logical_not(memory_mask), repeats=nhead, dim=0) - new_memory_mask = torch.zeros_like(memory_mask, dtype=torch.float32) - new_memory_mask.masked_fill_(memory_mask, -1e10) - if tgt_mask is not None: - tgt_mask = torch.repeat_interleave( - torch.logical_not(tgt_mask), repeats=nhead, dim=0) - new_tgt_mask = torch.zeros_like(tgt_mask, dtype=torch.float32) - new_tgt_mask.masked_fill_(tgt_mask, -1e10) - tgt_mask = new_tgt_mask - return tgt_mask, new_memory_mask - - -def shift_right(x, axis=1): - """Shift the input to the right by padding on axis 1.""" - pad_widths = [(0, 0)] * len(x.shape) - pad_widths[axis] = (1, 0) - pad_widths = tuple(t for tup in reversed(pad_widths) for t in tup) - padded = F.pad(x, pad_widths, mode='constant') - return padded[:, :-1] - - -class Transformer(nn.Module): - """Transformer architecture based on the model from the WMT Jax workload.""" - - def __init__(self, - ntoken: int = 32000, - d_model: int = 1024, - nhead: int = 16, - d_hid: int = 1024, - nlayers: int = 6, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - attention_temp: float = 1.0, - pre_ln: bool = True): - super().__init__() - self.pos_encoder = PositionalEncoding(d_model) - self.shared_embedding = nn.Embedding(ntoken, d_model) - self.encoder = Encoder(d_model, - nhead, - d_hid, - nlayers, - activation, - glu, - layer_norm_eps, - attention_temp, - pre_ln) - self.decoder = Decoder(d_model, - nhead, - d_hid, - nlayers, - activation, - glu, - layer_norm_eps, - attention_temp, - pre_ln) - # Share positional encoding and embedding between encoder and decoder. 
- self.encoder.pos_encoder = self.pos_encoder - self.encoder.shared_embedding = self.shared_embedding - self.decoder.pos_encoder = self.pos_encoder - self.decoder.shared_embedding = self.shared_embedding - - self._reset_parameters() - - def _reset_parameters(self): - """Initiate parameters in the transformer model.""" - for module in self.modules(): - if isinstance(module, nn.Linear): - xavier_uniform_(module.weight) - if module.bias is not None: - normal_(module.bias, std=1e-6) - - def forward(self, - src: Tensor, - tgt: Tensor, - inputs_positions: Optional[Tensor] = None, - targets_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None, - targets_segmentation: Optional[Tensor] = None, - decode: bool = False, - dropout_rate: float = DROPOUT_RATE) -> Tensor: - """ - Args: - src: Tensor, shape [batch_size, seq_len] - tgt: Tensor, shape [batch_size, seq_len] - inputs_positions: Optional[Tensor], shape [batch_size, seq_len] - targets_positions: Optional[Tensor], shape [batch_size, seq_len] - inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len] - targets_segmentation: Optional[Tensor], shape [batch_size, seq_len] - decode: bool - dropout_rate: float - - Returns: - output Tensor of shape [batch_size, seq_len, ntoken] - """ - if src.size(0) != tgt.size(0): - raise RuntimeError('The batch size of src and tgt must be equal.') - - memory = self.encoder( - src, - inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation, - dropout_rate=dropout_rate) - output = self.decoder( - tgt, - memory, - src, # just for calculating the padding mask - targets_positions=targets_positions, - inputs_segmentation=inputs_segmentation, - targets_segmentation=targets_segmentation, - decode=decode, - dropout_rate=dropout_rate) - return output - - -class TransformerEncoder(nn.Module): - r"""TransformerEncoder is a stack of N encoder layers. Users can build the - BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. - - Args: - encoder_layer: an instance of the TransformerEncoderLayer() class. - num_layers: the number of sub-encoder-layers in the encoder. - norm: the layer normalization component (optional). - enable_nested_tensor: if True, input will automatically convert to - nested tensor (and convert back on output). This will improve - the overall performance of TransformerEncoder when padding - rate is high. - - Examples:: - >>> encoder_layer = nn.TransformerEncoderLayer(12, 8) - >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, 6) - >>> src = torch.rand(10, 32, 512) - >>> out = transformer_encoder(src) - """ - __constants__ = ['norm'] - - def __init__(self, - encoder_layer, - num_layers, - norm=None, - enable_nested_tensor=True, - mask_check=True): - super().__init__() - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for _ in range(num_layers)]) - self.num_layers = num_layers - self.norm = norm - self.enable_nested_tensor = enable_nested_tensor - self.mask_check = mask_check - - def forward(self, src: Tensor, - mask: Optional[Tensor] = None, - dropout_rate: Optional[float] = 0.0) -> Tensor: - """Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - mask: the mask for the src sequence (optional). - dropout_rate: the dropout probability (optional). - - Shape: - see the docs in Transformer class. 
- """ - output = src - convert_to_nested = False - - for mod in self.layers: - output = mod(output, src_mask=mask, dropout_rate=dropout_rate) - - if convert_to_nested: - output = output.to_padded_tensor(0.) - - if self.norm is not None: - output = self.norm(output) - - return output - - -class Encoder(nn.Module): - - def __init__(self, - d_model: int = 1024, - nhead: int = 16, - d_hid: int = 1024, - nlayers: int = 6, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - attention_temp: float = 1.0, - pre_ln: bool = True): - super().__init__() - self.nhead = nhead - self.shared_embedding = None - self.pos_encoder = None - encoder_layer = TransformerEncoderLayer( - d_model, - nhead, - d_hid, - activation=activation, - glu=glu, - layer_norm_eps=layer_norm_eps, - attention_temp=attention_temp, - pre_ln=pre_ln) - encoder_norm = ( - nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None) - self.encoder = TransformerEncoder(encoder_layer, nlayers, encoder_norm) - - def forward(self, - src: Tensor, - inputs_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None, - dropout_rate: Optional[float] = 0.0) -> Tensor: - src = src.to(torch.int) - src_mask = make_src_mask(src, inputs_segmentation, self.nhead) - src = self.shared_embedding(src) - src = self.pos_encoder(src, inputs_positions, dropout_rate=dropout_rate) - memory = self.encoder(src, mask=src_mask, dropout_rate=dropout_rate) - return memory - - -class Decoder(nn.Module): - - def __init__(self, - d_model: int = 1024, - nhead: int = 16, - d_hid: int = 1024, - nlayers: int = 6, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - attention_temp: float = 1.0, - pre_ln: bool = True): - super().__init__() - self.nhead = nhead - self.shared_embedding = None - self.pos_encoder = None - self.decoder = TransformerDecoder(d_model, - nhead, - d_hid, - activation, - glu, - layer_norm_eps, - nlayers, - attention_temp, - pre_ln) - - def forward( - self, - tgt: Tensor, - memory: Tensor, - src: Tensor, # just for calculating the padding mask - targets_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None, - targets_segmentation: Optional[Tensor] = None, - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None, - dropout_rate: Optional[float] = 0.0) -> Any: - tgt = tgt.to(torch.int) - tgt_mask, memory_mask = make_tgt_and_memory_mask( - tgt, src, inputs_segmentation, targets_segmentation, - decode, self.nhead) - if not decode: - tgt = shift_right(tgt) - tgt = self.shared_embedding(tgt) - tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache, dropout_rate=dropout_rate) - if decode: - tgt, cache = tgt - output = self.decoder( - tgt, - memory, - tgt_mask=tgt_mask, - memory_mask=memory_mask, - decode=decode, - max_len=max_len, - cache=cache, - dropout_rate=dropout_rate) - if decode: - output, cache = output - normalize = math.sqrt(output.shape[-1]) - output = torch.matmul(output, self.shared_embedding.weight.T) / normalize - if decode: - return output, cache - return output - - -class PositionalEncoding(nn.Module): - - def __init__(self, - d_model: int, - max_len: int = 256): - super().__init__() - - position = torch.arange(max_len).unsqueeze(1) - scale_factor = -math.log(10000.0) / (d_model // 2 - 1) - div_term = torch.exp(torch.arange(d_model // 2) * scale_factor) - pe = torch.zeros(1, max_len, d_model) - pe[0, :, :d_model // 
-    pe[0, :, :d_model // 2] = torch.sin(position * div_term)
-    pe[0, :, d_model // 2:2 * (d_model // 2)] = torch.cos(position * div_term)
-    self.register_buffer('pe', pe)
-
-  def forward(
-      self,
-      x: Tensor,
-      inputs_positions: Optional[Tensor] = None,
-      decode: bool = False,
-      cache: Optional[Dict[str, Dict[str, Tensor]]] = None,
-      dropout_rate: Optional[float] = 0.0
-  ) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]:
-    """
-    Args:
-      x: Tensor (shape [batch_size, seq_len, embedding_dim])
-      inputs_positions: Tensor (shape [batch_size, seq_len]) or None
-      decode: bool
-      cache: Dict[str, Dict[str, Tensor]] or None
-      dropout_rate: Optional[float]
-    Returns:
-      Tensor or Tuple[Tensor, Dict[str, Dict[str, Tensor]]]
-    """
-    # We use a cache position index for tracking decoding position.
-    if decode:
-      name = self._get_name()
-      if cache is None:
-        cache = {
-            name: {
-                'cache_index':
-                    torch.tensor(0, dtype=torch.long, device=self.pe.device),
-            },
-        }
-      pe = self.pe[0, cache[name]['cache_index'], :]
-      cache[name]['cache_index'] += 1
-      return F.dropout(x + pe, dropout_rate, self.training), cache
-    if inputs_positions is None:
-      # normal unpacked case:
-      pe = self.pe[:, :x.size(1), :]
-    else:
-      # for packed data we need to use known position indices:
-      pe = self.pe[0, inputs_positions, :]
-    return F.dropout(x + pe, dropout_rate, self.training)
-
-
-# TransformerEncoderLayer and TransformerDecoderLayer are taken from:
-# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py
-# Main difference is the use of custom MultiheadAttention modules.
-class TransformerEncoderLayer(nn.Module):
-  r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
-  This standard encoder layer is based on the paper "Attention Is All You Need".
-  Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
-  Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all
-  you need. In Advances in Neural Information Processing Systems,
-  pages 6000-6010. Users may modify or implement in a different way during
-  application.
-  Args:
-    d_model: the number of expected features in the input (default=1024).
-    nhead: the number of heads in the multiheadattention models (default=16).
-    dim_feedforward: the dimension of the feedforward network model
-      (default=1024).
-    activation: the activation function of the intermediate layer, can be a
-      string ("relu" or "gelu") or a unary callable (default=F.relu).
-    layer_norm_eps: the eps value in layer normalization components
-      (default=1e-6).
-    pre_ln: if ``True``, layer norm is done prior to attention and
-      feedforward operations, respectively. Otherwise it's done after.
-      Default: ``True``.
-  Examples::
-    >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
-    >>> src = torch.rand(32, 10, 512)
-    >>> out = encoder_layer(src)
-  """
-  __constants__ = ['pre_ln']
-
-  def __init__(self,
-               d_model: int = 1024,
-               nhead: int = 16,
-               dim_feedforward: int = 1024,
-               activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
-               glu: bool = False,
-               layer_norm_eps: float = 1e-6,
-               attention_temp: float = 1.0,
-               pre_ln: bool = True,
-               device=None,
-               dtype=None) -> None:
-    factory_kwargs = {'device': device, 'dtype': dtype}
-    super().__init__()
-    self.self_attn = MultiheadAttention(
-        d_model,
-        nhead,
-        self_attn=True,
-        attention_temp=attention_temp,
-        bias=False,
-        **factory_kwargs)
-
-    # Implementation of Feedforward model.
-    self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
-    self.glu = glu
-    if self.glu:
-      self.linear_glu = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
-    self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
-
-    self.pre_ln = pre_ln
-    self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
-    self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
-
-    self.activation = activation
-
-  def forward(self,
-              src: Tensor,
-              src_mask: Optional[Tensor] = None,
-              dropout_rate: Optional[float] = 0.0) -> Tensor:
-    r"""Pass the input through the encoder layer.
-
-    Args:
-      src: the sequence to the encoder layer (required).
-      src_mask: the mask for the src sequence (optional).
-      dropout_rate: the dropout probability value (optional).
-    Shape:
-      see the docs in Transformer class.
-    """
-    x = src
-    if self.pre_ln:
-      x = x + self._sa_block(self.norm1(x), src_mask, dropout_rate)
-      x = x + self._ff_block(self.norm2(x), dropout_rate)
-    else:
-      x = self.norm1(x + self._sa_block(x, src_mask, dropout_rate))
-      x = self.norm2(x + self._ff_block(x, dropout_rate))
-
-    return x
-
-  # Self-attention block:
-  def _sa_block(self,
-                x: Tensor,
-                attn_mask: Optional[Tensor],
-                dropout_rate: Optional[float] = 0.0) -> Tensor:
-    x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate)
-    return F.dropout(x, dropout_rate, training=self.training)
-
-  # Feed forward block:
-  def _ff_block(self,
-                inputs: Tensor,
-                dropout_rate: Optional[float] = 0.0) -> Tensor:
-    x = self.activation(self.linear1(inputs))
-    if self.glu:
-      y = self.linear_glu(inputs)
-      x = x * y
-    x = self.linear2(F.dropout(x, dropout_rate, training=self.training))
-    return F.dropout(x, dropout_rate, training=self.training)
-
-
-# Modified to use cache for autoregressive decoding and custom
-# MultiheadAttention modules.
-class TransformerDecoder(nn.Module):
-  r"""TransformerDecoder is a stack of N decoder layers
-  Args:
-    d_model: the number of expected features in the input (default=1024)
-    nhead: the number of heads in the multiheadattention models (default=16)
-    d_hid: the dimension of the feedforward network model
-      (default=1024)
-    layer_norm_eps: the eps value in layer normalization components
-      (default=1e-6).
-    decoder_layer: an instance of the TransformerDecoderLayer() class
-    num_layers: the number of sub-decoder-layers in the decoder
-  Examples::
-    >>> decoder_layer = nn.TransformerDecoderLayer(12, 8)
-    >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, 6)
-    >>> memory = torch.rand(10, 32, 512)
-    >>> tgt = torch.rand(20, 32, 512)
-    >>> out = transformer_decoder(tgt, memory)
-  """
-  __constants__ = ['norm']
-
-  def __init__(self,
-               d_model,
-               nhead,
-               d_hid,
-               activation,
-               glu,
-               layer_norm_eps,
-               num_layers,
-               attention_temp,
-               pre_ln):
-    super().__init__()
-    self.layers = nn.ModuleList([
-        TransformerDecoderLayer(
-            d_model,
-            nhead,
-            d_hid,
-            activation,
-            glu,
-            layer_norm_eps=layer_norm_eps,
-            attention_temp=attention_temp,
-            pre_ln=pre_ln) for _ in range(num_layers)
-    ])
-    self.num_layers = num_layers
-    self.norm = (nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None)
-
-  def forward(self,
-              tgt: Tensor,
-              memory: Tensor,
-              tgt_mask: Optional[Tensor] = None,
-              memory_mask: Optional[Tensor] = None,
-              decode: bool = False,
-              max_len: Optional[int] = None,
-              cache: Optional[dict] = None,
-              dropout_rate: Optional[float] = 0.0) -> Any:
-    r"""Pass the inputs (and mask) through the decoder layer in turn.
-    Args:
-      tgt: the sequence to the decoder (required).
-      memory: the sequence from the last layer of the encoder (required).
-      tgt_mask: the mask for the tgt sequence (optional).
-      memory_mask: the mask for the memory sequence (optional).
-      decode: whether to use cache for autoregressive decoding or not.
-      max_len: maximum sequence length, necessary for decoding cache.
-      dropout_rate: the dropout probability value (optional)
-    Shape:
-      see the docs in Transformer class.
-    """
-    output = tgt
-
-    for idx, mod in enumerate(self.layers):
-      output, cache = mod(
-          output,
-          memory,
-          tgt_mask=tgt_mask,
-          memory_mask=memory_mask,
-          decode=decode,
-          max_len=max_len,
-          cache=cache,
-          index=idx,
-          dropout_rate=dropout_rate)
-
-    if self.norm is not None:
-      output = self.norm(output)
-
-    if decode:
-      return output, cache
-    return output
-
-
-# Modified to use cache for autoregressive decoding and custom
-# MultiheadAttention modules.
-class TransformerDecoderLayer(nn.Module):
-  r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and
-  feedforward network.
-  This standard decoder layer is based on the paper "Attention Is All You Need".
-  Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
-  Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all
-  you need. In Advances in Neural Information Processing Systems,
-  pages 6000-6010. Users may modify or implement in a different way during
-  application.
-  Args:
-    d_model: the number of expected features in the input (default=1024).
-    nhead: the number of heads in the multiheadattention models (default=16).
-    dim_feedforward: the dimension of the feedforward network model
-      (default=1024).
-    activation: the activation function of the intermediate layer, can be a
-      string ("relu" or "gelu") or a unary callable (default=F.relu).
-    layer_norm_eps: the eps value in layer normalization components
-      (default=1e-6).
-    pre_ln: if ``True``, layer norm is done prior to self attention,
-      multihead attention and feedforward operations, respectively.
-      Otherwise it's done after. Default: ``True``.
-  Examples::
-    >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
-    >>> memory = torch.rand(32, 10, 512)
-    >>> tgt = torch.rand(32, 20, 512)
-    >>> out = decoder_layer(tgt, memory)
-  """
-  __constants__ = ['pre_ln']
-
-  def __init__(self,
-               d_model: int = 1024,
-               nhead: int = 16,
-               dim_feedforward: int = 1024,
-               activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
-               glu: bool = False,
-               layer_norm_eps: float = 1e-6,
-               pre_ln: bool = True,
-               attention_temp: float = 1.0,
-               device=None,
-               dtype=None) -> None:
-    factory_kwargs = {'device': device, 'dtype': dtype}
-    super().__init__()
-    self.self_attn = MultiheadAttention(
-        d_model,
-        nhead,
-        self_attn=True,
-        attention_temp=attention_temp,
-        bias=False,
-        **factory_kwargs)
-    self.multihead_attn = MultiheadAttention(
-        d_model,
-        nhead,
-        self_attn=False,
-        attention_temp=attention_temp,
-        bias=False,
-        **factory_kwargs)
-
-    # Implementation of Feedforward model.
-    self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
-    self.glu = glu
-    if self.glu:
-      self.linear_glu = nn.Linear(dim_feedforward,
-                                  dim_feedforward,
-                                  **factory_kwargs)
-    self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
-
-    self.pre_ln = pre_ln
-    self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
-    self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
-    self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
-
-    self.activation = activation
-
-  def forward(  # pylint: disable=arguments-renamed
-      self,
-      tgt: Tensor,
-      memory: Tensor,
-      tgt_mask: Optional[Tensor] = None,
-      memory_mask: Optional[Tensor] = None,
-      decode: bool = False,
-      max_len: Optional[int] = None,
-      cache: Optional[dict] = None,
-      index: Optional[int] = None,
-      dropout_rate: Optional[float] = 0.0) -> Any:
-    r"""Pass the inputs (and mask) through the decoder layer.
-    Args:
-      tgt: the sequence to the decoder layer (required).
-      memory: the sequence from the last layer of the encoder (required).
-      tgt_mask: the mask for the tgt sequence (optional).
-      memory_mask: the mask for the memory sequence (optional).
-      decode: whether to use cache for autoregressive decoding or not.
-      max_len: maximum sequence length, necessary for decoding cache.
-      dropout_rate: the dropout probability value (optional)
-    Shape:
-      see the docs in Transformer class.
-    """
-    # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
-
-    x = tgt
-    if self.pre_ln:
-      sa_out, cache = self._sa_block(
-          self.norm1(x),
-          tgt_mask,
-          decode=decode,
-          max_len=max_len,
-          cache=cache,
-          index=index,
-          dropout_rate=dropout_rate)
-      x = x + sa_out
-      x = x + self._mha_block(self.norm2(x), memory, memory_mask, dropout_rate)
-      x = x + self._ff_block(self.norm3(x), dropout_rate)
-    else:
-      sa_out, cache = self._sa_block(
-          x,
-          tgt_mask,
-          decode=decode,
-          max_len=max_len,
-          cache=cache,
-          index=index,
-          dropout_rate=dropout_rate)
-      x = self.norm1(x + sa_out)
-      x = self.norm2(x + self._mha_block(x, memory, memory_mask, dropout_rate))
-      x = self.norm3(x + self._ff_block(x, dropout_rate))
-
-    return x, cache
-
-  # Self-attention block:
-  def _sa_block(  # pylint: disable=arguments-renamed
-      self,
-      x: Tensor,
-      attn_mask: Optional[Tensor],
-      decode: bool = False,
-      max_len: Optional[int] = None,
-      cache: Optional[dict] = None,
-      index: Optional[int] = None,
-      dropout_rate: Optional[float] = 0.0) -> Any:
-    x, cache = self.self_attn(
-        x,
-        attn_mask=attn_mask,
-        decode=decode,
-        max_len=max_len,
-        cache=cache,
-        index=index,
-        dropout_rate=dropout_rate)
-    return F.dropout(x, dropout_rate, self.training), cache
-
-  # Multihead attention block:
-  def _mha_block(self, x: Tensor, mem: Tensor,
-                 attn_mask: Optional[Tensor],
-                 dropout_rate: Optional[float] = 0.0) -> Tensor:
-    x, _ = self.multihead_attn(
-        x,
-        mem,
-        attn_mask=attn_mask,
-        dropout_rate=dropout_rate)
-    return F.dropout(x, dropout_rate, self.training)
-
-  # Feed forward block.
-  def _ff_block(self, inputs: Tensor,
-                dropout_rate: Optional[float] = 0.0) -> Tensor:
-    x = self.activation(self.linear1(inputs))
-    if self.glu:
-      y = self.linear_glu(inputs)
-      x = x * y
-    x = self.linear2(F.dropout(x, dropout_rate, self.training))
-    return F.dropout(x, dropout_rate, self.training)
-
-
-class MultiheadAttention(nn.Module):
-  r"""Allows the model to jointly attend to information
-  from different representation subspaces. Supports self-attention and
-  encoder-decoder attention.
-  See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
-  .. math::
-      \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
-  where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
-  Args:
-    embed_dim: Total dimension of the model.
-    num_heads: Number of parallel attention heads. Note that ``embed_dim`` will
-      be split across ``num_heads`` (i.e. each head will have dimension
-      ``embed_dim // num_heads``).
-    self_attn: Whether self attention or encoder-decoder attention is used.
-      Default: ``True``.
-    bias: If specified, adds bias to input / output projection layers.
-      Default: ``False``.
-    device: The device of the module.
-    dtype: The dtype of the module.
-  Examples::
-    >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
-    >>> attn_output, cache = multihead_attn(x)
-  """
-
-  def __init__(self,
-               embed_dim: int,
-               num_heads: int,
-               self_attn: bool = True,
-               attention_temp: float = 1.0,
-               bias: bool = False,
-               device: Optional[torch.device] = None,
-               dtype: Optional[torch.dtype] = None) -> None:
-    super().__init__()
-    self.embed_dim = embed_dim
-    self.num_heads = num_heads
-    self.self_attn = self_attn
-    self.head_dim = embed_dim // num_heads
-    self.attention_temp = attention_temp
-    assert self.head_dim * num_heads == self.embed_dim, \
-        'embed_dim must be divisible by num_heads.'
-
-    factory_kwargs = {'device': device, 'dtype': dtype}
-    if self_attn:
-      # Self-attention.
-      self.in_proj = nn.Linear(
-          embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
-    else:
-      # Encoder-decoder attention.
-      self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
-      self.kv_proj = nn.Linear(
-          embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
-    self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
-
-    self._reset_parameters()
-
-  def _reset_parameters(self):
-    """Initiate parameters in the MultiheadAttention module."""
-    for module in self.modules():
-      if isinstance(module, nn.Linear):
-        xavier_uniform_(module.weight)
-        if module.bias is not None:
-          normal_(module.bias, std=1e-6)
-
-  def forward(self,
-              x: Tensor,
-              mem: Optional[Tensor] = None,
-              attn_mask: Optional[Tensor] = None,
-              decode: bool = False,
-              max_len: Optional[int] = None,
-              cache: Optional[dict] = None,
-              index: Optional[int] = None,
-              dropout_rate: Optional[float] = 0.0) -> Any:  # TODO: (nico) remove default?!
-    r"""
-    Args:
-      x: Batch of input sequences of shape
-        (batch size, sequence length, embedding dimensionality) for self
-        attention mechanism. See "Attention Is All You Need" for more details.
-      mem: Batch of input sequences of shape
-        (batch size, sequence length, embedding dimensionality) for
-        encoder-decoder attention. See "Attention Is All You Need" for more
-        details.
-      attn_mask: If specified, a 2D or 3D mask preventing attention to certain
-        positions. Must be of shape :math:`(L, S)` or
-        :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the
-        batch size, :math:`L` is the target sequence length, and :math:`S`
-        is the source sequence length. A 2D mask will be broadcasted across
-        the batch while a 3D mask allows for a different mask for each entry
-        in the batch. Binary, byte, and float masks are supported.
-        For a binary mask, a ``True`` value indicates that the
-        corresponding position is not allowed to attend. For a byte mask,
-        a non-zero value indicates that the corresponding position is not
-        allowed to attend. For a float mask, the mask values will be added to
-        the attention weight.
-      decode: whether to use cache for autoregressive decoding or not.
-      max_len: maximum sequence length, necessary for decoding cache.
-      cache: cache dictionary for autoregressive decoding.
-      index: index of the current decoding step, necessary for decoding cache.
-      dropout_rate: dropout probability on ``attn_output_weights``.
-    Outputs:
-      - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where
-        :math:`L` is the target sequence length, :math:`N` is the batch size,
-        and :math:`E` is the embedding dimension ``embed_dim``.
-      - **cache** - For autoregressive decoding.
-    """
-    # Shape: (batch size, sequence length, embedding dimensionality)
-    bsz, seq_len, embed_dim = x.size()
-    # In projection.
-    if self.self_attn:
-      q, k, v = self.in_proj(x).split(self.embed_dim, dim=2)
-    else:
-      q = self.q_proj(x)
-      k, v = self.kv_proj(mem).split(self.embed_dim, dim=2)
-    # This is 1 (!= seq_len) during autoregressive decoding.
-    tgt_len = q.size(1)
-
-    # During fast autoregressive decoding, we feed one position at a time,
-    # and cache the keys and values step by step.
-    name = f'decoder.layers.{index}.self_attn'
-    loc_cache = cache[name] if decode and name in cache else None
-    if decode:
-      if loc_cache is None:
-        loc_cache = {
-            'cached_key':
-                torch.zeros((bsz, max_len, embed_dim),
-                            dtype=k.dtype,
-                            device=k.device),
-            'cached_value':
-                torch.zeros((bsz, max_len, embed_dim),
-                            dtype=v.dtype,
-                            device=v.device),
-            'cache_index':
-                torch.tensor(0, dtype=torch.long, device=k.device),
-        }
-      cached_key = loc_cache['cached_key']
-      cached_value = loc_cache['cached_value']
-      cache_index = loc_cache['cache_index']
-      # Shape check of cached keys against query input.
-      expected_shape = (bsz, 1, embed_dim)
-      if expected_shape != x.shape:
-        raise ValueError('Autoregressive cache shape error, expected query '
-                         f'shape {expected_shape} instead got {x.shape}.')
-      # Update key, value caches with our new 1d spatial slices.
-      cached_key[:, cache_index:cache_index + 1, :] = k
-      cached_value[:, cache_index:cache_index + 1, :] = v
-      k = cached_key
-      v = cached_value
-      cache_index += 1
-      # Causal mask for cached decoder self-attention:
-      # our single query position should only attend to those key
-      # positions that have already been generated and cached,
-      # not the remaining zero elements.
-      if attn_mask is not None:
-        raise ValueError('Attention mask has to be None for decode == True.')
-      attn_mask = (torch.arange(max_len, device=k.device) >=
-                   cache_index).reshape(1, max_len)
-
-    # Update sequence length to account for complete sequence.
-    seq_len = k.size(1)
-
-    # Rearrange q, k, v for multihead attention.
-    q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
-    k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-    v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-    # Check dtype and shape of attention mask.
-    if not decode and attn_mask is not None:
-      assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
-          f'Float and bool dtypes are supported, not {attn_mask.dtype}.'
-      # Ensure attn_mask's dim is 3.
-      if attn_mask.dim() == 3:
-        correct_3d_size = (bsz * self.num_heads, tgt_len, seq_len)
-        if attn_mask.shape != correct_3d_size:
-          raise RuntimeError(f'The shape of attn_mask is {attn_mask.shape}, '
-                             f'but should be {correct_3d_size}.')
-      else:
-        raise RuntimeError(
-            f"attn_mask's dimension {attn_mask.dim()} is not supported")
-      # Reshape attention mask to be consistent with q, k, v.
-      attn_mask = attn_mask.view(bsz, self.num_heads, tgt_len, seq_len)
-
-    # Convert attention mask to float.
-    if attn_mask is not None and attn_mask.dtype == torch.bool:
-      new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
-      new_attn_mask.masked_fill_(attn_mask, -1e10)
-      attn_mask = new_attn_mask
-
-    # Adjust dropout_rate probability.
-    attn_dropout_rate = dropout_rate if self.training else 0.0
-
-    # Calculate attention.
-    q = self.attention_temp * q
-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        q, k, v, attn_mask, attn_dropout_rate)
-    # Rearrange for output projection.
-    attn_output = attn_output.transpose(1, 2).contiguous().view(
-        bsz, tgt_len, embed_dim)
-    # Output projection.
-    attn_output = self.out_proj(attn_output)
-
-    if decode:
-      cache[name] = loc_cache
-
-    return attn_output, cache
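
For reference, a minimal standalone sketch of the right shift used to build decoder inputs for teacher forcing; it mirrors the shift_right helper in the hunk above, assuming integer token ids shifted along axis 1.

import torch
import torch.nn.functional as F


def shift_right_example():
  # Token ids of shape [batch_size=1, seq_len=4].
  tgt = torch.tensor([[11, 12, 13, 14]])
  # Pad one zero at the front of axis 1, then drop the last position.
  padded = F.pad(tgt, (1, 0), mode='constant')  # [[0, 11, 12, 13, 14]]
  return padded[:, :-1]                         # [[0, 11, 12, 13]]


print(shift_right_example())  # tensor([[ 0, 11, 12, 13]])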
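The masking code above converts boolean "disallowed" masks into additive float masks (filling disallowed positions with -1e10) before they reach scaled dot-product attention. A small sketch of that convention, assuming True marks a position that must not be attended to:

import torch

# True = position is NOT allowed to attend (same convention as above).
bool_mask = torch.tensor([[False, True, True],
                          [False, False, True]])
additive_mask = torch.zeros_like(bool_mask, dtype=torch.float32)
additive_mask.masked_fill_(bool_mask, -1e10)
# additive_mask is added to the attention logits, so masked positions
# receive essentially zero probability after the softmax.
print(additive_mask)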
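The decoder self-attention cache above preallocates cached_key/cached_value of length max_len and advances a cache_index by one position per decoding step; the causal mask then hides the slots that have not been written yet. A simplified sketch of that bookkeeping with hypothetical shapes and no projections:

import torch

bsz, max_len, embed_dim = 2, 5, 8
cached_key = torch.zeros(bsz, max_len, embed_dim)
cached_value = torch.zeros(bsz, max_len, embed_dim)
cache_index = torch.tensor(0, dtype=torch.long)

for step in range(3):
  # One new position per step, shape [bsz, 1, embed_dim].
  k_new = torch.randn(bsz, 1, embed_dim)
  v_new = torch.randn(bsz, 1, embed_dim)
  cached_key[:, cache_index:cache_index + 1, :] = k_new
  cached_value[:, cache_index:cache_index + 1, :] = v_new
  cache_index += 1
  # True marks the slots that have not been generated yet.
  attn_mask = (torch.arange(max_len) >= cache_index).reshape(1, max_len)
  print(step, attn_mask)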
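The encoder and decoder layers above support both pre-LN and post-LN residual orderings (cf. Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf). A toy sketch of the two orderings for a single sublayer, with a plain Linear standing in for the attention or feedforward block:

import torch
from torch import nn

d_model = 16
norm = nn.LayerNorm(d_model, eps=1e-6)
sublayer = nn.Linear(d_model, d_model)  # placeholder sublayer
x = torch.randn(4, 10, d_model)

# Pre-LN: normalize the sublayer input, then add the residual.
pre_ln_out = x + sublayer(norm(x))

# Post-LN: add the residual first, then normalize the sum.
post_ln_out = norm(x + sublayer(x))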
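When glu=True, the feedforward blocks above gate the activated up-projection with a second linear projection of the block input before the down-projection. A minimal sketch of that gated variant with hypothetical dimensions, following the encoder layer where the gate is projected from d_model (the decoder layer declares linear_glu as dim_feedforward -> dim_feedforward, which only lines up because d_model equals d_hid in this model):

import torch
import torch.nn.functional as F
from torch import nn

d_model, d_ff = 16, 32
linear1 = nn.Linear(d_model, d_ff)
linear_glu = nn.Linear(d_model, d_ff)
linear2 = nn.Linear(d_ff, d_model)
x = torch.randn(4, 10, d_model)

h = F.relu(linear1(x)) * linear_glu(x)  # gated hidden activations
out = linear2(h)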
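MultiheadAttention.forward above applies the attention temperature by scaling the queries before calling torch.nn.functional.scaled_dot_product_attention, which itself applies the usual 1/sqrt(head_dim) scaling internally. A small sketch of that call under assumed shapes:

import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 2, 4, 6, 8
attention_temp = 1.0

q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

# Temperature is folded into the queries; SDPA handles 1/sqrt(head_dim).
q = attention_temp * q
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0)
print(out.shape)  # torch.Size([2, 4, 6, 8])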