Skip to content

Commit 34df3e2

Browse files
Adds support for a dimension names attribute for block arguments
Adds support for a `tensorrt.dimension_names` function argument attribute, which is used to set dimension names on TensorRT network inputs. Giving dimensions the same name conveys to TensorRT that those dimensions are equal at runtime.
1 parent 30f9816 commit 34df3e2

File tree

4 files changed

+60
-4
lines changed

4 files changed

+60
-4
lines changed

mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/Passes.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,13 +240,17 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
240240

241241
StringRef tensorrtShapeBoundsAttrName =
242242
mlir::tensorrt::TensorRTDialect::getShapeProfileArgAttrName();
243+
StringRef tensorrtDimensionNamesAttrName =
244+
mlir::tensorrt::TensorRTDialect::getDimensionNamesArgAttrName();
243245
func::FuncOp funcContainingCluster =
244246
cluster.back()->getParentOfType<func::FuncOp>();
245247
SmallVector<Attribute> profileAttrsPerInput;
248+
SmallVector<Attribute> dimensionNamesAttrsPerInput;
246249
for (Value v : inputs) {
247250
auto rtt = dyn_cast<RankedTensorType>(v.getType());
248251
if (!rtt || rtt.hasStaticShape()) {
249252
profileAttrsPerInput.push_back(Attribute{});
253+
dimensionNamesAttrsPerInput.push_back(Attribute{});
250254
continue;
251255
}
252256

@@ -263,6 +267,10 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
263267
funcContainingCluster.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
264268
argIndex, tensorrtShapeBoundsAttrName));
265269

270+
dimensionNamesAttrsPerInput.push_back(
271+
funcContainingCluster.getArgAttrOfType<DictionaryAttr>(
272+
argIndex, tensorrtDimensionNamesAttrName));
273+
266274
if (!profileAttrsPerInput.back()) {
267275
return emitError(blockArg.getLoc())
268276
<< "Profile attribute (" << tensorrtShapeBoundsAttrName
@@ -271,10 +279,12 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
271279
}
272280

273281
for (unsigned idx = 0; idx < func->getNumArguments(); idx++) {
274-
if (!profileAttrsPerInput[idx])
275-
continue;
276-
func->setArgAttr(idx, tensorrtShapeBoundsAttrName,
277-
profileAttrsPerInput[idx]);
282+
if (profileAttrsPerInput[idx])
283+
func->setArgAttr(idx, tensorrtShapeBoundsAttrName,
284+
profileAttrsPerInput[idx]);
285+
if (dimensionNamesAttrsPerInput[idx])
286+
func->setArgAttr(idx, tensorrtDimensionNamesAttrName,
287+
dimensionNamesAttrsPerInput[idx]);
278288
}
279289

280290
rewriter.setInsertionPoint(inlineGroupOp);

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ def TensorRT_Dialect : Dialect {
3838
return "tensorrt.shape_profile";
3939
}
4040

41+
/// Return the name of the function argument attribute that encodes
42+
/// the dimension names. It should be a `DictionaryAttr` that maps each dimension index (written as a decimal string key, e.g. "0") to a string name.
43+
static StringRef getDimensionNamesArgAttrName() {
44+
return "tensorrt.dimension_names";
45+
}
46+
4147
/// TensorRT quantization and dequantization mode markers.
4248
static constexpr StringRef kTensorRTPerTensorQuantizationMarker = "tensorrt.pt_q";
4349
static constexpr StringRef kTensorRTPerChannelQuantizationMarker = "tensorrt.pc_q";

mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir-tensorrt-dialect/TensorRT/Utils/Utils.h"
2828
#include "mlir-tensorrt-dialect/Utils/NvInferAdaptor.h"
2929
#include "mlir-tensorrt-dialect/Utils/StaticValueUtils.h"
30+
#include "mlir/IR/BuiltinAttributes.h"
3031
#include "mlir/IR/BuiltinTypes.h"
3132
#include "mlir/Interfaces/FunctionInterfaces.h"
3233
#include "llvm/ADT/STLExtras.h"
@@ -922,6 +923,34 @@ LogicalResult NvInferNetworkEncoder::encodeFunc(FunctionOpInterface func) {
922923
return failure();
923924
nvinfer1::ITensor *inputTensor =
924925
getNetworkDefinition()->addInput(name.c_str(), *dtype, trtShape);
926+
927+
// setDimensionName must be called immediately after addInput, or TensorRT
928+
// will not deduplicate equal dimensions, which leads to perf gaps.
929+
auto dimNamesAttr = func.getArgAttrOfType<DictionaryAttr>(
930+
arg.getArgNumber(), TensorRTDialect::getDimensionNamesArgAttrName());
931+
if (dimNamesAttr) {
932+
for (NamedAttribute namedAttr : dimNamesAttr) {
933+
int32_t key;
934+
if (namedAttr.getName().getValue().getAsInteger(10, key))
935+
return func->emitOpError()
936+
<< "dimension name key '" << namedAttr.getName()
937+
<< "' is not an integer";
938+
939+
if (key < 0 || key >= argType.getRank())
940+
return func->emitOpError()
941+
<< "dimension name key '" << key
942+
<< "' is out of bounds for rank " << argType.getRank();
943+
944+
StringAttr strAttr = dyn_cast<StringAttr>(namedAttr.getValue());
945+
if (!strAttr)
946+
return func->emitOpError()
947+
<< "dimension name value '" << namedAttr.getValue()
948+
<< "' is not a string";
949+
950+
inputTensor->setDimensionName(key, strAttr.getValue().str().c_str());
951+
}
952+
}
953+
925954
if (!usesStronglyTyped && dtype == nvinfer1::DataType::kINT8)
926955
setIdentityInt8DynamicRange(inputTensor);
927956
this->map(arg, inputTensor);

mlir-tensorrt/tensorrt/test/Target/TensorRT/translate-to-tensorrt.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,14 @@ func.func @trt_reduce(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1xf32> {
5454
func.func @input_passthrough(%arg0: tensor<1xf32>, %arg1: tensor<1xf16>, %arg2: tensor<1xi32>) -> (tensor<1xf32>, tensor<1xf32>, tensor<1xf16>, tensor<1xi32>) {
5555
return %arg0, %arg0, %arg1, %arg2: tensor<1xf32>, tensor<1xf32>, tensor<1xf16>, tensor<1xi32>
5656
}
57+
58+
59+
// CHECK-LABEL: @trt_dim_names
60+
// CHECK-SAME: tensorrt.engine
61+
func.func @trt_dim_names(
62+
%arg0: tensor<?x?xf32> {tensorrt.dimension_names = {"0" = "batch", "1" = "features"}, tensorrt.shape_profile = #tensorrt.shape_profile<min=[2, 2], opt=[5, 5], max=[10, 10]>},
63+
%arg1: tensor<?x?xf32> {tensorrt.dimension_names = {"0" = "batch", "1" = "features"}, tensorrt.shape_profile = #tensorrt.shape_profile<min=[2, 2], opt=[5, 5], max=[10, 10]>},
64+
%arg2: tensor<2x10xf32>) -> tensor<?x?xf32> {
65+
%0 = tensorrt.identity %arg0 : tensor<?x?xf32> to tensor<?x?xf32>
66+
return %0 : tensor<?x?xf32>
67+
}

0 commit comments

Comments
 (0)