4141
4242
4343class OverlapPatchMerging (nn .Module ):
44+ """
45+ OverlapPatchMerging layer for merging patches.
46+
47+ Parameters
    ----------
49+ in_shape : tuple
50+ Input shape (height, width)
51+ out_shape : tuple
52+ Output shape (height, width)
53+ in_channels : int
54+ Number of input channels
55+ out_channels : int
56+ Number of output channels
57+ kernel_shape : tuple
58+ Kernel shape for convolution
59+ bias : bool, optional
60+ Whether to use bias, by default False
61+ """
4462 def __init__ (
4563 self ,
4664 in_shape = (721 , 1440 ),
@@ -88,6 +106,30 @@ def forward(self, x):
88106
89107
90108class MixFFN (nn .Module ):
109+ """
110+ MixFFN module combining MLP and depthwise convolution.
111+
112+ Parameters
    ----------
114+ shape : tuple
115+ Input shape (height, width)
116+ inout_channels : int
117+ Number of input/output channels
118+ hidden_channels : int
119+ Number of hidden channels in MLP
120+ mlp_bias : bool, optional
121+ Whether to use bias in MLP layers, by default True
122+ kernel_shape : tuple, optional
123+ Kernel shape for depthwise convolution, by default (3, 3)
124+ conv_bias : bool, optional
125+ Whether to use bias in convolution, by default False
126+ activation : callable, optional
127+ Activation function, by default nn.GELU
128+ use_mlp : bool, optional
129+ Whether to use MLP instead of linear layers, by default False
130+ drop_path : float, optional
131+ Drop path rate, by default 0.0
132+ """
91133 def __init__ (
92134 self ,
93135 shape ,
@@ -142,7 +184,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
142184 x = x .permute (0 , 3 , 1 , 2 )
143185
144186 # NOTE: we add another activation here
145- # because in the paper they only use depthwise conv,
187+ # because in the paper the authors only use depthwise conv,
146188 # but without this activation it would just be a fused MM
147189 # with the disco conv
148190 x = self .mlp_in (x )
@@ -162,13 +204,25 @@ class GlobalAttention(nn.Module):
162204
    Input shape: (B, H, W, C)
    Output shape: (B, H, W, C) with residual skip.
207+
208+ Parameters
    ----------
210+ chans : int
211+ Number of channels
212+ num_heads : int, optional
213+ Number of attention heads, by default 8
214+ dropout : float, optional
215+ Dropout rate, by default 0.0
216+ bias : bool, optional
217+ Whether to use bias, by default True
165218 """
166219
def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
    """Set up multi-head self-attention over flattened spatial positions.

    Parameters
    ----------
    chans : int
        Number of channels (embedding dimension of the attention).
    num_heads : int, optional
        Number of attention heads, by default 8
    dropout : float, optional
        Dropout rate applied inside the attention, by default 0.0
    bias : bool, optional
        Whether to use bias in the attention projections, by default True
    """
    super().__init__()
    # batch_first=True so inputs are (B, seq, C), matching the
    # flattened (B, H*W, C) layout produced in forward.
    self.attn = nn.MultiheadAttention(
        embed_dim=chans,
        num_heads=num_heads,
        dropout=dropout,
        batch_first=True,
        bias=bias,
    )
170223
171224 def forward (self , x ):
225+
        # x: B, H, W, C (channels-last; caller permutes before calling)
173227 B , H , W , C = x .shape
174228 # flatten spatial dims
@@ -181,6 +235,30 @@ def forward(self, x):
181235
182236
183237class AttentionWrapper (nn .Module ):
238+ """
239+ Wrapper for different attention mechanisms.
240+
241+ Parameters
    ----------
243+ channels : int
244+ Number of channels
245+ shape : tuple
246+ Input shape (height, width)
247+ heads : int
248+ Number of attention heads
249+ pre_norm : bool, optional
250+ Whether to apply normalization before attention, by default False
251+ attention_drop_rate : float, optional
252+ Attention dropout rate, by default 0.0
253+ drop_path : float, optional
254+ Drop path rate, by default 0.0
255+ attention_mode : str, optional
256+ Attention mode ("neighborhood", "global"), by default "neighborhood"
257+ kernel_shape : tuple, optional
258+ Kernel shape for neighborhood attention, by default (7, 7)
259+ bias : bool, optional
260+ Whether to use bias, by default True
261+ """
184262 def __init__ (self , channels , shape , heads , pre_norm = False , attention_drop_rate = 0.0 , drop_path = 0.0 , attention_mode = "neighborhood" , kernel_shape = (7 , 7 ), bias = True ):
185263 super ().__init__ ()
186264
@@ -203,11 +281,13 @@ def __init__(self, channels, shape, heads, pre_norm=False, attention_drop_rate=0
203281 self .apply (self ._init_weights )
204282
205283 def _init_weights (self , m ):
284+
206285 if isinstance (m , nn .LayerNorm ):
207286 nn .init .constant_ (m .bias , 0 )
208287 nn .init .constant_ (m .weight , 1.0 )
209288
210289 def forward (self , x : torch .Tensor ) -> torch .Tensor :
290+
211291 residual = x
212292 x = x .permute (0 , 2 , 3 , 1 )
213293 if self .norm is not None :
@@ -219,6 +299,41 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
219299
220300
221301class TransformerBlock (nn .Module ):
302+ """
303+ Transformer block with attention and MLP.
304+
305+ Parameters
306+ ----------
307+ in_shape : tuple
308+ Input shape (height, width)
309+ out_shape : tuple
310+ Output shape (height, width)
311+ in_channels : int
312+ Number of input channels
313+ out_channels : int
314+ Number of output channels
315+ mlp_hidden_channels : int
316+ Number of hidden channels in MLP
317+ nrep : int, optional
318+ Number of repetitions of attention and MLP blocks, by default 1
319+ heads : int, optional
320+ Number of attention heads, by default 1
321+ kernel_shape : tuple, optional
322+ Kernel shape for neighborhood attention, by default (3, 3)
323+ activation : torch.nn.Module, optional
324+ Activation function to use, by default nn.GELU
325+ att_drop_rate : float, optional
326+ Attention dropout rate, by default 0.0
327+ drop_path_rates : float or list, optional
328+ Drop path rates for each block, by default 0.0
329+ attention_mode : str, optional
330+ Attention mode ("neighborhood", "global"), by default "neighborhood"
331+ attn_kernel_shape : tuple, optional
332+ Kernel shape for neighborhood attention, by default (7, 7)
333+ bias : bool, optional
334+ Whether to use bias, by default True
335+ """
336+
222337 def __init__ (
223338 self ,
224339 in_shape ,
@@ -341,6 +456,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
341456
342457
343458class Upsampling (nn .Module ):
459+ """
460+ Upsampling block for the Segformer model.
461+
462+ Parameters
463+ ----------
464+ in_shape : tuple
465+ Input shape (height, width)
466+ out_shape : tuple
467+ Output shape (height, width)
468+ in_channels : int
469+ Number of input channels
470+ out_channels : int
471+ Number of output channels
472+ hidden_channels : int
473+ Number of hidden channels in MLP
474+ mlp_bias : bool, optional
475+ Whether to use bias in MLP, by default True
476+ kernel_shape : tuple, optional
477+ Kernel shape for convolution, by default (3, 3)
478+ conv_bias : bool, optional
479+ Whether to use bias in convolution, by default False
480+ activation : torch.nn.Module, optional
481+ Activation function to use, by default nn.GELU
482+ use_mlp : bool, optional
483+ Whether to use MLP, by default False
484+ """
485+
344486 def __init__ (
345487 self ,
346488 in_shape ,
@@ -382,7 +524,7 @@ class Segformer(nn.Module):
382524 Spherical segformer model designed to approximate mappings from spherical signals to spherical segmentation masks
383525
384526 Parameters
385- -----------
527+ ----------
386528 img_shape : tuple, optional
387529 Shape of the input channels, by default (128, 256)
388530 kernel_shape: tuple, int
@@ -414,7 +556,7 @@ class Segformer(nn.Module):
414556 Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
415557
416558 Example
417- -----------
    -------
418560 >>> model = Segformer(
419561 ... img_size=(128, 256),
420562 ... in_chans=3,
0 commit comments