@@ -280,12 +280,17 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
280280 mlir::tensorrt::TensorRTDialect::getShapeProfileArgAttrName ();
281281 StringRef tensorrtDimensionNamesAttrName =
282282 mlir::tensorrt::TensorRTDialect::getDimensionNamesArgAttrName ();
283+ StringRef tensorrtValueBoundsAttrName =
284+ mlir::tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName ();
285+ StringRef hostTensorAttrName = mlir::getHostTensorArgAttrName ();
286+ StringRef memorySpaceAttrName =
287+ plan::PlanDialect::getMemorySpaceConstraintAttrName ();
283288
284289 SmallVector<Attribute> profileAttrsPerInput;
285290 SmallVector<Attribute> dimensionNamesAttrsPerInput;
286291 for (Value v : inputs) {
287292 auto rtt = dyn_cast<RankedTensorType>(v.getType ());
288- if (!rtt || rtt. hasStaticShape () ) {
293+ if (!rtt) {
289294 profileAttrsPerInput.push_back (Attribute{});
290295 dimensionNamesAttrsPerInput.push_back (Attribute{});
291296 continue ;
@@ -299,30 +304,45 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
299304 }
300305
301306 int64_t argIndex = blockArg.getArgNumber ();
302- profileAttrsPerInput.push_back (
303- parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
304- argIndex, tensorrtShapeBoundsAttrName));
305-
306- dimensionNamesAttrsPerInput.push_back (
307- parentFunc.getArgAttrOfType <DictionaryAttr>(
308- argIndex, tensorrtDimensionNamesAttrName));
309-
310- if (!profileAttrsPerInput.back ()) {
311- return emitError (blockArg.getLoc ())
312- << " Profile attribute (" << tensorrtShapeBoundsAttrName
313- << " ) of argument " << argIndex << " is not set" ;
307+ // Get shape profile and dynamision name attributes of the input
308+ if (rtt.hasStaticShape ()) {
309+ // static-shaped argument can only have value bound attr (shape input)
310+ auto valueBoundAttr =
311+ parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
312+ argIndex, tensorrtValueBoundsAttrName);
313+ if (valueBoundAttr) {
314+ func->setArgAttr (argIndex, tensorrtValueBoundsAttrName, valueBoundAttr);
315+ }
316+ // Get host tensor attribute of the input
317+ auto hostTensorAttr = parentFunc.getArgAttr (argIndex, hostTensorAttrName);
318+ if (hostTensorAttr) {
319+ func->setArgAttr (argIndex, hostTensorAttrName, hostTensorAttr);
320+ // Add plan.memory_space attr, it is also required for the parent
321+ // function
322+ auto memorySpaceAttr = plan::MemorySpaceAttr::get (
323+ rewriter.getContext (), plan::MemorySpace::host);
324+ func->setArgAttr (argIndex, memorySpaceAttrName, memorySpaceAttr);
325+ parentFunc.setArgAttr (argIndex, memorySpaceAttrName, memorySpaceAttr);
326+ }
327+ } else {
328+ auto shapeBoundAttr =
329+ parentFunc.getArgAttrOfType <tensorrt::ShapeProfileAttr>(
330+ argIndex, tensorrtShapeBoundsAttrName);
331+ if (!shapeBoundAttr) {
332+ return emitError (blockArg.getLoc ())
333+ << " Profile attribute (" << tensorrtShapeBoundsAttrName
334+ << " ) of argument " << argIndex << " is not set" ;
335+ }
336+ func->setArgAttr (argIndex, tensorrtShapeBoundsAttrName, shapeBoundAttr);
337+ auto dimensionNameAttr = parentFunc.getArgAttrOfType <DictionaryAttr>(
338+ argIndex, tensorrtDimensionNamesAttrName);
339+ if (dimensionNameAttr) {
340+ func->setArgAttr (argIndex, tensorrtDimensionNamesAttrName,
341+ dimensionNameAttr);
342+ }
314343 }
315344 }
316345
317- for (unsigned idx = 0 ; idx < func->getNumArguments (); idx++) {
318- if (profileAttrsPerInput[idx])
319- func->setArgAttr (idx, tensorrtShapeBoundsAttrName,
320- profileAttrsPerInput[idx]);
321- if (dimensionNamesAttrsPerInput[idx])
322- func->setArgAttr (idx, tensorrtDimensionNamesAttrName,
323- dimensionNamesAttrsPerInput[idx]);
324- }
325-
326346 rewriter.setInsertionPoint (inlineGroupOp);
327347 auto callOp = rewriter.create <tensorrt::CallAllocOp>(
328348 inlineGroupOp.getLoc (), inlineGroupOp.getResultTypes (), inputs,
0 commit comments