1111from mmyolo .registry import MODELS
1212from mmyolo .models .layers import CSPLayerWithTwoConv
1313
14+ #AdaptiveAvgPool2dCustom and AdaptiveMaxPool2dCustom are compatible when exporting onnx format models
15+ # reference: https://github.com/pytorch/pytorch/issues/42653#issuecomment-1168816422
16+ class AdaptiveAvgPool2dCustom (nn .Module ):
17+ def __init__ (self , output_size ):
18+ super (AdaptiveAvgPool2dCustom , self ).__init__ ()
19+ self .output_size = torch .tensor (output_size )
20+
21+ def forward (self , x : torch .Tensor ):
22+ # Calculate the stride size required to achieve the desired output size
23+ stride_size = torch .floor (torch .tensor (x .shape [- 2 :]) / self .output_size ).to (torch .int32 )
24+
25+ # Calculate the kernel size based on the stride size and desired output size
26+ kernel_size = torch .tensor (x .shape [- 2 :]) - (self .output_size - 1 ) * stride_size
27+
28+ # Create a AvgPool2d layer with the calculated kernel and stride sizes
29+ avg = nn .AvgPool2d (kernel_size .tolist (), stride = stride_size .tolist ())
30+
31+ x = avg (x )
32+ return x
33+ class AdaptiveMaxPool2dCustom (nn .Module ):
34+ def __init__ (self , output_size ):
35+ super (AdaptiveMaxPool2dCustom , self ).__init__ ()
36+ self .output_size = torch .tensor (output_size )
37+
38+ def forward (self , x : torch .Tensor ):
39+ # Calculate the stride size required to achieve the desired output size
40+ stride_size = torch .floor (torch .tensor (x .shape [- 2 :]) / self .output_size ).to (torch .int32 )
41+
42+ # Calculate the kernel size based on the stride size and desired output size
43+ kernel_size = torch .tensor (x .shape [- 2 :]) - (self .output_size - 1 ) * stride_size
44+
45+ # Create a MaxPool2d layer with the calculated kernel and stride sizes
46+ max_pool = nn .MaxPool2d (kernel_size .tolist (), stride = stride_size .tolist ())
47+
48+ x = max_pool (x )
49+ return x
1450
1551@MODELS .register_module ()
1652class MaxSigmoidAttnBlock (BaseModule ):
@@ -31,7 +67,7 @@ def __init__(self,
3167 momentum = 0.03 ,
3268 eps = 0.001 ),
3369 init_cfg : OptMultiConfig = None ,
34- use_einsum : bool = True ) -> None :
70+ export_onnx : bool = True ) -> None :
3571 super ().__init__ (init_cfg = init_cfg )
3672 conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
3773
@@ -40,7 +76,7 @@ def __init__(self,
4076 'out_channels and embed_channels should be divisible by num_heads.'
4177 self .num_heads = num_heads
4278 self .head_channels = out_channels // num_heads
43- self .use_einsum = use_einsum
79+ self .export_onnx = export_onnx
4480
4581 self .embed_conv = ConvModule (
4682 in_channels ,
@@ -73,8 +109,7 @@ def forward(self, x: Tensor, guide: Tensor) -> Tensor:
73109 guide = guide .reshape (B , - 1 , self .num_heads , self .head_channels )
74110 embed = self .embed_conv (x ) if self .embed_conv is not None else x
75111 embed = embed .reshape (B , self .num_heads , self .head_channels , H , W )
76-
77- if self .use_einsum :
112+ if self .export_onnx == False :
78113 attn_weight = torch .einsum ('bmchw,bnmc->bmhwn' , embed , guide )
79114 else :
80115 batch , m , channel , height , width = embed .shape
@@ -116,7 +151,7 @@ def __init__(self,
116151 momentum = 0.03 ,
117152 eps = 0.001 ),
118153 init_cfg : OptMultiConfig = None ,
119- use_einsum : bool = True ) -> None :
154+ export_onnx : bool = True ) -> None :
120155 super ().__init__ (init_cfg = init_cfg )
121156 conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
122157
@@ -125,7 +160,7 @@ def __init__(self,
125160 'out_channels and embed_channels should be divisible by num_heads.'
126161 self .num_heads = num_heads
127162 self .head_channels = out_channels // num_heads
128- self .use_einsum = use_einsum
163+ self .export_onnx = export_onnx
129164
130165 self .embed_conv = ConvModule (
131166 in_channels ,
@@ -191,7 +226,7 @@ def __init__(
191226 norm_cfg : ConfigType = dict (type = 'BN' , momentum = 0.03 , eps = 0.001 ),
192227 act_cfg : ConfigType = dict (type = 'SiLU' , inplace = True ),
193228 init_cfg : OptMultiConfig = None ,
194- use_einsum : bool = True ) -> None :
229+ export_onnx : bool = True ) -> None :
195230 super ().__init__ (in_channels = in_channels ,
196231 out_channels = out_channels ,
197232 expand_ratio = expand_ratio ,
@@ -217,7 +252,7 @@ def __init__(
217252 with_scale = with_scale ,
218253 conv_cfg = conv_cfg ,
219254 norm_cfg = norm_cfg ,
220- use_einsum = use_einsum )
255+ export_onnx = export_onnx )
221256
222257 def forward (self , x : Tensor , guide : Tensor ) -> Tensor :
223258 """Forward process."""
@@ -247,7 +282,7 @@ def __init__(
247282 norm_cfg : ConfigType = dict (type = 'BN' , momentum = 0.03 , eps = 0.001 ),
248283 act_cfg : ConfigType = dict (type = 'SiLU' , inplace = True ),
249284 init_cfg : OptMultiConfig = None ,
250- use_einsum : bool = True ) -> None :
285+ export_onnx : bool = True ) -> None :
251286 super ().__init__ (in_channels = in_channels ,
252287 out_channels = out_channels ,
253288 expand_ratio = expand_ratio ,
@@ -274,7 +309,7 @@ def __init__(
274309 with_scale = with_scale ,
275310 conv_cfg = conv_cfg ,
276311 norm_cfg = norm_cfg ,
277- use_einsum = use_einsum )
312+ export_onnx = export_onnx )
278313
279314 def forward (self , x : Tensor , guide : Tensor ) -> Tensor :
280315 """Forward process."""
@@ -296,7 +331,7 @@ def __init__(self,
296331 num_feats : int = 3 ,
297332 num_heads : int = 8 ,
298333 pool_size : int = 3 ,
299- use_einsum : bool = True ):
334+ export_onnx : bool = True ):
300335 super ().__init__ ()
301336
302337 self .text_channels = text_channels
@@ -305,7 +340,7 @@ def __init__(self,
305340 self .num_feats = num_feats
306341 self .head_channels = embed_channels // num_heads
307342 self .pool_size = pool_size
308- self .use_einsum = use_einsum
343+ self .export_onnx = export_onnx
309344 if with_scale :
310345 self .scale = nn .Parameter (torch .tensor ([0. ]), requires_grad = True )
311346 else :
@@ -321,11 +356,16 @@ def __init__(self,
321356 self .value = nn .Sequential (nn .LayerNorm (embed_channels ),
322357 Linear (embed_channels , embed_channels ))
323358 self .proj = Linear (embed_channels , text_channels )
324-
325- self .image_pools = nn .ModuleList ([
326- nn .AdaptiveMaxPool2d ((pool_size , pool_size ))
327- for _ in range (num_feats )
328- ])
359+ if self .export_onnx == False :
360+ self .image_pools = nn .ModuleList ([
361+ nn .AdaptiveMaxPool2d ((pool_size , pool_size ))
362+ for _ in range (num_feats )
363+ ])
364+ else :
365+ self .image_pools = nn .ModuleList ([
366+ AdaptiveMaxPool2dCustom ((pool_size , pool_size ))
367+ for _ in range (num_feats )
368+ ])
329369
330370 def forward (self , text_features , image_features ):
331371 B = image_features [0 ].shape [0 ]
@@ -345,7 +385,7 @@ def forward(self, text_features, image_features):
345385 q = q .reshape (B , - 1 , self .num_heads , self .head_channels )
346386 k = k .reshape (B , - 1 , self .num_heads , self .head_channels )
347387 v = v .reshape (B , - 1 , self .num_heads , self .head_channels )
348- if self .use_einsum :
388+ if self .export_onnx == False :
349389 attn_weight = torch .einsum ('bnmc,bkmc->bmnk' , q , k )
350390 else :
351391 q = q .permute (0 , 2 , 1 , 3 )
@@ -354,7 +394,7 @@ def forward(self, text_features, image_features):
354394
355395 attn_weight = attn_weight / (self .head_channels ** 0.5 )
356396 attn_weight = F .softmax (attn_weight , dim = - 1 )
357- if self .use_einsum :
397+ if self .export_onnx == False :
358398 x = torch .einsum ('bmnk,bkmc->bnmc' , attn_weight , v )
359399 else :
360400 v = v .permute (0 , 2 , 1 , 3 )
0 commit comments