diff --git a/micromind/networks/phinet.py b/micromind/networks/phinet.py
index d859bf9..96066e8 100644
--- a/micromind/networks/phinet.py
+++ b/micromind/networks/phinet.py
@@ -449,7 +449,11 @@ def __init__(
                 kernel_size=k_size,
                 stride=stride,
                 bias=False,
-                padding=k_size // 2 if stride == 1 else (padding[1], padding[3]),
+                padding=k_size // 2
+                if isinstance(k_size, int) and stride == 1
+                else [x // 2 for x in k_size]
+                if stride == 1
+                else (padding[1], padding[3]),
             )
 
             bn_dw1 = nn.BatchNorm2d(
@@ -629,6 +633,7 @@ def __init__(
         squeeze_excite: bool = True,  # S1
         divisor: int = 1,
         return_layers=None,
+        flattened_embeddings=False,
     ) -> None:
         super(PhiNet, self).__init__()
         self.alpha = alpha
@@ -637,6 +642,8 @@ def __init__(
         self.num_layers = num_layers
         self.num_classes = num_classes
         self.return_layers = return_layers
+        self.flattened_embeddings = flattened_embeddings
+        self.features_dim = 0
 
         if compatibility:  # disables operations hard for some platforms
             h_swish = False
@@ -802,6 +809,12 @@ def __init__(
             )
             block_id += 1
 
+        if self.flattened_embeddings:
+            flatten = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten())
+            self._layers.append(flatten)
+
+        self.num_features = _make_divisible(int(block_filters * alpha), divisor=divisor)
+
         if include_top:
             # Includes classification head if required
             self.classifier = nn.Sequential(
diff --git a/micromind/networks/xinet.py b/micromind/networks/xinet.py
index 949f6aa..587f32e 100644
--- a/micromind/networks/xinet.py
+++ b/micromind/networks/xinet.py
@@ -250,6 +250,7 @@ def __init__(
         include_top=False,
         base_filters: int = 16,
         return_layers: Optional[List] = None,
+        flattened_embeddings=False,
     ):
         super().__init__()
 
@@ -258,6 +259,8 @@ def __init__(
         self.include_top = include_top
         self.return_layers = return_layers
         count_downsample = 0
+        self.flattened_embeddings = flattened_embeddings
+        self.features_dim = 0
 
         self.conv1 = nn.Sequential(
             nn.Conv2d(
@@ -340,6 +343,9 @@ def __init__(
             for i in self.return_layers:
                 print(f"Layer {i} - {self._layers[i].__class__}")
 
+        if self.flattened_embeddings:
+            self.flatten = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten())
+
         self.input_shape = input_shape
         if self.include_top:
             self.classifier = nn.Sequential(
@@ -348,6 +354,8 @@ def __init__(
                 nn.Linear(int(num_filters[-1] * alpha), num_classes),
             )
 
+        self.num_features = int(num_filters[-1] * alpha)
+
     def forward(self, x):
         """Computes the forward step of the XiNet.
         Arguments
@@ -374,6 +382,9 @@ def forward(self, x):
             if layer_id in self.return_layers:
                 ret.append(x)
 
+        if self.flattened_embeddings:
+            x = self.flatten(x)
+
         if self.include_top:
             x = self.classifier(x)
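
Reviewer note, not part of the patch: a minimal usage sketch of the new flattened_embeddings flag, assuming PhiNet is importable from micromind.networks and using illustrative constructor arguments (input_shape, alpha, and num_layers are existing parameters visible in the diff context; the values below are made up). With the flag set, the patch appends AdaptiveAvgPool2d((1, 1)) + Flatten to the backbone, so a headless network should return a 2-D (batch, num_features) embedding instead of a 4-D feature map; XiNet gets the same behavior via self.flatten in its forward pass.

import torch

from micromind.networks import PhiNet

# Illustrative arguments; flattened_embeddings=True makes the backbone
# end in AdaptiveAvgPool2d((1, 1)) + Flatten per this patch.
net = PhiNet(
    input_shape=(3, 224, 224),
    alpha=0.5,
    num_layers=7,
    include_top=False,
    flattened_embeddings=True,
)

x = torch.randn(2, 3, 224, 224)
emb = net(x)
print(emb.shape)  # expected: torch.Size([2, net.num_features])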