Conversation

@christopherbate
Collaborator

  • [executor] NFC: Remove dead code
  • [mlir-tensorrt] Fix Stablehlo dependency after last internal sync
  • [compiler] Add utility for approximately checking Stablehlo-to-Linalg convertability
  • [compiler] NFC: Update stablehlo-ext-constant-folding to drop duplicate patterns
  • [compiler] Bring stablehlo-ext-constant-folding up-to-par with stablehlo-aggressive-folders
  • [compiler] Refactor the Stablehlo input preprocessing pipeline
  • [compiler] Improve handling of constants in plan-bufferize-pipeline
  • [compiler] Update 'host' backend to support a much wider range of programs
  • [compiler] Further improve host backend
  • [compiler] Fix issues related to ToLoopsOpInterface
  • [compiler] Fix two bugs in Plan segmentation & bufferization pipeline
  • [compiler] Fix host backend integration tests to return an integer return code from entrypoint
  • [executor] Fix missing enum case for kE8M0 in TensorRT 10.12 module
  • [executor] Fix bug in cf.cond_br translation.
  • [compiler] Fix several issues in Stablehlo-to-Executable pipeline
  • [compiler] Fix potential nullptr dyn_cast in stablehlo-to-tensorrt
  • [compiler] Update CallOpInterface handling in TensorKindAnalysis
  • [executor] Add 'remf' support
  • [compiler] Fix some cases of bufferization for scf.if yielding constants
  • [compiler] Fix missing nullptr check in 01ad2ee2ce6cc392554922b9b0fe2c8fec56c246
  • [compiler] Add support for stablehlo.optimization_barrier
  • [compiler] Fix bufferization uses of bufferization.materialize_in_destination
  • [cmake][Executor] Fix standalone mlir-executor build
  • [mlir-tensorrt][cmake] Add TRT 10.12 download support

shelkesagar29 and others added 24 commits July 10, 2025 14:55
GitOrigin-RevId: b4edc6d7aa468cb48b5995616f7d776d43c62a52
[compiler] Add utility for approximately checking Stablehlo-to-Linalg convertability

Adds a utility under StablehloExt to check if a Stablehlo op can be converted
to a Linalg op. Used inside some clustering routines.

GitOrigin-RevId: 2559f856bae90a0ce6a5cd7c79e2cf8cf932d207
[compiler] NFC: Update stablehlo-ext-constant-folding to drop duplicate patterns

Some patterns that we have in `stablehlo-ext-constant-folding` are
also present in the upstream `stablehlo-aggressive-simplification`
patterns. This change drops the duplicate patterns.

GitOrigin-RevId: acab597927015301a45bfaa44bb544470d5b6bc3
[compiler] Bring stablehlo-ext-constant-folding up-to-par with `stablehlo-aggressive-folders`

This change updates patterns in `stablehlo-ext-constant-folding` to
ensure we can pass the upstream tests for
`stablehlo-aggressive-folders`.

A couple additional patterns were needed for simple binary elementwise
op folders.

Two bugs/deficiencies were found and corrected:

- Our utility for performing conversions between element types was not
  correctly returning failure during float->integer element
  conversions if a non-convertible value was present in the source
  array (e.g. NaN or Inf).

- Our pattern to fold `stablehlo.transpose` would crash on 0-d input
  tensors.
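
For concreteness, hand-written sketches of the two failure modes (the
actual regression tests in the patch may differ):

```mlir
// (1) Folding this convert must fail rather than produce a bogus
// integer: NaN has no integer representation.
func.func @nan_to_int() -> tensor<i32> {
  %nan = stablehlo.constant dense<0x7FC00000> : tensor<f32>
  %0 = stablehlo.convert %nan : (tensor<f32>) -> tensor<i32>
  return %0 : tensor<i32>
}

// (2) Folding a transpose of a 0-d tensor previously crashed.
func.func @transpose_0d() -> tensor<f32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.transpose %cst, dims = [] : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}
```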

Note that we previously did not (and still do not) use the upstream
`stablehlo-aggressive-folders` patterns due to some deficiencies
related to enforcing size limits. The upstream patterns don't
consistently enforce size limits on folder patterns when the
operand/result types are integer tensors. In addition, they often
abort in cases where the op result can be inferred to be static but
the specified type is a generalization of the inferred type; we handle
this through cast insertion, as sketched below. We should work on
contributing upstream to fix these deficiencies, which would allow us
to deduplicate the patterns in `stablehlo-ext-constant-folding` and
`stablehlo-aggressive-folders`.
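
A hand-written sketch of the cast-insertion approach (the IR produced
by our patterns may differ in detail): the concatenation below can be
inferred to produce `tensor<4xf32>`, but the op declares the more
general `tensor<?xf32>`.

```mlir
func.func @concat_fold() -> tensor<?xf32> {
  // Before folding: the declared result type generalizes the inferred
  // tensor<4xf32>.
  %c0 = stablehlo.constant dense<[1.0, 2.0]> : tensor<2xf32>
  %c1 = stablehlo.constant dense<[3.0, 4.0]> : tensor<2xf32>
  %0 = stablehlo.concatenate %c0, %c1, dim = 0
      : (tensor<2xf32>, tensor<2xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// After folding (rather than aborting): materialize the static
// constant and cast back to the declared type.
func.func @concat_folded() -> tensor<?xf32> {
  %cst = stablehlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
  %0 = tensor.cast %cst : tensor<4xf32> to tensor<?xf32>
  return %0 : tensor<?xf32>
}
```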

GitOrigin-RevId: cf38cacfdc9c2665fb7298a757041be283fa4523
This change refactors the Stablehlo input preprocessing pipeline to be
significantly simpler and to allow a greater degree of control over
constant folding and optimization patterns:

- The `stablehlo-ext-constant-folding` pass now uses a pass option to
  control the tensor volume limit (instead of a hard-coded constant).
  This is piped through and exposed to the top-level
  `stablehlo-to-executable` pipeline as well. The default remains the
  value of the previously hard-coded constant.

- We add a new pass `stablehlo-ext-target-specific-optimizations` that
  applies target-specific optimization patterns to the Stablehlo input
  IR. The allowed pattern sets are those currently exposed as separate
  passes (`stablehlo-ext-canonicalize-dot-general`,
  `stablehlo-ext-canonicalize-gather`,
  `stablehlo-ext-canonicalize-scatter`,
  `stablehlo-ext-canonicalize-convolution`, etc.). As discovered
  through recent performance work, these "optimizations" (e.g. for
  `stablehlo.dot_general`) are not always beneficial, so this change
  exposes options to disable individual pattern sets. The default is
  `all`, which enables every pattern set previously applied as a
  separate pass. This significantly simplifies the pipeline, making it
  easier to understand and control (and removes the phase-ordering
  issues that arise when running the patterns as separate passes). The
  option to control active pattern sets is also exposed on the
  top-level `stablehlo-to-executable` pipeline.

- The overall input preprocessing pipeline is now simpler. We run the
  inliner once instead of twice, and the number of shape and
  canonicalization patterns we apply is significantly reduced.
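
A hypothetical lit-style usage sketch; the tool name, option names,
and default value shown here are assumptions, not necessarily the
flags introduced by this change:

```mlir
// RUN: mlir-tensorrt-opt %s \
// RUN:   -stablehlo-ext-constant-folding="size-limit=65536" \
// RUN:   -stablehlo-ext-target-specific-optimizations="patterns=all"
```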

GitOrigin-RevId: 89103ece3fcdda9980ab5eb2d730259902543a7d
This change resolves several TODOs related to handling of constant
tensors in the bufferization pipeline. Several patterns in
`plan-alloc-tensors` can be removed now that we have migrated entirely
to using the tensor encodings prior to bufferization.
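
A minimal sketch of the encoding-based approach, assuming the
`#plan.memory_space` encoding spelling: the constant's destination
space rides on the tensor type itself, so bufferization can place it
without dedicated alloc-tensor patterns.

```mlir
// Memory space carried in the tensor encoding prior to bufferization
// (attribute spelling is an assumption):
%cst = arith.constant dense<[1.0, 2.0, 3.0, 4.0]>
    : tensor<4xf32, #plan.memory_space<device>>
```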

GitOrigin-RevId: d983ad2ce92b644ad3d096316a2669dd44a5af06
[compiler] Update 'host' backend to support a much wider range of programs

Updates the clustering rules to allow the host backend to accept a
range of upstream-provided MLIR op types (e.g. Arith, Math, Tensor,
Linalg) as well as Stablehlo ops that can be converted to that IR.

GitOrigin-RevId: bc2baa4c3e277c231dab52823e8d4bdd34c97f7a
Makes the following improvements and bug fixes in support of the host
backend:

- Fix issue with Stablehlo signed-to-signless conversion. Add back the
  `stablehlo-convert-unsigned-to-signless` conversion pass into the
  pipeline. Cherry-pick upstream bug fix in stablehlo for that pass.

- Allow module-scope `plan.memory_space` attributes to specify the
  default memory space for all functions in the module that don't have
  a `plan.memory_space` attribute or a `plan.cluster_kind` attribute.

- Improve `scf-detensorize-loops` to allow hoisting of
  `tensor.insert`/`tensor.extract` operations out of for-style loops
  (see the sketch after this list).

- Fix missing `cf.assert` to `executor.assert` conversion in the
  `std-to-executor` pass.

- Add an additional Plan bufferization pipeline test that checks
  whether efficient code can be generated when private functions
  annotated `no_inline` are called.

- Adds additional Stablehlo host backend integration tests.
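
A hand-written before/after sketch of the `scf-detensorize-loops`
hoisting referenced above (simplified relative to the actual pass
output):

```mlir
// Before: the loop carries a 0-d tensor and extracts/inserts the
// scalar on every iteration.
func.func @before(%init: tensor<f32>, %n: index) -> tensor<f32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %r = scf.for %i = %c0 to %n step %c1
      iter_args(%t = %init) -> (tensor<f32>) {
    %s = tensor.extract %t[] : tensor<f32>
    %s2 = arith.addf %s, %s : f32
    %t2 = tensor.insert %s2 into %t[] : tensor<f32>
    scf.yield %t2 : tensor<f32>
  }
  return %r : tensor<f32>
}

// After hoisting: the loop carries the scalar; the tensor traffic
// moves outside the loop.
func.func @after(%init: tensor<f32>, %n: index) -> tensor<f32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %s0 = tensor.extract %init[] : tensor<f32>
  %sN = scf.for %i = %c0 to %n step %c1
      iter_args(%s = %s0) -> (f32) {
    %s2 = arith.addf %s, %s : f32
    scf.yield %s2 : f32
  }
  %r = tensor.insert %sN into %init[] : tensor<f32>
  return %r : tensor<f32>
}
```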

GitOrigin-RevId: 1ae2b42887c97c1d1525d40ca11e9a87583da30c
GitOrigin-RevId: fb6b40d6e5f8fbdd1609e5a202cbbb373a9efef8
- Fix `scf.while` bufferization issues related to tied iteration
  arguments in the before/after regions bufferizing to different
  memory spaces (see the sketch after this list).

- Fix host backend configuration, which was allowing constants to be
  clusterable rather than cloned during outlining.
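
As referenced in the first item, a minimal sketch of the
tied-argument constraint: `%b` in the before region and `%a` in the
after region carry the same iteration value, so they must bufferize
into a single memory space.

```mlir
func.func @tied_iter_args(%init: tensor<4xf32>) -> tensor<4xf32> {
  %true = arith.constant true
  %res = scf.while (%b = %init) : (tensor<4xf32>) -> tensor<4xf32> {
    // %b is the before-region view of the iteration value.
    scf.condition(%true) %b : tensor<4xf32>
  } do {
  ^bb0(%a: tensor<4xf32>):
    // %a is tied to %b; assigning them different memory spaces was
    // the bug being fixed.
    scf.yield %a : tensor<4xf32>
  }
  return %res : tensor<4xf32>
}
```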

GitOrigin-RevId: ccceffd99168bfa043d60899a23d198208af2f97
[compiler] Fix host backend integration tests to return an integer return code from entrypoint

Fixes a failure discovered in debug builds. The new integration tests
for the host backend were not returning an integer return code from
the entrypoint. Adds improved logic to the runtime to catch this
mistake earlier.
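
The shape the integration-test entrypoints now follow (a minimal
sketch):

```mlir
// Entrypoints must return an integer status code; returning nothing
// (or a non-integer) is now caught by the runtime.
func.func @main() -> i32 {
  %c0_i32 = arith.constant 0 : i32
  return %c0_i32 : i32
}
```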

GitOrigin-RevId: fc98de008e7fd65221947aee10c458bfd997b4e5
Fixes a warning with TRT 10.12.

GitOrigin-RevId: 2d23f9c2873893acbc73a632e369e62c932ed122
Fixes a bug in the executor's translation of `cf.cond_br` that was
causing incorrect translation of nested loops. During a conditional
branch, the variables for the block arguments of the target blocks
must be assigned conditionally.
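
A minimal sketch of the kind of IR that exposed the bug: both
successors take block arguments, so the translated assignments for
`^bb1` and `^bb2` must be guarded by `%cond` rather than performed
unconditionally.

```mlir
func.func @cond_br_args(%cond: i1, %x: i32, %y: i32) -> i32 {
  // Each successor receives a different block argument.
  cf.cond_br %cond, ^bb1(%x : i32), ^bb2(%y : i32)
^bb1(%a: i32):
  return %a : i32
^bb2(%b: i32):
  return %b : i32
}
```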

GitOrigin-RevId: de5da9ed8e07bb06c201a29c2ddf4d5e7c171399
Fixes the following issues related to Stablehlo preprocessing and
bufferization:

- Fix incorrect use of `dyn_cast` in the
  `PromoteHostTensorsToHostPinned` pass that could cause a crash.

- Fix outdated use of TensorKindAnalysis in AllocTensors. The
  analysis is no longer needed there, and it could cause a crash when
  the IR is updated and the solver state goes out-of-date.
  Bufferization tests are also updated.

- During the Plan bufferization pipeline, move inlining so that it
  occurs right before `plan-alloc-tensors`. This is OK because at
  this point we have already assigned memory spaces and no longer
  need the function-scoped memory space hints. It is required because
  DPS loop rewriting produces better results once functions have been
  inlined, since `func.call` is not DPS.

- Fix missing `stablehlo.dynamic_update_slice` in
  `canConvertToLinalg` utility function.

- Fix incorrect logic in the `plan-assign-memory-spaces` pass that
  caused the internal type converter to not consult the module-scoped
  default memory space before falling back to the default space
  (device). See the sketch below.
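
As referenced in the last item, a sketch of the module-scoped default
(the exact attribute spellings are assumptions):

```mlir
// The module-level attribute supplies the default memory space for
// functions lacking their own plan.memory_space / plan.cluster_kind
// attributes; the type converter now consults it before falling back
// to `device`.
module attributes {plan.memory_space = #plan.memory_space<host>} {
  func.func @f(%arg0: tensor<4xf32>) -> tensor<4xf32> {
    return %arg0 : tensor<4xf32>
  }
}
```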

GitOrigin-RevId: e079888189d5c5e5f168bac1efdc5c3e8ad35061
Fixes a potential llvm::dyn_cast on nullptr in
`stablehlo-to-tensorrt`.

GitOrigin-RevId: 830065398f030e087ed7338aa829ab9df0f6f9c7
The MLIR backwards sparse dataflow framework has a hook called
`visitCallOperand`. Despite the name, its original purpose is to allow
analyses to specify how to handle CallOpInterface operands that are
*not* tied to the block arguments of the callee.

In the TensorKindAnalysis, we were assuming that this is the only
purpose of the hook and have been returning "host" for all such
operands. However, this is incorrect: by default the
`SparseBackwardDataFlowAnalysis` will also call `visitCallOperand` for
each operand when the callee is not visible to the solver (e.g.
because the callee is an SSA value or is only a declaration).

This commit updates the TensorKindAnalysis to explicitly implement
`visitExternalCall`, which gets us closer to the desired behavior. In
`visitExternalCall`, we set the lattice to "host" for all operands that
are not forwarded and invoke `setToExitState` for the operands that are
forwarded.
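
A minimal sketch of the situation that now routes through
`visitExternalCall`: the callee is only a declaration, so its body is
invisible to the solver.

```mlir
// Only a declaration: the analysis cannot see how @ext uses its
// operand.
func.func private @ext(tensor<4xf32>) -> tensor<4xf32>

func.func @caller(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // TensorKindAnalysis must decide the operand's kind without the
  // callee body; non-forwarded operands are set to "host".
  %0 = func.call @ext(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}
```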

In the future we should consider constraints specified on the callee's
argument attributes, but there is the additional complication that we
don't know whether a concurrent pass is changing that IR, so it is
technically not valid to access.

We can perhaps sidestep most of these issues by only running the
analysis in dedicated module-scoped analysis passes that cache the
result for the more fine-grained function passes to access.

GitOrigin-RevId: 61b1af910f50704b5b9c9cc55605481ea92fa617
Adds an op `executor.remf` equivalent to `arith.remf` (and LLVM's
`frem`). Adds a conversion from the arith dialect, a Lua
implementation, and relevant tests.
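
A sketch of the conversion's effect (the executor op's assembly
format is assumed):

```mlir
func.func @remainder(%a: f32, %b: f32) -> f32 {
  // Before conversion:
  %r = arith.remf %a, %b : f32
  return %r : f32
}
// After conversion (assumed spelling):
//   %r = executor.remf %a, %b : f32
```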

GitOrigin-RevId: f22b0c93449ced4fb1d56ecb5f123766cf51320e
Previously, IR involving `scf.if` where one branch yields a constant
and the other branch yields a tensor produced by
`bufferization.materialize_in_destination` with copy-between-spaces
semantics caused a compiler error.
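
A hand-written reduction of the failing shape (simplified; the real
cases involve a copy between memory spaces):

```mlir
func.func @repro(%cond: i1, %src: tensor<4xf32>,
                 %dest: tensor<4xf32>) -> tensor<4xf32> {
  %0 = scf.if %cond -> (tensor<4xf32>) {
    // One branch yields a constant...
    %cst = arith.constant dense<0.0> : tensor<4xf32>
    scf.yield %cst : tensor<4xf32>
  } else {
    // ...the other yields the result of a materialization.
    %m = bufferization.materialize_in_destination %src in %dest
        : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
    scf.yield %m : tensor<4xf32>
  }
  return %0 : tensor<4xf32>
}
```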

GitOrigin-RevId: 01ad2ee2ce6cc392554922b9b0fe2c8fec56c246
[compiler] Fix missing nullptr check in 01ad2ee2ce6cc392554922b9b0fe2c8fec56c246

Fixes missing nullptr check in the previous change and adds an
additional test case.

GitOrigin-RevId: 0ad873e2965993d460fa71bcfa3261fad6c85dc5
This change introduces `plan.optimization_barrier`, a utility
operation used to prevent certain pattern-based optimizations (e.g.
constant folding) from being applied to the input IR. Semantically,
the operation is an identity -- the operands are tied to the
results -- but it disrupts potential pattern matching.

Note that the operation is still functionally pure, so it does not
prevent optimizations like loop-invariant code motion or CSE.

We can use this as a conversion target for
`stablehlo.optimization_barrier` which has the equivalent semantic in
Stablehlo.
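
A sketch of the Stablehlo-side input; the plan op's assembly format
below is an assumption:

```mlir
func.func @barrier() -> tensor<4xf32> {
  // The barrier keeps %cst from being folded into its uses.
  %cst = stablehlo.constant dense<1.0> : tensor<4xf32>
  %0 = stablehlo.optimization_barrier %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}
// After conversion (assumed assembly format for the plan op):
//   %0 = plan.optimization_barrier %cst : tensor<4xf32>
```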

We only need `plan.optimization_barrier` until after bufferization. So
to eliminate `plan.optimization_barrier` from the IR, we simply add a
bufferization implementation that forwards the operands to the results.

GitOrigin-RevId: 6d724a00d58fd159a1845ceec8556e5c7bf35c1e
[compiler] Fix bufferization uses of bufferization.materialize_in_destination

A previous commit tried to fix a number of bufferization issues
involving our use of `bufferization.materialize_in_destination` with
`scf.if` ops. However, I discovered a better solution, which this
change implements.

The problem is that `bufferization.materialize_in_destination` is not
just a tensor-level copy operation. It indicates that the target of
the copy must be the buffer that will be associated with the `dest`
SSA value, and that the op must bufferize in-place; bufferization
raises an error otherwise. This is useful for indicating that the
bufferized IR *must* copy source data into a particular value with
important meaning (e.g. a function output argument).

However, we were using it in a couple of places (namely, converting
`tensor.cast` ops into `tensor.empty` +
`bufferization.materialize_in_destination`) where the in-place
requirement is unnecessary. This was causing bufferization failures
in edge cases involving `scf.if`.

To fix this, we just need an alternate copy-like operation that is
bufferizable and destination-style. Luckily, `linalg.copy` already
fits and can be used as a drop-in replacement. Then, to recover the
original behavior, we simply lower `linalg.copy` to `memref.copy`
after bufferization.
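
A sketch of the replacement for the `tensor.cast` lowering described
above:

```mlir
func.func @cast_copy(%src: tensor<4xf32>) -> tensor<4xf32> {
  // Instead of tensor.empty + bufferization.materialize_in_destination
  // (which demands in-place bufferization), use a plain
  // destination-style copy:
  %empty = tensor.empty() : tensor<4xf32>
  %copy = linalg.copy ins(%src : tensor<4xf32>)
                      outs(%empty : tensor<4xf32>) -> tensor<4xf32>
  // After bufferization, linalg.copy on memrefs is lowered to
  // memref.copy, recovering the original behavior.
  return %copy : tensor<4xf32>
}
```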

GitOrigin-RevId: 93ed038f47690d33633db23be5e8f70b3d89d119
This ensures `mlir-executor` can be built standalone by properly
depending on the `MLIRTensorRTCommon` package, which in turn requires
top-level CMake utilities.

GitOrigin-RevId: e204d59c9a95112803440565bc0810a492a1872b
@christopherbate christopherbate merged commit c1d6e9b into main Jul 10, 2025
1 check passed
@christopherbate christopherbate deleted the integrate_internal_changes-2 branch July 10, 2025 15:41