15 changes: 14 additions & 1 deletion micromind/networks/phinet.py
```diff
@@ -449,7 +449,11 @@ def __init__(
                 kernel_size=k_size,
                 stride=stride,
                 bias=False,
-                padding=k_size // 2 if stride == 1 else (padding[1], padding[3]),
+                padding=k_size // 2
+                if isinstance(k_size, int) and stride == 1
+                else [x // 2 for x in k_size]
+                if stride == 1
+                else (padding[1], padding[3]),
             )

         bn_dw1 = nn.BatchNorm2d(
```
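The previous one-liner computed `k_size // 2` for every unit-stride convolution, which raises a `TypeError` when `k_size` is a tuple (floor division is not defined on tuples); the new chained conditional computes "same" padding per dimension for rectangular kernels. A minimal standalone sketch of the rule (the `resolve_padding` helper and the sample values are illustrative, not part of micromind):

```python
# Illustrative restatement of the padding rule above; `resolve_padding`
# is a hypothetical helper, not a micromind API.
def resolve_padding(k_size, stride, padding):
    if isinstance(k_size, int) and stride == 1:
        # Square kernel, unit stride: symmetric "same" padding.
        return k_size // 2
    if stride == 1:
        # Tuple kernel, unit stride: per-dimension "same" padding.
        # The old expression crashed here, since `tuple // int` is invalid.
        return [x // 2 for x in k_size]
    # Strided convolution: fall back to the precomputed asymmetric padding.
    return (padding[1], padding[3])

print(resolve_padding(3, 1, None))          # 1
print(resolve_padding((3, 5), 1, None))     # [1, 2]
print(resolve_padding(3, 2, (0, 1, 0, 1)))  # (1, 1)
```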
```diff
@@ -629,6 +633,7 @@ def __init__(
         squeeze_excite: bool = True,  # S1
         divisor: int = 1,
         return_layers=None,
+        flattened_embeddings=False,
     ) -> None:
         super(PhiNet, self).__init__()
         self.alpha = alpha
```
```diff
@@ -637,6 +642,8 @@ def __init__(
         self.num_layers = num_layers
         self.num_classes = num_classes
         self.return_layers = return_layers
+        self.flattened_embeddings = flattened_embeddings
+        self.features_dim = 0

         if compatibility:  # disables operations hard for some platforms
             h_swish = False
```
```diff
@@ -802,6 +809,12 @@ def __init__(
             )
             block_id += 1

+        if self.flattened_embeddings:
+            flatten = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten())
+            self._layers.append(flatten)
+
+        self.num_features = _make_divisible(int(block_filters * alpha), divisor=divisor)
+
         if include_top:
             # Includes classification head if required
             self.classifier = nn.Sequential(
```
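A hedged usage sketch of the new flag (the constructor arguments shown are assumptions based on the surrounding code, not taken from this PR): with `flattened_embeddings=True`, the backbone ends in `AdaptiveAvgPool2d((1, 1))` followed by `Flatten`, so a headless model returns a `(batch, num_features)` embedding instead of a spatial feature map.

```python
# Sketch only: assumes PhiNet is importable from micromind.networks
# and accepts these constructor arguments.
import torch
from micromind.networks import PhiNet

model = PhiNet(
    input_shape=(3, 224, 224),
    include_top=False,           # no classification head
    flattened_embeddings=True,   # new flag added in this PR
)
emb = model(torch.randn(1, 3, 224, 224))
print(emb.shape)  # expected: torch.Size([1, model.num_features])
```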
11 changes: 11 additions & 0 deletions micromind/networks/xinet.py
```diff
@@ -250,6 +250,7 @@ def __init__(
         include_top=False,
         base_filters: int = 16,
         return_layers: Optional[List] = None,
+        flattened_embeddings=False,
     ):
         super().__init__()

```
```diff
@@ -258,6 +259,8 @@ def __init__(
         self.include_top = include_top
         self.return_layers = return_layers
         count_downsample = 0
+        self.flattened_embeddings = flattened_embeddings
+        self.features_dim = 0

         self.conv1 = nn.Sequential(
             nn.Conv2d(
```
```diff
@@ -340,6 +343,9 @@ def __init__(
             for i in self.return_layers:
                 print(f"Layer {i} - {self._layers[i].__class__}")

+        if self.flattened_embeddings:
+            self.flatten = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten())
+
         self.input_shape = input_shape
         if self.include_top:
             self.classifier = nn.Sequential(
```
```diff
@@ -348,6 +354,8 @@ def __init__(
                 nn.Linear(int(num_filters[-1] * alpha), num_classes),
             )

+        self.num_features = int(num_filters[-1] * alpha)
+
     def forward(self, x):
         """Computes the forward step of the XiNet.
         Arguments
```
```diff
@@ -374,6 +382,9 @@ def forward(self, x):
             if layer_id in self.return_layers:
                 ret.append(x)

+        if self.flattened_embeddings:
+            x = self.flatten(x)
+
         if self.include_top:
             x = self.classifier(x)

```
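XiNet mirrors the PhiNet behavior, with one difference visible in the diff: the flatten module is stored as `self.flatten` and applied explicitly in `forward()`, rather than appended to `self._layers`. A similarly hedged sketch (argument names assumed from the surrounding code, not taken from this PR):

```python
# Sketch only: assumes XiNet is importable from micromind.networks
# and accepts these constructor arguments.
import torch
from micromind.networks import XiNet

model = XiNet(
    input_shape=(3, 224, 224),
    include_top=False,
    flattened_embeddings=True,  # new flag added in this PR
)
emb = model(torch.randn(1, 3, 224, 224))
print(emb.shape)  # expected: torch.Size([1, model.num_features])
```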