@@ -280,14 +280,15 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
280280 mlir::tensorrt::TensorRTDialect::getShapeProfileArgAttrName ();
281281 StringRef tensorrtDimensionNamesAttrName =
282282 mlir::tensorrt::TensorRTDialect::getDimensionNamesArgAttrName ();
283+ StringRef tensorrtValueBoundsAttrName =
284+ mlir::tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName ();
285+ StringRef hostTensorAttrName = mlir::getHostTensorArgAttrName ();
286+ StringRef memorySpaceAttrName =
287+ plan::PlanDialect::getMemorySpaceConstraintAttrName ();
283288
284- SmallVector<Attribute> profileAttrsPerInput;
285- SmallVector<Attribute> dimensionNamesAttrsPerInput;
286289 for (Value v : inputs) {
287290 auto rtt = dyn_cast<RankedTensorType>(v.getType ());
288- if (!rtt || rtt.hasStaticShape ()) {
289- profileAttrsPerInput.push_back (Attribute{});
290- dimensionNamesAttrsPerInput.push_back (Attribute{});
291+ if (!rtt) {
291292 continue ;
292293 }
293294
@@ -299,30 +300,42 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
299300 }
300301
301302 int64_t argIndex = blockArg.getArgNumber ();
302- profileAttrsPerInput.push_back (
303- parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
304- argIndex, tensorrtShapeBoundsAttrName));
305-
306- dimensionNamesAttrsPerInput.push_back (
307- parentFunc.getArgAttrOfType <DictionaryAttr>(
308- argIndex, tensorrtDimensionNamesAttrName));
309-
310- if (!profileAttrsPerInput.back ()) {
311- return emitError (blockArg.getLoc ())
312- << " Profile attribute (" << tensorrtShapeBoundsAttrName
313- << " ) of argument " << argIndex << " is not set" ;
303+ // Get shape profile and dynamision name attributes of the input
304+ if (rtt.hasStaticShape ()) {
305+ // static-shaped argument can only have value bound attr (shape input)
306+ auto valueBoundAttr =
307+ parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
308+ argIndex, tensorrtValueBoundsAttrName);
309+ if (valueBoundAttr) {
310+ func->setArgAttr (argIndex, tensorrtValueBoundsAttrName, valueBoundAttr);
311+ }
312+ // Get memory space attribute of the input
313+ auto memorySpaceAttr =
314+ parentFunc.getArgAttr (argIndex, memorySpaceAttrName);
315+ if (memorySpaceAttr) {
316+ func->setArgAttr (argIndex, memorySpaceAttrName, memorySpaceAttr);
317+ // Add tensorrt.host_tensor attr, it is needed by NetworkEncoder for now
318+ func->setArgAttr (argIndex, hostTensorAttrName, rewriter.getUnitAttr ());
319+ }
320+ } else {
321+ auto shapeBoundAttr =
322+ parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
323+ argIndex, tensorrtShapeBoundsAttrName);
324+ if (!shapeBoundAttr) {
325+ return emitError (blockArg.getLoc ())
326+ << " Profile attribute (" << tensorrtShapeBoundsAttrName
327+ << " ) of argument " << argIndex << " is not set" ;
328+ }
329+ func->setArgAttr (argIndex, tensorrtShapeBoundsAttrName, shapeBoundAttr);
330+ auto dimensionNameAttr = parentFunc.getArgAttrOfType <DictionaryAttr>(
331+ argIndex, tensorrtDimensionNamesAttrName);
332+ if (dimensionNameAttr) {
333+ func->setArgAttr (argIndex, tensorrtDimensionNamesAttrName,
334+ dimensionNameAttr);
335+ }
314336 }
315337 }
316338
317- for (unsigned idx = 0 ; idx < func->getNumArguments (); idx++) {
318- if (profileAttrsPerInput[idx])
319- func->setArgAttr (idx, tensorrtShapeBoundsAttrName,
320- profileAttrsPerInput[idx]);
321- if (dimensionNamesAttrsPerInput[idx])
322- func->setArgAttr (idx, tensorrtDimensionNamesAttrName,
323- dimensionNamesAttrsPerInput[idx]);
324- }
325-
326339 rewriter.setInsertionPoint (inlineGroupOp);
327340 auto callOp = rewriter.create <tensorrt::CallAllocOp>(
328341 inlineGroupOp.getLoc (), inlineGroupOp.getResultTypes (), inputs,
0 commit comments