Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 75 additions & 5 deletions src/op/fill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../layout/layout.h"
#include "../layout/tcgen05_layout.h"
#include "../target/utils.h"
#include "../transform/common/attr.h"
#include "../transform/common/loop_fusion_utils.h"
Expand All @@ -21,6 +23,46 @@
namespace tvm {
namespace tl {

namespace {
/**
* @brief Check if a buffer can use the Sunmmio blockwise ZZ layout.
*
* This function enforces a relaxed constraint:
* 1. 1D buffers are strictly excluded.
* 2. "Degenerated 1D" buffers (e.g., (1, 64), (1, 1, 128)) are excluded
* and should fallback to linear layouts.
* 3. All other multi-dimensional shapes (e.g., (2, 64), (1, 4, 64), (32, 32))
* are allowed to use the blockwise ZZ layout. The python-side generator
* will automatically pad these shapes to the nearest 32x32 block boundary.
*/
bool CanUseBlockwiseZZ(const Buffer &buf) {
if (buf->shape.size() < 2)
return false;

// Check if it's (1, 1, ..., 1, N) shape
bool all_ones_except_last = true;
for (size_t i = 0; i < buf->shape.size() - 1; ++i) {
if (const auto *imm = buf->shape[i].as<IntImmNode>()) {
if (imm->value != 1) {
all_ones_except_last = false;
break;
}
} else {
// If any dimension before the last is not a constant,
// we conservatively assume it's not a (1,1,...,1,N) shape.
all_ones_except_last = false;
break;
}
}

if (all_ones_except_last) {
return false;
}

return true;
}
} // namespace

using namespace tir;

/**
Expand Down Expand Up @@ -277,19 +319,47 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
}
}

/**
 * @brief Sunmmio-specific layout inference for the Fill operator.
 *
 * Only acts at InferLevel::kStrict; any other level yields an empty map.
 * At the strict level the destination buffer must live in the
 * "shared.rsram" scope (checked with ICHECK). The destination is then mapped
 * to the blockwise ZZ layout when CanUseBlockwiseZZ(dst) holds, and to a
 * linear layout otherwise.
 *
 * @param T Layout inference context (not read by this function).
 * @param level The inference level requested.
 * @return LayoutMap containing a single entry for dst at the strict level,
 *         or an empty map for other levels.
 */
LayoutMap FillNode::InferLayoutSunmmioTileFill(const LayoutInferArgs &T,
                                               InferLevel level) const {
  if (level != InferLevel::kStrict) {
    return {};
  }

  auto dst_scope = dst.scope();
  ICHECK(dst_scope == "shared.rsram")
      << "For Sunmmio target, Fill operator dst must be in "
         "shared.rsram scope, but got "
      << dst_scope;

  // Shapes that cannot use the blockwise layout fall back to a linear one.
  if (!CanUseBlockwiseZZ(dst)) {
    auto linear_layout = makeLinearLayout(dst->shape);
    return {{dst, linear_layout}};
  }

  // The blockwise ZZ layout is built by a globally registered function.
  const auto make_zz =
      ffi::Function::GetGlobal("tl.layout.make_blockwise_zz_layout");
  ICHECK(make_zz != nullptr)
      << "Cannot find global function tl.layout.make_blockwise_zz_layout";
  auto zz_layout = Downcast<Layout>((*make_zz)(dst));
  return {{dst, zz_layout}};
}

/**
 * @brief Infer memory/layout mapping for the Fill operator.
 *
 * For Sunmmio targets the work is delegated to InferLayoutSunmmioTileFill.
 * For every other target no layout inference is performed and an empty
 * LayoutMap is returned.
 *
 * @param T Context required for layout inference; only T.target is read
 *          here, the whole context is forwarded to the Sunmmio helper.
 * @param level The inference level requested (forwarded to the helper).
 * @return LayoutMap Inferred layout mappings, possibly empty.
 */
LayoutMap FillNode::InferLayout(const LayoutInferArgs &T,
                                InferLevel level) const {
  if (!TargetIsSunmmio(T.target)) {
    return {};
  }
  return InferLayoutSunmmioTileFill(T, level);
}

Expand Down
5 changes: 4 additions & 1 deletion src/op/fill.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ class FillNode : public TileOperatorNode {
/// Create SIMT-style parallel loop for filling
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
/// Sunmmio Tile-based fill logic
Stmt MakeSunmmioTileFill(const LowerArgs &, arith::Analyzer *analyzer) const;
Stmt MakeSunmmioTileFill(const LowerArgs &T, arith::Analyzer *analyzer) const;
/// Sunmmio layout inference logic
LayoutMap InferLayoutSunmmioTileFill(const LayoutInferArgs &T,
InferLevel level) const;
};

/// Wrapper class for fill operations
Expand Down
1 change: 1 addition & 0 deletions src/op/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ struct LayoutInferArgs {
// fragment buffer accesses through let bindings
Map<Var, PrimExpr> let_var_to_expr;
LayoutMap global_layout_map;
TileViewMap tileview_map;
};

class TileOperator;
Expand Down
120 changes: 119 additions & 1 deletion src/op/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,46 @@
namespace tvm {
namespace tl {

namespace {
/**
* @brief Check if a buffer can use the Sunmmio blockwise ZZ layout.
*
* This function enforces a relaxed constraint:
* 1. 1D buffers are strictly excluded.
* 2. "Degenerated 1D" buffers (e.g., (1, 64), (1, 1, 128)) are excluded
* and should fallback to linear layouts.
* 3. All other multi-dimensional shapes (e.g., (2, 64), (1, 4, 64), (32, 32))
* are allowed to use the blockwise ZZ layout. The python-side generator
* will automatically pad these shapes to the nearest 32x32 block boundary.
*/
bool CanUseBlockwiseZZ(const Buffer &buf) {
if (buf->shape.size() < 2)
return false;

// Check if it's (1, 1, ..., 1, N) shape
bool all_ones_except_last = true;
for (size_t i = 0; i < buf->shape.size() - 1; ++i) {
if (const auto *imm = buf->shape[i].as<IntImmNode>()) {
if (imm->value != 1) {
all_ones_except_last = false;
break;
}
} else {
// If any dimension before the last is not a constant,
// we conservatively assume it's not a (1,1,...,1,N) shape.
all_ones_except_last = false;
break;
}
}

if (all_ones_except_last) {
return false;
}

return true;
}
} // namespace

using namespace tir;

// NormalizeToBufferRegion moved to src/op/utils.{h,cc}
Expand Down Expand Up @@ -711,8 +751,27 @@ Stmt ReduceOpNode::MakeSunmmioTileReduce(const LowerArgs &T,
alloc_buffers.push_back(dst_res.value());
}

Map<Buffer, Layout> local_layout_map;
const auto make_zz =
ffi::Function::GetGlobal("tl.layout.make_blockwise_zz_layout");
ICHECK(make_zz != nullptr)
<< "Cannot find global function tl.layout.make_blockwise_zz_layout";

if (CanUseBlockwiseZZ(acc)) {
auto acc_layout = Downcast<Layout>((*make_zz)(acc));
local_layout_map.Set(acc, acc_layout);
} else {
auto acc_layout = makeLinearLayout(acc->shape);
local_layout_map.Set(acc, acc_layout);
}

if (dst_res.defined()) {
auto res_layout = makeLinearLayout(dst_res.value()->shape);
local_layout_map.Set(dst_res.value(), res_layout);
}

body = Block({}, {}, {}, "reduce_tile_op", body, std::nullopt, alloc_buffers,
{}, {});
{}, {{attr::kLayoutMap, local_layout_map}});

return body;
}
Expand Down Expand Up @@ -1032,8 +1091,67 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return Stmt();
}

/**
 * @brief Sunmmio-specific layout inference for the Reduce operator.
 *
 * Only acts at InferLevel::kStrict; any other level yields an empty map.
 * At the strict level both src and dst must live in the "shared.rsram" scope
 * (checked with ICHECK).
 *
 * - src gets the blockwise ZZ layout when CanUseBlockwiseZZ(src) holds, and a
 *   linear layout otherwise.
 * - dst is only assigned a layout when a TileView is found for src. In that
 *   case dst is linear if any constant entry of the TileView index map
 *   (negative entries are wrapped by the rank of src) equals the reduction
 *   dimension, or if dst cannot use the blockwise layout; otherwise dst gets
 *   the blockwise ZZ layout.
 *
 * NOTE(review): when no TileView is found for src, dst receives no entry in
 * the returned map. Confirm this is intended.
 *
 * @param T Layout inference context; T.tileview_map is consulted.
 * @param level The inference level requested.
 * @return LayoutMap with the inferred layouts, or an empty map for
 *         non-strict levels.
 */
LayoutMap ReduceOpNode::InferLayoutSunmmioTileReduce(const LayoutInferArgs &T,
                                                     InferLevel level) const {
  if (level != InferLevel::kStrict) {
    return {};
  }

  auto src_scope = src.scope();
  auto dst_scope = dst.scope();
  ICHECK(src_scope == "shared.rsram" && dst_scope == "shared.rsram")
      << "For Sunmmio target, Reduce operator src and dst must be in "
         "shared.rsram scope, but got "
      << src_scope << " and " << dst_scope;

  LayoutMap inferred;
  const auto make_zz =
      ffi::Function::GetGlobal("tl.layout.make_blockwise_zz_layout");
  ICHECK(make_zz != nullptr)
      << "Cannot find global function tl.layout.make_blockwise_zz_layout";

  // Build either the blockwise ZZ layout or the linear fallback for a buffer.
  auto choose_layout = [&make_zz](const Buffer &buf,
                                  bool blockwise) -> Layout {
    if (blockwise) {
      return Downcast<Layout>((*make_zz)(buf));
    }
    return makeLinearLayout(buf->shape);
  };

  inferred.Set(src, choose_layout(src, CanUseBlockwiseZZ(src)));

  Optional<TileView> opt_tv = FindTileView(T.tileview_map, src);
  if (!opt_tv.defined()) {
    return inferred;
  }

  TileView tv = opt_tv.value();
  const auto &index_map = tv->IndexMap();
  const int src_ndim = src->shape.size();

  // Does the reduction axis coincide with one of the tiled axes?
  bool reduces_tiled_axis = false;
  for (size_t i = 0; i < index_map.size(); ++i) {
    const auto *imm = index_map[i].as<IntImmNode>();
    if (imm == nullptr) {
      continue; // Non-constant entries are ignored.
    }
    int axis = imm->value;
    if (axis < 0) {
      axis += src_ndim; // Wrap negative axes relative to the rank of src.
    }
    if (axis == this->dim) {
      reduces_tiled_axis = true;
      break;
    }
  }

  const bool dst_blockwise = !reduces_tiled_axis && CanUseBlockwiseZZ(dst);
  inferred.Set(dst, choose_layout(dst, dst_blockwise));
  return inferred;
}

LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (TargetIsSunmmio(T.target)) {
return InferLayoutSunmmioTileReduce(T, level);
}

if (level >= InferLevel::kStrict)
return {};

Expand Down
3 changes: 3 additions & 0 deletions src/op/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class ReduceOpNode : public TileOperatorNode {
/// Sunmmio Tile-based reduction logic
Stmt MakeSunmmioTileReduce(const LowerArgs &T,
arith::Analyzer *analyzer) const;
/// Sunmmio layout inference logic
LayoutMap InferLayoutSunmmioTileReduce(const LayoutInferArgs &T,
InferLevel level) const;
};

/// Wrapper class for reduction operations
Expand Down
25 changes: 25 additions & 0 deletions src/op/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,31 @@ inline bool IsLocalVarBuffer(const Buffer &buffer) {
return buffer.defined() && buffer.scope() == "local.var";
}

// Find the TileView metadata associated with a buffer.
//
// Lookup order:
//   1. Exact match on the buffer's data Var in `tileview_map`.
//   2. Fallback by name hint: both the buffer's name and each map key's name
//      are compared after removing a trailing "_<digits>" suffix (e.g. "_1",
//      "_2", "_10"). This is necessary because TVM may rename buffers by
//      appending such suffixes during lowering, which makes the direct
//      pointer-based lookup fail.
//
// Only purely numeric suffixes are stripped, so distinct buffers such as
// "A_s" and "A_l" are NOT conflated. If several keys share the same stripped
// name, the first one encountered while iterating the map is returned.
//
// Returns std::nullopt when neither lookup finds an entry.
inline Optional<TileView> FindTileView(const TileViewMap &tileview_map,
                                       const Buffer &buf) {
  if (tileview_map.count(buf->data)) {
    return tileview_map.at(buf->data);
  }

  // Drop a trailing "_<digits>" suffix; names without one are returned as-is.
  // A name that would become empty (e.g. "_1") is left untouched.
  auto strip_numeric_suffix = [](const std::string &name) -> std::string {
    const size_t sep = name.rfind('_');
    if (sep == std::string::npos || sep == 0 || sep + 1 == name.size()) {
      return name;
    }
    for (size_t i = sep + 1; i < name.size(); ++i) {
      if (name[i] < '0' || name[i] > '9') {
        return name;
      }
    }
    return name.substr(0, sep);
  };

  const std::string target_name = strip_numeric_suffix(buf->data->name_hint);
  for (const auto &kv : tileview_map) {
    if (strip_numeric_suffix(kv.first->name_hint) == target_name) {
      return kv.second;
    }
  }
  return std::nullopt;
}

} // namespace tl
} // namespace tvm

Expand Down
13 changes: 12 additions & 1 deletion src/transform/layout_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
buffer_oob,
{},
let_var_to_expr_,
global_layout_map_},
global_layout_map_,
tileview_map_},
level);

// Process the returned updates
Expand Down Expand Up @@ -755,6 +756,15 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
// BufferLoad/BufferStore
IRVisitorWithAnalyzer::VisitStmt_(op);

if (op->annotations.count(attr::kTileViewMap)) {
auto new_map = op->annotations.at(attr::kTileViewMap)
.as<Map<Var, TileView>>()
.value();
for (auto [k, v] : new_map) {
tileview_map_.Set(k, v);
}
}

// After visiting, apply layouts to all collected buffers
if (op->annotations.count(attr::kLayoutMap)) {
// Check if the layout map is Map<Var, Layout>
Expand Down Expand Up @@ -1023,6 +1033,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
Target target_;
LayoutMap annotated_layout_map_;
LayoutMap global_layout_map_;
TileViewMap tileview_map_;
bool skip_thread_partition_{false};

std::vector<TileOperator> BackupInferList() {
Expand Down
Loading