Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
c321c2e
implement the critical path schedule algorithm
wanghz18 Mar 3, 2026
d82ce8d
implementation of inject pipeline
wanghz18 Mar 10, 2026
2619a8b
rebase
wanghz18 Mar 10, 2026
47f6d1c
rebase
wanghz18 Mar 10, 2026
61e96ef
code style
wanghz18 Mar 10, 2026
88ebec0
refine the algorithm to avoid the read-write loop
wanghz18 Mar 11, 2026
7f59dc3
Memoization for b_level
wanghz18 Mar 11, 2026
4999702
optimize algorithm with 3 phases, refactor the sunmmio_pipeline_plann…
wanghz18 Mar 12, 2026
ad7a10d
disable log
wanghz18 Mar 12, 2026
edd7a16
Merge branch 'tilelang_mesh_main' of github.com:SUNMMIO/Tilelang-Mesh…
wanghz18 Mar 12, 2026
653e5ee
Merge branch 'SUNMMIO:tilelang_mesh_main' into pipeline_rebase
wanghz18 Mar 12, 2026
5ce0b92
update algorithm
wanghz18 Mar 16, 2026
a39d8b6
for pipeline algorithm
wanghz18 Mar 16, 2026
4d5bd14
Merge branch 'tilelang_mesh_main' of github.com:SUNMMIO/Tilelang-Mesh…
wanghz18 Mar 16, 2026
243f2b5
Merge branch 'SUNMMIO:tilelang_mesh_main' into pipeline_rebase
wanghz18 Mar 16, 2026
fc7b16f
Merge branch 'pipeline_rebase' of github.com:wanghz18/Tilelang-Mesh i…
wanghz18 Mar 16, 2026
f6841df
move debug control to python
wanghz18 Mar 16, 2026
83b7c75
for pull request
wanghz18 Mar 24, 2026
d6dbfc9
Merge branch 'SUNMMIO:tilelang_mesh_main' into pipeline_rebase
wanghz18 Mar 24, 2026
a6fbddd
Merge branch 'tilelang_mesh_main' of github.com:SUNMMIO/Tilelang-Mesh…
wanghz18 Mar 24, 2026
1939a41
Merge branch 'pipeline_rebase' of github.com:wanghz18/Tilelang-Mesh i…
wanghz18 Mar 25, 2026
7d13afb
for ci
wanghz18 Mar 25, 2026
5e09186
Merge branch 'tilelang_mesh_main' of github.com:SUNMMIO/Tilelang-Mesh…
wanghz18 Apr 1, 2026
e55d12c
Merge branch 'SUNMMIO:tilelang_mesh_main' into pipeline_rebase
wanghz18 Apr 5, 2026
2cd6155
update pipeline algorithm
wanghz18 Apr 15, 2026
4003165
Merge branch 'pipeline_rebase' of github.com:wanghz18/Tilelang-Mesh i…
wanghz18 Apr 15, 2026
85ea9d3
Merge branch 'SUNMMIO:tilelang_mesh_main' into pipeline_rebase
wanghz18 Apr 15, 2026
397ebbb
Merge branch 'tilelang_mesh_main' of github.com:SUNMMIO/Tilelang-Mesh…
wanghz18 Apr 15, 2026
dd06f4f
withdraw changes in fill.cc
wanghz18 Apr 15, 2026
468a31d
Merge branch 'pipeline_rebase' of github.com:wanghz18/Tilelang-Mesh i…
wanghz18 Apr 15, 2026
613d231
for ci
wanghz18 Apr 15, 2026
685abdb
for ci
wanghz18 Apr 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/op/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ inline bool IsSharedBuffer(const Buffer &buffer, bool allow_dynamic = true) {
}
}

inline bool IsSunmmioSharedBuffer(const Buffer &buffer) {
return buffer.defined() &&
(buffer.scope() == "shared.asram" ||
buffer.scope() == "shared.wsram" || buffer.scope() == "shared.rsram");
}

inline bool IsGlobalBuffer(const Buffer &buffer) {
return buffer.defined() && buffer.scope() == "global";
}
Expand Down
317 changes: 317 additions & 0 deletions src/transform/common/ast_traverser.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,317 @@
#ifndef AST_TRAVERSER_H
#define AST_TRAVERSER_H

#include "../../op/builtin.h"
#include "../../op/utils.h"
#include "tvm/ir/expr.h"
#include "tvm/runtime/logging.h"
#include "tvm/tir/buffer.h"
#include "tvm/tir/expr.h"
#include "tvm/tir/function.h"
#include "tvm/tir/stmt.h"
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tl {

using namespace tir;

class BufferAccessCollector : public ExprVisitor {
public:
BufferAccessCollector(Map<Var, Buffer> buffer_data_to_buffer)
: buffer_data_to_buffer_(buffer_data_to_buffer) {}

Array<BufferRegion> GetReads() const { return reads_; }
Array<BufferRegion> GetWrites() const { return writes_; }

private:
void VisitExpr_(const BufferLoadNode *op) final {
auto load_buffer = op->buffer;
Array<PrimExpr> indices = op->indices;
// convert indices to region
Array<Range> region;
for (const auto &index : indices) {
region.push_back(Range::FromMinExtent(index, 1));
}
auto load_region = BufferRegion(load_buffer, region);
reads_.push_back(load_region);
}

void VisitExpr_(const CallNode *op) final {
auto args = op->args;
if (op->op.same_as(builtin::address_of())) {
BufferRegion buffer_region;
if (const auto *load = op->args[0].as<BufferLoadNode>()) {
buffer_region = BufferRegion::FullRegion(load->buffer);
} else if (const auto *var_node = op->args[0].as<VarNode>()) {
Var data_var = tvm::ffi::GetRef<Var>(var_node);
auto it = buffer_data_to_buffer_.find(data_var);
if (it != buffer_data_to_buffer_.end()) {
buffer_region = BufferRegion::FullRegion((*it).second);
}
}
if (buffer_region.defined()) {
reads_.push_back(buffer_region);
}
} else if (op->op.same_as(builtin::tvm_access_ptr())) {
const VarNode *buffer_var = op->args[1].as<VarNode>();
ICHECK(buffer_var);
auto it = buffer_data_to_buffer_.find(tvm::ffi::GetRef<Var>(buffer_var));
if (it != buffer_data_to_buffer_.end()) {
const Buffer &buffer = (*it).second;
const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
reads_.push_back(buffer_region);
}
}
// else if (op->op.same_as(tl::mbarrier_wait_parity())) {
// ICHECK(args[0].as<BufferLoadNode>());
// Buffer mbar_buf = args[0].as<BufferLoadNode>()->buffer;
// auto buffer_reads =
// chain_builder_.mbar_to_buffer_reads_.find(mbar_buf.get());
// auto buffer_writes =
// chain_builder_.mbar_to_buffer_writes_.find(mbar_buf.get());
// if (buffer_reads != chain_builder_.mbar_to_buffer_reads_.end()) {
// reads_.insert(reads_.end(), buffer_reads->second.begin(),
// buffer_reads->second.end());
// }
// if (buffer_writes != chain_builder_.mbar_to_buffer_writes_.end()) {
// writes_.insert(
// writes_.end(),
// chain_builder_.mbar_to_buffer_writes_.at(mbar_buf.get()).begin(),
// chain_builder_.mbar_to_buffer_writes_.at(mbar_buf.get()).end());
// }
// }

else {
ExprVisitor::VisitExpr_(op);
}
}

private:
Array<BufferRegion> reads_;
Array<BufferRegion> writes_;
Map<Var, Buffer> buffer_data_to_buffer_;
};

class ASTTraverser : public StmtVisitor {
public:
ASTTraverser(const PrimFunc &f) {
for (const auto &[_, buffer] : f->buffer_map) {
this->buffer_data_to_buffer_.Set(buffer->data, buffer);
}
}

std::pair<Array<BufferRegion>, Array<BufferRegion>>
buffer_region_collector(const PrimExpr &expr) {
auto buf_load_collector = BufferAccessCollector(buffer_data_to_buffer_);
buf_load_collector(expr);
Array<BufferRegion> read_regions = buf_load_collector.GetReads();
Array<BufferRegion> write_regions = buf_load_collector.GetWrites();
return {read_regions, write_regions};
}

void VisitStmt_(const AttrStmtNode *op) {
auto [read_regions, write_regions] = buffer_region_collector(op->value);
for (auto it : read_regions) {
read_buffer_regions_.insert(it);
}
for (auto it : write_regions) {
write_buffer_regions_.insert(it);
}
StmtVisitor::VisitStmt_(op);
}

void VisitStmt_(const LetStmtNode *op) {
auto [read_regions, write_regions] = buffer_region_collector(op->value);
for (auto it : read_regions) {
read_buffer_regions_.insert(it);
}
for (auto it : write_regions) {
write_buffer_regions_.insert(it);
}
StmtVisitor::VisitStmt_(op);
}

void VisitStmt_(const ForNode *op) {
auto [min_read_regions, min_write_regions] =
buffer_region_collector(op->min);
for (auto it : min_read_regions) {
read_buffer_regions_.insert(it);
}
for (auto it : min_write_regions) {
write_buffer_regions_.insert(it);
}

auto [extent_read_regions, extent_write_regions] =
buffer_region_collector(op->extent);
for (auto it : extent_read_regions) {
read_buffer_regions_.insert(it);
}
for (auto it : extent_write_regions) {
write_buffer_regions_.insert(it);
}
StmtVisitor::VisitStmt_(op);
}

void VisitStmt_(const WhileNode *op) {
auto [read_regions, write_regions] = buffer_region_collector(op->condition);
for (auto it : read_regions) {
read_buffer_regions_.insert(it);
}
for (auto it : write_regions) {
write_buffer_regions_.insert(it);
}
StmtVisitor::VisitStmt_(op);
}

void VisitStmt_(const AllocateNode *op) {
auto [read_regions, write_regions] = buffer_region_collector(op->condition);
for (auto it : read_regions) {
read_buffer_regions_.insert(it);
}
for (auto it : write_regions) {
write_buffer_regions_.insert(it);
}
StmtVisitor::VisitStmt_(op);
}

void VisitStmt_(const BufferRealizeNode *op) {
auto [read_regions, write_regions] = buffer_region_collector(op->condition);
for (auto it : read_regions) {
read_buffer_regions_.insert(it);
}
for (auto it : write_regions) {
write_buffer_regions_.insert(it);
}
StmtVisitor::VisitStmt_(op);
}

void VisitStmt_(const AssertStmtNode *op) {
auto [condition_read_regions, condition_write_regions] =
buffer_region_collector(op->condition);
for (auto it : condition_read_regions) {
read_buffer_regions_.insert(it);
}
for (auto it : condition_write_regions) {
write_buffer_regions_.insert(it);
}

auto [message_read_regions, message_write_regions] =
buffer_region_collector(op->message);
for (auto it : message_read_regions) {
read_buffer_regions_.insert(it);
}
for (auto it : message_write_regions) {
write_buffer_regions_.insert(it);
}
StmtVisitor::VisitStmt_(op);
}

void VisitStmt_(const BlockRealizeNode *op) {
auto [read_regions, write_regions] = buffer_region_collector(op->predicate);
for (auto it : read_regions) {
read_buffer_regions_.insert(it);
}
for (auto it : write_regions) {
write_buffer_regions_.insert(it);
}
StmtVisitor::VisitStmt_(op);
}

void VisitStmt_(const BufferStoreNode *op) {
// For a buffer store statement, we need to check the dependencies for the
// buffer to be stored. For example, in the statement A[i] = B[j] + C[k], we
// need to check the dependencies for the buffer A.
Buffer store_buffer = op->buffer;
Array<PrimExpr> indices = op->indices;
// convert indices to region
Array<Range> region;
for (const auto &index : indices) {
region.push_back(Range::FromMinExtent(index, 1));
}
auto store_region = BufferRegion(store_buffer, region);
write_buffer_regions_.insert(store_region);

// For a store statement, we also need to check the read dependencies for
// the value to be stored. For example, in the statement A[i] = B[j] + C[k],
// we need to check the read dependencies for the buffers B and C.
auto [read_regions, write_regions] = buffer_region_collector(op->value);
for (auto it : read_regions) {
read_buffer_regions_.insert(it);
}
for (auto it : write_regions) {
write_buffer_regions_.insert(it);
}
StmtVisitor::VisitStmt_(op);
}

void VisitStmt_(const EvaluateNode *op) {
const CallNode *call = op->value.as<CallNode>();
if (call->op.same_as(dma_copy())) {
read_buffer_regions_.insert(NormalizeToBufferRegion(call->args[0]));
write_buffer_regions_.insert(NormalizeToBufferRegion(call->args[1]));
} else if (call->op.same_as(mma_sunmmio())) {
read_buffer_regions_.insert(NormalizeToBufferRegion(call->args[0]));
read_buffer_regions_.insert(NormalizeToBufferRegion(call->args[1]));
read_buffer_regions_.insert(NormalizeToBufferRegion(call->args[2]));

write_buffer_regions_.insert(NormalizeToBufferRegion(call->args[2]));
// } else if (call->op.same_as(broadcast_())) {
// read_buffer_regions_.insert(NormalizeToBufferRegion(call->args[0]));
// write_buffer_regions_.insert(NormalizeToBufferRegion(call->args[1]));
} else {
auto [read_regions, write_regions] = buffer_region_collector(op->value);
for (auto it : read_regions) {
read_buffer_regions_.insert(it);
}
for (auto it : write_regions) {
write_buffer_regions_.insert(it);
}
}
StmtVisitor::VisitStmt_(op);
}

void VisitStmt_(const BlockNode *op) final {
for (const auto &buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
StmtVisitor::VisitStmt_(op);
for (const auto &buffer : op->alloc_buffers) {
buffer_data_to_buffer_.erase(buffer->data);
}
}

void clear() {
read_buffer_regions_.clear();
write_buffer_regions_.clear();
}

void traverse_stmt(Stmt stmt) {
clear();
VisitStmt(stmt);
}

void traverse_expr(PrimExpr expr) {
clear();
buffer_region_collector(expr);
}

public:
Map<Var, Buffer> buffer_data_to_buffer_;

std::set<BufferRegion> read_buffer_regions_;
std::set<BufferRegion> write_buffer_regions_;
};

} // namespace tl
} // namespace tvm

#endif
18 changes: 18 additions & 0 deletions src/transform/common/sunmmio_pipeline_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef SUNMMIO_PIPELINE_UTILS_H
#define SUNMMIO_PIPELINE_UTILS_H

#include <string>

namespace tvm {
namespace tl {
inline int name2iter(const std::string &name) {
return std::stoi(name.substr(0, name.find('-')));
}

inline int name2id(const std::string &name) {
return std::stoi(name.substr(name.find('-') + 1));
}

} // namespace tl
} // namespace tvm
#endif
Loading
Loading