diff --git a/src/op/fill.cc b/src/op/fill.cc index 210ff87ef1..e2a3210f43 100644 --- a/src/op/fill.cc +++ b/src/op/fill.cc @@ -10,6 +10,8 @@ #include #include +#include "../layout/layout.h" +#include "../layout/tcgen05_layout.h" #include "../target/utils.h" #include "../transform/common/attr.h" #include "../transform/common/loop_fusion_utils.h" @@ -21,6 +23,46 @@ namespace tvm { namespace tl { +namespace { +/** + * @brief Check if a buffer can use the Sunmmio blockwise ZZ layout. + * + * This function enforces a relaxed constraint: + * 1. 1D buffers are strictly excluded. + * 2. "Degenerated 1D" buffers (e.g., (1, 64), (1, 1, 128)) are excluded + * and should fallback to linear layouts. + * 3. All other multi-dimensional shapes (e.g., (2, 64), (1, 4, 64), (32, 32)) + * are allowed to use the blockwise ZZ layout. The python-side generator + * will automatically pad these shapes to the nearest 32x32 block boundary. + */ +bool CanUseBlockwiseZZ(const Buffer &buf) { + if (buf->shape.size() < 2) + return false; + + // Check if it's (1, 1, ..., 1, N) shape + bool all_ones_except_last = true; + for (size_t i = 0; i < buf->shape.size() - 1; ++i) { + if (const auto *imm = buf->shape[i].as()) { + if (imm->value != 1) { + all_ones_except_last = false; + break; + } + } else { + // If any dimension before the last is not a constant, + // we conservatively assume it's not a (1,1,...,1,N) shape. + all_ones_except_last = false; + break; + } + } + + if (all_ones_except_last) { + return false; + } + + return true; +} +} // namespace + using namespace tir; /** @@ -277,19 +319,47 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } } +LayoutMap FillNode::InferLayoutSunmmioTileFill(const LayoutInferArgs &T, + InferLevel level) const { + if (level == InferLevel::kStrict) { + auto dst_scope = dst.scope(); + ICHECK(dst_scope == "shared.rsram") + << "For Sunmmio target, Fill operator dst must be in " + "shared.rsram scope, but got " + << dst_scope; + + if (CanUseBlockwiseZZ(dst)) { + const auto f = + ffi::Function::GetGlobal("tl.layout.make_blockwise_zz_layout"); + ICHECK(f != nullptr) + << "Cannot find global function tl.layout.make_blockwise_zz_layout"; + auto layout = Downcast((*f)(dst)); + return {{dst, layout}}; + } else { + auto layout = makeLinearLayout(dst->shape); + return {{dst, layout}}; + } + } + return {}; +} + /** * @brief Infer memory/layout mapping for the Fill operator. * * Returns the layout mapping produced by layout inference for this FillNode. - * Currently no layout inference is performed for Fill and the function returns - * an empty LayoutMap. + * For Sunmmio targets, if the destination buffer is in the shared.rsram scope, + * it strictly infers a blockwise zz layout. Otherwise, it returns an empty + * LayoutMap. * - * @param T Context required for layout inference (unused). - * @param level The inference level requested (unused). - * @return LayoutMap Empty map indicating no inferred layouts for this operator. + * @param T Context required for layout inference (target, layout map, etc.). + * @param level The inference level requested. + * @return LayoutMap Inferred layout mappings. */ LayoutMap FillNode::InferLayout(const LayoutInferArgs &T, InferLevel level) const { + if (TargetIsSunmmio(T.target)) { + return InferLayoutSunmmioTileFill(T, level); + } return {}; } diff --git a/src/op/fill.h b/src/op/fill.h index 75bdc8932f..cceb90804e 100644 --- a/src/op/fill.h +++ b/src/op/fill.h @@ -40,7 +40,10 @@ class FillNode : public TileOperatorNode { /// Create SIMT-style parallel loop for filling For MakeSIMTLoop(arith::Analyzer *analyzer) const; /// Sunmmio Tile-based fill logic - Stmt MakeSunmmioTileFill(const LowerArgs &, arith::Analyzer *analyzer) const; + Stmt MakeSunmmioTileFill(const LowerArgs &T, arith::Analyzer *analyzer) const; + /// Sunmmio layout inference logic + LayoutMap InferLayoutSunmmioTileFill(const LayoutInferArgs &T, + InferLevel level) const; }; /// Wrapper class for fill operations diff --git a/src/op/operator.h b/src/op/operator.h index 2ef67e8cd9..89aac1dd5e 100644 --- a/src/op/operator.h +++ b/src/op/operator.h @@ -73,6 +73,7 @@ struct LayoutInferArgs { // fragment buffer accesses through let bindings Map let_var_to_expr; LayoutMap global_layout_map; + TileViewMap tileview_map; }; class TileOperator; diff --git a/src/op/reduce.cc b/src/op/reduce.cc index af7f4bde5b..0e4e03128c 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -28,6 +28,46 @@ namespace tvm { namespace tl { +namespace { +/** + * @brief Check if a buffer can use the Sunmmio blockwise ZZ layout. + * + * This function enforces a relaxed constraint: + * 1. 1D buffers are strictly excluded. + * 2. "Degenerated 1D" buffers (e.g., (1, 64), (1, 1, 128)) are excluded + * and should fallback to linear layouts. + * 3. All other multi-dimensional shapes (e.g., (2, 64), (1, 4, 64), (32, 32)) + * are allowed to use the blockwise ZZ layout. The python-side generator + * will automatically pad these shapes to the nearest 32x32 block boundary. + */ +bool CanUseBlockwiseZZ(const Buffer &buf) { + if (buf->shape.size() < 2) + return false; + + // Check if it's (1, 1, ..., 1, N) shape + bool all_ones_except_last = true; + for (size_t i = 0; i < buf->shape.size() - 1; ++i) { + if (const auto *imm = buf->shape[i].as()) { + if (imm->value != 1) { + all_ones_except_last = false; + break; + } + } else { + // If any dimension before the last is not a constant, + // we conservatively assume it's not a (1,1,...,1,N) shape. + all_ones_except_last = false; + break; + } + } + + if (all_ones_except_last) { + return false; + } + + return true; +} +} // namespace + using namespace tir; // NormalizeToBufferRegion moved to src/op/utils.{h,cc} @@ -711,8 +751,27 @@ Stmt ReduceOpNode::MakeSunmmioTileReduce(const LowerArgs &T, alloc_buffers.push_back(dst_res.value()); } + Map local_layout_map; + const auto make_zz = + ffi::Function::GetGlobal("tl.layout.make_blockwise_zz_layout"); + ICHECK(make_zz != nullptr) + << "Cannot find global function tl.layout.make_blockwise_zz_layout"; + + if (CanUseBlockwiseZZ(acc)) { + auto acc_layout = Downcast((*make_zz)(acc)); + local_layout_map.Set(acc, acc_layout); + } else { + auto acc_layout = makeLinearLayout(acc->shape); + local_layout_map.Set(acc, acc_layout); + } + + if (dst_res.defined()) { + auto res_layout = makeLinearLayout(dst_res.value()->shape); + local_layout_map.Set(dst_res.value(), res_layout); + } + body = Block({}, {}, {}, "reduce_tile_op", body, std::nullopt, alloc_buffers, - {}, {}); + {}, {{attr::kLayoutMap, local_layout_map}}); return body; } @@ -1032,8 +1091,67 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return Stmt(); } +LayoutMap ReduceOpNode::InferLayoutSunmmioTileReduce(const LayoutInferArgs &T, + InferLevel level) const { + if (level == InferLevel::kStrict) { + auto src_scope = src.scope(); + auto dst_scope = dst.scope(); + ICHECK(src_scope == "shared.rsram" && dst_scope == "shared.rsram") + << "For Sunmmio target, Reduce operator src and dst must be in " + "shared.rsram scope, but got " + << src_scope << " and " << dst_scope; + + LayoutMap result; + const auto make_zz = + ffi::Function::GetGlobal("tl.layout.make_blockwise_zz_layout"); + ICHECK(make_zz != nullptr) + << "Cannot find global function tl.layout.make_blockwise_zz_layout"; + + if (CanUseBlockwiseZZ(src)) { + auto src_layout = Downcast((*make_zz)(src)); + result.Set(src, src_layout); + } else { + auto src_layout = makeLinearLayout(src->shape); + result.Set(src, src_layout); + } + + Optional opt_tv = FindTileView(T.tileview_map, src); + if (opt_tv.defined()) { + TileView tv = opt_tv.value(); + bool is_tiled = false; + int src_ndim = src->shape.size(); + for (size_t i = 0; i < tv->IndexMap().size(); i++) { + const auto *idx_ptr = tv->IndexMap()[i].as(); + if (idx_ptr) { + int dim = idx_ptr->value; + if (dim < 0) + dim += src_ndim; + if (dim == this->dim) { + is_tiled = true; + break; + } + } + } + + if (is_tiled || !CanUseBlockwiseZZ(dst)) { + auto dst_layout = makeLinearLayout(dst->shape); + result.Set(dst, dst_layout); + } else { + auto dst_layout = Downcast((*make_zz)(dst)); + result.Set(dst, dst_layout); + } + } + return result; + } + return {}; +} + LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T, InferLevel level) const { + if (TargetIsSunmmio(T.target)) { + return InferLayoutSunmmioTileReduce(T, level); + } + if (level >= InferLevel::kStrict) return {}; diff --git a/src/op/reduce.h b/src/op/reduce.h index fba633bc38..14c358d793 100644 --- a/src/op/reduce.h +++ b/src/op/reduce.h @@ -121,6 +121,9 @@ class ReduceOpNode : public TileOperatorNode { /// Sunmmio Tile-based reduction logic Stmt MakeSunmmioTileReduce(const LowerArgs &T, arith::Analyzer *analyzer) const; + /// Sunmmio layout inference logic + LayoutMap InferLayoutSunmmioTileReduce(const LayoutInferArgs &T, + InferLevel level) const; }; /// Wrapper class for reduction operations diff --git a/src/op/utils.h b/src/op/utils.h index d0b27aa014..5ad2959ff9 100644 --- a/src/op/utils.h +++ b/src/op/utils.h @@ -85,6 +85,31 @@ inline bool IsLocalVarBuffer(const Buffer &buffer) { return buffer.defined() && buffer.scope() == "local.var"; } +// Helper to find TileView metadata for a buffer, supporting name-hint +// fallback. This is necessary because TVM may rename buffers (e.g., +// adding suffixes like _1, _2) during lowering, which causes direct +// pointer-based lookup in tileview_map to fail. +inline Optional FindTileView(const TileViewMap &tileview_map, + const Buffer &buf) { + if (tileview_map.count(buf->data)) { + return tileview_map.at(buf->data); + } + // Fallback: match by name hint, ignoring common suffixes like _1, _2. + auto simplify_name = [](std::string name) { + if (name.size() > 2 && name[name.size() - 2] == '_') { + return name.substr(0, name.size() - 2); + } + return name; + }; + std::string target_name = simplify_name(buf->data->name_hint); + for (const auto &kv : tileview_map) { + if (simplify_name(kv.first->name_hint) == target_name) { + return kv.second; + } + } + return std::nullopt; +} + } // namespace tl } // namespace tvm diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index 21229d4460..642686018d 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -123,7 +123,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { buffer_oob, {}, let_var_to_expr_, - global_layout_map_}, + global_layout_map_, + tileview_map_}, level); // Process the returned updates @@ -755,6 +756,15 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { // BufferLoad/BufferStore IRVisitorWithAnalyzer::VisitStmt_(op); + if (op->annotations.count(attr::kTileViewMap)) { + auto new_map = op->annotations.at(attr::kTileViewMap) + .as>() + .value(); + for (auto [k, v] : new_map) { + tileview_map_.Set(k, v); + } + } + // After visiting, apply layouts to all collected buffers if (op->annotations.count(attr::kLayoutMap)) { // Check if the layout map is Map @@ -1023,6 +1033,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { Target target_; LayoutMap annotated_layout_map_; LayoutMap global_layout_map_; + TileViewMap tileview_map_; bool skip_thread_partition_{false}; std::vector BackupInferList() {