Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 9f9997c

Browse files
committed Nov 21, 2024
Fix 1D tiling
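The original code from the blog looks wrong: it reads and writes columns `col` through `col + 3` without checking them against `dimensions.n`. The code in the repo has these bounds checks, and they make the tests pass.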
1 parent 7dbd06a commit 9f9997c

File tree

3 files changed

+56
-9
lines changed
  • blog/2024-11-21-optimizing-matrix-mul/code

3 files changed

+56
-9
lines changed
 

‎blog/2024-11-21-optimizing-matrix-mul/code/bin/blog/src/bin.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ fn main() {
3030
run_tests(matmul::naive::wgpu(), &sizes);
3131
run_tests(matmul::workgroup_256::wgpu(), &sizes);
3232
run_tests(matmul::workgroup_2d::wgpu(), &sizes);
33-
//run_tests(matmul::tiling_1d::wgpu(), &sizes);
33+
run_tests(matmul::tiling_1d::wgpu(), &sizes);
3434
run_tests(matmul::tiling_2d_simd::wgpu(), &sizes);
3535

3636
run_tests(matmul::isomorphic::wgpu(), &sizes);

‎blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/backends/cpu.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,37 @@ mod tests {
171171
assert_eq!(result, expected);
172172
}
173173

174+
#[test]
175+
fn test_single_threaded_matmul_4x4() {
176+
let m = 4;
177+
let k = 4;
178+
let n = 4;
179+
180+
// Define matrix `a` (4x4) in row-major order
181+
let a = vec![
182+
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
183+
];
184+
185+
// Define matrix `b` (4x4) in row-major order
186+
let b = vec![
187+
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
188+
31.0, 32.0,
189+
];
190+
191+
// Expected result (4x4) after multiplying `a` and `b`
192+
let expected = vec![
193+
250.0, 260.0, 270.0, 280.0, 618.0, 644.0, 670.0, 696.0, 986.0, 1028.0, 1070.0, 1112.0,
194+
1354.0, 1412.0, 1470.0, 1528.0,
195+
];
196+
197+
let variant = crate::variants::Isomorphic;
198+
let matrix_multiplier = futures::executor::block_on(SingleThreadedMatMul::new(variant));
199+
200+
let result = matrix_multiplier.multiply(&a, &b, m, k, n);
201+
202+
assert_eq!(result, expected);
203+
}
204+
174205
#[test]
175206
fn test_multithreaded_matmul_2x1x1() {
176207
let m = 2;

‎blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/tiling_1d/src/lib.rs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,30 @@ pub fn matmul(
2727

2828
for i in 0..dimensions.k as usize {
2929
let a_elem = a[row * dimensions.k as usize + i];
30-
sum00 += a_elem * b[i * dimensions.n as usize + col];
31-
sum01 += a_elem * b[i * dimensions.n as usize + col + 1];
32-
sum02 += a_elem * b[i * dimensions.n as usize + col + 2];
33-
sum03 += a_elem * b[i * dimensions.n as usize + col + 3];
30+
if col < dimensions.n as usize {
31+
sum00 += a_elem * b[i * dimensions.n as usize + col];
32+
}
33+
if col + 1 < dimensions.n as usize {
34+
sum01 += a_elem * b[i * dimensions.n as usize + col + 1];
35+
}
36+
if col + 2 < dimensions.n as usize {
37+
sum02 += a_elem * b[i * dimensions.n as usize + col + 2];
38+
}
39+
if col + 3 < dimensions.n as usize {
40+
sum03 += a_elem * b[i * dimensions.n as usize + col + 3];
41+
}
3442
}
3543

36-
result[row * dimensions.n as usize + col] = sum00;
37-
result[row * dimensions.n as usize + col + 1] = sum01;
38-
result[row * dimensions.n as usize + col + 2] = sum02;
39-
result[row * dimensions.n as usize + col + 3] = sum03;
44+
if col < dimensions.n as usize {
45+
result[row * dimensions.n as usize + col] = sum00;
46+
}
47+
if col + 1 < dimensions.n as usize {
48+
result[row * dimensions.n as usize + col + 1] = sum01;
49+
}
50+
if col + 2 < dimensions.n as usize {
51+
result[row * dimensions.n as usize + col + 2] = sum02;
52+
}
53+
if col + 3 < dimensions.n as usize {
54+
result[row * dimensions.n as usize + col + 3] = sum03;
55+
}
4056
}

0 commit comments

Comments
 (0)
Please sign in to comment.