Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 36 additions & 17 deletions src/op/comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ BroadcastOp::BroadcastOp(Array<PrimExpr> args,
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
node->size = Downcast<IntImm>(args[2]);
node->dst_offset = Downcast<IntImm>(args[3]);
- node->src_core = Downcast<IntImm>(args[4]);
+ node->src_core = args[4];
node->direction = Downcast<IntImm>(args[5])->value;
data_ = std::move(node);
}
Expand Down Expand Up @@ -145,9 +145,12 @@ Stmt BroadcastOpNode::Lower(const LowerArgs &T,
int mesh_ncol = mesh.ncol;

// check for valid core id
- ICHECK(src_core->value >= 0 and src_core->value < mesh_nrow * mesh_ncol)
- << "Source core id " << src_core->value << " out of range [0, "
- << mesh_nrow * mesh_ncol << ")";
+ if (src_core.as<IntImmNode>()) {
+ int src_core_val = src_core.as<IntImmNode>()->value;
+ ICHECK(src_core_val >= 0 and src_core_val < mesh_nrow * mesh_ncol)
+ << "Source core id " << src_core_val << " out of range [0, "
+ << mesh_nrow * mesh_ncol << ")";
+ }

// check for src and dst buffer sizes
PrimExpr src_elements = 1;
Expand Down Expand Up @@ -197,7 +200,6 @@ Stmt BroadcastOpNode::Lower(const LowerArgs &T,
PrimExpr dst_addr =
dst.access_ptr(2, DataType::Handle(), 1,
Downcast<IntImm>(dst_offset->value), src_elements);
- int src_core_col = src_core->value % mesh_ncol;

if (direction == 0 or direction == 1) {
// 1D broadcast
Expand All @@ -211,6 +213,11 @@ Stmt BroadcastOpNode::Lower(const LowerArgs &T,
return broadcast;
} else {
// 2D broadcast
+ ICHECK(src_core.as<IntImmNode>())
+ << "2D broadcast only supports constant source core id.";
+ int src_core_val = src_core.as<IntImmNode>()->value;
+ int src_core_col = src_core_val % mesh_ncol;

Array<Stmt> seq;
// vertical broadcast
Array<PrimExpr> args;
Expand Down Expand Up @@ -255,8 +262,8 @@ PutOp::PutOp(Array<PrimExpr> args, Map<String, ObjectRef> annotations) {
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
node->size = Downcast<IntImm>(args[2]);
- node->src_core = Downcast<IntImm>(args[3]);
- node->dst_core = Downcast<IntImm>(args[4]);
+ node->src_core = args[3];
+ node->dst_core = args[4];
data_ = std::move(node);
}

Expand Down Expand Up @@ -287,12 +294,18 @@ Stmt PutOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int mesh_ncol = mesh.ncol;

// check for valid core id
- ICHECK(src_core->value >= 0 and src_core->value < mesh_nrow * mesh_ncol)
- << "Source core id " << src_core->value << " out of range [0, "
- << mesh_nrow * mesh_ncol << ")";
- ICHECK(dst_core->value >= 0 and dst_core->value < mesh_nrow * mesh_ncol)
- << "Destination core id " << dst_core->value << " out of range [0, "
- << mesh_nrow * mesh_ncol << ")";
+ if (src_core.as<IntImmNode>()) {
+ int src_core_val = src_core.as<IntImmNode>()->value;
+ ICHECK(src_core_val >= 0 and src_core_val < mesh_nrow * mesh_ncol)
+ << "Source core id " << src_core_val << " out of range [0, "
+ << mesh_nrow * mesh_ncol << ")";
+ }
+ if (dst_core.as<IntImmNode>()) {
+ int dst_core_val = dst_core.as<IntImmNode>()->value;
+ ICHECK(dst_core_val >= 0 and dst_core_val < mesh_nrow * mesh_ncol)
+ << "Destination core id " << dst_core_val << " out of range [0, "
+ << mesh_nrow * mesh_ncol << ")";
+ }

// check for src and dst buffer sizes
PrimExpr src_elements = 1;
Expand Down Expand Up @@ -334,10 +347,16 @@ Stmt PutOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
// all checks passed, generate the call
PrimExpr src_addr = src.access_ptr(1, DataType::Handle(), 1, 0, src_elements);
PrimExpr dst_addr = dst.access_ptr(2, DataType::Handle(), 1, 0, dst_elements);
- int src_core_row = src_core->value / mesh_ncol;
- int src_core_col = src_core->value % mesh_ncol;
- int dst_core_row = dst_core->value / mesh_ncol;
- int dst_core_col = dst_core->value % mesh_ncol;
+ ICHECK(src_core.as<IntImmNode>())
+ << "Put only supports constant source core id.";
+ ICHECK(dst_core.as<IntImmNode>())
+ << "Put only supports constant destination core id.";
+ int src_core_val = src_core.as<IntImmNode>()->value;
+ int dst_core_val = dst_core.as<IntImmNode>()->value;
+ int src_core_row = src_core_val / mesh_ncol;
+ int src_core_col = src_core_val % mesh_ncol;
+ int dst_core_row = dst_core_val / mesh_ncol;
+ int dst_core_col = dst_core_val % mesh_ncol;

if (src_core_row == dst_core_row) {
// 1D put via horizontal communication
Expand Down
12 changes: 6 additions & 6 deletions src/op/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class BroadcastOpNode : public TileOperatorNode {
PrimExpr src_expr, dst_expr;
IntImm size;
IntImm dst_offset;
- IntImm src_core;
+ PrimExpr src_core;
int direction;

TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.comm_broadcast", BroadcastOpNode,
Expand All @@ -44,7 +44,7 @@ class BroadcastOpNode : public TileOperatorNode {
.def_ro("dst_offset", &BroadcastOpNode::dst_offset);
}

- TileOperator Clone() const;
+ TileOperator Clone() const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
Expand All @@ -64,7 +64,7 @@ class PutOpNode : public TileOperatorNode {
Buffer src, dst;
Array<Range> src_range, dst_range;
PrimExpr src_expr, dst_expr;
- IntImm src_core, dst_core;
+ PrimExpr src_core, dst_core;
IntImm size;

TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.comm_put", PutOpNode, TileOperatorNode);
Expand All @@ -81,7 +81,7 @@ class PutOpNode : public TileOperatorNode {
.def_ro("size", &PutOpNode::size);
}

- TileOperator Clone() const;
+ TileOperator Clone() const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
Expand Down Expand Up @@ -112,7 +112,7 @@ class AllgatherOpNode : public TileOperatorNode {
.def_ro("size", &AllgatherOpNode::size);
}

- TileOperator Clone() const;
+ TileOperator Clone() const override;
LayoutMap ComputeLayout(const LayoutInferArgs &T, InferLevel level,
Buffer src, Buffer dst) const;
LayoutMap InferLayout(const LayoutInferArgs &T,
Expand Down Expand Up @@ -156,7 +156,7 @@ class AllreduceOpNode : public TileOperatorNode {
.def_ro("dst_copy", &AllreduceOpNode::dst_copy);
}

- TileOperator Clone() const;
+ TileOperator Clone() const override;
LayoutMap ComputeLayout(const LayoutInferArgs &T, InferLevel level,
Buffer src, Buffer dst, int dim) const;
LayoutMap InferLayout(const LayoutInferArgs &T,
Expand Down
Loading
Loading