Skip to content

[Relax] Use optional dtype for absent Relax dtype fields#19890

Open
tqchen wants to merge 8 commits into
apache:mainfrom
tqchen:new-task-tvm-avoid-use-void-dldatatype-as-null
Open

[Relax] Use optional dtype for absent Relax dtype fields#19890
tqchen wants to merge 8 commits into
apache:mainfrom
tqchen:new-task-tvm-avoid-use-void-dldatatype-as-null

Conversation

@tqchen

@tqchen tqchen commented Jun 25, 2026

Copy link
Copy Markdown
Member

This PR replaces Relax uses of void dtype sentinels for semantically absent dtypes with explicit optional/null representations.

Summary:

  • Make unknown TensorType dtype optional instead of represented as void.
  • Convert optional Relax dtype attrs to optional values and update inference paths to check presence directly.
  • Represent R.memory.view(..., dtype=None) with R.null_value() at the expression boundary, and lower to a concrete runtime dtype when required.
  • Reject R.memory.view(..., dtype=R.dtype("void")) in inference instead of treating the legacy void spelling as absence.
  • Handle review feedback around unknown optional dtype propagation in utility, qdq, and mixed-precision paths.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors TVM Relax's TensorTypeNode to represent unknown data types using ffi::Optionaltvm::PrimType initialized to std::nullopt instead of an opaque DLDataType placeholder. This change is propagated across various operators, type inference helpers, and transformations. The code review identified several critical issues where calling .value() on the newly introduced dtype optionals could cause crashes when the data type is unknown, specifically in utils.cc, qdq.cc (quantize/dequantize), and to_mixed_precision.cc. Additionally, a minor improvement was suggested in op_common.h to print a cleaner concrete dtype in error messages.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread src/relax/utils.cc
Comment on lines 185 to 187
if (const auto* tensor = ty.as<TensorTypeNode>()) {
dtype = tensor->dtype->dtype;
dtype = tensor->dtype.value()->dtype;
ndim = tensor->ndim;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

When permit_unknown_dtype is true and tensor->IsUnknownDtype() is true, calling tensor->dtype.value() will throw an exception or crash due to accessing an empty optional. We should safely handle the unknown dtype case when permit_unknown_dtype is enabled.

  if (const auto* tensor = ty.as<TensorTypeNode>()) {
    if (tensor->IsUnknownDtype()) {
      if (!permit_unknown_dtype) {
        return false;
      }
      dtype = DLDataType{kDLOpaqueHandle, 0, 0};
    } else {
      dtype = tensor->dtype.value()->dtype;
    }
    ndim = tensor->ndim;

Comment on lines +72 to +74
PrimType input_dtype = input_ty->dtype.value();
PrimType scale_dtype = scale_ty->dtype.value();
PrimType zp_dtype = zp_ty->dtype.value();

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If any of the input tensors (input_ty, scale_ty, or zp_ty) has an unknown dtype, calling .value() directly on their dtype fields will crash. Since type inference can be run on expressions with unknown dtypes, we should check IsUnknownDtype() and return a TensorType with an unknown dtype instead of crashing.

Suggested change
PrimType input_dtype = input_ty->dtype.value();
PrimType scale_dtype = scale_ty->dtype.value();
PrimType zp_dtype = zp_ty->dtype.value();
if (input_ty->IsUnknownDtype() || scale_ty->IsUnknownDtype() || zp_ty->IsUnknownDtype()) {
if (input_ty->shape.defined()) {
return TensorType(input_ty->shape.value(), std::nullopt, input_ty->vdevice);
} else {
return TensorType(std::nullopt, input_ty->ndim, input_ty->vdevice);
}
}
PrimType input_dtype = input_ty->dtype.value();
PrimType scale_dtype = scale_ty->dtype.value();
PrimType zp_dtype = zp_ty->dtype.value();

Comment on lines +174 to +176
PrimType input_dtype = input_ty->dtype.value();
PrimType scale_dtype = scale_ty->dtype.value();
PrimType zp_dtype = zp_ty->dtype.value();

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If any of the input tensors (input_ty, scale_ty, or zp_ty) has an unknown dtype, calling .value() directly on their dtype fields will crash. Since type inference can be run on expressions with unknown dtypes, we should check IsUnknownDtype() and return a TensorType with an unknown dtype instead of crashing.

  if (input_ty->IsUnknownDtype() || scale_ty->IsUnknownDtype() || zp_ty->IsUnknownDtype()) {
    if (input_ty->shape.defined()) {
      return TensorType(input_ty->shape.value(), std::nullopt, input_ty->vdevice);
    } else {
      return TensorType(std::nullopt, input_ty->ndim, input_ty->vdevice);
    }
  }
  PrimType input_dtype = input_ty->dtype.value();
  PrimType scale_dtype = scale_ty->dtype.value();
  PrimType zp_dtype = zp_ty->dtype.value();

// We only rewrite the expr if the dtype is fp16 or fp32, dtypes such as int32, float64 is not
// supported to be rewritten
DLDataType tensor_dtype = tensor->dtype->dtype;
DLDataType tensor_dtype = tensor->dtype.value()->dtype;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If tensor->IsUnknownDtype() is true, calling tensor->dtype.value() will crash. We should check if the tensor has an unknown dtype and return the original expression early.

Suggested change
DLDataType tensor_dtype = tensor->dtype.value()->dtype;
if (tensor->IsUnknownDtype()) return expr;
DLDataType tensor_dtype = tensor->dtype.value()->dtype;

Comment thread src/relax/op/op_common.h Outdated
TVM_FFI_VISIT_THROW(TypeError, call)
<< call->op
<< " requires the input tensor to have float dtype. However, the given input dtype is "
<< input_ty->dtype;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since we have already verified that !input_ty->IsUnknownDtype(), we should print input_ty->dtype.value() instead of the optional wrapper input_ty->dtype to ensure the error message displays a clean concrete dtype (e.g., float32).

Suggested change
<< input_ty->dtype;
<< input_ty->dtype.value();

@tqchen tqchen force-pushed the new-task-tvm-avoid-use-void-dldatatype-as-null branch from f58a065 to 6bc5adf Compare June 25, 2026 17:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants