Skip to content

Commit 0625715

Browse files
authored
[MathToVecLib] Add support for setting bit-widths for AVX512, AVX, and SSE to prevent "Illegal instruction (core dumped)" (#234)
* [MathToVecLib] Add support for setting bit-widths for AVX512, AVX, and SSE to prevent "Illegal instruction (core dumped)" * [MathToVecLib] Fix incorrect vec_size_in_bits update method and initialization * [MathToVecLib] Add tests for generating SLEEF functions with different ISA. * [MathToVecLib] Remove unrelated huge vector contact from test * [MathToVecLib] Fix coding style issues and apply necessary adjustments * [MathToVecLib] Fix code formatting issues updated by pre-commit
1 parent d3dd504 commit 0625715

File tree

3 files changed

+59
-4
lines changed

3 files changed

+59
-4
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// RUN: triton-opt %s -split-input-file -triton-cpu-math-to-vec-lib="cpu_features=sse" | FileCheck %s --check-prefix=CHECK-SSE
2+
// RUN: triton-opt %s -split-input-file -triton-cpu-math-to-vec-lib="cpu_features=sse,sse2,sse3" | FileCheck %s --check-prefix=CHECK-SSE
3+
// RUN: triton-opt %s -split-input-file -triton-cpu-math-to-vec-lib="cpu_features=avx" | FileCheck %s --check-prefix=CHECK-AVX
4+
// RUN: triton-opt %s -split-input-file -triton-cpu-math-to-vec-lib="cpu_features=avx,avx2" | FileCheck %s --check-prefix=CHECK-AVX
5+
// RUN: triton-opt %s -split-input-file -triton-cpu-math-to-vec-lib="cpu_features=avx,sse" | FileCheck %s --check-prefix=CHECK-AVX
6+
// RUN: triton-opt %s -split-input-file -triton-cpu-math-to-vec-lib="cpu_features=avx512f" | FileCheck %s --check-prefix=CHECK-AVX512F
7+
// RUN: triton-opt %s -split-input-file -triton-cpu-math-to-vec-lib="cpu_features=avx512f,avx" | FileCheck %s --check-prefix=CHECK-AVX512F
8+
// RUN: triton-opt %s -split-input-file -triton-cpu-math-to-vec-lib="cpu_features=avx512f,avx,sse" | FileCheck %s --check-prefix=CHECK-AVX512F
9+
10+
// Convert math ops to VecLib ops.
11+
12+
// CHECK-SSE-LABEL: @exp_kernel
13+
// CHECK-SSE: %[[EXTRACTED:.*]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<256x4xf32>
14+
// CHECK-SSE-NEXT: %[[CALLED:.*]] = func.call @Sleef_expf4_u10(%[[EXTRACTED]]) : (vector<4xf32>) -> vector<4xf32>
15+
// CHECK-SSE-NEXT: %[[INSERTED:.*]] = vector.insert %[[CALLED]], %{{.*}}[0] : vector<4xf32> into vector<256x4xf32>
16+
17+
// CHECK-AVX-LABEL: @exp_kernel
18+
// CHECK-AVX: %[[EXTRACTED:.*]] = vector.extract %{{.*}}[0] : vector<8xf32> from vector<128x8xf32>
19+
// CHECK-AVX-NEXT: %[[CALLED:.*]] = func.call @Sleef_expf8_u10(%[[EXTRACTED]]) : (vector<8xf32>) -> vector<8xf32>
20+
// CHECK-AVX-NEXT: %[[INSERTED:.*]] = vector.insert %[[CALLED]], %{{.*}}[0] : vector<8xf32> into vector<128x8xf32>
21+
22+
// CHECK-AVX512F-LABEL: @exp_kernel
23+
// CHECK-AVX512F: %[[EXTRACTED:.*]] = vector.extract %{{.*}}[0] : vector<16xf32> from vector<64x16xf32>
24+
// CHECK-AVX512F-NEXT: %[[CALLED:.*]] = func.call @Sleef_expf16_u10(%[[EXTRACTED]]) : (vector<16xf32>) -> vector<16xf32>
25+
// CHECK-AVX512F-NEXT: %[[INSERTED:.*]] = vector.insert %[[CALLED]], %{{.*}}[0] : vector<16xf32> into vector<64x16xf32>
26+
27+
module {
28+
tt.func public @exp_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} , %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} , %arg2: i32 {tt.divisibility = 16 : i32} ) attributes {noinline = false} {
29+
%c0 = arith.constant 0 : index
30+
%0 = tt.get_program_id x : i32
31+
%1 = arith.muli %0, %arg2 : i32
32+
%2 = tt.addptr %arg1, %1 : !tt.ptr<f32>, i32
33+
%3 = triton_cpu.ptr_to_memref %2 : <f32> -> memref<1024xf32>
34+
%4 = vector.load %3[%c0] : memref<1024xf32>, vector<1024xf32>
35+
%5 = math.exp %4 : vector<1024xf32>
36+
%6 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
37+
%7 = triton_cpu.ptr_to_memref %6 : <f32> -> memref<1024xf32>
38+
vector.store %5, %7[%c0] : memref<1024xf32>, vector<1024xf32>
39+
tt.return
40+
}
41+
}

third_party/cpu/include/TritonCPUToLLVM/Passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ def MathToVecLib : Pass<"triton-cpu-math-to-vec-lib", "mlir::ModuleOp"> {
104104
clEnumValN(mlir::triton::cpu::VecLib::Mvec, "mvec",
105105
"Use Mvec as mm lib")
106106
)}]>,
107+
ListOption<"cpu_features", "cpu_features", "std::string",
108+
"A list of available CPU features to choose proper vector functions">,
107109
];
108110

109111
let dependentDialects = ["mlir::vector::VectorDialect",

third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,8 @@ void populatePatternsForOp(RewritePatternSet &patterns,
346346
struct MathToVecLibPass
347347
: public mlir::triton::cpu::impl::MathToVecLibBase<MathToVecLibPass> {
348348
MathToVecLibPass() = default;
349-
size_t vec_size_in_bits;
349+
// Default to 128-bit if no features are specified.
350+
size_t vec_size_in_bits = 128;
350351

351352
explicit MathToVecLibPass(VecLib lib, std::set<std::string> cpu_features) {
352353
this->lib = lib;
@@ -358,10 +359,15 @@ struct MathToVecLibPass
358359
// Refactor this as an independent function.
359360
// And improve this to support other x86 SIMD ISAs and also for arm SVE
360361
// (VLA)
361-
vec_size_in_bits = 512;
362362
for (auto feature : cpu_features) {
363-
// Arm NEON is fixed 128-bit SIMD ISA.
364-
if (feature == "neon") {
363+
if (feature == "avx512f") {
364+
vec_size_in_bits = std::max<size_t>(vec_size_in_bits, 512);
365+
} else if (feature == "avx") {
366+
vec_size_in_bits = std::max<size_t>(vec_size_in_bits, 256);
367+
} else if (feature == "sse") {
368+
vec_size_in_bits = std::max<size_t>(vec_size_in_bits, 128);
369+
} else if (feature == "neon") {
370+
// Arm NEON is fixed 128-bit SIMD ISA.
365371
vec_size_in_bits = 128;
366372
break;
367373
}
@@ -374,6 +380,12 @@ struct MathToVecLibPass
374380

375381
RewritePatternSet patterns(context);
376382

383+
if (!cpu_features.empty()) {
384+
std::set<std::string> cpu_features_set{cpu_features.begin(),
385+
cpu_features.end()};
386+
update_vec_size(cpu_features_set);
387+
}
388+
377389
switch (lib) {
378390
case VecLib::Mvec: {
379391
populateCommonPatterns<MvecNameGenerator>(patterns);

0 commit comments

Comments
 (0)