Skip to content

Make backends return errors #38

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions blog/2024-11-21-optimizing-matrix-mul/code/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions blog/2024-11-21-optimizing-matrix-mul/code/Cargo.toml
Original file line number Diff line number Diff line change
@@ -48,6 +48,7 @@ spirv-std = { git = "https://github.com/rust-gpu/rust-gpu", rev = "0da80f8a61867
futures = "0.3"
glam = { version = "0.29.2", features = ["cuda", "bytemuck"] }
tracing = "0.1.40"
wgpu = { version = "23.0", features = ["spirv", "vulkan-portability"] }

# Enable incremental by default in release mode.
[profile.release]
30 changes: 7 additions & 23 deletions blog/2024-11-21-optimizing-matrix-mul/code/benches/gpu_bench.rs
Original file line number Diff line number Diff line change
@@ -29,17 +29,17 @@ const SIZES: &[(u32, u32, u32)] = &[
(64, 32, 128), // A: 64x32, B: 32x128, Result: 64x128
(1024, 512, 2048), // A: 1024x512, B: 512x2048, Result: 1024x2048
(2048, 1024, 4096), // A: 2048x1024, B: 1024x4096, Result: 2048x4096
*/
];

fn bench_all_variants(c: &mut Criterion) {
// Initialize all variants outside the loop
let multiplier_naive = matmul::naive::wgpu();
let multiplier_workgroup_256 = matmul::workgroup_256::wgpu();
let multiplier_workgroup_2d = matmul::workgroup_2d::wgpu();
let multiplier_tiling_1d = matmul::tiling_1d::wgpu();
let multiplier_tiling_1d_loop = matmul::tiling_1d_loop::wgpu();
let multiplier_tiling_2d = matmul::tiling_2d::wgpu();
let multiplier_isomorphic_gpu = matmul::isomorphic::wgpu();
let multiplier_naive = matmul::naive::wgpu().unwrap();
let multiplier_workgroup_256 = matmul::workgroup_256::wgpu().unwrap();
let multiplier_workgroup_2d = matmul::workgroup_2d::wgpu().unwrap();
let multiplier_tiling_1d = matmul::tiling_1d::wgpu().unwrap();
let multiplier_tiling_1d_loop = matmul::tiling_1d_loop::wgpu().unwrap();
let multiplier_tiling_2d = matmul::tiling_2d::wgpu().unwrap();

for &(m, k, n) in SIZES {
// Calculate FLOPs for this size
@@ -134,22 +134,6 @@ fn bench_all_variants(c: &mut Criterion) {
});
},
);

group.bench_with_input(
BenchmarkId::new("isomorphic:wgpu", format!("{}x{}x{}", m, k, n)),
&(m, k, n),
|bench, &(m, k, n)| {
bench.iter(|| {
black_box(multiplier_isomorphic_gpu.multiply(
black_box(&a),
black_box(&b),
m,
k,
n,
))
});
},
);
}
}

Original file line number Diff line number Diff line change
@@ -6,7 +6,6 @@ use rand::Rng;
use std::time::Duration;

const WARMUP_TIME: Duration = Duration::from_secs(2);
const MEASUREMENT_TIME: Duration = Duration::from_secs(5 * 60);
const SAMPLE_SIZE: usize = 10;

/// Matrix sizes to benchmark
@@ -34,19 +33,18 @@ const SIZES: &[(u32, u32, u32)] = &[

fn bench_isomorphic_variants(c: &mut Criterion) {
// Initialize isomorphic variants
let multiplier_isomorphic_gpu = matmul::isomorphic::wgpu();
let multiplier_isomorphic_cpu_single = matmul::isomorphic::cpu::single_threaded();
let multiplier_isomorphic_cpu_multi = matmul::isomorphic::cpu::multi_threaded();
let multiplier_isomorphic_gpu = matmul::isomorphic::wgpu().unwrap();
let multiplier_isomorphic_cpu_single = matmul::isomorphic::cpu::single_threaded().unwrap();
let multiplier_isomorphic_cpu_multi = matmul::isomorphic::cpu::multi_threaded().unwrap();

for &(m, k, n) in SIZES {
// Calculate FLOPs for this size
let flops = 2.0 * (m as f64 * n as f64 * k as f64);

let mut group = c.benchmark_group(format!("isomorphic_matmul{}x{}x{}", m, k, n));
let mut group = c.benchmark_group("isomorphic");
group.sampling_mode(SamplingMode::Flat);
group.warm_up_time(WARMUP_TIME);
//group.measurement_time(MEASUREMENT_TIME);
group.sample_size(SAMPLE_SIZE);

// Calculate FLOPs for this size
let flops = 2.0 * (m as f64 * n as f64 * k as f64);
group.throughput(Throughput::Elements(flops as u64));

// Create matrices for the given size
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@ path = "src/bin.rs"
[dependencies]
matmul = { path = "../../crates/cpu/matmul" }
settings = { path = "../../crates/shared/settings" }
wgpu.workspace = true
futures.workspace = true
tracing.workspace = true
tracing-subscriber = { version = "0.3.18", features = ["env-filter", "std"] }
196 changes: 143 additions & 53 deletions blog/2024-11-21-optimizing-matrix-mul/code/bin/blog/src/bin.rs
Original file line number Diff line number Diff line change
@@ -1,83 +1,173 @@
use matmul::MatrixMultiply;
use std::fmt::Display;
use std::sync::{Mutex, OnceLock};
use std::time::Instant;
use tracing::{debug, info, instrument, span, trace, Level};
use tracing::{debug, error, info, instrument, span, trace, warn, Level};
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
use wgpu::Device;

// Thread-safe global error state for WGPU.
// See https://github.com/gfx-rs/wgpu/issues/2912
static WGPU_ERROR_STATE: OnceLock<Mutex<Option<wgpu::Error>>> = OnceLock::new();

/// Initializes the global error state. Should be called once at startup.
///
/// # Panics
///
/// Panics if the error state has already been initialized.
fn init_error_state() {
    // `expect` (rather than a bare `unwrap`) states the invariant that is
    // being relied on: this function runs exactly once, before any GPU work.
    WGPU_ERROR_STATE
        .set(Mutex::new(None))
        .expect("Error state already initialized!");
}

/// Sets an error in the global error state, replacing any previous one.
///
/// # Panics
///
/// Panics if `init_error_state` has not been called yet, or if the state
/// mutex is poisoned.
fn set_error(error: wgpu::Error) {
    // `.expect` replaces the original if-let/else-panic with the same message.
    let state = WGPU_ERROR_STATE
        .get()
        .expect("Error state not initialized!");
    *state.lock().unwrap() = Some(error);
}

/// Clears the global error state.
///
/// # Panics
///
/// Panics if `init_error_state` has not been called yet, or if the state
/// mutex is poisoned.
fn clear_error() {
    // `.expect` replaces the original if-let/else-panic with the same message.
    let state = WGPU_ERROR_STATE
        .get()
        .expect("Error state not initialized!");
    *state.lock().unwrap() = None;
}

/// Retrieves and clears the last error, if any.
///
/// # Panics
///
/// Panics if `init_error_state` has not been called yet, or if the state
/// mutex is poisoned.
fn take_error() -> Option<wgpu::Error> {
    // `Option::take` moves the error out and leaves `None` behind, so a
    // single call both reads and clears the state.
    WGPU_ERROR_STATE
        .get()
        .expect("Error state not initialized!")
        .lock()
        .unwrap()
        .take()
}

/// Installs a global error handler for the given device.
fn install_error_handler(device: &Device) {
device.on_uncaptured_error(Box::new(move |error| {
set_error(error);
}));
}

fn main() {
// Initialize the error state.
init_error_state();

tracing_subscriber::registry()
.with(fmt::Layer::default())
.with(EnvFilter::from_default_env())
.init();

let sizes = [
// Square matrices
(2, 2, 2),
(4, 4, 4),
(8, 8, 8),
(16, 16, 16),
(32, 32, 32),
(64, 64, 64),
(128, 128, 128),
// Non-square matrices
(4, 2, 8), // A: 4x2, B: 2x8, Result: 4x8
(8, 4, 2), // A: 8x4, B: 4x2, Result: 8x2
(16, 8, 32), // A: 16x8, B: 8x32, Result: 16x32
(32, 16, 8), // A: 32x16, B: 16x8, Result: 32x8
(64, 32, 128), // A: 64x32, B: 32x128, Result: 64x128
(256, 256, 256),
(512, 512, 512),
(1024, 1024, 1024),
(2048, 2048, 2048),
];

run_tests(matmul::naive::wgpu(), &sizes);
run_tests(matmul::workgroup_256::wgpu(), &sizes);
run_tests(matmul::workgroup_2d::wgpu(), &sizes);
run_tests(matmul::tiling_1d::wgpu(), &sizes);
run_tests(matmul::tiling_1d_loop::wgpu(), &sizes);
run_tests(matmul::tiling_2d::wgpu(), &sizes);
for size in sizes {
let matmul = matmul::naive::wgpu().unwrap();
install_error_handler(&matmul.device);
run_test(matmul, size);
clear_error();
}

for size in sizes {
let matmul = matmul::workgroup_256::wgpu().unwrap();
install_error_handler(&matmul.device);
run_test(matmul, size);
clear_error();
}

for size in sizes {
let matmul = matmul::workgroup_2d::wgpu().unwrap();
install_error_handler(&matmul.device);
run_test(matmul, size);
clear_error();
}

for size in sizes {
let matmul = matmul::tiling_1d::wgpu().unwrap();
install_error_handler(&matmul.device);
run_test(matmul, size);
clear_error();
}

run_tests(matmul::isomorphic::wgpu(), &sizes);
run_tests(matmul::isomorphic::cpu::single_threaded(), &sizes);
run_tests(matmul::isomorphic::cpu::multi_threaded(), &sizes);
for size in sizes {
let matmul = matmul::tiling_1d_loop::wgpu().unwrap();
install_error_handler(&matmul.device);
run_test(matmul, size);
clear_error();
}

for size in sizes {
let matmul = matmul::tiling_2d::wgpu().unwrap();
install_error_handler(&matmul.device);
run_test(matmul, size);
clear_error();
}
}

#[instrument(skip(multiplier, sizes), fields(algorithm = %multiplier))]
fn run_tests<T: Display, U: MatrixMultiply<T>>(multiplier: U, sizes: &[(u32, u32, u32)]) {
#[instrument(skip(multiplier, size), fields(algorithm = %multiplier, size=?size))]
fn run_test<T: Display, U: MatrixMultiply<T>>(multiplier: U, size: (u32, u32, u32)) {
debug!(algorithm = %multiplier, "Starting tests");
let (m, k, n) = size;

let span = tracing::span!(Level::DEBUG, "matmul", algorithm = %multiplier, m, k, n);
let _enter = span.enter();

for &(m, k, n) in sizes {
let span = tracing::span!(Level::INFO, "matrix_test", algorithm = %multiplier, m, k, n);
let _enter = span.enter();

info!("Testing size: {}x{}x{}", m, k, n);

// Setup phase
let setup_span = span!(Level::INFO, "setup_phase");
let _setup_enter = setup_span.enter();
let a: Vec<f32> = (0..m * k).map(|i| i as f32).collect();
let b: Vec<f32> = (0..k * n).map(|i| i as f32).collect();
drop(_setup_enter);

// Compute phase
let compute_span = span!(Level::INFO, "compute_phase");
let compute_start = Instant::now();
let _compute_enter = compute_span.enter();
let result = multiplier.multiply(&a, &b, m, k, n);
let compute_time = compute_start.elapsed();
drop(_compute_enter);

// Calculate GFLOPS
let gflop_span = span!(Level::INFO, "calculate_gflops");
let _gflop_enter = gflop_span.enter();
let ops = 2.0 * (m * n * k) as f64;
let flops = ops / compute_time.as_secs_f64() / 1e9;
info!("Flops: {}", flops);
drop(_gflop_enter);

// Verification phase
let verify_span = span!(Level::INFO, "verification_phase");
let _verify_enter = verify_span.enter();
verify_results(&a, &b, &result, m, k, n);
drop(_verify_enter);
trace!("Testing size: {}x{}x{}", m, k, n);

// Setup phase
let setup_span = span!(Level::DEBUG, "setup_phase");
let _setup_enter = setup_span.enter();
let a: Vec<f32> = (0..m * k).map(|i| i as f32).collect();
let b: Vec<f32> = (0..k * n).map(|i| i as f32).collect();
drop(_setup_enter);

// Compute phase
let compute_span = span!(Level::DEBUG, "compute_phase");
let compute_start = Instant::now();
let _compute_enter = compute_span.enter();
let result = multiplier.multiply(&a, &b, m, k, n);
let compute_time = compute_start.elapsed();
drop(_compute_enter);

if let Some(error) = take_error() {
warn!("wgpu error occurred: {:?}", error);
return;
}

if result.is_err() {
error!("Error during computation: {:?}", result);
return;
}

let result = result.unwrap();

// Calculate FLOPS
let flop_span = span!(Level::DEBUG, "calculate_flops");
let _flop_enter = flop_span.enter();
let ops = 2.0 * (m * n * k) as f64;
let flops = ops / compute_time.as_secs_f64() / 1e9;
info!("Flops: {}", flops);
drop(_flop_enter);

// Verification phase
let verify_span = span!(Level::DEBUG, "verification_phase");
let _verify_enter = verify_span.enter();
verify_results(&a, &b, &result, m, k, n);
drop(_verify_enter);
}

#[instrument(skip(a, b, result), fields(rows = m, cols = n))]
Original file line number Diff line number Diff line change
@@ -9,12 +9,12 @@ crate-type = ["lib", "cdylib"]
[dependencies]
settings = { path = "../../shared/settings" }
bytemuck = { version = "1.9", features = ["derive"] }
wgpu = { version = "23.0", features = ["spirv"] }
ash = { version = "0.37" }
rayon = "1.10"
futures.workspace = true
tracing.workspace = true
glam.workspace = true
tracing.workspace = true
wgpu.workspace = true

# The following dependencies are used to link to the compiled shaders.
compiled_naive = { path = "../compiled_for_gpu/naive" }
@@ -26,3 +26,4 @@ compiled_tiling_2d = { path = "../compiled_for_gpu/tiling_2d" }
compiled_isomorphic = { path = "../compiled_for_gpu/isomorphic" }
# The CPU side of the isomorphic implementation.
isomorphic = { path = "../../shared/isomorphic" }
thiserror = "2.0.3"
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{Cpu, GridComputation, MatrixMultiply};
use crate::{Cpu, GridComputation, MatrixMultiply, MatrixMultiplyError};
use glam::UVec3;
use rayon::prelude::*;
use settings::Dimensions;
@@ -23,11 +23,18 @@ impl<T> MatrixMultiply<T> for SingleThreadedMatMul<T>
where
T: Cpu + GridComputation + Display + Send + Sync,
{
fn new(variant: T) -> impl Future<Output = Self> + Send {
async move { SingleThreadedMatMul { variant } }
fn new(variant: T) -> impl Future<Output = Result<Self, MatrixMultiplyError>> + Send {
async move { Ok(SingleThreadedMatMul { variant }) }
}

fn multiply(&self, a: &[f32], b: &[f32], m: u32, k: u32, n: u32) -> Vec<f32> {
fn multiply(
&self,
a: &[f32],
b: &[f32],
m: u32,
k: u32,
n: u32,
) -> Result<Vec<f32>, MatrixMultiplyError> {
// Initialize the result vector with zeros as that is what the GPU does.
let mut result = vec![0.0; (m * n) as usize];

@@ -68,7 +75,7 @@ where
}
}

result
Ok(result)
}
}

@@ -87,11 +94,18 @@ impl<T> MatrixMultiply<T> for MultiThreadedMatMul<T>
where
T: Cpu + GridComputation + Display + Send + Sync,
{
fn new(variant: T) -> impl Future<Output = Self> + Send {
async move { MultiThreadedMatMul { variant } }
fn new(variant: T) -> impl Future<Output = Result<Self, MatrixMultiplyError>> + Send {
async move { Ok(MultiThreadedMatMul { variant }) }
}

fn multiply(&self, a: &[f32], b: &[f32], m: u32, k: u32, n: u32) -> Vec<f32> {
fn multiply(
&self,
a: &[f32],
b: &[f32],
m: u32,
k: u32,
n: u32,
) -> Result<Vec<f32>, MatrixMultiplyError> {
// Initialize the result vector with zeros
let result = vec![0.0; (m * n) as usize];
let result = Mutex::new(result);
@@ -123,12 +137,14 @@ where
.collect();

// Process each (x, y) pair in parallel
tasks.par_iter().for_each(|&(x, y)| {
tasks.par_iter().try_for_each(|&(x, y)| {
// Define global_id (adjust z if necessary)
let global_id = UVec3::new(x as u32, y as u32, 0); // Changed z to 0 for consistency

// Lock the mutex to get mutable access to the result vector
let mut result_lock = result.lock().unwrap();
let mut result_lock = result
.lock()
.map_err(|_| MatrixMultiplyError::CpuLockError)?;

// Perform the matmul operation for element (x, y)
<T as Cpu>::call(
@@ -139,18 +155,21 @@ where
&b,
&mut result_lock,
);
});

Ok(())
})?;

// Extract the result vector from the Mutex
let result = Mutex::into_inner(result).unwrap();
let result = Mutex::into_inner(result).map_err(|_| MatrixMultiplyError::CpuLockError)?;

result
Ok(result)
}
}

#[cfg(test)]
mod tests {
use super::*;
use futures::executor::block_on;

#[test]
fn test_single_threaded_matmul_2x1x1() {
@@ -164,9 +183,12 @@ mod tests {
let expected = vec![3.0, 6.0];

let variant = crate::variants::Isomorphic;
let matrix_multiplier = futures::executor::block_on(SingleThreadedMatMul::new(variant));
let matrix_multiplier =
block_on(SingleThreadedMatMul::new(variant)).expect("Failed to create");

let result = matrix_multiplier.multiply(&a, &b, m, k, n);
let result = matrix_multiplier
.multiply(&a, &b, m, k, n)
.expect("Matrix multiplication failed");

assert_eq!(result, expected);
}
@@ -195,9 +217,12 @@ mod tests {
];

let variant = crate::variants::Isomorphic;
let matrix_multiplier = futures::executor::block_on(SingleThreadedMatMul::new(variant));
let matrix_multiplier =
block_on(SingleThreadedMatMul::new(variant)).expect("Failed to create");

let result = matrix_multiplier.multiply(&a, &b, m, k, n);
let result = matrix_multiplier
.multiply(&a, &b, m, k, n)
.expect("Matrix multiplication failed");

assert_eq!(result, expected);
}
@@ -214,9 +239,46 @@ mod tests {
let expected = vec![3.0, 6.0];

let variant = crate::variants::Isomorphic;
let matrix_multiplier = futures::executor::block_on(MultiThreadedMatMul::new(variant));
let matrix_multiplier =
block_on(MultiThreadedMatMul::new(variant)).expect("Failed to create");

let result = matrix_multiplier
.multiply(&a, &b, m, k, n)
.expect("Matrix multiplication failed");

assert_eq!(result, expected);
}

#[test]
fn test_multithreaded_matmul_4x4() {
    let m = 4;
    let k = 4;
    let n = 4;

    // Define matrix `a` (4x4) in row-major order
    let a = vec![
        1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
    ];

    // Define matrix `b` (4x4) in row-major order
    let b = vec![
        17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
        31.0, 32.0,
    ];

    // Expected result (4x4) after multiplying `a` and `b`
    let expected = vec![
        250.0, 260.0, 270.0, 280.0, 618.0, 644.0, 670.0, 696.0, 986.0, 1028.0, 1070.0, 1112.0,
        1354.0, 1412.0, 1470.0, 1528.0,
    ];

    let variant = crate::variants::Isomorphic;
    let matrix_multiplier =
        block_on(MultiThreadedMatMul::new(variant)).expect("Failed to create");

    // The multiplication is performed exactly once; the stale pre-refactor
    // call (which shadowed this binding and ran the multiply twice) is gone.
    let result = matrix_multiplier
        .multiply(&a, &b, m, k, n)
        .expect("Matrix multiplication failed");

    assert_eq!(result, expected);
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{Gpu, GridComputation, MatrixMultiply};
use crate::{Gpu, GridComputation, MatrixMultiply, MatrixMultiplyError};
use bytemuck;
use futures::channel::oneshot;
use futures::executor::block_on;
@@ -11,7 +11,7 @@ use wgpu::{self, util::DeviceExt};

/// Matrix multiplication on the GPU using `wgpu`.
pub struct MatrixMultiplier<T> {
device: wgpu::Device,
pub device: wgpu::Device,
queue: wgpu::Queue,
pipeline: wgpu::ComputePipeline,
bind_group_layout: wgpu::BindGroupLayout,
@@ -29,7 +29,7 @@ where
T: Gpu + GridComputation + Display + Send,
{
/// Initializes a new `MatrixMultiplier` with necessary GPU resources.
async fn new(variant: T) -> Self {
async fn new(variant: T) -> Result<Self, MatrixMultiplyError> {
// Set up WGPU to talk to the system's GPUs and manage rendering or compute tasks.
let instance = create_instance().await;

@@ -51,20 +51,27 @@ where
// Build the actual GPU pipeline to run the GPU program and manage execution.
let pipeline = create_compute_pipeline(&device, &pipeline_layout, &shader);

Self {
Ok(Self {
device,
queue,
pipeline,
bind_group_layout,
variant,
}
})
}

/// Executes matrix multiplication for given input matrices.
///
/// Uploads the input matrices to the GPU, dispatches the compute shader,
/// and retrieves the result.
fn multiply(&self, a: &[f32], b: &[f32], m: u32, k: u32, n: u32) -> Vec<f32> {
fn multiply(
&self,
a: &[f32],
b: &[f32],
m: u32,
k: u32,
n: u32,
) -> Result<Vec<f32>, MatrixMultiplyError> {
trace!(?a, ?b, "Starting matrix multiplication");

let result_size = (m * n * std::mem::size_of::<f32>() as u32) as u64;
@@ -100,7 +107,7 @@ where
//
// This is a `uniform` buffer instead of `storage` buffer because the data is
// the same for all workgroups, it is read-only, and it is small enough to fit
// in a single buffer (`uniform` buffers are limited to to 64 KB on most GPUs
// in a single buffer (`uniform` buffers are limited to 64 KB on most GPUs
// and often less on older GPUs).
let dimensions = Dimensions::new(m, k, n);
let dimensions_buffer = create_buffer_init(
@@ -163,8 +170,8 @@ where

// Wait for the mapping to complete and verify success.
block_on(receiver)
.expect("Failed to receive data")
.expect("Map async failed");
.map_err(|_| MatrixMultiplyError::GpuDataReceive)?
.map_err(|_| MatrixMultiplyError::GpuBufferMapping)?;

// Read and convert the result data into a typed vector instead of raw bytes.
let data = slice.get_mapped_range();
@@ -173,7 +180,7 @@ where
staging_buffer.unmap();

trace!(?result, "Matrix multiplication result");
result
Ok(result)
}
}

Original file line number Diff line number Diff line change
@@ -4,14 +4,39 @@ use glam::UVec3;
use settings::Dimensions;
use std::fmt::Display;
use std::future::Future;
use thiserror::Error;

mod backends;
pub mod variants;

/// Errors that can happen for matrix multiply on the CPU or GPU.
#[derive(Error, Debug)]
pub enum MatrixMultiplyError {
/// Creating the wgpu instance failed.
#[error("Failed to initialize GPU instance")]
GpuInstanceCreation,
/// No suitable GPU adapter could be found on this system.
#[error("Failed to find an appropriate GPU adapter")]
GpuAdapterRequest,
/// The adapter could not provide a device/queue pair.
#[error("Failed to create GPU device and queue")]
GpuDeviceCreation,
/// The readback channel from the GPU was dropped before delivering data.
#[error("Failed to receive data from the GPU")]
GpuDataReceive,
/// Mapping the staging buffer for reading failed.
#[error("Mapping GPU buffer failed")]
GpuBufferMapping,
/// A mutex guarding the CPU result vector was poisoned.
#[error("Failed to acquire a lock on the result vector")]
CpuLockError,
}

/// The trait that defines how to multiply two matrices.
pub trait MatrixMultiply<T>: Display {
fn new(variant: T) -> impl Future<Output = Self> + Send;
fn multiply(&self, a: &[f32], b: &[f32], m: u32, k: u32, n: u32) -> Vec<f32>;
pub trait MatrixMultiply<T>: Display + Sized {
/// Asynchronously constructs a multiplier for the given `variant`,
/// failing with a [`MatrixMultiplyError`] if resources (e.g. a GPU
/// device) cannot be acquired.
fn new(variant: T) -> impl Future<Output = Result<Self, MatrixMultiplyError>> + Send;
/// Multiplies `a` (m x k) by `b` (k x n), both given in row-major order,
/// returning the m x n result or an error if the backend fails.
fn multiply(
&self,
a: &[f32],
b: &[f32],
m: u32,
k: u32,
n: u32,
) -> Result<Vec<f32>, MatrixMultiplyError>;
}

/// Matrix multiplication logic that can be run on the CPU.
@@ -44,7 +69,7 @@ pub mod naive {
use super::*;
use crate::backends::wgpu::MatrixMultiplier;

pub fn wgpu() -> MatrixMultiplier<variants::Naive> {
pub fn wgpu() -> Result<MatrixMultiplier<variants::Naive>, MatrixMultiplyError> {
futures::executor::block_on(backends::wgpu::MatrixMultiplier::new(variants::Naive))
}
}
@@ -53,7 +78,7 @@ pub mod workgroup_256 {
use super::*;
use crate::backends::wgpu::MatrixMultiplier;

pub fn wgpu() -> MatrixMultiplier<variants::Workgroup256> {
pub fn wgpu() -> Result<MatrixMultiplier<variants::Workgroup256>, MatrixMultiplyError> {
futures::executor::block_on(backends::wgpu::MatrixMultiplier::new(
variants::Workgroup256,
))
@@ -64,7 +89,7 @@ pub mod workgroup_2d {
use super::*;
use crate::backends::wgpu::MatrixMultiplier;

pub fn wgpu() -> MatrixMultiplier<variants::Workgroup2d> {
pub fn wgpu() -> Result<MatrixMultiplier<variants::Workgroup2d>, MatrixMultiplyError> {
futures::executor::block_on(MatrixMultiplier::new(variants::Workgroup2d))
}
}
@@ -73,7 +98,7 @@ pub mod tiling_1d {
use super::*;
use crate::backends::wgpu::MatrixMultiplier;

pub fn wgpu() -> MatrixMultiplier<variants::Tiling1d> {
pub fn wgpu() -> Result<MatrixMultiplier<variants::Tiling1d>, MatrixMultiplyError> {
futures::executor::block_on(MatrixMultiplier::new(variants::Tiling1d))
}
}
@@ -82,7 +107,7 @@ pub mod tiling_1d_loop {
use super::*;
use crate::backends::wgpu::MatrixMultiplier;

pub fn wgpu() -> MatrixMultiplier<variants::Tiling1dLoop> {
pub fn wgpu() -> Result<MatrixMultiplier<variants::Tiling1dLoop>, MatrixMultiplyError> {
futures::executor::block_on(MatrixMultiplier::new(variants::Tiling1dLoop))
}
}
@@ -91,7 +116,7 @@ pub mod tiling_2d {
use super::*;
use crate::backends::wgpu::MatrixMultiplier;

pub fn wgpu() -> MatrixMultiplier<variants::Tiling2d> {
pub fn wgpu() -> Result<MatrixMultiplier<variants::Tiling2d>, MatrixMultiplyError> {
futures::executor::block_on(MatrixMultiplier::new(variants::Tiling2d))
}
}
@@ -100,19 +125,21 @@ pub mod isomorphic {
use super::*;
use crate::backends::wgpu::MatrixMultiplier;

pub fn wgpu() -> MatrixMultiplier<variants::Isomorphic> {
pub fn wgpu() -> Result<MatrixMultiplier<variants::Isomorphic>, MatrixMultiplyError> {
futures::executor::block_on(MatrixMultiplier::new(variants::Isomorphic))
}

pub mod cpu {
use super::*;
use crate::backends::cpu::{MultiThreadedMatMul, SingleThreadedMatMul};

pub fn single_threaded() -> SingleThreadedMatMul<variants::Isomorphic> {
pub fn single_threaded(
) -> Result<SingleThreadedMatMul<variants::Isomorphic>, MatrixMultiplyError> {
futures::executor::block_on(SingleThreadedMatMul::new(variants::Isomorphic))
}

pub fn multi_threaded() -> MultiThreadedMatMul<variants::Isomorphic> {
pub fn multi_threaded(
) -> Result<MultiThreadedMatMul<variants::Isomorphic>, MatrixMultiplyError> {
futures::executor::block_on(MultiThreadedMatMul::new(variants::Isomorphic))
}
}