@@ -280,12 +280,17 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
280280 mlir::tensorrt::TensorRTDialect::getShapeProfileArgAttrName ();
281281 StringRef tensorrtDimensionNamesAttrName =
282282 mlir::tensorrt::TensorRTDialect::getDimensionNamesArgAttrName ();
283+ StringRef tensorrtValueBoundsAttrName =
284+ mlir::tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName ();
285+ StringRef hostTensorAttrName = mlir::getHostTensorArgAttrName ();
286+ StringRef memorySpaceAttrName =
287+ plan::PlanDialect::getMemorySpaceConstraintAttrName ();
283288
284289 SmallVector<Attribute> profileAttrsPerInput;
285290 SmallVector<Attribute> dimensionNamesAttrsPerInput;
286291 for (Value v : inputs) {
287292 auto rtt = dyn_cast<RankedTensorType>(v.getType ());
288- if (!rtt || rtt. hasStaticShape () ) {
293+ if (!rtt) {
289294 profileAttrsPerInput.push_back (Attribute{});
290295 dimensionNamesAttrsPerInput.push_back (Attribute{});
291296 continue ;
@@ -299,30 +304,42 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
299304 }
300305
301306 int64_t argIndex = blockArg.getArgNumber ();
302- profileAttrsPerInput.push_back (
303- parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
304- argIndex, tensorrtShapeBoundsAttrName));
305-
306- dimensionNamesAttrsPerInput.push_back (
307- parentFunc.getArgAttrOfType <DictionaryAttr>(
308- argIndex, tensorrtDimensionNamesAttrName));
309-
310- if (!profileAttrsPerInput.back ()) {
311- return emitError (blockArg.getLoc ())
312- << " Profile attribute (" << tensorrtShapeBoundsAttrName
313- << " ) of argument " << argIndex << " is not set" ;
307+ // Get shape profile and dynamision name attributes of the input
308+ if (rtt.hasStaticShape ()) {
309+ // static-shaped argument can only have value bound attr (shape input)
310+ auto valueBoundAttr =
311+ parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
312+ argIndex, tensorrtValueBoundsAttrName);
313+ if (valueBoundAttr) {
314+ func->setArgAttr (argIndex, tensorrtValueBoundsAttrName, valueBoundAttr);
315+ }
316+ // Get memory space attribute of the input
317+ auto memorySpaceAttr =
318+ parentFunc.getArgAttr (argIndex, memorySpaceAttrName);
319+ if (memorySpaceAttr) {
320+ func->setArgAttr (argIndex, memorySpaceAttrName, memorySpaceAttr);
321+ // Add tensorrt.host_tensor attr, it is needed by NetworkEncoder for now
322+ func->setArgAttr (argIndex, hostTensorAttrName, rewriter.getUnitAttr ());
323+ }
324+ } else {
325+ auto shapeBoundAttr =
326+ parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
327+ argIndex, tensorrtShapeBoundsAttrName);
328+ if (!shapeBoundAttr) {
329+ return emitError (blockArg.getLoc ())
330+ << " Profile attribute (" << tensorrtShapeBoundsAttrName
331+ << " ) of argument " << argIndex << " is not set" ;
332+ }
333+ func->setArgAttr (argIndex, tensorrtShapeBoundsAttrName, shapeBoundAttr);
334+ auto dimensionNameAttr = parentFunc.getArgAttrOfType <DictionaryAttr>(
335+ argIndex, tensorrtDimensionNamesAttrName);
336+ if (dimensionNameAttr) {
337+ func->setArgAttr (argIndex, tensorrtDimensionNamesAttrName,
338+ dimensionNameAttr);
339+ }
314340 }
315341 }
316342
317- for (unsigned idx = 0 ; idx < func->getNumArguments (); idx++) {
318- if (profileAttrsPerInput[idx])
319- func->setArgAttr (idx, tensorrtShapeBoundsAttrName,
320- profileAttrsPerInput[idx]);
321- if (dimensionNamesAttrsPerInput[idx])
322- func->setArgAttr (idx, tensorrtDimensionNamesAttrName,
323- dimensionNamesAttrsPerInput[idx]);
324- }
325-
326343 rewriter.setInsertionPoint (inlineGroupOp);
327344 auto callOp = rewriter.create <tensorrt::CallAllocOp>(
328345 inlineGroupOp.getLoc (), inlineGroupOp.getResultTypes (), inputs,
0 commit comments