Commit c7afb54

Merge pull request #95 from NVIDIA/aparis/docs
Docstrings PR
2 parents b5c410c + 644465b commit c7afb54


44 files changed: +2356 -427 lines changed

Changelog.md

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
 * Reorganized examples folder, including new examples based on the 2d3ds dataset
 * Added spherical loss functions to examples
 * Added plotting module
+* Updated docstrings

 ### v0.7.6

examples/baseline_models/segformer.py

Lines changed: 145 additions & 3 deletions
@@ -41,6 +41,24 @@


 class OverlapPatchMerging(nn.Module):
+    """
+    OverlapPatchMerging layer for merging patches.
+
+    Parameters
+    -----------
+    in_shape : tuple
+        Input shape (height, width)
+    out_shape : tuple
+        Output shape (height, width)
+    in_channels : int
+        Number of input channels
+    out_channels : int
+        Number of output channels
+    kernel_shape : tuple
+        Kernel shape for convolution
+    bias : bool, optional
+        Whether to use bias, by default False
+    """
     def __init__(
         self,
         in_shape=(721, 1440),
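
The docstring above only lists constructor arguments, so a brief illustration of the idea may help: patches "overlap" when the merging convolution uses a kernel larger than its stride. The sketch below is a planar (non-spherical) analogue under that assumption; the class name PlanarOverlapPatchMerging and its defaults are made up for the example and are not the implementation in this file, which operates on explicit (height, width) grids such as (721, 1440).

    import torch
    import torch.nn as nn

    class PlanarOverlapPatchMerging(nn.Module):
        # Hypothetical planar analogue: overlapping patches come from a strided
        # convolution whose kernel (7x7) is larger than its stride (4).
        def __init__(self, in_channels=3, out_channels=64, kernel_size=7, stride=4):
            super().__init__()
            self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                                  stride=stride, padding=kernel_size // 2)
            self.norm = nn.LayerNorm(out_channels)

        def forward(self, x):
            # x: (B, C_in, H, W) -> (B, C_out, H / stride, W / stride)
            x = self.proj(x)
            # LayerNorm normalizes over channels, so move them last and back.
            x = x.permute(0, 2, 3, 1)
            x = self.norm(x)
            return x.permute(0, 3, 1, 2)

    x = torch.randn(2, 3, 128, 256)
    print(PlanarOverlapPatchMerging()(x).shape)  # torch.Size([2, 64, 32, 64])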
@@ -88,6 +106,30 @@ def forward(self, x):


 class MixFFN(nn.Module):
+    """
+    MixFFN module combining MLP and depthwise convolution.
+
+    Parameters
+    -----------
+    shape : tuple
+        Input shape (height, width)
+    inout_channels : int
+        Number of input/output channels
+    hidden_channels : int
+        Number of hidden channels in MLP
+    mlp_bias : bool, optional
+        Whether to use bias in MLP layers, by default True
+    kernel_shape : tuple, optional
+        Kernel shape for depthwise convolution, by default (3, 3)
+    conv_bias : bool, optional
+        Whether to use bias in convolution, by default False
+    activation : callable, optional
+        Activation function, by default nn.GELU
+    use_mlp : bool, optional
+        Whether to use MLP instead of linear layers, by default False
+    drop_path : float, optional
+        Drop path rate, by default 0.0
+    """
     def __init__(
         self,
         shape,
@@ -142,7 +184,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 3, 1, 2)

         # NOTE: we add another activation here
-        # because in the paper they only use depthwise conv,
+        # because in the paper the authors only use depthwise conv,
         # but without this activation it would just be a fused MM
         # with the disco conv
         x = self.mlp_in(x)
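
The comment above explains why an activation sits between the depthwise convolution and the following projection: without it, the two linear operations would fuse into a single matrix multiply. A minimal planar sketch of a MixFFN-style block follows; it assumes pointwise convolutions around a depthwise convolution, whereas this file uses disco convolutions on the sphere, and the name PlanarMixFFN is hypothetical.

    import torch
    import torch.nn as nn

    class PlanarMixFFN(nn.Module):
        # Hypothetical planar sketch: pointwise expansion, depthwise mixing of
        # neighbouring positions, then pointwise projection. The GELU between the
        # depthwise conv and the projection keeps the two from collapsing into a
        # single fused linear map.
        def __init__(self, channels=64, hidden_channels=128, kernel_size=3):
            super().__init__()
            self.expand = nn.Conv2d(channels, hidden_channels, kernel_size=1)
            self.dwconv = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=kernel_size,
                                    padding=kernel_size // 2, groups=hidden_channels)
            self.act = nn.GELU()
            self.project = nn.Conv2d(hidden_channels, channels, kernel_size=1)

        def forward(self, x):
            # x: (B, C, H, W); the output keeps the same shape
            return self.project(self.act(self.dwconv(self.expand(x))))

    x = torch.randn(2, 64, 32, 64)
    print(PlanarMixFFN()(x).shape)  # torch.Size([2, 64, 32, 64])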
@@ -162,13 +204,25 @@ class GlobalAttention(nn.Module):

     Input shape: (B, C, H, W)
     Output shape: (B, C, H, W) with residual skip.
+
+    Parameters
+    -----------
+    chans : int
+        Number of channels
+    num_heads : int, optional
+        Number of attention heads, by default 8
+    dropout : float, optional
+        Dropout rate, by default 0.0
+    bias : bool, optional
+        Whether to use bias, by default True
     """

     def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
         super().__init__()
         self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)

     def forward(self, x):
+
         # x: B, C, H, W
         B, H, W, C = x.shape
         # flatten spatial dims
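
For readers skimming the diff, here is a self-contained sketch of the pattern the forward above implements: flatten the spatial grid into a sequence of H*W tokens, run full self-attention with torch.nn.MultiheadAttention, reshape back, and add the residual. The helper name global_attention_sketch is made up for the illustration.

    import torch
    import torch.nn as nn

    def global_attention_sketch(x, attn):
        # x: (B, H, W, C); flatten spatial dims into a token sequence
        B, H, W, C = x.shape
        tokens = x.reshape(B, H * W, C)
        out, _ = attn(tokens, tokens, tokens, need_weights=False)
        return out.reshape(B, H, W, C) + x  # residual skip, as documented

    chans = 64
    attn = nn.MultiheadAttention(embed_dim=chans, num_heads=8, dropout=0.0, batch_first=True)
    x = torch.randn(2, 16, 32, chans)
    print(global_attention_sketch(x, attn).shape)  # torch.Size([2, 16, 32, 64])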
@@ -181,6 +235,30 @@ def forward(self, x):


 class AttentionWrapper(nn.Module):
+    """
+    Wrapper for different attention mechanisms.
+
+    Parameters
+    -----------
+    channels : int
+        Number of channels
+    shape : tuple
+        Input shape (height, width)
+    heads : int
+        Number of attention heads
+    pre_norm : bool, optional
+        Whether to apply normalization before attention, by default False
+    attention_drop_rate : float, optional
+        Attention dropout rate, by default 0.0
+    drop_path : float, optional
+        Drop path rate, by default 0.0
+    attention_mode : str, optional
+        Attention mode ("neighborhood", "global"), by default "neighborhood"
+    kernel_shape : tuple, optional
+        Kernel shape for neighborhood attention, by default (7, 7)
+    bias : bool, optional
+        Whether to use bias, by default True
+    """
     def __init__(self, channels, shape, heads, pre_norm=False, attention_drop_rate=0.0, drop_path=0.0, attention_mode="neighborhood", kernel_shape=(7, 7), bias=True):
         super().__init__()

@@ -203,11 +281,13 @@ def __init__(self, channels, shape, heads, pre_norm=False, attention_drop_rate=0
         self.apply(self._init_weights)

     def _init_weights(self, m):
+
         if isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+
         residual = x
         x = x.permute(0, 2, 3, 1)
         if self.norm is not None:
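
The _init_weights hook above is dispatched through self.apply, which walks the module tree and calls the function on every submodule. A small standalone illustration of that PyTorch mechanism follows; the function name init_layer_norms and the toy model are made up for the example.

    import torch.nn as nn

    def init_layer_norms(m):
        # Mirrors the _init_weights pattern above: only LayerNorm modules are touched.
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 4), nn.LayerNorm(4))
    model.apply(init_layer_norms)  # nn.Module.apply recurses into every submodule
    print(model[1].weight)  # the LayerNorm weights are all ones after initialization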
@@ -219,6 +299,41 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class TransformerBlock(nn.Module):
+    """
+    Transformer block with attention and MLP.
+
+    Parameters
+    ----------
+    in_shape : tuple
+        Input shape (height, width)
+    out_shape : tuple
+        Output shape (height, width)
+    in_channels : int
+        Number of input channels
+    out_channels : int
+        Number of output channels
+    mlp_hidden_channels : int
+        Number of hidden channels in MLP
+    nrep : int, optional
+        Number of repetitions of attention and MLP blocks, by default 1
+    heads : int, optional
+        Number of attention heads, by default 1
+    kernel_shape : tuple, optional
+        Kernel shape for neighborhood attention, by default (3, 3)
+    activation : torch.nn.Module, optional
+        Activation function to use, by default nn.GELU
+    att_drop_rate : float, optional
+        Attention dropout rate, by default 0.0
+    drop_path_rates : float or list, optional
+        Drop path rates for each block, by default 0.0
+    attention_mode : str, optional
+        Attention mode ("neighborhood", "global"), by default "neighborhood"
+    attn_kernel_shape : tuple, optional
+        Kernel shape for neighborhood attention, by default (7, 7)
+    bias : bool, optional
+        Whether to use bias, by default True
+    """
+
     def __init__(
         self,
         in_shape,
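
The docstring above documents drop_path_rates as "float or list". A common convention in transformer stacks is to expand a scalar into a per-block schedule and apply stochastic depth (drop path) inside each block; the sketch below shows that convention as an assumption, not necessarily how this file expands the argument, and the DropPath class here is illustrative.

    import torch
    import torch.nn as nn

    class DropPath(nn.Module):
        # Minimal stochastic-depth sketch: during training, zero the whole residual
        # branch for a random subset of samples and rescale the survivors.
        def __init__(self, drop_prob=0.0):
            super().__init__()
            self.drop_prob = drop_prob

        def forward(self, x):
            if self.drop_prob == 0.0 or not self.training:
                return x
            keep_prob = 1.0 - self.drop_prob
            mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
            mask = torch.rand(mask_shape, device=x.device) < keep_prob
            return x * mask / keep_prob

    # Expanding a scalar rate into one rate per repeated block (nrep blocks):
    nrep, drop_path_rates = 4, 0.2
    rates = torch.linspace(0, drop_path_rates, nrep).tolist()
    print(rates)  # [0.0, 0.0667, 0.1333, 0.2] (approximately)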
@@ -341,6 +456,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class Upsampling(nn.Module):
+    """
+    Upsampling block for the Segformer model.
+
+    Parameters
+    ----------
+    in_shape : tuple
+        Input shape (height, width)
+    out_shape : tuple
+        Output shape (height, width)
+    in_channels : int
+        Number of input channels
+    out_channels : int
+        Number of output channels
+    hidden_channels : int
+        Number of hidden channels in MLP
+    mlp_bias : bool, optional
+        Whether to use bias in MLP, by default True
+    kernel_shape : tuple, optional
+        Kernel shape for convolution, by default (3, 3)
+    conv_bias : bool, optional
+        Whether to use bias in convolution, by default False
+    activation : torch.nn.Module, optional
+        Activation function to use, by default nn.GELU
+    use_mlp : bool, optional
+        Whether to use MLP, by default False
+    """
+
     def __init__(
         self,
         in_shape,
@@ -382,7 +524,7 @@ class Segformer(nn.Module):
     Spherical segformer model designed to approximate mappings from spherical signals to spherical segmentation masks

     Parameters
-    -----------
+    ----------
     img_shape : tuple, optional
         Shape of the input channels, by default (128, 256)
     kernel_shape: tuple, int
@@ -414,7 +556,7 @@ class Segformer(nn.Module):
         Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"

     Example
-    -----------
+    ----------
     >>> model = Segformer(
     ...     img_size=(128, 256),
     ...     in_chans=3,

examples/baseline_models/transformer.py

Lines changed: 65 additions & 1 deletion
@@ -57,11 +57,34 @@ def __init__(
         self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=kernel_shape, bias=bias, stride=(stride_h, stride_w), padding=(pad_h, pad_w), groups=groups)

     def forward(self, x):
+
         x = self.conv(x)
         return x


 class Decoder(nn.Module):
+    """
+    Decoder module for upsampling and feature processing.
+
+    Parameters
+    -----------
+    in_shape : tuple, optional
+        Input shape (height, width), by default (480, 960)
+    out_shape : tuple, optional
+        Output shape (height, width), by default (721, 1440)
+    in_chans : int, optional
+        Number of input channels, by default 2
+    out_chans : int, optional
+        Number of output channels, by default 2
+    kernel_shape : tuple, optional
+        Kernel shape for convolution, by default (3, 3)
+    groups : int, optional
+        Number of groups for convolution, by default 1
+    bias : bool, optional
+        Whether to use bias, by default False
+    upsampling_method : str, optional
+        Upsampling method ("conv", "pixel_shuffle"), by default "conv"
+    """
     def __init__(self, in_shape=(480, 960), out_shape=(721, 1440), in_chans=2, out_chans=2, kernel_shape=(3, 3), groups=1, bias=False, upsampling_method="conv"):
         super().__init__()
         self.out_shape = out_shape
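
The Decoder docstring names two upsampling methods, "conv" and "pixel_shuffle". The sketch below shows what those two strategies typically look like in PyTorch (interpolate-then-convolve versus channel expansion followed by nn.PixelShuffle); it is an assumption about the general technique, and the helper make_upsampler is not part of this file, whose layers may differ.

    import torch
    import torch.nn as nn

    def make_upsampler(upsampling_method, in_chans, out_chans, scale=2):
        # Illustrative only; the Decoder in this file may compose these differently.
        if upsampling_method == "conv":
            # Interpolate to the target resolution, then refine with a convolution.
            return nn.Sequential(
                nn.Upsample(scale_factor=scale, mode="bilinear", align_corners=False),
                nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1),
            )
        elif upsampling_method == "pixel_shuffle":
            # Expand channels by scale**2, then rearrange them into spatial positions.
            return nn.Sequential(
                nn.Conv2d(in_chans, out_chans * scale**2, kernel_size=3, padding=1),
                nn.PixelShuffle(scale),
            )
        raise ValueError(f"Unknown upsampling method {upsampling_method}")

    x = torch.randn(1, 2, 480, 960)
    print(make_upsampler("pixel_shuffle", 2, 2)(x).shape)  # torch.Size([1, 2, 960, 1920])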
@@ -87,6 +110,7 @@ def __init__(self, in_shape=(480, 960), out_shape=(721, 1440), in_chans=2, out_c
             raise ValueError(f"Unknown upsampling method {upsampling_method}")

     def forward(self, x):
+
         x = self.upsample(x)
         return x

@@ -97,13 +121,25 @@ class GlobalAttention(nn.Module):

     Input shape: (B, C, H, W)
     Output shape: (B, C, H, W) with residual skip.
+
+    Parameters
+    -----------
+    chans : int
+        Number of channels
+    num_heads : int, optional
+        Number of attention heads, by default 8
+    dropout : float, optional
+        Dropout rate, by default 0.0
+    bias : bool, optional
+        Whether to use bias, by default True
     """

     def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
         super().__init__()
         self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)

     def forward(self, x):
+
         # x: B, C, H, W
         B, H, W, C = x.shape
         # flatten spatial dims
@@ -118,8 +154,36 @@ def forward(self, x):
 class AttentionBlock(nn.Module):
     """
     Neighborhood attention block based on Natten.
+
+    Parameters
+    -----------
+    in_shape : tuple, optional
+        Input shape (height, width), by default (480, 960)
+    out_shape : tuple, optional
+        Output shape (height, width), by default (480, 960)
+    chans : int, optional
+        Number of channels, by default 2
+    num_heads : int, optional
+        Number of attention heads, by default 1
+    mlp_ratio : float, optional
+        Ratio of MLP hidden dim to input dim, by default 2.0
+    drop_rate : float, optional
+        Dropout rate, by default 0.0
+    drop_path : float, optional
+        Drop path rate, by default 0.0
+    act_layer : callable, optional
+        Activation function, by default nn.GELU
+    norm_layer : str, optional
+        Normalization layer type, by default "none"
+    use_mlp : bool, optional
+        Whether to use MLP, by default True
+    bias : bool, optional
+        Whether to use bias, by default True
+    attention_mode : str, optional
+        Attention mode ("neighborhood", "global"), by default "neighborhood"
+    attn_kernel_shape : tuple, optional
+        Kernel shape for neighborhood attention, by default (7, 7)
     """
-
     def __init__(
         self,
         in_shape=(480, 960),
