
Commit 64b34ae

chore: support onnx export

Author: wufei2
1 parent: dcae54c

File tree: 1 file changed (+59, -19 lines)

yolo_world/models/layers/yolo_bricks.py

Lines changed: 59 additions & 19 deletions
@@ -11,6 +11,42 @@
 from mmyolo.registry import MODELS
 from mmyolo.models.layers import CSPLayerWithTwoConv
 
+# AdaptiveAvgPool2dCustom and AdaptiveMaxPool2dCustom are replacements for the
+# adaptive pooling layers that remain compatible when exporting ONNX-format models.
+# reference: https://github.com/pytorch/pytorch/issues/42653#issuecomment-1168816422
+class AdaptiveAvgPool2dCustom(nn.Module):
+
+    def __init__(self, output_size):
+        super(AdaptiveAvgPool2dCustom, self).__init__()
+        self.output_size = torch.tensor(output_size)
+
+    def forward(self, x: torch.Tensor):
+        # Calculate the stride size required to achieve the desired output size
+        stride_size = torch.floor(torch.tensor(x.shape[-2:]) / self.output_size).to(torch.int32)
+
+        # Calculate the kernel size based on the stride size and desired output size
+        kernel_size = torch.tensor(x.shape[-2:]) - (self.output_size - 1) * stride_size
+
+        # Create an AvgPool2d layer with the calculated kernel and stride sizes
+        avg = nn.AvgPool2d(kernel_size.tolist(), stride=stride_size.tolist())
+
+        x = avg(x)
+        return x
+
+
+class AdaptiveMaxPool2dCustom(nn.Module):
+
+    def __init__(self, output_size):
+        super(AdaptiveMaxPool2dCustom, self).__init__()
+        self.output_size = torch.tensor(output_size)
+
+    def forward(self, x: torch.Tensor):
+        # Calculate the stride size required to achieve the desired output size
+        stride_size = torch.floor(torch.tensor(x.shape[-2:]) / self.output_size).to(torch.int32)
+
+        # Calculate the kernel size based on the stride size and desired output size
+        kernel_size = torch.tensor(x.shape[-2:]) - (self.output_size - 1) * stride_size
+
+        # Create a MaxPool2d layer with the calculated kernel and stride sizes
+        max_pool = nn.MaxPool2d(kernel_size.tolist(), stride=stride_size.tolist())
+
+        x = max_pool(x)
+        return x
 
 
 @MODELS.register_module()
 class MaxSigmoidAttnBlock(BaseModule):
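
The custom classes avoid the adaptive pooling ops that the ONNX exporter handles poorly (see the linked PyTorch issue): they compute a fixed kernel and stride from the input shape at trace time. A minimal sanity check, not part of this commit: when the spatial size divides evenly by output_size the fixed-window pooling reproduces nn.AdaptiveMaxPool2d exactly; for indivisible sizes the window boundaries differ slightly, so the result is only an approximation.

    import torch
    import torch.nn as nn

    x = torch.randn(1, 8, 9, 9)                  # 9 divides evenly by 3
    ref = nn.AdaptiveMaxPool2d((3, 3))(x)
    custom = AdaptiveMaxPool2dCustom((3, 3))(x)  # class added by this commit
    assert torch.equal(ref, custom)              # identical for divisible sizes
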
@@ -31,7 +67,7 @@ def __init__(self,
                                             momentum=0.03,
                                             eps=0.001),
                  init_cfg: OptMultiConfig = None,
-                 use_einsum: bool = True) -> None:
+                 export_onnx: bool = True) -> None:
         super().__init__(init_cfg=init_cfg)
         conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
 
@@ -40,7 +76,7 @@ def __init__(self,
             'out_channels and embed_channels should be divisible by num_heads.'
         self.num_heads = num_heads
         self.head_channels = out_channels // num_heads
-        self.use_einsum = use_einsum
+        self.export_onnx = export_onnx
 
         self.embed_conv = ConvModule(
             in_channels,
@@ -73,8 +109,7 @@ def forward(self, x: Tensor, guide: Tensor) -> Tensor:
         guide = guide.reshape(B, -1, self.num_heads, self.head_channels)
         embed = self.embed_conv(x) if self.embed_conv is not None else x
         embed = embed.reshape(B, self.num_heads, self.head_channels, H, W)
-
-        if self.use_einsum:
+        if self.export_onnx == False:
             attn_weight = torch.einsum('bmchw,bnmc->bmhwn', embed, guide)
         else:
             batch, m, channel, height, width = embed.shape
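
The diff truncates the non-einsum branch here. For reference, a standalone sketch (shapes are illustrative) showing that the einsum 'bmchw,bnmc->bmhwn' can be reproduced with permute/reshape/matmul, which ONNX exporters trace more reliably than einsum:

    import torch

    b, m, c, h, w, n = 2, 4, 8, 5, 6, 3
    embed = torch.randn(b, m, c, h, w)           # image embedding per head
    guide = torch.randn(b, n, m, c)              # text guide per head

    ref = torch.einsum('bmchw,bnmc->bmhwn', embed, guide)

    flat = embed.permute(0, 1, 3, 4, 2).reshape(b, m, h * w, c)  # (b, m, hw, c)
    alt = torch.matmul(flat, guide.permute(0, 2, 3, 1))          # (b, m, hw, n)
    alt = alt.reshape(b, m, h, w, n)

    assert torch.allclose(ref, alt, atol=1e-6)
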
@@ -116,7 +151,7 @@ def __init__(self,
                                             momentum=0.03,
                                             eps=0.001),
                  init_cfg: OptMultiConfig = None,
-                 use_einsum: bool = True) -> None:
+                 export_onnx: bool = True) -> None:
         super().__init__(init_cfg=init_cfg)
         conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
 
@@ -125,7 +160,7 @@ def __init__(self,
             'out_channels and embed_channels should be divisible by num_heads.'
         self.num_heads = num_heads
         self.head_channels = out_channels // num_heads
-        self.use_einsum = use_einsum
+        self.export_onnx = export_onnx
 
         self.embed_conv = ConvModule(
             in_channels,
@@ -191,7 +226,7 @@ def __init__(
             norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
             act_cfg: ConfigType = dict(type='SiLU', inplace=True),
             init_cfg: OptMultiConfig = None,
-            use_einsum: bool = True) -> None:
+            export_onnx: bool = True) -> None:
         super().__init__(in_channels=in_channels,
                          out_channels=out_channels,
                          expand_ratio=expand_ratio,
@@ -217,7 +252,7 @@ def __init__(
             with_scale=with_scale,
             conv_cfg=conv_cfg,
             norm_cfg=norm_cfg,
-            use_einsum=use_einsum)
+            export_onnx=export_onnx)
 
     def forward(self, x: Tensor, guide: Tensor) -> Tensor:
         """Forward process."""
@@ -247,7 +282,7 @@ def __init__(
             norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
             act_cfg: ConfigType = dict(type='SiLU', inplace=True),
             init_cfg: OptMultiConfig = None,
-            use_einsum: bool = True) -> None:
+            export_onnx: bool = True) -> None:
         super().__init__(in_channels=in_channels,
                          out_channels=out_channels,
                          expand_ratio=expand_ratio,
@@ -274,7 +309,7 @@ def __init__(
             with_scale=with_scale,
             conv_cfg=conv_cfg,
             norm_cfg=norm_cfg,
-            use_einsum=use_einsum)
+            export_onnx=export_onnx)
 
     def forward(self, x: Tensor, guide: Tensor) -> Tensor:
         """Forward process."""
@@ -296,7 +331,7 @@ def __init__(self,
                  num_feats: int = 3,
                  num_heads: int = 8,
                  pool_size: int = 3,
-                 use_einsum: bool = True):
+                 export_onnx: bool = True):
         super().__init__()
 
         self.text_channels = text_channels
@@ -305,7 +340,7 @@ def __init__(self,
         self.num_feats = num_feats
         self.head_channels = embed_channels // num_heads
         self.pool_size = pool_size
-        self.use_einsum = use_einsum
+        self.export_onnx = export_onnx
         if with_scale:
             self.scale = nn.Parameter(torch.tensor([0.]), requires_grad=True)
         else:
@@ -321,11 +356,16 @@ def __init__(self,
         self.value = nn.Sequential(nn.LayerNorm(embed_channels),
                                    Linear(embed_channels, embed_channels))
         self.proj = Linear(embed_channels, text_channels)
-
-        self.image_pools = nn.ModuleList([
-            nn.AdaptiveMaxPool2d((pool_size, pool_size))
-            for _ in range(num_feats)
-        ])
+        if self.export_onnx == False:
+            self.image_pools = nn.ModuleList([
+                nn.AdaptiveMaxPool2d((pool_size, pool_size))
+                for _ in range(num_feats)
+            ])
+        else:
+            self.image_pools = nn.ModuleList([
+                AdaptiveMaxPool2dCustom((pool_size, pool_size))
+                for _ in range(num_feats)
+            ])
 
     def forward(self, text_features, image_features):
         B = image_features[0].shape[0]
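
With export_onnx=True the module builds AdaptiveMaxPool2dCustom pools, so ONNX tracing no longer hits the unsupported adaptive-pooling path. A hedged export sketch (file name, input shape, and opset are illustrative, not from the commit):

    import torch

    pool = AdaptiveMaxPool2dCustom((3, 3))
    dummy = torch.randn(1, 256, 24, 24)          # illustrative input shape
    torch.onnx.export(pool, dummy, 'pool.onnx', opset_version=12)
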
@@ -345,7 +385,7 @@ def forward(self, text_features, image_features):
         q = q.reshape(B, -1, self.num_heads, self.head_channels)
         k = k.reshape(B, -1, self.num_heads, self.head_channels)
         v = v.reshape(B, -1, self.num_heads, self.head_channels)
-        if self.use_einsum:
+        if self.export_onnx == False:
             attn_weight = torch.einsum('bnmc,bkmc->bmnk', q, k)
         else:
             q = q.permute(0, 2, 1, 3)
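
The else-branch is again truncated in the diff. A sketch, with illustrative shapes, of the equivalence the permute path presumably relies on: 'bnmc,bkmc->bmnk' is a per-head q @ k-transpose that matmul reproduces once the head axis is moved forward.

    import torch

    b, n, k, m, c = 2, 5, 7, 4, 8
    q = torch.randn(b, n, m, c)                  # queries: (batch, n, heads, chans)
    key = torch.randn(b, k, m, c)                # keys:    (batch, k, heads, chans)

    ref = torch.einsum('bnmc,bkmc->bmnk', q, key)

    alt = torch.matmul(q.permute(0, 2, 1, 3),    # (b, m, n, c)
                       key.permute(0, 2, 3, 1))  # (b, m, c, k)

    assert torch.allclose(ref, alt, atol=1e-6)   # both are (b, m, n, k)
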
@@ -354,7 +394,7 @@ def forward(self, text_features, image_features):
 
         attn_weight = attn_weight / (self.head_channels**0.5)
         attn_weight = F.softmax(attn_weight, dim=-1)
-        if self.use_einsum:
+        if self.export_onnx == False:
             x = torch.einsum('bmnk,bkmc->bnmc', attn_weight, v)
         else:
             v = v.permute(0, 2, 1, 3)
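
The same trick covers the value-aggregation einsum; a sketch with illustrative shapes:

    import torch

    b, m, n, k, c = 2, 4, 5, 7, 8
    attn = torch.randn(b, m, n, k)               # softmaxed attention weights
    v = torch.randn(b, k, m, c)                  # values: (batch, k, heads, chans)

    ref = torch.einsum('bmnk,bkmc->bnmc', attn, v)

    alt = torch.matmul(attn, v.permute(0, 2, 1, 3))  # (b, m, n, k) @ (b, m, k, c)
    alt = alt.permute(0, 2, 1, 3)                    # -> (b, n, m, c)

    assert torch.allclose(ref, alt, atol=1e-6)
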
