diff --git a/docs/api-specs/precompiled_shaders.md b/docs/api-specs/precompiled_shaders.md new file mode 100644 index 00000000000..398e8d95e42 --- /dev/null +++ b/docs/api-specs/precompiled_shaders.md @@ -0,0 +1,11 @@ +# Precompiled shaders +There are two main issues an implementation needs to cover +* Including and using reflection info +* Exposing how individual backends compile shaders outside of the backends +What changes need to be made +* I propose making a new crate, `wgpu-shaders` + * This crate would be a "wrapper" around `naga`, that would include all shader compiling logic + * This logic could then be used by both compile time macros and `wgpu-hal` itself + * This crate would include "backend"-specific parts, but it wouldn't need actual access to backends +* I also propose moving many `naga` types into `wgpu-types`, primarily those useful for reflection. + * The type to look out for here is `wgpu_core::validation::Interface`. This would also need to be moved into `wgpu-types` \ No newline at end of file diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index 3c3e8d38dd0..b22f87424ac 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -996,6 +996,12 @@ impl Global { runtime_checks: wgt::ShaderRuntimeChecks::unchecked(), } } + pipeline::ShaderModuleDescriptorPassthrough::Generic(inner) => { + pipeline::ShaderModuleDescriptor { + label: inner.label.clone(), + runtime_checks: wgt::ShaderRuntimeChecks::unchecked(), + } + } }, data, }); diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index 86bb5eb9b1d..63af5ea050a 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -1796,7 +1796,7 @@ impl Device { pipeline::ShaderModuleDescriptorPassthrough::Msl(inner) => { self.require_features(wgt::Features::MSL_SHADER_PASSTHROUGH)?; hal::ShaderInput::Msl { - shader: inner.source.to_string(), + shader: inner.source, entry_point: inner.entry_point.to_string(), num_workgroups: inner.num_workgroups, } @@ -1817,6 +1817,16 @@ impl Device { num_workgroups: inner.num_workgroups, } } + pipeline::ShaderModuleDescriptorPassthrough::Generic(inner) => { + self.require_features(wgt::Features::EXPERIMENTAL_PRECOMPILED_SHADERS)?; + hal::ShaderInput::Generic { + entry_point: inner.entry_point.clone(), + num_workgroups: inner.num_workgroups, + spirv: inner.spirv.as_deref(), + dxil: inner.dxil.as_deref(), + msl: inner.msl.as_deref(), + } + } }; let hal_desc = hal::ShaderModuleDescriptor { diff --git a/wgpu-hal/src/dx12/adapter.rs b/wgpu-hal/src/dx12/adapter.rs index ae3478b0949..29d16f6dd62 100644 --- a/wgpu-hal/src/dx12/adapter.rs +++ b/wgpu-hal/src/dx12/adapter.rs @@ -361,7 +361,8 @@ impl super::Adapter { | wgt::Features::DUAL_SOURCE_BLENDING | wgt::Features::TEXTURE_FORMAT_NV12 | wgt::Features::FLOAT32_FILTERABLE - | wgt::Features::TEXTURE_ATOMIC; + | wgt::Features::TEXTURE_ATOMIC + | wgt::Features::EXPERIMENTAL_PRECOMPILED_SHADERS; //TODO: in order to expose this, we need to run a compute shader // that extract the necessary statistics out of the D3D12 result. diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index 24cd3826d4b..54ce1fec00a 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -1732,6 +1732,22 @@ impl crate::Device for super::Device { raw_name, runtime_checks: desc.runtime_checks, }), + crate::ShaderInput::Generic { + dxil, + entry_point, + num_workgroups, + .. + } => Ok(super::ShaderModule { + source: super::ShaderModuleSource::DxilPassthrough(super::DxilPassthroughShader { + shader: dxil + .expect("Generic passthrough was given to dx12 backend without DXIL data") + .to_vec(), + entry_point, + num_workgroups, + }), + raw_name, + runtime_checks: desc.runtime_checks, + }), } } unsafe fn destroy_shader_module(&self, _module: super::ShaderModule) { diff --git a/wgpu-hal/src/gles/device.rs b/wgpu-hal/src/gles/device.rs index 0f36f734b8c..95070480595 100644 --- a/wgpu-hal/src/gles/device.rs +++ b/wgpu-hal/src/gles/device.rs @@ -1349,6 +1349,9 @@ impl crate::Device for super::Device { crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => { panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled") } + crate::ShaderInput::Generic { .. } => { + panic!("`Features::EXPERIMENTAL_PRECOMPILED_SHADERS` is not enabled") + } }, label: desc.label.map(|str| str.to_string()), id: self.shared.next_shader_id.fetch_add(1, Ordering::Relaxed), diff --git a/wgpu-hal/src/lib.rs b/wgpu-hal/src/lib.rs index c5d81b28601..7fcc370a994 100644 --- a/wgpu-hal/src/lib.rs +++ b/wgpu-hal/src/lib.rs @@ -2116,7 +2116,7 @@ impl fmt::Debug for NagaShader { pub enum ShaderInput<'a> { Naga(NagaShader), Msl { - shader: String, + shader: &'a str, entry_point: String, num_workgroups: (u32, u32, u32), }, @@ -2131,6 +2131,14 @@ pub enum ShaderInput<'a> { entry_point: String, num_workgroups: (u32, u32, u32), }, + Generic { + entry_point: String, + num_workgroups: (u32, u32, u32), + + spirv: Option<&'a [u32]>, + dxil: Option<&'a [u8]>, + msl: Option<&'a str>, + }, } pub struct ShaderModuleDescriptor<'a> { diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs index 6ecbff679f3..59901df4639 100644 --- a/wgpu-hal/src/metal/adapter.rs +++ b/wgpu-hal/src/metal/adapter.rs @@ -926,7 +926,8 @@ impl super::PrivateCapabilities { | F::TEXTURE_FORMAT_16BIT_NORM | F::SHADER_F16 | F::DEPTH32FLOAT_STENCIL8 - | F::BGRA8UNORM_STORAGE; + | F::BGRA8UNORM_STORAGE + | F::EXPERIMENTAL_PRECOMPILED_SHADERS; features.set(F::FLOAT32_FILTERABLE, self.supports_float_filtering); features.set( diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 6b3aeb3f9a8..d80b9144545 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1012,12 +1012,18 @@ impl crate::Device for super::Device { shader: source, entry_point, num_workgroups, + } + | crate::ShaderInput::Generic { + msl: Some(source), + entry_point, + num_workgroups, + .. } => { let options = metal::CompileOptions::new(); // Obtain the locked device from shared let device = self.shared.device.lock(); let library = device - .new_library_with_source(&source, &options) + .new_library_with_source(source, &options) .map_err(|e| crate::ShaderError::Compilation(format!("MSL: {:?}", e)))?; let function = library.get_function(&entry_point, None).map_err(|_| { crate::ShaderError::Compilation(format!( @@ -1042,6 +1048,9 @@ impl crate::Device for super::Device { crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => { panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled for this backend") } + crate::ShaderInput::Generic { .. } => { + panic!("Generic passthrough was given to metal backend without MSL data") + } } } diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index b429f2314dc..10a07ab271e 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -555,7 +555,8 @@ impl PhysicalDeviceFeatures { | F::CLEAR_TEXTURE | F::PIPELINE_CACHE | F::SHADER_EARLY_DEPTH_TEST - | F::TEXTURE_ATOMIC; + | F::TEXTURE_ATOMIC + | F::EXPERIMENTAL_PRECOMPILED_SHADERS; let mut dl_flags = Df::COMPUTE_SHADERS | Df::BASE_VERTEX diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index b6c8dba4053..fff8c66d6d9 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -1912,6 +1912,9 @@ impl crate::Device for super::Device { panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled") } crate::ShaderInput::SpirV(spv) => Cow::Borrowed(spv), + crate::ShaderInput::Generic { spirv, .. } => Cow::Borrowed( + spirv.expect("Generic passthrough was given to vulkan backend without SPIRV data"), + ), }; let raw = self.create_shader_module_impl(&spv)?; diff --git a/wgpu-types/src/features.rs b/wgpu-types/src/features.rs index e01885fc412..d1096487f46 100644 --- a/wgpu-types/src/features.rs +++ b/wgpu-types/src/features.rs @@ -1254,6 +1254,19 @@ bitflags_array! { /// /// This is a native only feature. const HLSL_DXIL_SHADER_PASSTHROUGH = 1 << 53; + + /// Enables creating shaders from passthrough with reflection info (unsafe) + /// + /// Shader code isn't parsed or interpreted in any way. It is the user's + /// responsibility to ensure the reflection is correct. + /// + /// Supported platforms + /// - Vulkan + /// - DX12 + /// + /// Ideally, in the future, all platforms will be supported. For more info, see + /// [my comment](https://github.com/gfx-rs/wgpu/issues/3103#issuecomment-2833058367). + const EXPERIMENTAL_PRECOMPILED_SHADERS = 1 << 54; } /// Features that are not guaranteed to be supported. diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index 7915b88020d..a00cafdd852 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -7769,6 +7769,8 @@ pub enum CreateShaderModuleDescriptorPassthrough<'a, L> { Dxil(ShaderModuleDescriptorDxil<'a, L>), /// Passthrough for HLSL Hlsl(ShaderModuleDescriptorHlsl<'a, L>), + /// Passthrough for multiple types of sources, with optional reflection + Generic(ShaderModuleDescriptorGeneric<'a, L>), } impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> { @@ -7791,7 +7793,7 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> { entry_point: inner.entry_point.clone(), label: fun(&inner.label), num_workgroups: inner.num_workgroups, - source: inner.source.clone(), + source: inner.source, }) } CreateShaderModuleDescriptorPassthrough::Dxil(inner) => { @@ -7810,6 +7812,20 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> { source: inner.source, }) } + CreateShaderModuleDescriptorPassthrough::Generic(inner) => { + CreateShaderModuleDescriptorPassthrough::<'_, K>::Generic( + ShaderModuleDescriptorGeneric { + entry_point: inner.entry_point.clone(), + label: fun(&inner.label), + num_workgroups: inner.num_workgroups, + reflection: inner.reflection.clone(), + spirv: inner.spirv.clone(), + dxil: inner.dxil.clone(), + msl: inner.msl.clone(), + runtime_checks: inner.runtime_checks, + }, + ) + } } } @@ -7820,6 +7836,7 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> { CreateShaderModuleDescriptorPassthrough::Msl(inner) => &inner.label, CreateShaderModuleDescriptorPassthrough::Dxil(inner) => &inner.label, CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => &inner.label, + CreateShaderModuleDescriptorPassthrough::Generic(inner) => &inner.label, } } @@ -7833,6 +7850,17 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> { CreateShaderModuleDescriptorPassthrough::Msl(inner) => inner.source.as_bytes(), CreateShaderModuleDescriptorPassthrough::Dxil(inner) => inner.source, CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => inner.source.as_bytes(), + CreateShaderModuleDescriptorPassthrough::Generic(inner) => { + if let Some(spirv) = &inner.spirv { + bytemuck::cast_slice(spirv) + } else if let Some(msl) = &inner.msl { + msl.as_bytes() + } else if let Some(dxil) = &inner.dxil { + dxil + } else { + panic!("No binary data provided to `ShaderModuleDescriptorGeneric`") + } + } } } @@ -7844,6 +7872,17 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> { CreateShaderModuleDescriptorPassthrough::Msl(..) => "msl", CreateShaderModuleDescriptorPassthrough::Dxil(..) => "dxil", CreateShaderModuleDescriptorPassthrough::Hlsl(..) => "hlsl", + CreateShaderModuleDescriptorPassthrough::Generic(inner) => { + if inner.spirv.is_some() { + "spv" + } else if inner.msl.is_some() { + "msl" + } else if inner.dxil.is_some() { + "dxil" + } else { + panic!("No binary data provided to `ShaderModuleDescriptorGeneric`") + } + } } } } @@ -7861,7 +7900,7 @@ pub struct ShaderModuleDescriptorMsl<'a, L> { /// Number of workgroups in each dimension x, y and z. pub num_workgroups: (u32, u32, u32), /// Shader MSL source. - pub source: Cow<'a, str>, + pub source: &'a str, } /// Descriptor for a shader module given by DirectX DXIL source. @@ -7876,11 +7915,10 @@ pub struct ShaderModuleDescriptorDxil<'a, L> { pub label: L, /// Number of workgroups in each dimension x, y and z. pub num_workgroups: (u32, u32, u32), - /// Shader DXIL source. + /// Shader MSL source. pub source: &'a [u8], } - -/// Descriptor for a shader module given by DirectX HLSL source. +/// Descriptor for a shader module given by DirectX DXIL source. /// /// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, /// only WGSL source code strings are accepted. @@ -7892,7 +7930,7 @@ pub struct ShaderModuleDescriptorHlsl<'a, L> { pub label: L, /// Number of workgroups in each dimension x, y and z. pub num_workgroups: (u32, u32, u32), - /// Shader HLSL source. + /// Shader MSL source. pub source: &'a str, } @@ -7907,3 +7945,33 @@ pub struct ShaderModuleDescriptorSpirV<'a, L> { /// Binary SPIR-V data, in 4-byte words. pub source: Cow<'a, [u32]>, } + +/// Descriptor for a shader module given by any of several sources, with optional reflection information. +/// All shader types that may be used by the backend must be `Some`, otherwise usage is undefined behavior +#[derive(Debug, Clone)] +pub struct ShaderModuleDescriptorGeneric<'a, L> { + /// Entrypoint. + pub entry_point: String, + /// Debug label of the shader module. This will show up in graphics debuggers for easy identification. + pub label: L, + /// Number of workgroups in each dimension x, y and z. + pub num_workgroups: (u32, u32, u32), + /// Reflection information + pub reflection: Option, + /// Runtime checks that should be enabled. + pub runtime_checks: ShaderRuntimeChecks, + + /// Binary SPIR-V data, in 4-byte words. + pub spirv: Option>, + /// Binary DXIL data + pub dxil: Option>, + /// Shader MSL source. + pub msl: Option>, +} + +/// Reflection information for a shader compiled with `naga` +#[derive(Debug, Clone)] +pub struct ShaderModuleReflection { + /// Number of workgroups in each dimension x, y and z. + pub num_workgroups: (u32, u32, u32), +} diff --git a/wgpu/src/api/shader_module.rs b/wgpu/src/api/shader_module.rs index c481de6218a..18e1c340eae 100644 --- a/wgpu/src/api/shader_module.rs +++ b/wgpu/src/api/shader_module.rs @@ -248,7 +248,7 @@ pub type ShaderModuleDescriptorMsl<'a> = wgt::ShaderModuleDescriptorMsl<'a, Labe /// only WGSL source code strings are accepted. pub type ShaderModuleDescriptorSpirV<'a> = wgt::ShaderModuleDescriptorSpirV<'a, Label<'a>>; -/// Descriptor for a shader module given by DirectX HLSL source. +/// Descriptor for a shader module given by DirectX HLSl source. /// /// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, /// only WGSL source code strings are accepted. @@ -259,3 +259,17 @@ pub type ShaderModuleDescriptorHlsl<'a> = wgt::ShaderModuleDescriptorHlsl<'a, La /// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, /// only WGSL source code strings are accepted. pub type ShaderModuleDescriptorDxil<'a> = wgt::ShaderModuleDescriptorDxil<'a, Label<'a>>; + +/// Descriptor for a shader module given by any of several sources, with optional reflection information. +/// All shader types that may be used by the backend must be `Some`, otherwise usage is undefined behavior +/// +/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, +/// only WGSL source code strings are accepted. +pub type ShaderModuleDescriptorGeneric<'a> = wgt::ShaderModuleDescriptorGeneric<'a, Label<'a>>; + +/// Reflection info for a shader module, created by compiling a shader with naga, +/// either at compile-time or run-time. +/// +/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, +/// only WGSL source code strings are accepted. +pub type ShaderModuleReflection = wgt::ShaderModuleReflection;