
Commit 6320c48

[TensorRT] Copy plan.memory_space attribute in outline pass

Copies the plan.memory_space attribute and also adds the tensorrt.host_tensor attribute accordingly, which is needed by the NetworkEncoder to determine whether an input is a TRT shape tensor.

1 parent c1d6e9b commit 6320c48
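
For intuition, here is a minimal before/after sketch of the attribute propagation this commit introduces (the function names and bodies are hypothetical; the attribute names are the ones used in the test added below):

// Before outlining: the parent function's argument carries the host
// memory-space constraint.
func.func @main(%arg0: tensor<i32> {plan.memory_space = #plan.memory_space<host>}) -> tensor<i32> {
  return %arg0 : tensor<i32>
}

// After outlining (sketch): the outlined function's argument keeps
// plan.memory_space and gains the unit attribute tensorrt.host_tensor,
// which tells the NetworkEncoder to treat the input as a TRT shape tensor.
func.func @main_outlined(%arg0: tensor<i32> {plan.memory_space = #plan.memory_space<host>, tensorrt.host_tensor}) -> tensor<i32> {
  return %arg0 : tensor<i32>
}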

File tree

3 files changed: +105 −22 lines changed


.github/workflows/mlir-tensorrt-ci.yml

Lines changed: 42 additions & 0 deletions
@@ -36,6 +36,48 @@ jobs:

           sudo apt-get autoremove -y
           sudo apt-get autoclean -y
+
+      # Value of `github.workspace` is /home/runner/work/{repo_name}/{repo_name},
+      # i.e. /home/runner/work/TensorRT-Incubator/TensorRT-Incubator in our case.
+      # After this action, the repo is cloned inside the above path.
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 5
+
+      - name: Validate commit message
+        if: ${{ github.event_name == 'pull_request' }}
+        env:
+          PR_HEAD_COMMIT_SHA: ${{ github.event.pull_request.head.sha }}
+        run: |
+          cat > commit_message_checker.py <<EOF
+          #!/usr/bin/python3
+          import re
+          import sys
+          import subprocess
+
+          git_cmd = f"git show -s --format=%B {sys.argv[1]}"
+          try:
+              commit_message_cmd = subprocess.run(git_cmd.split(' '), capture_output=True, text=True, check=True)
+              commit_message = commit_message_cmd.stdout.strip()
+          except subprocess.CalledProcessError as e:
+              print(f"Failed to get PR HEAD commit message with error: {e.stderr.strip()}")
+              sys.exit(1)
+
+          match = re.search(r"^(\[bot\].+|NFC: .+|(.+\n\n+.+\n+.+))$", commit_message, re.DOTALL)
+          if match:
+              print("Commit message is in canonical form :)")
+              sys.exit(0)
+          print("Commit message is not in the canonical form!")
+          print(commit_message)
+          print("")
+          print("Expected format is:")
+          print("<title>")
+          print("<body>")
+          print("NOTE: Body should start on a new line. '2 spaces + enter' for a new line!")
+          print("NOTE: Body should be at least two lines.")
+          sys.exit(1)
+          EOF
+
+          python3 commit_message_checker.py ${PR_HEAD_COMMIT_SHA}

       # Run initial format check
       - name: Run python format and clang check
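
For reference, the regex above accepts exactly three forms of commit message: bot-generated messages whose title starts with "[bot]", no-functional-change messages whose title starts with "NFC: ", and regular messages consisting of a title, a blank line, and a body of at least two lines (hence the two NOTE hints printed on failure).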

mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/Passes.cpp

Lines changed: 39 additions & 22 deletions
@@ -280,12 +280,17 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
       mlir::tensorrt::TensorRTDialect::getShapeProfileArgAttrName();
   StringRef tensorrtDimensionNamesAttrName =
       mlir::tensorrt::TensorRTDialect::getDimensionNamesArgAttrName();
+  StringRef tensorrtValueBoundsAttrName =
+      mlir::tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName();
+  StringRef hostTensorAttrName = mlir::getHostTensorArgAttrName();
+  StringRef memorySpaceAttrName =
+      plan::PlanDialect::getMemorySpaceConstraintAttrName();

   SmallVector<Attribute> profileAttrsPerInput;
   SmallVector<Attribute> dimensionNamesAttrsPerInput;
   for (Value v : inputs) {
     auto rtt = dyn_cast<RankedTensorType>(v.getType());
-    if (!rtt || rtt.hasStaticShape()) {
+    if (!rtt) {
       profileAttrsPerInput.push_back(Attribute{});
       dimensionNamesAttrsPerInput.push_back(Attribute{});
       continue;
@@ -299,30 +304,42 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
     }

     int64_t argIndex = blockArg.getArgNumber();
-    profileAttrsPerInput.push_back(
-        parentFunc.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
-            argIndex, tensorrtShapeBoundsAttrName));
-
-    dimensionNamesAttrsPerInput.push_back(
-        parentFunc.getArgAttrOfType<DictionaryAttr>(
-            argIndex, tensorrtDimensionNamesAttrName));
-
-    if (!profileAttrsPerInput.back()) {
-      return emitError(blockArg.getLoc())
-             << "Profile attribute (" << tensorrtShapeBoundsAttrName
-             << ") of argument " << argIndex << " is not set";
+    // Get the shape profile and dimension name attributes of the input.
+    if (rtt.hasStaticShape()) {
+      // A static-shaped argument can only have a value-bounds attr (shape input).
+      auto valueBoundAttr =
+          parentFunc.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
+              argIndex, tensorrtValueBoundsAttrName);
+      if (valueBoundAttr) {
+        func->setArgAttr(argIndex, tensorrtValueBoundsAttrName, valueBoundAttr);
+      }
+      // Get the memory space attribute of the input.
+      auto memorySpaceAttr =
+          parentFunc.getArgAttr(argIndex, memorySpaceAttrName);
+      if (memorySpaceAttr) {
+        func->setArgAttr(argIndex, memorySpaceAttrName, memorySpaceAttr);
+        // Add the tensorrt.host_tensor attr; it is needed by NetworkEncoder for now.
+        func->setArgAttr(argIndex, hostTensorAttrName, rewriter.getUnitAttr());
+      }
+    } else {
+      auto shapeBoundAttr =
+          parentFunc.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
+              argIndex, tensorrtShapeBoundsAttrName);
+      if (!shapeBoundAttr) {
+        return emitError(blockArg.getLoc())
+               << "Profile attribute (" << tensorrtShapeBoundsAttrName
+               << ") of argument " << argIndex << " is not set";
+      }
+      func->setArgAttr(argIndex, tensorrtShapeBoundsAttrName, shapeBoundAttr);
+      auto dimensionNameAttr = parentFunc.getArgAttrOfType<DictionaryAttr>(
+          argIndex, tensorrtDimensionNamesAttrName);
+      if (dimensionNameAttr) {
+        func->setArgAttr(argIndex, tensorrtDimensionNamesAttrName,
+                         dimensionNameAttr);
+      }
     }
   }

-  for (unsigned idx = 0; idx < func->getNumArguments(); idx++) {
-    if (profileAttrsPerInput[idx])
-      func->setArgAttr(idx, tensorrtShapeBoundsAttrName,
-                       profileAttrsPerInput[idx]);
-    if (dimensionNamesAttrsPerInput[idx])
-      func->setArgAttr(idx, tensorrtDimensionNamesAttrName,
-                       dimensionNamesAttrsPerInput[idx]);
-  }
-
   rewriter.setInsertionPoint(inlineGroupOp);
   auto callOp = rewriter.create<tensorrt::CallAllocOp>(
       inlineGroupOp.getLoc(), inlineGroupOp.getResultTypes(), inputs,
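
The net effect of this change: the outline pass no longer skips statically shaped inputs. A dynamically shaped input must still carry a tensorrt.shape_profile attribute (with optional dimension names), while a statically shaped input may now carry tensorrt.value_bounds and plan.memory_space. All of these attributes are set on the outlined function's arguments directly inside the loop, which makes the previous second pass over profileAttrsPerInput and dimensionNamesAttrsPerInput unnecessary.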

mlir-tensorrt/tensorrt/test/Target/TensorRT/translate-to-tensorrt.mlir

Lines changed: 24 additions & 0 deletions
@@ -65,3 +65,27 @@ func.func @trt_dim_names(
   %0 = tensorrt.identity %arg0 : tensor<?x?xf32> to tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// CHECK-LABEL: @trt_host_input
+// CHECK-SAME: tensorrt.engine
+func.func @trt_host_input(%arg0: tensor<?x4xf32> {tensorrt.dimension_names = {}, tensorrt.shape_profile = #tensorrt.shape_profile<min = [2, 4], opt = [4, 4], max = [6, 4]>}, %arg1: tensor<i32> {plan.memory_space = #plan.memory_space<host>, tensorrt.value_bounds = #tensorrt.shape_profile<min = [1], opt = [2], max = [3]>}) -> tensor<?x?xf32> {
+  %0 = tensorrt.element_wise <kSUM>(%arg0, %arg0 : tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+  %1 = tensorrt.shape %0 : tensor<?x4xf32> -> tensor<2xi32>
+  %2 = tensorrt.slice %1[0][1][1] : tensor<2xi32> to tensor<1xi32>
+  %3 = tensorrt.collapse_rank %2 : tensor<1xi32> to tensor<i32>
+  %cst_i32 = tensorrt.constant dense<1> : tensor<i32>
+  %4 = tensorrt.element_wise <kPROD>(%3, %cst_i32 : tensor<i32>, tensor<i32>) -> tensor<i32>
+  %5 = tensorrt.slice %1[1][1][1] : tensor<2xi32> to tensor<1xi32>
+  %6 = tensorrt.collapse_rank %5 : tensor<1xi32> to tensor<i32>
+  %7 = tensorrt.element_wise <kPROD>(%4, %6 : tensor<i32>, tensor<i32>) -> tensor<i32>
+  %cst_i32_0 = tensorrt.constant dense<1> : tensor<i32>
+  %8 = tensorrt.element_wise <kPROD>(%arg1, %cst_i32_0 : tensor<i32>, tensor<i32>) -> tensor<i32>
+  %9 = tensorrt.element_wise <kFLOOR_DIV>(%7, %8 : tensor<i32>, tensor<i32>) -> tensor<i32>
+  %cst_i32_1 = tensorrt.constant dense<1> : tensor<1xi32>
+  %10 = tensorrt.reshape %9 shape(%cst_i32_1: tensor<1xi32>) : tensor<i32> to tensor<?xi32>
+  %cst_i32_2 = tensorrt.constant dense<1> : tensor<1xi32>
+  %11 = tensorrt.reshape %arg1 shape(%cst_i32_2: tensor<1xi32>) : tensor<i32> to tensor<?xi32>
+  %12 = tensorrt.concatenation {axis = 0 : i32} ins(%10, %11 : tensor<?xi32>, tensor<?xi32>) -> tensor<2xi32>
+  %13 = tensorrt.reshape %0 shape(%12: tensor<2xi32>) : tensor<?x4xf32> to tensor<?x?xf32>
+  return %13 : tensor<?x?xf32>
+}
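
In this test, %arg1 is a host-resident scalar whose value feeds the shape operands of tensorrt.reshape, exercising the shape-tensor path described in the commit message; the CHECK lines verify that translation succeeds and attaches a tensorrt.engine attribute to the function.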
