diff --git a/python/tvm/relax/transform/legalize_ops/unary.py b/python/tvm/relax/transform/legalize_ops/unary.py index fc6b96a438e0..4d09c6d61cc8 100644 --- a/python/tvm/relax/transform/legalize_ops/unary.py +++ b/python/tvm/relax/transform/legalize_ops/unary.py @@ -38,6 +38,9 @@ register_legalize("relax.cosh", _call_topi_without_attr(topi.cosh, "tir_cosh")) register_legalize("relax.exp", _call_topi_without_attr(topi.exp, "tir_exp")) register_legalize("relax.floor", _call_topi_without_attr(topi.floor, "tir_floor")) +register_legalize("relax.isfinite", _call_topi_without_attr(topi.isfinite, "tir_isfinite")) +register_legalize("relax.isinf", _call_topi_without_attr(topi.isinf, "tir_isinf")) +register_legalize("relax.isnan", _call_topi_without_attr(topi.isnan, "tir_isnan")) register_legalize("relax.log", _call_topi_without_attr(topi.log, "tir_log")) register_legalize("relax.logical_not", _call_topi_without_attr(topi.logical_not, "tir_logical_not")) register_legalize("relax.negative", _call_topi_without_attr(topi.negative, "tir_negative")) diff --git a/tests/python/relax/test_transform_legalize_ops_unary.py b/tests/python/relax/test_transform_legalize_ops_unary.py index e32f95a3c4d2..4f8ee67d99ee 100644 --- a/tests/python/relax/test_transform_legalize_ops_unary.py +++ b/tests/python/relax/test_transform_legalize_ops_unary.py @@ -89,6 +89,9 @@ def main(x: R.Tensor(("m", "n"), dtype)): ("exp", R.exp, topi.exp, "float32"), ("floor", R.floor, topi.floor, "float32"), ("floor", R.floor, topi.identity, "int32"), + ("isfinite", R.isfinite, topi.isfinite, "float32"), + ("isinf", R.isinf, topi.isinf, "float32"), + ("isnan", R.isnan, topi.isnan, "float32"), ("log", R.log, topi.log, "float32"), ("negative", R.negative, topi.negative, "float32"), ("round", R.round, topi.round, "float32"),