Skip to content

Commit 1596e5f

Browse files
authored
[Dialect/Plan] Update Plan dialect to use non-dps calling convention (#285)
The Plan dialect now supports a non-DPS (non-destination-passing-style) calling convention, with corresponding updates to PlanAllocTensorsPass and CreateClosedRegionsPass. Additionally, the Plan transformation now converts non-DPS group operations to CallAllocOp and DPS group operations to CallOp, as required by subsequent transformations.
1 parent 8923f16 commit 1596e5f

File tree

14 files changed

+827
-165
lines changed

14 files changed

+827
-165
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@ struct StableHLOToExecutableOptions : public mlir::OptionsContext {
128128
/// Whether to disallow host tensors in TensorRT clusters.
129129
bool disallowHostTensorsInTensorRTClusters = false;
130130

131+
/// Use non-DPS style calling convention for entrypoint function
132+
/// and backend types that support allocating results.
133+
bool enableNonDPSReturns = false;
134+
131135
/// Entrypoint function name.
132136
std::string entrypoint = "main";
133137

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanOps.td

Lines changed: 103 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,41 @@ def Plan_InlineGroupOp : Plan_GroupOpBase<"inline_group", [
131131
}
132132

133133
//===----------------------------------------------------------------------===//
134-
// InlineClosedGroupOp
134+
// Plan_InlineClosedGroupBase
135135
//===----------------------------------------------------------------------===//
136136

137-
def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
138-
IsolatedFromAbove,
137+
// Base class shared by the closed (isolated-from-above) group operations.
// `IsolatedFromAbove` is appended to whatever traits the derived op supplies.
class Plan_InlineClosedGroupBase<string mnemonic, list<Trait> traits = []> :
    Plan_GroupOpBase<mnemonic, traits # [IsolatedFromAbove]> {

  code baseInlineClosedExtraClassDeclaration = baseExtraClassDeclaration # [{
    // Helpers shared by the DPS and non-DPS closed group operations.

    /// Returns true when the input operand at `inputIdx` has a ranked tensor
    /// type.
    bool argHasTensorType(unsigned inputIdx) {
      assert(inputIdx < getInputs().size() && "input index out-of-bounds");
      auto inputType = getInputs()[inputIdx].getType();
      return isa<RankedTensorType>(inputType);
    }

    /// Returns the entry of `input_attrs` at `inputIdx` as a BoundsAttr.
    BoundsAttr getInputBoundsAttr(unsigned inputIdx) {
      assert(inputIdx < getInputs().size() && "input index out-of-bounds");
      auto boundsEntry = getInputAttrs()[inputIdx];
      return cast<BoundsAttr>(boundsEntry);
    }

    /// Replaces `input_attrs` with an ArrayAttr built from `boundsAttrs`.
    void setInputAttrsAttr(ArrayRef<BoundsAttr> boundsAttrs) {
      ArrayRef<Attribute> asAttributes(boundsAttrs.begin(), boundsAttrs.end());
      auto *context = getOperation()->getContext();
      setInputAttrsAttr(::mlir::ArrayAttr::get(context, asAttributes));
    }
  }];

  let extraClassDeclaration = baseInlineClosedExtraClassDeclaration;
}
163+
164+
//===----------------------------------------------------------------------===//
165+
// Plan_InlineClosedGroupOp
166+
//===----------------------------------------------------------------------===//
167+
168+
def Plan_InlineClosedGroupOp : Plan_InlineClosedGroupBase<"inline_closed_group", [
139169
AttrSizedOperandSegments,
140170
DestinationStyleOpInterface,
141171
SingleBlockImplicitTerminator<"plan::YieldOp">,
@@ -226,24 +256,12 @@ def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
226256
CArg<"ArrayRef<BoundsAttr>", "{}">:$res_attrs)>
227257
];
228258

229-
let extraClassDeclaration = baseExtraClassDeclaration # [{
259+
let extraClassDeclaration = baseInlineClosedExtraClassDeclaration # [{
230260

231261
MutableOperandRange getDpsInitsMutable() {
232262
return getOutsMutable();
233263
}
234264

235-
/// Returns true if the `i-th` input argument has a tensor type.
236-
bool argHasTensorType(unsigned inputIdx) {
237-
assert(inputIdx < getInputs().size() && "input index out-of-bounds");
238-
return isa<RankedTensorType>(getInputs()[inputIdx].getType());
239-
}
240-
241-
/// Returns the i-th input argument's bounds attribute.
242-
BoundsAttr getInputBoundsAttr(unsigned inputIdx) {
243-
assert(inputIdx < getInputs().size() && "input index out-of-bounds");
244-
return cast<BoundsAttr>(getInputAttrs()[inputIdx]);
245-
}
246-
247265
ArrayRef<BlockArgument> getRegionOutArgs() {
248266
return getBody().getArguments().take_back(getOuts().size());
249267
}
@@ -255,16 +273,77 @@ def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
255273
ArrayRef<Attribute>(boundsAttrs.begin(), boundsAttrs.end())
256274
));
257275
}
276+
}];
277+
}
258278

259-
/// Populate the `input_attrs` from an array of BoundsAttrs.
260-
void setInputAttrsAttr(ArrayRef<BoundsAttr> boundsAttrs) {
261-
setInputAttrsAttr(::mlir::ArrayAttr::get(
262-
getOperation()->getContext(),
263-
ArrayRef<Attribute>(boundsAttrs.begin(), boundsAttrs.end())
264-
));
265-
}
279+
//===----------------------------------------------------------------------===//
280+
// InlineClosedAllocGroupOp
281+
//===----------------------------------------------------------------------===//
282+
283+
// Non-DPS variant of `plan.inline_closed_group`.
// `IsolatedFromAbove` is not listed here because `Plan_InlineClosedGroupBase`
// already appends it to the trait list.
def Plan_InlineClosedAllocGroupOp : Plan_InlineClosedGroupBase<"inline_closed_alloc_group", [
    SingleBlockImplicitTerminator<"plan::YieldOp">,
    DeclareOpInterfaceMethods<RegionBranchOpInterface,
                              ["getEntrySuccessorOperands"]>,
    DeclareOpInterfaceMethods<OpAsmOpInterface,
                              ["getAsmBlockArgumentNames"]>
  ]> {
  let description = [{
    The `plan.inline_closed_alloc_group` operation is a variant of the
    `plan.inline_closed_group` operation that does not use destination-passing style
    (DPS). It is isolated from above and explicitly captures input operands, but unlike
    its DPS counterpart, it does not capture destination operands because its results must
    be lowered to allocation(s). The allocations may or may not be of a size that can only
    be computed inside of the region.

    This operation takes input operands and their corresponding bounds attributes,
    and produces results. The `input_attrs` hold bounds attribute information for
    the input operands. The absence of bounds information is allowed (`none` bounds).

    The `target` attribute specifies the execution target for the group.

    #### Example

    Consider the following simple program containing operations with dynamically shaped operands:

    ```mlir
    %0 = ... : tensor<?xf32> // A dynamically shaped operand
    %1 = ... : index // A dynamic calculation of %0's extent

    %2 = plan.inline_closed_alloc_group target(#plan.cluster_target<tensorrt>)
      inputs(%0, %1 : tensor<?xf32>, index)
      in_attrs [#plan.bounds<shape, [10], [20]>, #plan.bounds<none>]-> tensor<?xf32> {
      %3 = plan.with_shape %0 (%1) : (tensor<?xf32>, index) -> tensor<?xf32>
      %4 = stablehlo.exponential %3 : tensor<?xf32>
      yield %4 : tensor<?xf32>
    }
    ```
  }];

  // The only operands are `inputs`; there are no destination operands.
  let arguments = (ins Variadic<AnyTypeOf<[AnyRankedTensor, AnySignlessIntegerOrIndex]>>:$inputs,
                       BoundsAttrArray:$input_attrs,
                       AnyAttr:$target);

  let results = (outs Variadic<AnyTypeOf<[AnyRankedTensor]>>:$results);

  let assemblyFormat = [{
    `target` `(` $target `)` `\n`
    `inputs` `(` ( $inputs^ `:` type($inputs) `)` ) : ( `)` ) ? `\n`
    `in_attrs` $input_attrs `\n`
    attr-dict-with-keyword `->` type($results)
    $body
  }];

  let hasVerifier = 1;

  let skipDefaultBuilders = 1;

  let builders = [
    OpBuilder<(ins "TypeRange":$results,
                   "Attribute":$target,
                   "ValueRange":$inputs,
                   CArg<"ArrayRef<BoundsAttr>", "{}">:$input_attrs)>,
  ];

  let extraClassDeclaration = baseInlineClosedExtraClassDeclaration;
}
269348

270349
//===----------------------------------------------------------------------===//
@@ -276,7 +355,7 @@ def Plan_YieldOp : Plan_Op<"yield", [
276355
Terminator,
277356
ReturnLike,
278357
ParentOneOf<["plan::InlineGroupOp",
279-
"plan::InlineClosedGroupOp"]>]> {
358+
"plan::InlineClosedGroupOp", "plan::InlineClosedAllocGroupOp"]>]> {
280359

281360
let arguments = (ins Variadic<AnyType>:$results);
282361

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Transforms/Passes.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ executorOneShotModuleBufferize(ModuleOp targetOp,
6969
const ExecutorBufferizationOptions &options);
7070

7171
/// Build a pipeline (targeting ModuleOp) for bufferization.
72-
void buildPlanBufferizationPipeline(OpPassManager &pm);
72+
void buildPlanBufferizationPipeline(
73+
OpPassManager &pm, const plan::PlanAllocTensorsPassOptions &options);
7374

7475
/// Build a post-bufferization pipeline that performs optimizations on memrefs.
7576
void buildPlanBufferOptimizationPipeline(OpPassManager &pm);

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Transforms/Passes.td

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,9 @@ def StablehloClusteringPass : Pass<"stablehlo-clustering", "::mlir::ModuleOp"> {
248248
Option<"entrypoint", "entrypoint", "std::string", "\"\"",
249249
"the name of the entrypoint function; if empty then the clustering runs"
250250
" on all functions">,
251+
Option<"enableNonDPSReturns",
252+
"enable-non-dps-returns", "bool", "false",
253+
"allow backend clusters to directly allocate outputs">,
251254
Option<"disallowHostTensorsInTensorRTClusters",
252255
"disallow-host-tensors-in-tensorrt-clusters", "bool", "false",
253256
"don't cluster host tensors in TensorRT clusters">,
@@ -332,7 +335,10 @@ def CreateClosedRegionsPass : Pass<"plan-create-closed-regions", "::mlir::Module
332335
Option<"testPreWalkOrder", "test-pre-walk-order", "bool", "false",
333336
"(used only in testing) specifies to outline regions by walking in "
334337
" pre-order; used for verifying results are not sensitive "
335-
"to traversal order">
338+
"to traversal order">,
339+
Option<"enableNonDPSReturns", "enable-non-dps-returns", "bool",
340+
/*default=*/"false",
341+
"Allow backend clusters to directly allocate outputs">
336342
];
337343

338344
let dependentDialects = [
@@ -428,6 +434,13 @@ def PlanAllocTensorsPass : Pass<"plan-alloc-tensors",
428434
"::mlir::bufferization::BufferizationDialect",
429435
"::mlir::plan::PlanDialect"
430436
];
437+
438+
let options = [
439+
Option<"enableNonDPSReturns", "enable-non-dps-returns", "bool",
440+
/*default=*/"false",
441+
"Allow backend clusters to directly allocate outputs">
442+
];
443+
431444
}
432445

433446
//===----------------------------------------------------------------------===//

mlir-tensorrt/compiler/lib/Dialect/Plan/IR/PlanOps.cpp

Lines changed: 88 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ void InlineGroupOp::getSuccessorRegions(
295295
}
296296

297297
//===----------------------------------------------------------------------===//
298-
// InlineClosedGroupOp
298+
// InlineClosedGroupOp and InlineClosedAllocGroupOp Helpers
299299
//===----------------------------------------------------------------------===//
300300

301301
static LogicalResult
@@ -371,36 +371,38 @@ verifyBoundsAttr(StringRef argOrResult, unsigned idx, Type type,
371371
return success();
372372
}
373373

374-
LogicalResult InlineClosedGroupOp::verify() {
375-
SmallVector<BoundsAttr> inputAttrs =
376-
llvm::to_vector(getInputAttrs().getAsRange<BoundsAttr>());
377-
if (inputAttrs.size() != getInputs().size())
378-
return emitOpError("expected number of inputs (")
379-
<< getInputs().size()
380-
<< " to equal the number of input_attrs BoundsAttrs ("
381-
<< inputAttrs.size() << ")";
382-
383-
for (auto [idx, type] : llvm::enumerate(TypeRange(getInputs()))) {
384-
BoundsAttr boundsAttr = inputAttrs[idx];
385-
if (failed(verifyBoundsAttr("input argument", idx, type, boundsAttr,
386-
[&]() { return emitOpError(); })))
374+
static LogicalResult verifyBoundsAttrs(Operation *op, ValueRange operands,
375+
ArrayAttr attrsArray, StringRef attrName,
376+
StringRef boundName) {
377+
SmallVector<BoundsAttr> attrs =
378+
llvm::to_vector(attrsArray.getAsRange<BoundsAttr>());
379+
if (attrs.size() != operands.size())
380+
return op->emitOpError("expected number of ")
381+
<< attrName << " (" << operands.size() << ") to equal the number of "
382+
<< boundName << " BoundsAttrs (" << attrs.size() << ")";
383+
384+
for (auto [idx, type] : llvm::enumerate(TypeRange(operands))) {
385+
BoundsAttr boundsAttr = attrs[idx];
386+
if (failed(verifyBoundsAttr(attrName, idx, type, boundsAttr,
387+
[&]() { return op->emitOpError(); })))
387388
return failure();
388389
}
389390

390-
SmallVector<BoundsAttr> resAttrs =
391-
llvm::to_vector(getResAttrs().getAsRange<BoundsAttr>());
392-
if (resAttrs.size() != getNumResults())
393-
return emitOpError("expected number of results (")
394-
<< getNumResults()
395-
<< ") to equal the number of res_attrs BoundsAttrs ("
396-
<< resAttrs.size() << ")";
397-
398-
for (auto [idx, type] : llvm::enumerate(getResultTypes())) {
399-
BoundsAttr boundsAttr = resAttrs[idx];
400-
if (failed(verifyBoundsAttr("result", idx, type, boundsAttr,
401-
[&]() { return emitOpError(); })))
402-
return failure();
403-
}
391+
return success();
392+
}
393+
394+
//===----------------------------------------------------------------------===//
395+
// InlineClosedGroupOp
396+
//===----------------------------------------------------------------------===//
397+
398+
/// Verifies the bounds attributes of the DPS closed group operation.
///
/// Checks `input_attrs` against the input operands and `res_attrs` against the
/// results via `verifyBoundsAttrs`. Returns failure as soon as either check
/// fails; the diagnostic is emitted by the helper.
LogicalResult InlineClosedGroupOp::verify() {
  if (failed(verifyBoundsAttrs(getOperation(), getInputs(), getInputAttrs(),
                               "inputs", "input_attrs")))
    return failure();

  // The result bounds are stored in the `res_attrs` attribute, so that is the
  // name reported in the diagnostic.
  if (failed(verifyBoundsAttrs(getOperation(), getResults(), getResAttrs(),
                               "results", "res_attrs")))
    return failure();

  return success();
}
@@ -465,6 +467,64 @@ void InlineClosedGroupOp::build(OpBuilder &b, OperationState &state,
465467
state.addTypes(TypeRange(outs));
466468
}
467469

470+
//===----------------------------------------------------------------------===//
471+
// InlineClosedAllocGroupOp
472+
//===----------------------------------------------------------------------===//
473+
474+
/// Verifies the non-DPS closed group operation: it must not carry a
/// `res_attrs` attribute, and `input_attrs` must pass `verifyBoundsAttrs`
/// against the input operands.
LogicalResult InlineClosedAllocGroupOp::verify() {
  Operation *self = getOperation();

  // This op has no result bounds; reject a stray `res_attrs` attribute.
  if (self->hasAttr("res_attrs"))
    return self->emitOpError("must not contain 'res_attrs' attribute");

  if (failed(verifyBoundsAttrs(self, getInputs(), getInputAttrs(), "inputs",
                               "input_attrs")))
    return failure();
  return success();
}
482+
483+
/// Reports the control-flow successors of `point` by appending to `regions`.
void InlineClosedAllocGroupOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  // Leaving the body: control returns to this operation's results.
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  // Entering from this operation: branch into the body region, forwarding to
  // its block arguments.
  Region &bodyRegion = getBody();
  regions.push_back(RegionSuccessor(&bodyRegion, bodyRegion.getArguments()));
}
493+
494+
/// Returns the operands forwarded to the body on entry. All operands of the
/// operation are returned regardless of `point`.
OperandRange
InlineClosedAllocGroupOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  Operation *self = getOperation();
  return self->getOperands();
}
498+
499+
/// Suggests the name "in" for every block argument of `region` when the
/// operation is printed.
void InlineClosedAllocGroupOp::getAsmBlockArgumentNames(
    Region &region, OpAsmSetValueNameFn setNameFn) {
  assert(region.getNumArguments() == getInputs().size() &&
         "expected one block arg for each input argument");
  for (unsigned argIdx = 0, numArgs = region.getNumArguments();
       argIdx < numArgs; ++argIdx)
    setNameFn(region.getArgument(argIdx), "in");
}
506+
507+
/// Populates `state` for a `plan.inline_closed_alloc_group` operation.
///
/// \param b            builder used to create the `input_attrs` array attribute.
/// \param state        operation state being populated.
/// \param resultTypes  types of the operation's results.
/// \param target       value stored in the `target` property.
/// \param inputs       input operands; the body's entry block receives one
///                     argument per input with the same type and location.
/// \param input_attrs  bounds attributes stored in the `input_attrs` property.
void InlineClosedAllocGroupOp::build(OpBuilder &b, OperationState &state,
                                     TypeRange resultTypes, Attribute target,
                                     ValueRange inputs,
                                     ArrayRef<BoundsAttr> input_attrs) {
  state.addTypes(resultTypes);
  state.addOperands(inputs);

  Properties &props = state.getOrAddProperties<Properties>();
  props.target = target;
  SmallVector<Attribute> boundsAsAttrs(input_attrs.begin(), input_attrs.end());
  props.setInputAttrs(b.getArrayAttr(boundsAsAttrs));

  // Each block argument reuses the location of the input it mirrors.
  SmallVector<Location> argLocs;
  argLocs.reserve(inputs.size());
  for (Value input : inputs)
    argLocs.push_back(input.getLoc());

  Region *body = state.addRegion();
  Block &entryBlock = body->emplaceBlock();
  entryBlock.addArguments(TypeRange(inputs), argLocs);
}
527+
468528
//===----------------------------------------------------------------------===//
469529
// YieldOp
470530
//===----------------------------------------------------------------------===//

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/AllocTensors.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -856,11 +856,13 @@ class AllocTensorsPass
856856
}
857857
}
858858

859-
// First rewrite public functions to conform to DPS style.
860859
IRRewriter rewriter(ctx);
861-
if (failed(rewriteNotPrivateFuncsToDPS(rewriter, op))) {
862-
op->emitError("Failed to convert non-private functions to DPS");
863-
return signalPassFailure();
860+
if (!enableNonDPSReturns) {
861+
// First rewrite public functions to conform to DPS style.
862+
if (failed(rewriteNotPrivateFuncsToDPS(rewriter, op))) {
863+
op->emitError("Failed to convert non-private functions to DPS");
864+
return signalPassFailure();
865+
}
864866
}
865867

866868
// Rewrite SCF for and while loop bodies for better bufferization results,

0 commit comments

Comments
 (0)