@@ -49,14 +49,8 @@ def Conv2D(filters, seed=None, **kwargs):  # pylint: disable=invalid-name
  return tf.keras.layers.Conv2D(filters, **default_kwargs)


-def basic_block(
-    inputs: tf.Tensor,
-    filters: int,
-    strides: int,
-    conv_l2: float,
-    bn_l2: float,
-    seed: int,
-    version: int) -> tf.Tensor:
+def basic_block(inputs: tf.Tensor, filters: int, strides: int, conv_l2: float,
+                bn_l2: float, seed: int, version: int) -> tf.Tensor:
  """Basic residual block of two 3x3 convs.

  Args:
@@ -75,30 +69,42 @@ def basic_block(
  x = inputs
  y = inputs
  if version == 2:
-    y = BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(bn_l2),
-                           gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(y)
+    y = BatchNormalization(
+        beta_regularizer=tf.keras.regularizers.l2(bn_l2),
+        gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(
+            y)
    y = tf.keras.layers.Activation('relu')(y)
  seeds = tf.random.experimental.stateless_split([seed, seed + 1], 3)[:, 0]
-  y = Conv2D(filters,
-             strides=strides,
-             seed=seeds[0],
-             kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(y)
-  y = BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(bn_l2),
-                         gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(y)
+  y = Conv2D(
+      filters,
+      strides=strides,
+      seed=seeds[0],
+      kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(
+          y)
+  y = BatchNormalization(
+      beta_regularizer=tf.keras.regularizers.l2(bn_l2),
+      gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(
+          y)
  y = tf.keras.layers.Activation('relu')(y)
-  y = Conv2D(filters,
-             strides=1,
-             seed=seeds[1],
-             kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(y)
+  y = Conv2D(
+      filters,
+      strides=1,
+      seed=seeds[1],
+      kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(
+          y)
  if version == 1:
-    y = BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(bn_l2),
-                           gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(y)
+    y = BatchNormalization(
+        beta_regularizer=tf.keras.regularizers.l2(bn_l2),
+        gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(
+            y)
  if not x.shape.is_compatible_with(y.shape):
-    x = Conv2D(filters,
-               kernel_size=1,
-               strides=strides,
-               seed=seeds[2],
-               kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(x)
+    x = Conv2D(
+        filters,
+        kernel_size=1,
+        strides=strides,
+        seed=seeds[2],
+        kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(
+            x)
  x = tf.keras.layers.add([x, y])
  if version == 1:
    x = tf.keras.layers.Activation('relu')(x)
@@ -107,8 +113,8 @@ def basic_block(


def group(inputs, filters, strides, num_blocks, conv_l2, bn_l2, version, seed):
  """Group of residual blocks."""
-  seeds = tf.random.experimental.stateless_split(
-      [seed, seed + 1], num_blocks)[:, 0]
+  seeds = tf.random.experimental.stateless_split([seed, seed + 1],
+                                                 num_blocks)[:, 0]
  x = basic_block(
      inputs,
      filters=filters,
@@ -187,49 +193,59 @@ def wide_resnet(
    raise ValueError('depth should be 6n+4 (e.g., 16, 22, 28, 40).')
  num_blocks = (depth - 4) // 6
  inputs = tf.keras.layers.Input(shape=input_shape)
-  x = Conv2D(16,
-             strides=1,
-             seed=seeds[0],
-             kernel_regularizer=l2_reg(hps['input_conv_l2']))(inputs)
+  x = Conv2D(
+      16,
+      strides=1,
+      seed=seeds[0],
+      kernel_regularizer=l2_reg(hps['input_conv_l2']))(
+          inputs)
  if version == 1:
-    x = BatchNormalization(beta_regularizer=l2_reg(hps['bn_l2']),
-                           gamma_regularizer=l2_reg(hps['bn_l2']))(x)
+    x = BatchNormalization(
+        beta_regularizer=l2_reg(hps['bn_l2']),
+        gamma_regularizer=l2_reg(hps['bn_l2']))(
+            x)
    x = tf.keras.layers.Activation('relu')(x)
-  x = group(x,
-            filters=16 * width_multiplier,
-            strides=1,
-            num_blocks=num_blocks,
-            conv_l2=hps['group_1_conv_l2'],
-            bn_l2=hps['bn_l2'],
-            version=version,
-            seed=seeds[1])
-  x = group(x,
-            filters=32 * width_multiplier,
-            strides=2,
-            num_blocks=num_blocks,
-            conv_l2=hps['group_2_conv_l2'],
-            bn_l2=hps['bn_l2'],
-            version=version,
-            seed=seeds[2])
-  x = group(x,
-            filters=64 * width_multiplier,
-            strides=2,
-            num_blocks=num_blocks,
-            conv_l2=hps['group_3_conv_l2'],
-            bn_l2=hps['bn_l2'],
-            version=version,
-            seed=seeds[3])
+  x = group(
+      x,
+      filters=16 * width_multiplier,
+      strides=1,
+      num_blocks=num_blocks,
+      conv_l2=hps['group_1_conv_l2'],
+      bn_l2=hps['bn_l2'],
+      version=version,
+      seed=seeds[1])
+  x = group(
+      x,
+      filters=32 * width_multiplier,
+      strides=2,
+      num_blocks=num_blocks,
+      conv_l2=hps['group_2_conv_l2'],
+      bn_l2=hps['bn_l2'],
+      version=version,
+      seed=seeds[2])
+  x = group(
+      x,
+      filters=64 * width_multiplier,
+      strides=2,
+      num_blocks=num_blocks,
+      conv_l2=hps['group_3_conv_l2'],
+      bn_l2=hps['bn_l2'],
+      version=version,
+      seed=seeds[3])
  if version == 2:
-    x = BatchNormalization(beta_regularizer=l2_reg(hps['bn_l2']),
-                           gamma_regularizer=l2_reg(hps['bn_l2']))(x)
+    x = BatchNormalization(
+        beta_regularizer=l2_reg(hps['bn_l2']),
+        gamma_regularizer=l2_reg(hps['bn_l2']))(
+            x)
    x = tf.keras.layers.Activation('relu')(x)
  x = tf.keras.layers.AveragePooling2D(pool_size=8)(x)
  x = tf.keras.layers.Flatten()(x)
  x = tf.keras.layers.Dense(
      num_classes,
      kernel_initializer=tf.keras.initializers.HeNormal(seed=seeds[4]),
      kernel_regularizer=l2_reg(hps['dense_kernel_l2']),
-      bias_regularizer=l2_reg(hps['dense_bias_l2']))(x)
+      bias_regularizer=l2_reg(hps['dense_bias_l2']))(
+          x)
  return tf.keras.Model(
      inputs=inputs,
      outputs=x,
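
For context, a minimal usage sketch of the builder touched by this diff follows. The full signature of wide_resnet is not shown in these hunks, so the keyword-argument names, the hps values, and the CIFAR-sized input shape below are illustrative assumptions; only the hps keys, the seed-splitting behavior, and the 6n+4 depth constraint come from the code above.

# Hypothetical usage sketch -- not part of the commit. Argument names and
# values are assumptions; only the hps keys and the depth constraint
# (depth = 6n + 4) appear in the diff itself. Assumes wide_resnet is importable
# from the module this file defines.
hps = {
    'input_conv_l2': 3e-4,
    'group_1_conv_l2': 3e-4,
    'group_2_conv_l2': 3e-4,
    'group_3_conv_l2': 3e-4,
    'bn_l2': 3e-4,
    'dense_kernel_l2': 3e-4,
    'dense_bias_l2': 3e-4,
}
model = wide_resnet(
    input_shape=(32, 32, 3),  # assumed CIFAR-sized input
    depth=28,                 # must satisfy depth = 6n + 4
    width_multiplier=10,
    num_classes=10,
    version=2,                # pre-activation variant, per the version flag above
    seed=42,
    hps=hps)
model.summary()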