Skip to content

Commit 3b89331

Browse files
authored
Merge branch 'main' into addflaggemstest
2 parents 392a782 + dce5335 commit 3b89331

File tree

324 files changed

+1146
-32875
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

324 files changed

+1146
-32875
lines changed

.github/workflows/mthreads-build-and-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ jobs:
6868
./python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt --pass-pipeline='builtin.module(convert-triton-to-tritongpu{target="cuda:CC" num-warps=4 threads-per-warp=32 num-ctas=1})' ./test/bin/mthreads/add_kernel.ttir
6969
./python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt --convert-ub-to-llvm ./test/bin/mthreads/add_kernel.ttgir
7070
python3 -m pytest -s third_party/mthreads/python/test/unit
71+
MUSA_VISIBLE_DEVICES=7 python3 -m pytest -s third_party/mthreads/python/test/unit
7172
7273
- name: FlagTree Test with FlagGems
7374
shell: bash
@@ -78,4 +79,3 @@ jobs:
7879
cd FlagGems
7980
git checkout addflagtreetest
8081
python3 -m pytest -s --mode=quick --limit-cases=1 --skipped-unselected-nodeids
81-

CMakeLists.txt

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ elseif(FLAGTREE_BACKEND STREQUAL "iluvatar")
2929
set (DEFAULT_PLUGIN_DIR "${Python3_SITELIB}/triton/_C")
3030
endif()
3131
add_definitions(-DDEFAULT_PLUGIN_DIR="${DEFAULT_PLUGIN_DIR}")
32-
add_compile_options("-Wno-deprecated-declarations")
33-
add_compile_options("-Wno-error=deprecated-declarations")
3432
elseif(FLAGTREE_BACKEND STREQUAL "mthreads")
3533
set(ENV{PATH} "$ENV{LLVM_SYSPATH}/bin:$ENV{PATH}")
3634
set(CMAKE_C_COMPILER clang)
@@ -62,6 +60,18 @@ function(get_flagtree_backend_lib lib_name output_lib)
6260
set(${output_lib} ${ret} PARENT_SCOPE)
6361
endfunction()
6462

63+
# FLAGTREE SPEC TD FILE GET FUNC
64+
function(set_flagtree_backend_td output_td td_filename)
65+
set(ret ${td_filename})
66+
file(RELATIVE_PATH relative_path "${PROJECT_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}")
67+
get_filename_component(BACKEND_SPEC_ROOT "${BACKEND_SPEC_INCLUDE_DIR}" DIRECTORY)
68+
set(BACKEND_SPEC_TD ${BACKEND_SPEC_ROOT}/${relative_path}/${td_filename})
69+
if(EXISTS ${BACKEND_SPEC_TD})
70+
set(ret ${BACKEND_SPEC_TD})
71+
endif()
72+
set(${output_td} ${ret} PARENT_SCOPE)
73+
endfunction()
74+
6575
project(triton)
6676
include(CTest)
6777

@@ -101,13 +111,14 @@ if(NOT WIN32)
101111
endif()
102112

103113
# Compiler flags
114+
set(FLAGTREE_BACKEND_DIR ${PROJECT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND})
104115
## flagtree spec include dir
105-
set(BACKEND_SPEC_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}/backend/flagtree_backend_specialization/include)
116+
set(BACKEND_SPEC_INCLUDE_DIR ${FLAGTREE_BACKEND_DIR}/backend/spec/include)
106117
if(FLAGTREE_BACKEND AND EXISTS ${BACKEND_SPEC_INCLUDE_DIR})
107118
include_directories(${BACKEND_SPEC_INCLUDE_DIR})
108119
endif()
109120
## flagtree third_party include dir
110-
set(BACKEND_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}/include)
121+
set(BACKEND_INCLUDE_DIR ${FLAGTREE_BACKEND_DIR}/include)
111122
if(FLAGTREE_BACKEND AND EXISTS "${BACKEND_INCLUDE_DIR}")
112123
include_directories(${BACKEND_INCLUDE_DIR})
113124
else()
@@ -118,9 +129,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17
118129

119130
if(FLAGTREE_BACKEND STREQUAL "metax")
120131
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_MACA -DUSE_MACA_OPAQUE_PTR -DUSE_BUILTIN -Wno-unused-result -Wno-attributes")
121-
endif()
122-
123-
if(FLAGTREE_BACKEND STREQUAL "hcu")
132+
elseif(FLAGTREE_BACKEND STREQUAL "hcu")
124133
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -Wno-error=return-type -std=gnu++17")
125134
endif()
126135

@@ -222,7 +231,7 @@ endif()
222231
# ------
223232
if(TRITON_BUILD_PYTHON_MODULE)
224233
message(STATUS "Adding Python module")
225-
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}/python/src)
234+
set(PYTHON_SRC_PATH ${FLAGTREE_BACKEND_DIR}/python/src)
226235
if(NOT (FLAGTREE_BACKEND AND EXISTS "${PYTHON_SRC_PATH}"))
227236
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
228237
endif()

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
<div align="right"><a href="/README_cn.md">中文版</a></div>
22

3-
## FlagTree
3+
## <img width="30" height="30" alt="FlagTree-GitHub" src="https://github.com/user-attachments/assets/d8d24c81-6f46-4adc-94e2-b89b03afcb43" /> FlagTree
44

55
FlagTree is an open source, unified compiler for multiple AI chips project dedicated to developing a diverse ecosystem of AI chip compilers and related tooling platforms, thereby fostering and strengthening the upstream and downstream Triton ecosystem. Currently in its initial phase, the project aims to maintain compatibility with existing adaptation solutions while unifying the codebase to rapidly implement single-repository multi-backend support. For upstream model users, it provides unified compilation capabilities across multiple backends; for downstream chip manufacturers, it offers examples of Triton ecosystem integration.
66

77
## Latest News
88

9+
* 2025/11/26 Add FlagTree_Backend_Specialization Unified Design Document [FlagTree_Backend_Specialization](reports/decoupling/).
910
* 2025/09/30 Support flagtree_hints for shared memory on GPGPU.
1011
* 2025/09/29 SDK storage migrated to ksyuncs, improving download stability.
1112
* 2025/09/25 Support flagtree_hints for ascend backend compilation capability.

README_cn.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
<div align="right"><a href="/README.md">English</a></div>
22

3-
## FlagTree
3+
## <img width="30" height="30" alt="FlagTree-GitHub" src="https://github.com/user-attachments/assets/d8d24c81-6f46-4adc-94e2-b89b03afcb43" /> FlagTree
44

55
FlagTree 是面向多种 AI 芯片的开源、统一编译器。FlagTree 致力于打造多元 AI 芯片编译器及相关工具平台,发展和壮大 Triton 上下游生态。项目当前处于初期,目标是兼容现有适配方案,统一代码仓库,快速实现单仓库多后端支持。对于上游模型用户,提供多后端的统一编译能力;对于下游芯片厂商,提供 Triton 生态接入范例。
66

77
## 新特性
8+
* 2025/11/26 添加 FlagTree 后端特化统一设计文档 [FlagTree_Backend_Specialization](reports/decoupling/)
89
* 2025/09/30 在 GPGPU 上支持编译指导 shared memory。
910
* 2025/09/29 SDK 存储迁移至金山云,大幅提升下载稳定性。
1011
* 2025/09/25 支持编译指导 ascend 的后端编译能力。

include/triton/Analysis/Allocation.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,36 @@
99

1010
#include "triton/Dialect/Triton/IR/Dialect.h"
1111
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
12+
#ifdef __NVIDIA__
1213
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
14+
#endif
1315
#include <atomic>
1416
#include <limits>
1517

18+
#if __has_include("flagtree_spec.h")
19+
#include "flagtree_spec.h"
20+
#endif
21+
1622
namespace mlir {
1723

1824
namespace triton {
25+
26+
#ifdef FLAGTREE_SPEC_Analysis_Allocation_getCvtOrder
27+
std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
28+
getCvtOrder(Attribute srcLayout, Attribute dstLayout);
29+
#endif
30+
1931
class AllocationAnalysis;
2032

2133
SmallVector<unsigned>
2234
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
2335
unsigned &outVec);
2436
SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op);
2537

38+
#ifdef FLAGTREE_SPEC_Analysis_Allocation_AllocationAnalysis_getScratchValueSizeElems
39+
unsigned getScratchValueSizeElems(const SmallVector<unsigned> &smemShape);
40+
#endif
41+
2642
} // namespace triton
2743

2844
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
@@ -175,6 +191,12 @@ class Allocation {
175191
/// BufferId -> Buffer
176192
using BufferSetT = std::map<BufferId, BufferT>;
177193

194+
#ifdef FLAGTREE_SPEC_Analysis_Allocation_AllocationAnalysis_dump
195+
public:
196+
friend class AllocationAnalysis;
197+
static void dump(llvm::MapVector<BufferT *, Interval<size_t>> bufferRange);
198+
#endif
199+
178200
private:
179201
template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
180202
void addBuffer(KeyType &key, Args &&...args) {

include/triton/Analysis/AxisInfo.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,18 @@
1313
#include <optional>
1414
#include <type_traits>
1515

16+
#if __has_include("flagtree_spec.h")
17+
#include "flagtree_spec.h"
18+
#endif
19+
1620
namespace mlir::triton {
1721

1822
//===----------------------------------------------------------------------===//
1923
// AxisInfo
2024
//===----------------------------------------------------------------------===//
2125

2226
/// This lattice value represents known information on the axes of a lattice.
27+
#ifndef FLAGTREE_SPEC_AxisInfo
2328
class AxisInfo {
2429
public:
2530
typedef SmallVector<int64_t> DimVectorT;
@@ -151,6 +156,7 @@ class AxisInfo {
151156
// The constant value of the lattice if we can infer it.
152157
std::optional<int64_t> constantValue;
153158
};
159+
#endif
154160

155161
// Module level axis info analysis based on the call graph, assuming that we do
156162
// not have recursive functions.

include/triton/Analysis/Membar.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66

77
#include <set>
88

9+
#if __has_include("flagtree_spec.h")
10+
#include "flagtree_spec.h"
11+
#endif
12+
913
namespace mlir {
1014

1115
class OpBuilder;
@@ -43,6 +47,16 @@ struct BlockInfo {
4347
syncWriteIntervals.clear();
4448
}
4549

50+
#ifdef FLAGTREE_SPEC_BlockInfo_erase
51+
// type: 0 all | 1 del W from other R |2 del R from other W
52+
void erase(BlockInfo &other, int type = 0);
53+
#endif
54+
55+
#ifdef FLAGTREE_SPEC_BlockInfo_printIntervals
56+
// for debug
57+
void printIntervals();
58+
#endif
59+
4660
/// Compares two BlockInfo objects.
4761
bool operator==(const BlockInfo &other) const {
4862
return syncReadIntervals == other.syncReadIntervals &&

include/triton/Analysis/Utility.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
#include "triton/Dialect/Triton/IR/Dialect.h"
88
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
99

10+
#if __has_include("flagtree_spec.h")
11+
#include "flagtree_spec.h"
12+
#endif
13+
1014
namespace mlir {
1115

1216
inline bool isZeroConst(Value v) {
@@ -194,6 +198,10 @@ bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
194198

195199
bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
196200

201+
#ifdef FLAGTREE_SPEC_Analysis_Utility_isMmaToMmaShortcut
202+
bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding);
203+
#endif
204+
197205
// Return true if the src and dst layout match.
198206
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
199207
RankedTensorType dstTy);
@@ -210,10 +218,35 @@ bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);
210218
SetVector<Operation *>
211219
multiRootTopologicalSort(const SetVector<Operation *> &toSort);
212220

221+
#ifdef FLAGTREE_SPEC_Utility_isMmaToDotSlowShortcut
222+
bool isMmaToDotSlowShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
223+
#endif
224+
225+
#ifdef FLAGTREE_SPEC_Utility_getBackwardSliceCorex
226+
/// This function dones't use assertion check.
227+
void getBackwardSliceCorex(Operation *op, SetVector<Operation *> *backwardSlice,
228+
TransitiveFilter filter = nullptr,
229+
bool omitBlockArguments = false);
230+
#endif
231+
232+
#ifdef FLAGTREE_SPEC_Utility_getBackwardSliceImplCorex
233+
void getBackwardSliceImplCorex(Operation *op,
234+
SetVector<Operation *> *backwardSlice,
235+
TransitiveFilter filter,
236+
bool omitBlockArguments = false);
237+
238+
#endif
239+
213240
/// This uses the toplogicalSort above
214241
SetVector<Operation *>
215242
multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr,
243+
#ifndef FLAGTREE_SPEC_Utility_multiRootGetSlice_ARG
216244
TransitiveFilter forwardFilter = nullptr);
245+
#else
246+
TransitiveFilter forwardFilter = nullptr,
247+
FLAGTREE_SPEC_Utility_multiRootGetSlice_ARG
248+
omitBlockArguments = true);
249+
#endif
217250

218251
/// Create a basic DataFlowSolver with constant and dead code analysis included.
219252
std::unique_ptr<DataFlowSolver> createDataFlowSolver();

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
99
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
1010

11+
#if __has_include("flagtree_spec.h")
12+
#include "flagtree_spec.h"
13+
#endif
14+
1115
using namespace mlir;
1216
using namespace mlir::triton;
1317

@@ -28,6 +32,15 @@ SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
2832

2933
Type getElementType(Value value);
3034

35+
#ifdef FLAGTREE_SPEC_ElementwiseOpConversionBase_maybeDeduplicate
36+
bool maybeDeduplicate_baseEncoding(Attribute baseEncoding);
37+
#endif
38+
39+
#ifdef FLAGTREE_SPEC_ElementwiseOpConversionBase_matchAndRewrite
40+
void matchAndRewrite_elemTy(const mlir::TypeConverter *typeConverter,
41+
mlir::Type &elemTy, const mlir::Type &resultTy);
42+
#endif
43+
3144
class MultipleOperandsRange
3245
: public iterator_range<SmallVector<SmallVector<Value>>::iterator> {
3346
using ContainerT = SmallVector<SmallVector<Value>>;
@@ -102,6 +115,10 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
102115
// test_core::test_fp8_dot_acc
103116
return resultVals;
104117
}
118+
#ifdef FLAGTREE_SPEC_ElementwiseOpConversionBase_maybeDeduplicate
119+
if (maybeDeduplicate_baseEncoding(baseEncoding))
120+
return resultVals;
121+
#endif
105122

106123
SmallVector<unsigned> elemsPerThread = getElemsPerThread(rtType);
107124
int rank = elemsPerThread.size();
@@ -182,6 +199,9 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
182199
// element type
183200
auto resultElementTy = getElementTypeOrSelf(resultTy);
184201
Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
202+
#ifdef FLAGTREE_SPEC_ElementwiseOpConversionBase_matchAndRewrite
203+
matchAndRewrite_elemTy(this->getTypeConverter(), elemTy, resultTy);
204+
#endif
185205
SmallVector<SmallVector<Value>> allOperands;
186206
for (auto operand : adaptor.getOperands()) {
187207
auto argTy = op->getOperand(0).getType();

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33

44
#include "triton/Conversion/MLIRTypes.h"
55

6+
#if __has_include("flagtree_spec.h")
7+
#include "flagtree_spec.h"
8+
#endif
9+
610
namespace mlir::triton {
711
class TargetInfoBase {
812
public:
@@ -13,11 +17,18 @@ class TargetInfoBase {
1317
virtual Value ballot(ConversionPatternRewriter &rewriter, Location loc,
1418
Type type, Value cmp) const = 0;
1519

20+
#ifdef FLAGTREE_SPEC_TargetInfoBase_function
21+
virtual Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
22+
Value ptr, Value val, Value pred) const = 0;
23+
virtual Value loadShared(ConversionPatternRewriter &rewriter, Location loc,
24+
Value ptr, Type elemTy, Value pred) const = 0;
25+
#else
1626
virtual void storeShared(ConversionPatternRewriter &rewriter, Location loc,
1727
Value ptr, Value val, Value pred) const = 0;
1828
virtual Value loadShared(ConversionPatternRewriter &rewriter, Location loc,
1929
const TypeConverter *converter, Value ptr,
2030
Type elemTy, Value pred) const = 0;
31+
#endif
2132

2233
virtual Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc,
2334
Value val, int i) const = 0;

0 commit comments

Comments
 (0)