
Commit e00c198

[TensorRT] Copy tensorrt.host_tensor attribute in outline pass

Copies the tensorrt.host_tensor attribute from the parent function to the outlined function, and also adds the plan.memory_space attribute accordingly.

Parent: c1d6e9b
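
In effect, the outline pass now propagates the host-tensor marking. A minimal sketch of the intended result (the function names and the printed form #plan.memory_space<host> are illustrative assumptions, not literal pass output):

// Parent function argument before outlining:
func.func @parent(%arg0: tensor<i32> {tensorrt.host_tensor}) { ... }

// Outlined function argument after the pass; the same plan.memory_space
// attribute is also set on the parent argument:
func.func @outlined(%arg0: tensor<i32> {tensorrt.host_tensor,
    plan.memory_space = #plan.memory_space<host>}) { ... }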

2 files changed: +66 −22 lines changed

mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/Passes.cpp

42 additions & 22 deletions
@@ -280,12 +280,17 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
       mlir::tensorrt::TensorRTDialect::getShapeProfileArgAttrName();
   StringRef tensorrtDimensionNamesAttrName =
       mlir::tensorrt::TensorRTDialect::getDimensionNamesArgAttrName();
+  StringRef tensorrtValueBoundsAttrName =
+      mlir::tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName();
+  StringRef hostTensorAttrName = mlir::getHostTensorArgAttrName();
+  StringRef memorySpaceAttrName =
+      plan::PlanDialect::getMemorySpaceConstraintAttrName();
 
   SmallVector<Attribute> profileAttrsPerInput;
   SmallVector<Attribute> dimensionNamesAttrsPerInput;
   for (Value v : inputs) {
     auto rtt = dyn_cast<RankedTensorType>(v.getType());
-    if (!rtt || rtt.hasStaticShape()) {
+    if (!rtt) {
       profileAttrsPerInput.push_back(Attribute{});
       dimensionNamesAttrsPerInput.push_back(Attribute{});
       continue;
@@ -299,30 +304,45 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
     }
 
     int64_t argIndex = blockArg.getArgNumber();
-    profileAttrsPerInput.push_back(
-        parentFunc.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
-            argIndex, tensorrtShapeBoundsAttrName));
-
-    dimensionNamesAttrsPerInput.push_back(
-        parentFunc.getArgAttrOfType<DictionaryAttr>(
-            argIndex, tensorrtDimensionNamesAttrName));
-
-    if (!profileAttrsPerInput.back()) {
-      return emitError(blockArg.getLoc())
-             << "Profile attribute (" << tensorrtShapeBoundsAttrName
-             << ") of argument " << argIndex << " is not set";
+    // Get the shape profile and dimension name attributes of the input.
+    if (rtt.hasStaticShape()) {
+      // A static-shaped argument can only have a value bounds attribute
+      // (shape input).
+      auto valueBoundAttr =
+          parentFunc.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
+              argIndex, tensorrtValueBoundsAttrName);
+      if (valueBoundAttr) {
+        func->setArgAttr(argIndex, tensorrtValueBoundsAttrName, valueBoundAttr);
+      }
+      // Get the host tensor attribute of the input.
+      auto hostTensorAttr = parentFunc.getArgAttr(argIndex, hostTensorAttrName);
+      if (hostTensorAttr) {
+        func->setArgAttr(argIndex, hostTensorAttrName, hostTensorAttr);
+        // Add the plan.memory_space attribute; it is also required on the
+        // parent function.
+        auto memorySpaceAttr = plan::MemorySpaceAttr::get(
+            rewriter.getContext(), plan::MemorySpace::host);
+        func->setArgAttr(argIndex, memorySpaceAttrName, memorySpaceAttr);
+        parentFunc.setArgAttr(argIndex, memorySpaceAttrName, memorySpaceAttr);
+      }
+    } else {
+      auto shapeBoundAttr =
+          parentFunc.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
+              argIndex, tensorrtShapeBoundsAttrName);
+      if (!shapeBoundAttr) {
+        return emitError(blockArg.getLoc())
+               << "Profile attribute (" << tensorrtShapeBoundsAttrName
+               << ") of argument " << argIndex << " is not set";
+      }
+      func->setArgAttr(argIndex, tensorrtShapeBoundsAttrName, shapeBoundAttr);
+      auto dimensionNameAttr = parentFunc.getArgAttrOfType<DictionaryAttr>(
+          argIndex, tensorrtDimensionNamesAttrName);
+      if (dimensionNameAttr) {
+        func->setArgAttr(argIndex, tensorrtDimensionNamesAttrName,
+                         dimensionNameAttr);
+      }
     }
   }
 
-  for (unsigned idx = 0; idx < func->getNumArguments(); idx++) {
-    if (profileAttrsPerInput[idx])
-      func->setArgAttr(idx, tensorrtShapeBoundsAttrName,
-                       profileAttrsPerInput[idx]);
-    if (dimensionNamesAttrsPerInput[idx])
-      func->setArgAttr(idx, tensorrtDimensionNamesAttrName,
-                       dimensionNamesAttrsPerInput[idx]);
-  }
-
   rewriter.setInsertionPoint(inlineGroupOp);
   auto callOp = rewriter.create<tensorrt::CallAllocOp>(
       inlineGroupOp.getLoc(), inlineGroupOp.getResultTypes(), inputs,
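
The restructured loop now distinguishes two kinds of arguments. A minimal sketch of both (illustrative signature fragments, mirroring the new test below):

// Dynamically shaped input: tensorrt.shape_profile is mandatory and is
// copied to the outlined function; its absence is now the only hard error.
%arg0: tensor<?x4xf32> {tensorrt.shape_profile = #tensorrt.shape_profile<min = [2, 4], opt = [4, 4], max = [6, 4]>}

// Statically shaped shape-tensor input: tensorrt.value_bounds and
// tensorrt.host_tensor are optional and copied only when present.
%arg1: tensor<i32> {tensorrt.host_tensor, tensorrt.value_bounds = #tensorrt.shape_profile<min = [1], opt = [2], max = [3]>}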

mlir-tensorrt/tensorrt/test/Target/TensorRT/translate-to-tensorrt.mlir

24 additions & 0 deletions
@@ -65,3 +65,27 @@ func.func @trt_dim_names(
   %0 = tensorrt.identity %arg0 : tensor<?x?xf32> to tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// CHECK-LABEL: @trt_host_input
+// CHECK-SAME: tensorrt.engine
+func.func @trt_host_input(%arg0: tensor<?x4xf32> {tensorrt.dimension_names = {}, tensorrt.shape_profile = #tensorrt.shape_profile<min = [2, 4], opt = [4, 4], max = [6, 4]>}, %arg1: tensor<i32> {tensorrt.host_tensor, tensorrt.value_bounds = #tensorrt.shape_profile<min = [1], opt = [2], max = [3]>}) -> tensor<?x?xf32> {
+  %0 = tensorrt.element_wise <kSUM>(%arg0, %arg0 : tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+  %1 = tensorrt.shape %0 : tensor<?x4xf32> -> tensor<2xi32>
+  %2 = tensorrt.slice %1[0][1][1] : tensor<2xi32> to tensor<1xi32>
+  %3 = tensorrt.collapse_rank %2 : tensor<1xi32> to tensor<i32>
+  %cst_i32 = tensorrt.constant dense<1> : tensor<i32>
+  %4 = tensorrt.element_wise <kPROD>(%3, %cst_i32 : tensor<i32>, tensor<i32>) -> tensor<i32>
+  %5 = tensorrt.slice %1[1][1][1] : tensor<2xi32> to tensor<1xi32>
+  %6 = tensorrt.collapse_rank %5 : tensor<1xi32> to tensor<i32>
+  %7 = tensorrt.element_wise <kPROD>(%4, %6 : tensor<i32>, tensor<i32>) -> tensor<i32>
+  %cst_i32_0 = tensorrt.constant dense<1> : tensor<i32>
+  %8 = tensorrt.element_wise <kPROD>(%arg1, %cst_i32_0 : tensor<i32>, tensor<i32>) -> tensor<i32>
+  %9 = tensorrt.element_wise <kFLOOR_DIV>(%7, %8 : tensor<i32>, tensor<i32>) -> tensor<i32>
+  %cst_i32_1 = tensorrt.constant dense<1> : tensor<1xi32>
+  %10 = tensorrt.reshape %9 shape(%cst_i32_1: tensor<1xi32>) : tensor<i32> to tensor<?xi32>
+  %cst_i32_2 = tensorrt.constant dense<1> : tensor<1xi32>
+  %11 = tensorrt.reshape %arg1 shape(%cst_i32_2: tensor<1xi32>) : tensor<i32> to tensor<?xi32>
+  %12 = tensorrt.concatenation {axis = 0 : i32} ins(%10, %11 : tensor<?xi32>, tensor<?xi32>) -> tensor<2xi32>
+  %13 = tensorrt.reshape %0 shape(%12: tensor<2xi32>) : tensor<?x4xf32> to tensor<?x?xf32>
+  return %13 : tensor<?x?xf32>
+}
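
The new @trt_host_input test exercises this path end to end: %arg1 is a host (shape) tensor whose runtime value, bounded by tensorrt.value_bounds, flows into the shape operand %12 of the final tensorrt.reshape, and the CHECK-SAME: tensorrt.engine line verifies that the function still translates to a serialized TensorRT engine.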
