Skip to content

Commit 006bde8

Browse files
Adds support for a dimension names attribute for block arguments
Adds support for a `tensorrt.dimension_names` function-argument attribute, which is used to set dimension names on TensorRT network inputs. Giving dimensions the same name conveys to TensorRT that those dimensions are equal at runtime.
1 parent 30f9816 commit 006bde8

File tree

4 files changed

+51
-4
lines changed

4 files changed

+51
-4
lines changed

mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/Passes.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,13 +240,17 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
240240

241241
StringRef tensorrtShapeBoundsAttrName =
242242
mlir::tensorrt::TensorRTDialect::getShapeProfileArgAttrName();
243+
StringRef tensorrtDimensionNamesAttrName =
244+
mlir::tensorrt::TensorRTDialect::getDimensionNamesArgAttrName();
243245
func::FuncOp funcContainingCluster =
244246
cluster.back()->getParentOfType<func::FuncOp>();
245247
SmallVector<Attribute> profileAttrsPerInput;
248+
SmallVector<Attribute> dimensionNamesAttrsPerInput;
246249
for (Value v : inputs) {
247250
auto rtt = dyn_cast<RankedTensorType>(v.getType());
248251
if (!rtt || rtt.hasStaticShape()) {
249252
profileAttrsPerInput.push_back(Attribute{});
253+
dimensionNamesAttrsPerInput.push_back(Attribute{});
250254
continue;
251255
}
252256

@@ -263,6 +267,10 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
263267
funcContainingCluster.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
264268
argIndex, tensorrtShapeBoundsAttrName));
265269

270+
dimensionNamesAttrsPerInput.push_back(
271+
funcContainingCluster.getArgAttrOfType<DictionaryAttr>(
272+
argIndex, tensorrtDimensionNamesAttrName));
273+
266274
if (!profileAttrsPerInput.back()) {
267275
return emitError(blockArg.getLoc())
268276
<< "Profile attribute (" << tensorrtShapeBoundsAttrName
@@ -271,10 +279,12 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
271279
}
272280

273281
for (unsigned idx = 0; idx < func->getNumArguments(); idx++) {
274-
if (!profileAttrsPerInput[idx])
275-
continue;
276-
func->setArgAttr(idx, tensorrtShapeBoundsAttrName,
277-
profileAttrsPerInput[idx]);
282+
if (profileAttrsPerInput[idx])
283+
func->setArgAttr(idx, tensorrtShapeBoundsAttrName,
284+
profileAttrsPerInput[idx]);
285+
if (dimensionNamesAttrsPerInput[idx])
286+
func->setArgAttr(idx, tensorrtDimensionNamesAttrName,
287+
dimensionNamesAttrsPerInput[idx]);
278288
}
279289

280290
rewriter.setInsertionPoint(inlineGroupOp);

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ def TensorRT_Dialect : Dialect {
3838
return "tensorrt.shape_profile";
3939
}
4040

41+
/// Return the name of the function arg attr that encodes the dimension names.
42+
/// It should be a `DictionaryAttr` keyed by decimal dimension index with string name values.
43+
static StringRef getDimensionNamesArgAttrName() {
44+
return "tensorrt.dimension_names";
45+
}
46+
4147
/// TensorRT quantization and dequantization mode markers.
4248
static constexpr StringRef kTensorRTPerTensorQuantizationMarker = "tensorrt.pt_q";
4349
static constexpr StringRef kTensorRTPerChannelQuantizationMarker = "tensorrt.pc_q";

mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,29 @@ LogicalResult NvInferNetworkEncoder::encodeFunc(FunctionOpInterface func) {
922922
return failure();
923923
nvinfer1::ITensor *inputTensor =
924924
getNetworkDefinition()->addInput(name.c_str(), *dtype, trtShape);
925+
926+
// setDimensionName must be called immediately after addInput, or TensorRT
927+
// will not deduplicate equal dimensions, which leads to perf gaps.
928+
auto dimNamesAttr = func.getArgAttrOfType<DictionaryAttr>(
929+
arg.getArgNumber(), TensorRTDialect::getDimensionNamesArgAttrName());
930+
if (dimNamesAttr) {
931+
932+
for (NamedAttribute namedAttr : dimNamesAttr) {
933+
int32_t key;
934+
if (namedAttr.getName().getValue().getAsInteger(10, key)) {
935+
return func->emitOpError()
936+
<< "dimension name key '" << namedAttr.getName()
937+
<< "' is not an integer";
938+
}
939+
940+
if (StringAttr strAttr = namedAttr.getValue().dyn_cast<StringAttr>()) {
941+
StringRef value = strAttr.getValue();
942+
inputTensor->setDimensionName(static_cast<int32_t>(key),
943+
value.str().c_str());
944+
}
945+
}
946+
}
947+
925948
if (!usesStronglyTyped && dtype == nvinfer1::DataType::kINT8)
926949
setIdentityInt8DynamicRange(inputTensor);
927950
this->map(arg, inputTensor);
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// RUN: %pick-one-gpu tensorrt-opt -split-input-file -pass-pipeline="builtin.module(translate-tensorrt-to-engine)" \
2+
// RUN: -mlir-elide-elementsattrs-if-larger=32 -tensorrt-builder-opt-level=0 %s | FileCheck %s
3+
4+
// CHECK-LABEL: @trt_dim_names
5+
// Identity function whose single argument carries a `tensorrt.dimension_names`
// dictionary (dimension index -> name); run through `translate-tensorrt-to-engine`.
// Only the function label is checked by FileCheck.
func.func @trt_dim_names(%arg0: tensor<2x10xf32> {tensorrt.dimension_names = {"0" = "batch", "1" = "features"}}) -> tensor<2x10xf32> {
6+
%0 = tensorrt.identity %arg0 : tensor<2x10xf32> to tensor<2x10xf32>
7+
return %0 : tensor<2x10xf32>
8+
}

0 commit comments

Comments
 (0)