diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py
index 12924428a..a61a33d21 100644
--- a/coremltools/converters/mil/frontend/torch/ops.py
+++ b/coremltools/converters/mil/frontend/torch/ops.py
@@ -1502,8 +1502,18 @@ def maximum(context, node):
 
 @register_torch_op
 def div(context, node):
     inputs = _get_inputs(context, node, expected=[2, 3])
-    x = mb.cast(x=inputs[0], dtype="fp32")
-    y = mb.cast(x=inputs[1], dtype="fp32")
+    x = inputs[0]
+    y = inputs[1]
+    # torch promotes integer division to float; cast only when needed so
+    # float inputs (e.g. fp16) keep their original precision.
+    if not types.is_float(x.dtype) and not types.is_float(y.dtype):
+        x = mb.cast(x=x, dtype="fp32")
+        y = mb.cast(x=y, dtype="fp32")
+    elif not types.is_float(x.dtype):
+        # mb.cast's dtype input is a string ("fp32", "fp16", ...), not a
+        # builtin type object, so convert the other operand's dtype first.
+        x = mb.cast(x=x, dtype=types.builtin_to_string(y.dtype))
+    elif not types.is_float(y.dtype):
+        y = mb.cast(x=y, dtype=types.builtin_to_string(x.dtype))
     if len(inputs) > 2 and inputs[2] is not None:
         rounding_mode = inputs[2].val