Skip to content

Commit a868afd

Browse files
committed
Remove the DPS (destination-passing style) calling convention from the Plan dialect
1 parent 5c412aa commit a868afd

File tree

15 files changed: 101 additions and 410 deletions.

mlir-tensorrt/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ compile_commands.json
3232
**/*.private.*
3333
*.private
3434
**/tmp/**
35+
**/tmp**
36+
**/tripy/**
3537

3638
# TRT Timing Cache artifacts
3739
*.timing-cache

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanOps.td

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ include "mlir-tensorrt-dialect/Interface/TensorKindOpInterface.td"
77
include "mlir/Interfaces/SideEffectInterfaces.td"
88
include "mlir/Interfaces/ControlFlowInterfaces.td"
99
include "mlir/Interfaces/DestinationStyleOpInterface.td"
10-
include "mlir/Interfaces/InferTypeOpInterface.td"
1110
include "mlir/IR/OpAsmInterface.td"
1211

1312
class Plan_NativeOpTrait<string name,
@@ -136,8 +135,6 @@ def Plan_InlineGroupOp : Plan_GroupOpBase<"inline_group", [
136135

137136
def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
138137
IsolatedFromAbove,
139-
AttrSizedOperandSegments,
140-
DestinationStyleOpInterface,
141138
SingleBlockImplicitTerminator<"plan::YieldOp">,
142139
DeclareOpInterfaceMethods<RegionBranchOpInterface,
143140
["getEntrySuccessorOperands"]>,
@@ -199,19 +196,16 @@ def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
199196

200197
}];
201198
let arguments = (ins Variadic<AnyTypeOf<[AnyRankedTensor, AnySignlessIntegerOrIndex]>>:$inputs,
202-
Variadic<AnyRankedTensor>:$outs,
203199
BoundsAttrArray:$input_attrs,
204-
BoundsAttrArray:$res_attrs,
205200
AnyAttr:$target);
206201

207202
let results = (outs Variadic<AnyType>:$results);
208203

209204
let assemblyFormat = [{
210205
`target` `(` $target `)` `\n`
211206
`inputs` `(` ( $inputs^ `:` type($inputs) `)` ) : ( `)` ) ? `\n`
212-
`outs` `(` $outs `:` type($outs) `)` `\n`
213207
`in_attrs` $input_attrs `\n`
214-
`res_attrs` $res_attrs attr-dict-with-keyword `->` type($results)
208+
attr-dict-with-keyword `->` type($results)
215209
$body
216210
}];
217211

@@ -220,18 +214,14 @@ def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
220214
let skipDefaultBuilders = 1;
221215

222216
let builders = [
223-
OpBuilder<(ins "Attribute":$target,
224-
"ValueRange":$inputs, "ValueRange":$outs,
225-
CArg<"ArrayRef<BoundsAttr>", "{}">:$input_attrs,
226-
CArg<"ArrayRef<BoundsAttr>", "{}">:$res_attrs)>
217+
OpBuilder<(ins "TypeRange":$results,
218+
"Attribute":$target,
219+
"ValueRange":$inputs,
220+
CArg<"ArrayRef<BoundsAttr>", "{}">:$input_attrs)>
227221
];
228222

229223
let extraClassDeclaration = baseExtraClassDeclaration # [{
230224

231-
MutableOperandRange getDpsInitsMutable() {
232-
return getOutsMutable();
233-
}
234-
235225
/// Returns true if the `i-th` input argument has a tensor type.
236226
bool argHasTensorType(unsigned inputIdx) {
237227
assert(inputIdx < getInputs().size() && "input index out-of-bounds");
@@ -244,17 +234,6 @@ def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
244234
return cast<BoundsAttr>(getInputAttrs()[inputIdx]);
245235
}
246236

247-
ArrayRef<BlockArgument> getRegionOutArgs() {
248-
return getBody().getArguments().take_back(getOuts().size());
249-
}
250-
251-
/// Populate the `res_attrs` from an array of BoundsAttrs.
252-
void setResAttrsAttr(ArrayRef<BoundsAttr> boundsAttrs) {
253-
setResAttrsAttr(::mlir::ArrayAttr::get(
254-
getOperation()->getContext(),
255-
ArrayRef<Attribute>(boundsAttrs.begin(), boundsAttrs.end())
256-
));
257-
}
258237

259238
/// Populate the `input_attrs` from an array of BoundsAttrs.
260239
void setInputAttrsAttr(ArrayRef<BoundsAttr> boundsAttrs) {

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/TensorRTRuntime/IR/TensorRTRuntimeOps.td

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,8 @@ def TensorRTRuntime_CompileOp : TensorRTRuntime_Op<"compile", [Pure]> {
6161
//===----------------------------------------------------------------------===//
6262

6363
def TensorRTRuntime_EnqueueOp : TensorRTRuntime_Op<"enqueue", [
64-
DeclareOpInterfaceMethods<InferTypeOpInterface>,
6564
DeclareOpInterfaceMethods<TensorKindOpInterface>,
6665
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
67-
AttrSizedOperandSegments,
68-
DestinationStyleOpInterface
6966
]> {
7067
let description = [{
7168

@@ -88,23 +85,19 @@ def TensorRTRuntime_EnqueueOp : TensorRTRuntime_Op<"enqueue", [
8885
let arguments = (ins TensorRTRuntime_Context:$execution_context,
8986
CUDA_Stream:$stream,
9087
Variadic<AnyShaped>:$inputs,
91-
Variadic<AnyShaped>:$outs,
9288
OptionalAttr<DenseI64ArrayAttr>:$host_tensor_args);
9389
let results = (outs Variadic<AnyType>:$results);
9490

9591
let assemblyFormat = [{
9692
$execution_context `stream` `(` $stream `)` ` `
9793
(`host_tensor_args` $host_tensor_args^ ` ` )?
98-
`(` $inputs `)` `outs` `(` $outs `)`
99-
attr-dict `:` functional-type($inputs, $outs)
94+
`(` $inputs `)`
95+
attr-dict `:` functional-type($inputs, $results)
10096
}];
10197

10298
let hasVerifier = 1;
10399

104100
let extraClassDeclaration = [{
105-
// Declare the outs as inits/outs to DestinationStyleOpInterface.
106-
MutableOperandRange getDpsInitsMutable() { return getOutsMutable(); }
107-
108101
/// Return true if the operand at the specified index is a host tensor
109102
/// argument.
110103
bool isOperandOnHost(int64_t operandIdx) {

mlir-tensorrt/compiler/lib/Conversion/TensorRTRuntimeToExecutor/TensorRTRuntimeToExecutor.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,11 +217,6 @@ struct ConvertEnqueueToCall
217217
if (failed(createMemRefAndExractPtr(oldVal, newVal)))
218218
return failure();
219219
}
220-
for (auto [oldVal, newVal] : llvm::zip(op.getOuts(), adaptor.getOuts())) {
221-
if (failed(createMemRefAndExractPtr(oldVal, newVal)))
222-
return failure();
223-
}
224-
225220
// Create the table containing the pointer/offset args and append it to the
226221
// arguments for the call op.
227222
Value args = b.create<executor::CreateTableOp>(

mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@
3535
#include "mlir/IR/PatternMatch.h"
3636
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3737

38+
#include "llvm/Support/Debug.h"
39+
40+
#define DEBUG_TYPE "tensorrt-to-tensorrt-runtime"
41+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
42+
3843
namespace mlir {
3944
#define GEN_PASS_DEF_CONVERTTENSORRTTOTENSORRTRUNTIMEPASS
4045
#include "mlir-tensorrt/Conversion/Passes.h.inc"
@@ -123,12 +128,15 @@ class ConvertTensorRTToRuntimePass
123128
{FlatSymbolRefAttr::get(trtFunc)}));
124129
Value stream = rewriter.create<cuda::GetGlobalStreamOp>(loc, 0);
125130
auto enqueueOp = rewriter.create<trtrt::EnqueueOp>(
126-
loc, executionContext, stream, callOp.getInputs(),
127-
callOp.getOutputs(),
131+
loc, callOp->getResultTypes(), executionContext, stream, callOp.getInputs(),
128132
/*host_tensors_args=*/hostTensorArgs.empty()
129133
? DenseI64ArrayAttr{}
130134
: DenseI64ArrayAttr::get(ctx, hostTensorArgs));
131135
rewriter.setInsertionPointAfter(enqueueOp);
136+
137+
DBGS() << "Number of call op results: " << callOp->getNumResults() << "\n";
138+
DBGS() << "Number of enqueue op results: " << enqueueOp->getNumResults() << "\n";
139+
132140
rewriter.replaceOp(callOp, enqueueOp->getResults());
133141
}
134142
}

mlir-tensorrt/compiler/lib/Dialect/Plan/IR/PlanOps.cpp

Lines changed: 7 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -387,21 +387,6 @@ LogicalResult InlineClosedGroupOp::verify() {
387387
return failure();
388388
}
389389

390-
SmallVector<BoundsAttr> resAttrs =
391-
llvm::to_vector(getResAttrs().getAsRange<BoundsAttr>());
392-
if (resAttrs.size() != getNumResults())
393-
return emitOpError("expected number of results (")
394-
<< getNumResults()
395-
<< ") to equal the number of res_attrs BoundsAttrs ("
396-
<< resAttrs.size() << ")";
397-
398-
for (auto [idx, type] : llvm::enumerate(getResultTypes())) {
399-
BoundsAttr boundsAttr = resAttrs[idx];
400-
if (failed(verifyBoundsAttr("result", idx, type, boundsAttr,
401-
[&]() { return emitOpError(); })))
402-
return failure();
403-
}
404-
405390
return success();
406391
}
407392

@@ -424,33 +409,22 @@ InlineClosedGroupOp::getEntrySuccessorOperands(RegionBranchPoint point) {
424409

425410
void InlineClosedGroupOp::getAsmBlockArgumentNames(
426411
Region &region, OpAsmSetValueNameFn setNameFn) {
427-
assert(region.front().getNumArguments() ==
428-
getInputs().size() + getOuts().size() &&
429-
"expected one block arg for each input and destination argument");
430-
unsigned numInputs = getInputs().size();
412+
assert(region.front().getNumArguments() == getInputs().size() &&
413+
"expected one block arg for each input argument");
431414
for (BlockArgument arg : region.front().getArguments()) {
432-
StringRef name = arg.getArgNumber() < numInputs ? "in" : "out";
433-
setNameFn(arg, name);
415+
setNameFn(arg, "in");
434416
}
435417
}
436418

437419
void InlineClosedGroupOp::build(OpBuilder &b, OperationState &state,
438-
Attribute target, ValueRange inputs,
439-
ValueRange outs,
440-
ArrayRef<BoundsAttr> input_attrs,
441-
ArrayRef<BoundsAttr> result_attrs) {
420+
TypeRange resultTypes, Attribute target,
421+
ValueRange inputs,
422+
ArrayRef<BoundsAttr> input_attrs) {
423+
state.addTypes(resultTypes);
442424
state.addOperands(inputs);
443-
state.addOperands(outs);
444425
state.getOrAddProperties<Properties>().target = target;
445426
state.getOrAddProperties<Properties>().setInputAttrs(b.getArrayAttr(
446427
SmallVector<Attribute>(input_attrs.begin(), input_attrs.end())));
447-
state.getOrAddProperties<Properties>().setResAttrs(b.getArrayAttr(
448-
SmallVector<Attribute>(result_attrs.begin(), result_attrs.end())));
449-
450-
llvm::copy(
451-
ArrayRef<int32_t>{static_cast<int32_t>(inputs.size()),
452-
static_cast<int32_t>(outs.size())},
453-
state.getOrAddProperties<Properties>().operandSegmentSizes.begin());
454428
Region *body = state.addRegion();
455429
auto getLocs = [](ValueRange r) {
456430
SmallVector<Location> locs;
@@ -461,8 +435,6 @@ void InlineClosedGroupOp::build(OpBuilder &b, OperationState &state,
461435
};
462436
(void)body->emplaceBlock();
463437
body->addArguments(TypeRange(inputs), getLocs(inputs));
464-
body->addArguments(TypeRange(outs), getLocs(outs));
465-
state.addTypes(TypeRange(outs));
466438
}
467439

468440
//===----------------------------------------------------------------------===//

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/AllocTensors.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -854,10 +854,10 @@ class AllocTensorsPass
854854

855855
// First rewrite public functions to conform to DPS style.
856856
IRRewriter rewriter(ctx);
857-
if (failed(rewriteNotPrivateFuncsToDPS(rewriter, op))) {
858-
op->emitError("Failed to convert non-private functions to DPS");
859-
return signalPassFailure();
860-
}
857+
// if (failed(rewriteNotPrivateFuncsToDPS(rewriter, op))) {
858+
// op->emitError("Failed to convert non-private functions to DPS");
859+
// return signalPassFailure();
860+
// }
861861

862862
// Rewrite SCF for and while loop bodies for better bufferization results,
863863
// if possible.

0 commit comments

Comments
 (0)