@@ -280,16 +280,22 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
280280 mlir::tensorrt::TensorRTDialect::getShapeProfileArgAttrName ();
281281 StringRef tensorrtDimensionNamesAttrName =
282282 mlir::tensorrt::TensorRTDialect::getDimensionNamesArgAttrName ();
283+ StringRef tensorrtValueBoundsAttrName =
284+ mlir::tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName ();
285+ StringRef hostTensorAttrName = mlir::getHostTensorArgAttrName ();
283286
284287 SmallVector<Attribute> profileAttrsPerInput;
285288 SmallVector<Attribute> dimensionNamesAttrsPerInput;
286289 for (Value v : inputs) {
287290 auto rtt = dyn_cast<RankedTensorType>(v.getType ());
288- if (!rtt || rtt. hasStaticShape () ) {
291+ if (!rtt) {
289292 profileAttrsPerInput.push_back (Attribute{});
290293 dimensionNamesAttrsPerInput.push_back (Attribute{});
291294 continue ;
292295 }
296+ if (rtt.hasStaticShape ()) {
297+ dimensionNamesAttrsPerInput.push_back (Attribute{});
298+ }
293299
294300 auto blockArg = dyn_cast<BlockArgument>(v);
295301 if (!blockArg || blockArg.getOwner ()->getParentOp () != parentFunc) {
@@ -299,18 +305,30 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
299305 }
300306
301307 int64_t argIndex = blockArg.getArgNumber ();
302- profileAttrsPerInput.push_back (
308+ // Get shape profile and dynamision name attributes of the input
309+ if (rtt.hasStaticShape ()) {
310+ // static-shaped argument can only have value bound attr (shape input)
311+ profileAttrsPerInput.push_back (
312+ parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
313+ argIndex, tensorrtValueBoundsAttrName));
314+ dimensionNamesAttrsPerInput.push_back (Attribute{});
315+ // Get host tensor attribute of the input
316+ auto hostTensorAttr = parentFunc.getArgAttr (argIndex, hostTensorAttrName);
317+ if (hostTensorAttr) {
318+ func->setArgAttr (argIndex, hostTensorAttrName, hostTensorAttr);
319+ }
320+ } else {
321+ profileAttrsPerInput.push_back (
303322 parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
304323 argIndex, tensorrtShapeBoundsAttrName));
305-
306- dimensionNamesAttrsPerInput.push_back (
324+ dimensionNamesAttrsPerInput.push_back (
307325 parentFunc.getArgAttrOfType <DictionaryAttr>(
308326 argIndex, tensorrtDimensionNamesAttrName));
309-
310- if (!profileAttrsPerInput. back ()) {
311- return emitError (blockArg. getLoc ())
312- << " Profile attribute ( " << tensorrtShapeBoundsAttrName
313- << " ) of argument " << argIndex << " is not set " ;
327+ if (!profileAttrsPerInput. back ()) {
328+ return emitError (blockArg. getLoc ())
329+ << " Profile attribute ( " << tensorrtShapeBoundsAttrName
330+ << " ) of argument " << argIndex << " is not set " ;
331+ }
314332 }
315333 }
316334
0 commit comments