Skip to content

Commit 05ce756

Browse files
committed
[TensorRT] Copy tensorrt.host_tensor attribute in outline pass
1 parent b55e62c commit 05ce756

File tree

1 file changed

+27
-9
lines changed
  • mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable

1 file changed

+27
-9
lines changed

mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/Passes.cpp

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -280,16 +280,22 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
280280
mlir::tensorrt::TensorRTDialect::getShapeProfileArgAttrName();
281281
StringRef tensorrtDimensionNamesAttrName =
282282
mlir::tensorrt::TensorRTDialect::getDimensionNamesArgAttrName();
283+
StringRef tensorrtValueBoundsAttrName =
284+
mlir::tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName();
285+
StringRef hostTensorAttrName = mlir::getHostTensorArgAttrName();
283286

284287
SmallVector<Attribute> profileAttrsPerInput;
285288
SmallVector<Attribute> dimensionNamesAttrsPerInput;
286289
for (Value v : inputs) {
287290
auto rtt = dyn_cast<RankedTensorType>(v.getType());
288-
if (!rtt || rtt.hasStaticShape()) {
291+
if (!rtt) {
289292
profileAttrsPerInput.push_back(Attribute{});
290293
dimensionNamesAttrsPerInput.push_back(Attribute{});
291294
continue;
292295
}
296+
if (rtt.hasStaticShape()) {
297+
dimensionNamesAttrsPerInput.push_back(Attribute{});
298+
}
293299

294300
auto blockArg = dyn_cast<BlockArgument>(v);
295301
if (!blockArg || blockArg.getOwner()->getParentOp() != parentFunc) {
@@ -299,18 +305,30 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
299305
}
300306

301307
int64_t argIndex = blockArg.getArgNumber();
302-
profileAttrsPerInput.push_back(
308+
// Get shape profile and dynamision name attributes of the input
309+
if (rtt.hasStaticShape()) {
310+
// static-shaped argument can only have value bound attr (shape input)
311+
profileAttrsPerInput.push_back(
312+
parentFunc.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
313+
argIndex, tensorrtValueBoundsAttrName));
314+
dimensionNamesAttrsPerInput.push_back(Attribute{});
315+
// Get host tensor attribute of the input
316+
auto hostTensorAttr = parentFunc.getArgAttr(argIndex, hostTensorAttrName);
317+
if (hostTensorAttr) {
318+
func->setArgAttr(argIndex, hostTensorAttrName, hostTensorAttr);
319+
}
320+
} else {
321+
profileAttrsPerInput.push_back(
303322
parentFunc.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
304323
argIndex, tensorrtShapeBoundsAttrName));
305-
306-
dimensionNamesAttrsPerInput.push_back(
324+
dimensionNamesAttrsPerInput.push_back(
307325
parentFunc.getArgAttrOfType<DictionaryAttr>(
308326
argIndex, tensorrtDimensionNamesAttrName));
309-
310-
if (!profileAttrsPerInput.back()) {
311-
return emitError(blockArg.getLoc())
312-
<< "Profile attribute (" << tensorrtShapeBoundsAttrName
313-
<< ") of argument " << argIndex << " is not set";
327+
if (!profileAttrsPerInput.back()) {
328+
return emitError(blockArg.getLoc())
329+
<< "Profile attribute (" << tensorrtShapeBoundsAttrName
330+
<< ") of argument " << argIndex << " is not set";
331+
}
314332
}
315333
}
316334

0 commit comments

Comments
 (0)