Skip to content

Commit a8725ee

Browse files
committed
translate upsample_bicubic2d into bilinear layer
1 parent 6d46721 commit a8725ee

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3937,6 +3937,9 @@ def _translate_torch_args(x, output_size, align_corners, scales) -> Var:
39373937
"upsample_bilinear2d.vec",
39383938
"_upsample_bilinear2d_aa",
39393939
"_upsample_bilinear2d_aa.vec",
3940+
# Note that in CoreML, there is no bicubic2d layer,
3941+
# hence we use the bilinear layer for the approximation.
3942+
"upsample_bicubic2d",
39403943
],
39413944
)
39423945
def upsample_bilinear2d(context, node):
@@ -4003,6 +4006,13 @@ def _translate_torch_args(x, output_size, align_corners, scales_h, scales_w) ->
40034006
scales_h, scales_w = scales
40044007
return scales_h, scales_w
40054008

4009+
if node.kind == "upsample_bicubic2d":
4010+
logger.warning(
4011+
"upsample_bicubic2d is not supported in CoreML. "
4012+
"Hence be approximated by the upsample_bilinear layer. "
4013+
"This could lead to incorrect inference results!"
4014+
)
4015+
40064016
x, output_size, align_corners, scales_h, scales_w = _parse_positional_args(context, node)
40074017
scales_h, scales_w = _parse_keyword_args(context, node, scales_h, scales_w)
40084018
scales_h, scales_w = _translate_torch_args(x, output_size, align_corners, scales_h, scales_w)

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2872,6 +2872,33 @@ def forward(self, args):
28722872
if layer.WhichOneof("layer") == "upsample":
28732873
assert len(layer.upsample.fractionalScalingFactor) == 0
28742874

2875+
@pytest.mark.parametrize(
2876+
"compute_unit, backend, frontend, scales_h, scales_w",
2877+
itertools.product(compute_units, backends, frontends, [2, 3, 4.5], [4, 5, 5.5]),
2878+
)
2879+
def test_upsample_upsample_bicubic2d_with_scales(
2880+
self, compute_unit, backend, frontend, scales_h, scales_w
2881+
):
2882+
if backend[0] == "neuralnetwork":
2883+
if isinstance(scales_h, float) or isinstance(scales_w, float):
2884+
return # Skip fractional scale factors tests for neuralnetwork
2885+
2886+
input_shape = (1, 3, 10, 10)
2887+
class Model(nn.Module):
2888+
def forward(self, x):
2889+
y = nn.functional.interpolate(
2890+
x,
2891+
scale_factor=(scales_h, scales_w),
2892+
mode="bicubic"
2893+
)
2894+
# since we approximate bicubic2d with bilinear,
2895+
# we only check the output shape.
2896+
return torch.tensor(y.shape)
2897+
2898+
self.run_compare_torch(
2899+
input_shape, Model().eval(), frontend=frontend, backend=backend, compute_unit=compute_unit
2900+
)
2901+
28752902

28762903
class TestEmpty(TorchBaseTest):
28772904
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)