diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index 370f4401a..afc68c2ee 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -5271,7 +5271,7 @@ def baddbmm(context, node): inputs = _get_inputs(context, node, expected=5) bias, batch1, batch2, beta, alpha = inputs - if beta.val != 1.0: + if beta.val != 0.0: # Apply scaling factor beta to the bias. bias = mb.mul(x=beta, y=bias, name=bias.name + "_scaled") context.add(bias)