Skip to content

Commit 92adfc4

Browse files
Adds support for a dimension names attribute for block arguments
Adds support for a dimension names attribute which is used to set dimension names in TensorRT. This conveys that the dimensions are equal at runtime.
1 parent a62916f commit 92adfc4

File tree

4 files changed

+60
-4
lines changed

4 files changed

+60
-4
lines changed

mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/Passes.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,11 +278,16 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
278278

279279
StringRef tensorrtShapeBoundsAttrName =
280280
mlir::tensorrt::TensorRTDialect::getShapeProfileArgAttrName();
281+
StringRef tensorrtDimensionNamesAttrName =
282+
mlir::tensorrt::TensorRTDialect::getDimensionNamesArgAttrName();
283+
281284
SmallVector<Attribute> profileAttrsPerInput;
285+
SmallVector<Attribute> dimensionNamesAttrsPerInput;
282286
for (Value v : inputs) {
283287
auto rtt = dyn_cast<RankedTensorType>(v.getType());
284288
if (!rtt || rtt.hasStaticShape()) {
285289
profileAttrsPerInput.push_back(Attribute{});
290+
dimensionNamesAttrsPerInput.push_back(Attribute{});
286291
continue;
287292
}
288293

@@ -298,6 +303,10 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
298303
parentFunc.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
299304
argIndex, tensorrtShapeBoundsAttrName));
300305

306+
dimensionNamesAttrsPerInput.push_back(
307+
parentFunc.getArgAttrOfType<DictionaryAttr>(
308+
argIndex, tensorrtDimensionNamesAttrName));
309+
301310
if (!profileAttrsPerInput.back()) {
302311
return emitError(blockArg.getLoc())
303312
<< "Profile attribute (" << tensorrtShapeBoundsAttrName
@@ -306,10 +315,12 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
306315
}
307316

308317
for (unsigned idx = 0; idx < func->getNumArguments(); idx++) {
309-
if (!profileAttrsPerInput[idx])
310-
continue;
311-
func->setArgAttr(idx, tensorrtShapeBoundsAttrName,
312-
profileAttrsPerInput[idx]);
318+
if (profileAttrsPerInput[idx])
319+
func->setArgAttr(idx, tensorrtShapeBoundsAttrName,
320+
profileAttrsPerInput[idx]);
321+
if (dimensionNamesAttrsPerInput[idx])
322+
func->setArgAttr(idx, tensorrtDimensionNamesAttrName,
323+
dimensionNamesAttrsPerInput[idx]);
313324
}
314325

315326
rewriter.setInsertionPoint(inlineGroupOp);

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ def TensorRT_Dialect : Dialect {
3838
return "tensorrt.shape_profile";
3939
}
4040

41+
/// Return the name of the function arg attr that encodes
42+
/// the dimension names. It should have a type `DictionaryAttr`.
43+
static StringRef getDimensionNamesArgAttrName() {
44+
return "tensorrt.dimension_names";
45+
}
46+
4147
/// TensorRT quantization and dequantization mode markers.
4248
static constexpr StringRef kTensorRTPerTensorQuantizationMarker = "tensorrt.pt_q";
4349
static constexpr StringRef kTensorRTPerChannelQuantizationMarker = "tensorrt.pc_q";

mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,34 @@ LogicalResult NvInferNetworkEncoder::encodeFunc(FunctionOpInterface func) {
922922
return failure();
923923
nvinfer1::ITensor *inputTensor =
924924
getNetworkDefinition()->addInput(name.c_str(), *dtype, trtShape);
925+
926+
// setDimensionName must be called immediately after addInput, or TensorRT
927+
// will not deduplicate equal dimensions, which leads to perf gaps.
928+
auto dimNamesAttr = func.getArgAttrOfType<DictionaryAttr>(
929+
arg.getArgNumber(), TensorRTDialect::getDimensionNamesArgAttrName());
930+
if (dimNamesAttr) {
931+
for (NamedAttribute namedAttr : dimNamesAttr) {
932+
int32_t key;
933+
if (namedAttr.getName().getValue().getAsInteger(10, key))
934+
return func->emitOpError()
935+
<< "dimension name key '" << namedAttr.getName()
936+
<< "' is not an integer";
937+
938+
if (key < 0 || key >= argType.getRank())
939+
return func->emitOpError()
940+
<< "dimension name key '" << key
941+
<< "' is out of bounds for rank " << argType.getRank();
942+
943+
StringAttr strAttr = dyn_cast<StringAttr>(namedAttr.getValue());
944+
if (!strAttr)
945+
return func->emitOpError()
946+
<< "dimension name value '" << namedAttr.getValue()
947+
<< "' is not a string";
948+
949+
inputTensor->setDimensionName(key, strAttr.getValue().str().c_str());
950+
}
951+
}
952+
925953
if (!usesStronglyTyped && dtype == nvinfer1::DataType::kINT8)
926954
setIdentityInt8DynamicRange(inputTensor);
927955
this->map(arg, inputTensor);

mlir-tensorrt/tensorrt/test/Target/TensorRT/translate-to-tensorrt.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,14 @@ func.func @trt_reduce(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1xf32> {
5454
func.func @input_passthrough(%arg0: tensor<1xf32>, %arg1: tensor<1xf16>, %arg2: tensor<1xi32>) -> (tensor<1xf32>, tensor<1xf32>, tensor<1xf16>, tensor<1xi32>) {
5555
return %arg0, %arg0, %arg1, %arg2: tensor<1xf32>, tensor<1xf32>, tensor<1xf16>, tensor<1xi32>
5656
}
57+
58+
59+
// CHECK-LABEL: @trt_dim_names
60+
// CHECK-SAME: tensorrt.engine
61+
func.func @trt_dim_names(
62+
%arg0: tensor<?x?xf32> {tensorrt.dimension_names = {"0" = "batch", "1" = "features"}, tensorrt.shape_profile = #tensorrt.shape_profile<min=[2, 2], opt=[5, 5], max=[10, 10]>},
63+
%arg1: tensor<?x?xf32> {tensorrt.dimension_names = {"0" = "batch", "1" = "features"}, tensorrt.shape_profile = #tensorrt.shape_profile<min=[2, 2], opt=[5, 5], max=[10, 10]>},
64+
%arg2: tensor<2x10xf32>) -> tensor<?x?xf32> {
65+
%0 = tensorrt.identity %arg0 : tensor<?x?xf32> to tensor<?x?xf32>
66+
return %0 : tensor<?x?xf32>
67+
}

0 commit comments

Comments
 (0)