@@ -76,11 +76,11 @@ class TinyWideConvNet(nn.Module):
7676 def __init__ (self , in_size :int | Sequence [int ], in_channels :int , out_channels :int , act_cls : Callable = nn .ReLU , dropout = 0.5 ):
7777 super ().__init__ ()
7878 if isinstance (in_size , int ): in_size = (in_size , )
79- ndim = len (in_size )
79+ self . ndim = len (in_size )
8080
81- Conv = ConvNd (ndim )
82- MaxPool = MaxPoolNd (ndim )
83- Dropout = DropoutNd (ndim )
81+ Conv = ConvNd (self . ndim )
82+ MaxPool = MaxPoolNd (self . ndim )
83+ Dropout = DropoutNd (self . ndim )
8484
8585 self .c1 = nn .Sequential (
8686 Conv (in_channels , 8 , kernel_size = 5 ), # ~37
@@ -105,7 +105,8 @@ def forward(self, x):
105105 if x .ndim == 2 : x = x .unsqueeze (1 )
106106 x = self .c1 (x )
107107 x = self .c2 (x )
108- x = self .c3 (x ).mean (- 1 )
108+ dims = [- i for i in range (1 , self .ndim + 1 )]
109+ x = self .c3 (x ).mean (dims )
109110 return self .linear (x )
110111
111112
@@ -114,11 +115,11 @@ class TinyLongConvNet(nn.Module):
114115 def __init__ (self , in_size :int | Sequence [int ], in_channels :int , out_channels :int , act_cls : Callable = nn .ReLU , dropout = 0.0 ):
115116 super ().__init__ ()
116117 if isinstance (in_size , int ): in_size = (in_size , )
117- ndim = len (in_size )
118+ self . ndim = len (in_size )
118119
119- Conv = ConvNd (ndim )
120- Dropout = DropoutNd (ndim )
121- BatchNorm = BatchNormNd (ndim )
120+ Conv = ConvNd (self . ndim )
121+ Dropout = DropoutNd (self . ndim )
122+ BatchNorm = BatchNormNd (self . ndim )
122123
123124 self .c1 = nn .Sequential (
124125 Conv (in_channels , 4 , kernel_size = 2 , bias = False ),
@@ -158,7 +159,9 @@ def forward(self, x):
158159 if x .ndim == 2 : x = x .unsqueeze (1 )
159160 x = self .c1 (x )
160161 x = self .c2 (x )
161- x = self .c3 (x ).mean (- 1 )
162+
163+ dims = [- i for i in range (1 , self .ndim + 1 )]
164+ x = self .c3 (x ).mean (dims )
162165 return self .linear (x )
163166
164167
@@ -258,10 +261,10 @@ class MobileNet(nn.Module):
258261 def __init__ (self , in_size :int | Sequence [int ], in_channels :int , out_channels :int , act_cls : Callable = nn .ReLU , dropout = 0.5 ):
259262 super ().__init__ ()
260263 if isinstance (in_size , int ): in_size = (in_size , )
261- ndim = len (in_size )
264+ self . ndim = len (in_size )
262265
263- Conv = ConvNd (ndim )
264- Dropout = DropoutNd (ndim )
266+ Conv = ConvNd (self . ndim )
267+ Dropout = DropoutNd (self . ndim )
265268
266269 self .c1 = nn .Sequential (
267270 Conv (in_channels , 32 , kernel_size = 3 , stride = 2 , padding = 1 ),
@@ -297,7 +300,8 @@ def forward(self, x):
297300 x = self .c1 (x )
298301 x = self .c2 (x )
299302 x = self .c3 (x )
300- return x .mean (- 1 )
303+ dims = [- i for i in range (1 , self .ndim + 1 )]
304+ return x .mean (dims )
301305
302306def convblocknd (in_channels , out_channels , kernel_size , stride , padding , act_cls , bn : bool , dropout :float | None , transpose = False , ndim :int = 2 ):
303307 ConvCls = ConvTransposeNd (ndim ) if transpose else ConvNd (ndim )
0 commit comments