
Commit 515b0e3

[TensorRT] Copy plan.memory_space attribute in outline pass (#672)
Copies the plan.memory_space and tensorrt.value_bounds argument attributes onto the outlined function and also adds the tensorrt.host_tensor attribute accordingly.
1 parent c1d6e9b commit 515b0e3
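For illustration, here is a minimal MLIR sketch (not part of the commit; @main and its body are hypothetical, while the attribute syntax follows the regression test below) of the attribute propagation this change performs:

// Hypothetical parent function: a host shape-tensor argument carries a
// memory space constraint.
func.func @main(%arg0: tensor<i32> {plan.memory_space = #plan.memory_space<host>}) -> tensor<i32> {
  return %arg0 : tensor<i32>
}
// After outlining, the matching argument of the outlined function is expected
// to keep plan.memory_space and additionally gain the unit attribute
// tensorrt.host_tensor, roughly:
//   %arg0: tensor<i32> {plan.memory_space = #plan.memory_space<host>, tensorrt.host_tensor}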

File tree

3 files changed, +105 -26 lines changed


.github/workflows/mlir-tensorrt-ci.yml

Lines changed: 42 additions & 0 deletions
@@ -36,6 +36,48 @@ jobs:
 
           sudo apt-get autoremove -y
           sudo apt-get autoclean -y
+
+      # Value of `github.workspace` is /home/runner/work/{repo-name}/{repo-name},
+      # i.e. /home/runner/work/TensorRT-Incubator/TensorRT-Incubator in our case.
+      # After this action, the repo is cloned inside the above path.
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 5
+
+      - name: Validate commit message
+        if: ${{ github.event_name == 'pull_request' }}
+        env:
+          PR_HEAD_COMMIT_SHA: ${{ github.event.pull_request.head.sha }}
+        run: |
+          cat > commit_message_checker.py <<EOF
+          #!/usr/bin/python3
+          import re
+          import sys
+          import subprocess
+
+          git_cmd = f"git show -s --format=%B {sys.argv[1]}"
+          try:
+              commit_message_cmd = subprocess.run(git_cmd.split(' '), capture_output=True, text=True, check=True)
+              commit_message = commit_message_cmd.stdout.strip()
+          except subprocess.CalledProcessError as e:
+              raise SystemExit(f"Failed to get PR HEAD commit message with error: {e.stderr.strip()}")
+
+          match = re.search(r"^(\[bot\].+|NFC: .+|(.+\n\n+.+\n+.+))$", commit_message, re.DOTALL)
+          if match:
+              print("Commit message is in canonical form :)")
+              sys.exit(0)
+          print("Commit message is not in the canonical form!")
+          print(commit_message)
+          print("")
+          print("Expected format is:")
+          print("<title>")
+          print("<body>")
+          print("NOTE: Body should start on a new line. '2 spaces + enter' for a new line!")
+          print("NOTE: Body should be at least two lines.")
+          sys.exit(1)
+          EOF
+
+          python3 commit_message_checker.py ${PR_HEAD_COMMIT_SHA}
 
       # Run initial format check
       - name: Run python format and clang check

mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/Passes.cpp

Lines changed: 39 additions & 26 deletions
@@ -280,14 +280,15 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
       mlir::tensorrt::TensorRTDialect::getShapeProfileArgAttrName();
   StringRef tensorrtDimensionNamesAttrName =
       mlir::tensorrt::TensorRTDialect::getDimensionNamesArgAttrName();
+  StringRef tensorrtValueBoundsAttrName =
+      mlir::tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName();
+  StringRef hostTensorAttrName = mlir::getHostTensorArgAttrName();
+  StringRef memorySpaceAttrName =
+      plan::PlanDialect::getMemorySpaceConstraintAttrName();
 
-  SmallVector<Attribute> profileAttrsPerInput;
-  SmallVector<Attribute> dimensionNamesAttrsPerInput;
   for (Value v : inputs) {
     auto rtt = dyn_cast<RankedTensorType>(v.getType());
-    if (!rtt || rtt.hasStaticShape()) {
-      profileAttrsPerInput.push_back(Attribute{});
-      dimensionNamesAttrsPerInput.push_back(Attribute{});
+    if (!rtt) {
       continue;
     }
 
@@ -299,30 +300,42 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
     }
 
     int64_t argIndex = blockArg.getArgNumber();
-    profileAttrsPerInput.push_back(
-        parentFunc.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
-            argIndex, tensorrtShapeBoundsAttrName));
-
-    dimensionNamesAttrsPerInput.push_back(
-        parentFunc.getArgAttrOfType<DictionaryAttr>(
-            argIndex, tensorrtDimensionNamesAttrName));
-
-    if (!profileAttrsPerInput.back()) {
-      return emitError(blockArg.getLoc())
-             << "Profile attribute (" << tensorrtShapeBoundsAttrName
-             << ") of argument " << argIndex << " is not set";
+    // Get the shape profile and dimension name attributes of the input.
+    if (rtt.hasStaticShape()) {
+      // A static-shaped argument can only have a value bounds attr (shape input).
+      auto valueBoundAttr =
+          parentFunc.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
+              argIndex, tensorrtValueBoundsAttrName);
+      if (valueBoundAttr) {
+        func->setArgAttr(argIndex, tensorrtValueBoundsAttrName, valueBoundAttr);
+      }
+      // Get the memory space attribute of the input.
+      auto memorySpaceAttr =
+          parentFunc.getArgAttr(argIndex, memorySpaceAttrName);
+      if (memorySpaceAttr) {
+        func->setArgAttr(argIndex, memorySpaceAttrName, memorySpaceAttr);
+        // Add the tensorrt.host_tensor attr; it is needed by NetworkEncoder for now.
+        func->setArgAttr(argIndex, hostTensorAttrName, rewriter.getUnitAttr());
+      }
+    } else {
+      auto shapeBoundAttr =
+          parentFunc.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
+              argIndex, tensorrtShapeBoundsAttrName);
+      if (!shapeBoundAttr) {
+        return emitError(blockArg.getLoc())
+               << "Profile attribute (" << tensorrtShapeBoundsAttrName
+               << ") of argument " << argIndex << " is not set";
+      }
+      func->setArgAttr(argIndex, tensorrtShapeBoundsAttrName, shapeBoundAttr);
+      auto dimensionNameAttr = parentFunc.getArgAttrOfType<DictionaryAttr>(
+          argIndex, tensorrtDimensionNamesAttrName);
+      if (dimensionNameAttr) {
+        func->setArgAttr(argIndex, tensorrtDimensionNamesAttrName,
+                         dimensionNameAttr);
+      }
     }
   }
 
-  for (unsigned idx = 0; idx < func->getNumArguments(); idx++) {
-    if (profileAttrsPerInput[idx])
-      func->setArgAttr(idx, tensorrtShapeBoundsAttrName,
-                       profileAttrsPerInput[idx]);
-    if (dimensionNamesAttrsPerInput[idx])
-      func->setArgAttr(idx, tensorrtDimensionNamesAttrName,
-                       dimensionNamesAttrsPerInput[idx]);
-  }
-
   rewriter.setInsertionPoint(inlineGroupOp);
   auto callOp = rewriter.create<tensorrt::CallAllocOp>(
       inlineGroupOp.getLoc(), inlineGroupOp.getResultTypes(), inputs,
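As a rough sketch of the two argument kinds the rewritten loop distinguishes (hypothetical @parent function, not from the commit; attribute syntax as in the test below): a dynamically shaped argument must carry tensorrt.shape_profile, while a static-shaped shape-tensor argument may carry tensorrt.value_bounds and plan.memory_space:

// Sketch only; @parent and its trivial body are hypothetical.
func.func @parent(
    // Dynamic shape: tensorrt.shape_profile is mandatory (the pass emits an
    // error when it is missing); tensorrt.dimension_names is optional.
    %arg0: tensor<?x4xf32> {tensorrt.shape_profile = #tensorrt.shape_profile<min = [2, 4], opt = [4, 4], max = [6, 4]>},
    // Static shape (shape tensor): tensorrt.value_bounds and plan.memory_space
    // are optional; when plan.memory_space is present, the outlined function's
    // argument also receives tensorrt.host_tensor.
    %arg1: tensor<i32> {plan.memory_space = #plan.memory_space<host>, tensorrt.value_bounds = #tensorrt.shape_profile<min = [1], opt = [2], max = [3]>}) -> tensor<?x4xf32> {
  return %arg0 : tensor<?x4xf32>
}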

mlir-tensorrt/tensorrt/test/Target/TensorRT/translate-to-tensorrt.mlir

Lines changed: 24 additions & 0 deletions
@@ -65,3 +65,27 @@ func.func @trt_dim_names(
   %0 = tensorrt.identity %arg0 : tensor<?x?xf32> to tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// CHECK-LABEL: @trt_host_input
+// CHECK-SAME: tensorrt.engine
+func.func @trt_host_input(%arg0: tensor<?x4xf32> {tensorrt.dimension_names = {}, tensorrt.shape_profile = #tensorrt.shape_profile<min = [2, 4], opt = [4, 4], max = [6, 4]>}, %arg1: tensor<i32> {plan.memory_space = #plan.memory_space<host>, tensorrt.value_bounds = #tensorrt.shape_profile<min = [1], opt = [2], max = [3]>}) -> tensor<?x?xf32> {
+  %0 = tensorrt.element_wise <kSUM>(%arg0, %arg0 : tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+  %1 = tensorrt.shape %0 : tensor<?x4xf32> -> tensor<2xi32>
+  %2 = tensorrt.slice %1[0][1][1] : tensor<2xi32> to tensor<1xi32>
+  %3 = tensorrt.collapse_rank %2 : tensor<1xi32> to tensor<i32>
+  %cst_i32 = tensorrt.constant dense<1> : tensor<i32>
+  %4 = tensorrt.element_wise <kPROD>(%3, %cst_i32 : tensor<i32>, tensor<i32>) -> tensor<i32>
+  %5 = tensorrt.slice %1[1][1][1] : tensor<2xi32> to tensor<1xi32>
+  %6 = tensorrt.collapse_rank %5 : tensor<1xi32> to tensor<i32>
+  %7 = tensorrt.element_wise <kPROD>(%4, %6 : tensor<i32>, tensor<i32>) -> tensor<i32>
+  %cst_i32_0 = tensorrt.constant dense<1> : tensor<i32>
+  %8 = tensorrt.element_wise <kPROD>(%arg1, %cst_i32_0 : tensor<i32>, tensor<i32>) -> tensor<i32>
+  %9 = tensorrt.element_wise <kFLOOR_DIV>(%7, %8 : tensor<i32>, tensor<i32>) -> tensor<i32>
+  %cst_i32_1 = tensorrt.constant dense<1> : tensor<1xi32>
+  %10 = tensorrt.reshape %9 shape(%cst_i32_1: tensor<1xi32>) : tensor<i32> to tensor<?xi32>
+  %cst_i32_2 = tensorrt.constant dense<1> : tensor<1xi32>
+  %11 = tensorrt.reshape %arg1 shape(%cst_i32_2: tensor<1xi32>) : tensor<i32> to tensor<?xi32>
+  %12 = tensorrt.concatenation {axis = 0 : i32} ins(%10, %11 : tensor<?xi32>, tensor<?xi32>) -> tensor<2xi32>
+  %13 = tensorrt.reshape %0 shape(%12: tensor<2xi32>) : tensor<?x4xf32> to tensor<?x?xf32>
+  return %13 : tensor<?x?xf32>
+}
