@@ -280,12 +280,15 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
280280 mlir::tensorrt::TensorRTDialect::getShapeProfileArgAttrName ();
281281 StringRef tensorrtDimensionNamesAttrName =
282282 mlir::tensorrt::TensorRTDialect::getDimensionNamesArgAttrName ();
283+ StringRef tensorrtValueBoundsAttrName =
284+ mlir::tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName ();
285+ StringRef hostTensorAttrName = mlir::getHostTensorArgAttrName ();
283286
284287 SmallVector<Attribute> profileAttrsPerInput;
285288 SmallVector<Attribute> dimensionNamesAttrsPerInput;
286289 for (Value v : inputs) {
287290 auto rtt = dyn_cast<RankedTensorType>(v.getType ());
288- if (!rtt || rtt. hasStaticShape () ) {
291+ if (!rtt) {
289292 profileAttrsPerInput.push_back (Attribute{});
290293 dimensionNamesAttrsPerInput.push_back (Attribute{});
291294 continue ;
@@ -299,30 +302,39 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
299302 }
300303
301304 int64_t argIndex = blockArg.getArgNumber ();
302- profileAttrsPerInput.push_back (
303- parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
304- argIndex, tensorrtShapeBoundsAttrName));
305-
306- dimensionNamesAttrsPerInput.push_back (
307- parentFunc.getArgAttrOfType <DictionaryAttr>(
308- argIndex, tensorrtDimensionNamesAttrName));
309-
310- if (!profileAttrsPerInput.back ()) {
311- return emitError (blockArg.getLoc ())
312- << " Profile attribute (" << tensorrtShapeBoundsAttrName
313- << " ) of argument " << argIndex << " is not set" ;
305+ // Get shape profile and dynamision name attributes of the input
306+ if (rtt.hasStaticShape ()) {
307+ // static-shaped argument can only have value bound attr (shape input)
308+ auto valueBoundAttr =
309+ parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
310+ argIndex, tensorrtValueBoundsAttrName);
311+ if (valueBoundAttr) {
312+ func->setArgAttr (argIndex, tensorrtValueBoundsAttrName, valueBoundAttr);
313+ }
314+ // Get host tensor attribute of the input
315+ auto hostTensorAttr = parentFunc.getArgAttr (argIndex, hostTensorAttrName);
316+ if (hostTensorAttr) {
317+ func->setArgAttr (argIndex, hostTensorAttrName, hostTensorAttr);
318+ }
319+ } else {
320+ auto shapeBoundAttr =
321+ parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
322+ argIndex, tensorrtShapeBoundsAttrName);
323+ if (!shapeBoundAttr) {
324+ return emitError (blockArg.getLoc ())
325+ << " Profile attribute (" << tensorrtShapeBoundsAttrName
326+ << " ) of argument " << argIndex << " is not set" ;
327+ }
328+ func->setArgAttr (argIndex, tensorrtShapeBoundsAttrName, shapeBoundAttr);
329+ auto dimensionNameAttr = parentFunc.getArgAttrOfType <DictionaryAttr>(
330+ argIndex, tensorrtDimensionNamesAttrName);
331+ if (dimensionNameAttr) {
332+ func->setArgAttr (argIndex, tensorrtDimensionNamesAttrName,
333+ dimensionNameAttr);
334+ }
314335 }
315336 }
316337
317- for (unsigned idx = 0 ; idx < func->getNumArguments (); idx++) {
318- if (profileAttrsPerInput[idx])
319- func->setArgAttr (idx, tensorrtShapeBoundsAttrName,
320- profileAttrsPerInput[idx]);
321- if (dimensionNamesAttrsPerInput[idx])
322- func->setArgAttr (idx, tensorrtDimensionNamesAttrName,
323- dimensionNamesAttrsPerInput[idx]);
324- }
325-
326338 rewriter.setInsertionPoint (inlineGroupOp);
327339 auto callOp = rewriter.create <tensorrt::CallAllocOp>(
328340 inlineGroupOp.getLoc (), inlineGroupOp.getResultTypes (), inputs,
0 commit comments