Skip to content

Commit d762dec

Browse files
gortizjicopybara-github
authored andcommitted
Internal.
PiperOrigin-RevId: 471778275
1 parent faaabf5 commit d762dec

File tree

1 file changed

+78
-62
lines changed

1 file changed

+78
-62
lines changed

uncertainty_baselines/models/wide_resnet.py

Lines changed: 78 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,8 @@ def Conv2D(filters, seed=None, **kwargs): # pylint: disable=invalid-name
4949
return tf.keras.layers.Conv2D(filters, **default_kwargs)
5050

5151

52-
def basic_block(
53-
inputs: tf.Tensor,
54-
filters: int,
55-
strides: int,
56-
conv_l2: float,
57-
bn_l2: float,
58-
seed: int,
59-
version: int) -> tf.Tensor:
52+
def basic_block(inputs: tf.Tensor, filters: int, strides: int, conv_l2: float,
53+
bn_l2: float, seed: int, version: int) -> tf.Tensor:
6054
"""Basic residual block of two 3x3 convs.
6155
6256
Args:
@@ -75,30 +69,42 @@ def basic_block(
7569
x = inputs
7670
y = inputs
7771
if version == 2:
78-
y = BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(bn_l2),
79-
gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(y)
72+
y = BatchNormalization(
73+
beta_regularizer=tf.keras.regularizers.l2(bn_l2),
74+
gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(
75+
y)
8076
y = tf.keras.layers.Activation('relu')(y)
8177
seeds = tf.random.experimental.stateless_split([seed, seed + 1], 3)[:, 0]
82-
y = Conv2D(filters,
83-
strides=strides,
84-
seed=seeds[0],
85-
kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(y)
86-
y = BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(bn_l2),
87-
gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(y)
78+
y = Conv2D(
79+
filters,
80+
strides=strides,
81+
seed=seeds[0],
82+
kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(
83+
y)
84+
y = BatchNormalization(
85+
beta_regularizer=tf.keras.regularizers.l2(bn_l2),
86+
gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(
87+
y)
8888
y = tf.keras.layers.Activation('relu')(y)
89-
y = Conv2D(filters,
90-
strides=1,
91-
seed=seeds[1],
92-
kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(y)
89+
y = Conv2D(
90+
filters,
91+
strides=1,
92+
seed=seeds[1],
93+
kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(
94+
y)
9395
if version == 1:
94-
y = BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(bn_l2),
95-
gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(y)
96+
y = BatchNormalization(
97+
beta_regularizer=tf.keras.regularizers.l2(bn_l2),
98+
gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(
99+
y)
96100
if not x.shape.is_compatible_with(y.shape):
97-
x = Conv2D(filters,
98-
kernel_size=1,
99-
strides=strides,
100-
seed=seeds[2],
101-
kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(x)
101+
x = Conv2D(
102+
filters,
103+
kernel_size=1,
104+
strides=strides,
105+
seed=seeds[2],
106+
kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(
107+
x)
102108
x = tf.keras.layers.add([x, y])
103109
if version == 1:
104110
x = tf.keras.layers.Activation('relu')(x)
@@ -107,8 +113,8 @@ def basic_block(
107113

108114
def group(inputs, filters, strides, num_blocks, conv_l2, bn_l2, version, seed):
109115
"""Group of residual blocks."""
110-
seeds = tf.random.experimental.stateless_split(
111-
[seed, seed + 1], num_blocks)[:, 0]
116+
seeds = tf.random.experimental.stateless_split([seed, seed + 1],
117+
num_blocks)[:, 0]
112118
x = basic_block(
113119
inputs,
114120
filters=filters,
@@ -187,49 +193,59 @@ def wide_resnet(
187193
raise ValueError('depth should be 6n+4 (e.g., 16, 22, 28, 40).')
188194
num_blocks = (depth - 4) // 6
189195
inputs = tf.keras.layers.Input(shape=input_shape)
190-
x = Conv2D(16,
191-
strides=1,
192-
seed=seeds[0],
193-
kernel_regularizer=l2_reg(hps['input_conv_l2']))(inputs)
196+
x = Conv2D(
197+
16,
198+
strides=1,
199+
seed=seeds[0],
200+
kernel_regularizer=l2_reg(hps['input_conv_l2']))(
201+
inputs)
194202
if version == 1:
195-
x = BatchNormalization(beta_regularizer=l2_reg(hps['bn_l2']),
196-
gamma_regularizer=l2_reg(hps['bn_l2']))(x)
203+
x = BatchNormalization(
204+
beta_regularizer=l2_reg(hps['bn_l2']),
205+
gamma_regularizer=l2_reg(hps['bn_l2']))(
206+
x)
197207
x = tf.keras.layers.Activation('relu')(x)
198-
x = group(x,
199-
filters=16 * width_multiplier,
200-
strides=1,
201-
num_blocks=num_blocks,
202-
conv_l2=hps['group_1_conv_l2'],
203-
bn_l2=hps['bn_l2'],
204-
version=version,
205-
seed=seeds[1])
206-
x = group(x,
207-
filters=32 * width_multiplier,
208-
strides=2,
209-
num_blocks=num_blocks,
210-
conv_l2=hps['group_2_conv_l2'],
211-
bn_l2=hps['bn_l2'],
212-
version=version,
213-
seed=seeds[2])
214-
x = group(x,
215-
filters=64 * width_multiplier,
216-
strides=2,
217-
num_blocks=num_blocks,
218-
conv_l2=hps['group_3_conv_l2'],
219-
bn_l2=hps['bn_l2'],
220-
version=version,
221-
seed=seeds[3])
208+
x = group(
209+
x,
210+
filters=16 * width_multiplier,
211+
strides=1,
212+
num_blocks=num_blocks,
213+
conv_l2=hps['group_1_conv_l2'],
214+
bn_l2=hps['bn_l2'],
215+
version=version,
216+
seed=seeds[1])
217+
x = group(
218+
x,
219+
filters=32 * width_multiplier,
220+
strides=2,
221+
num_blocks=num_blocks,
222+
conv_l2=hps['group_2_conv_l2'],
223+
bn_l2=hps['bn_l2'],
224+
version=version,
225+
seed=seeds[2])
226+
x = group(
227+
x,
228+
filters=64 * width_multiplier,
229+
strides=2,
230+
num_blocks=num_blocks,
231+
conv_l2=hps['group_3_conv_l2'],
232+
bn_l2=hps['bn_l2'],
233+
version=version,
234+
seed=seeds[3])
222235
if version == 2:
223-
x = BatchNormalization(beta_regularizer=l2_reg(hps['bn_l2']),
224-
gamma_regularizer=l2_reg(hps['bn_l2']))(x)
236+
x = BatchNormalization(
237+
beta_regularizer=l2_reg(hps['bn_l2']),
238+
gamma_regularizer=l2_reg(hps['bn_l2']))(
239+
x)
225240
x = tf.keras.layers.Activation('relu')(x)
226241
x = tf.keras.layers.AveragePooling2D(pool_size=8)(x)
227242
x = tf.keras.layers.Flatten()(x)
228243
x = tf.keras.layers.Dense(
229244
num_classes,
230245
kernel_initializer=tf.keras.initializers.HeNormal(seed=seeds[4]),
231246
kernel_regularizer=l2_reg(hps['dense_kernel_l2']),
232-
bias_regularizer=l2_reg(hps['dense_bias_l2']))(x)
247+
bias_regularizer=l2_reg(hps['dense_bias_l2']))(
248+
x)
233249
return tf.keras.Model(
234250
inputs=inputs,
235251
outputs=x,

0 commit comments

Comments
 (0)