@@ -280,12 +280,17 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
280280 mlir::tensorrt::TensorRTDialect::getShapeProfileArgAttrName ();
281281 StringRef tensorrtDimensionNamesAttrName =
282282 mlir::tensorrt::TensorRTDialect::getDimensionNamesArgAttrName ();
283+ StringRef tensorrtValueBoundsAttrName =
284+ mlir::tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName ();
285+ StringRef hostTensorAttrName = mlir::getHostTensorArgAttrName ();
286+ StringRef memorySpaceAttrName =
287+ plan::PlanDialect::getMemorySpaceConstraintAttrName ();
283288
284289 SmallVector<Attribute> profileAttrsPerInput;
285290 SmallVector<Attribute> dimensionNamesAttrsPerInput;
286291 for (Value v : inputs) {
287292 auto rtt = dyn_cast<RankedTensorType>(v.getType ());
288- if (!rtt || rtt. hasStaticShape () ) {
293+ if (!rtt) {
289294 profileAttrsPerInput.push_back (Attribute{});
290295 dimensionNamesAttrsPerInput.push_back (Attribute{});
291296 continue ;
@@ -299,30 +304,41 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
299304 }
300305
301306 int64_t argIndex = blockArg.getArgNumber ();
302- profileAttrsPerInput.push_back (
303- parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
304- argIndex, tensorrtShapeBoundsAttrName));
305-
306- dimensionNamesAttrsPerInput.push_back (
307- parentFunc.getArgAttrOfType <DictionaryAttr>(
308- argIndex, tensorrtDimensionNamesAttrName));
309-
310- if (!profileAttrsPerInput.back ()) {
311- return emitError (blockArg.getLoc ())
312- << " Profile attribute (" << tensorrtShapeBoundsAttrName
313- << " ) of argument " << argIndex << " is not set" ;
307+ // Get shape profile and dynamision name attributes of the input
308+ if (rtt.hasStaticShape ()) {
309+ // static-shaped argument can only have value bound attr (shape input)
310+ auto valueBoundAttr =
311+ parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
312+ argIndex, tensorrtValueBoundsAttrName);
313+ if (valueBoundAttr) {
314+ func->setArgAttr (argIndex, tensorrtValueBoundsAttrName, valueBoundAttr);
315+ }
316+ // Get host tensor attribute of the input
317+ auto memorySpaceAttr = parentFunc.getArgAttr (argIndex, memorySpaceAttrName);
318+ if (memorySpaceAttr) {
319+ func->setArgAttr (argIndex, memorySpaceAttrName, memorySpaceAttr);
320+ // Add tensorrt.host_tensor attr, it is needed by NetworkEncoder for now
321+ func->setArgAttr (argIndex, hostTensorAttrName, rewriter.getUnitAttr ());
322+ }
323+ } else {
324+ auto shapeBoundAttr =
325+ parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
326+ argIndex, tensorrtShapeBoundsAttrName);
327+ if (!shapeBoundAttr) {
328+ return emitError (blockArg.getLoc ())
329+ << " Profile attribute (" << tensorrtShapeBoundsAttrName
330+ << " ) of argument " << argIndex << " is not set" ;
331+ }
332+ func->setArgAttr (argIndex, tensorrtShapeBoundsAttrName, shapeBoundAttr);
333+ auto dimensionNameAttr = parentFunc.getArgAttrOfType <DictionaryAttr>(
334+ argIndex, tensorrtDimensionNamesAttrName);
335+ if (dimensionNameAttr) {
336+ func->setArgAttr (argIndex, tensorrtDimensionNamesAttrName,
337+ dimensionNameAttr);
338+ }
314339 }
315340 }
316341
317- for (unsigned idx = 0 ; idx < func->getNumArguments (); idx++) {
318- if (profileAttrsPerInput[idx])
319- func->setArgAttr (idx, tensorrtShapeBoundsAttrName,
320- profileAttrsPerInput[idx]);
321- if (dimensionNamesAttrsPerInput[idx])
322- func->setArgAttr (idx, tensorrtDimensionNamesAttrName,
323- dimensionNamesAttrsPerInput[idx]);
324- }
325-
326342 rewriter.setInsertionPoint (inlineGroupOp);
327343 auto callOp = rewriter.create <tensorrt::CallAllocOp>(
328344 inlineGroupOp.getLoc (), inlineGroupOp.getResultTypes (), inputs,
0 commit comments