From 7c3d43c847f8899848da49aa29fd628ab4655ba8 Mon Sep 17 00:00:00 2001 From: firestar99 Date: Mon, 28 Jul 2025 13:04:46 +0200 Subject: [PATCH 1/3] wgsl: add `spirv-unknown-naga-wgsl` target, transpiling with naga 27 --- Cargo.lock | 2 + crates/rustc_codegen_spirv/Cargo.toml | 2 + crates/rustc_codegen_spirv/src/lib.rs | 1 + crates/rustc_codegen_spirv/src/link.rs | 5 ++ .../rustc_codegen_spirv/src/naga_transpile.rs | 79 +++++++++++++++++++ crates/rustc_codegen_spirv/src/target.rs | 77 ++++++++++++++++++ 6 files changed, 166 insertions(+) create mode 100644 crates/rustc_codegen_spirv/src/naga_transpile.rs diff --git a/Cargo.lock b/Cargo.lock index 64ee62a4795..c1d4a283b84 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2941,6 +2941,7 @@ dependencies = [ "lazy_static", "libc", "log", + "naga", "object 0.37.3", "pretty_assertions", "regex", @@ -2953,6 +2954,7 @@ dependencies = [ "spirt", "spirv-std-types", "spirv-tools", + "strum", "termcolor", "thorin-dwp", "tracing", diff --git a/crates/rustc_codegen_spirv/Cargo.toml b/crates/rustc_codegen_spirv/Cargo.toml index ed72fbd2800..1d8ae642f0b 100644 --- a/crates/rustc_codegen_spirv/Cargo.toml +++ b/crates/rustc_codegen_spirv/Cargo.toml @@ -61,6 +61,8 @@ itertools = "0.14.0" tracing.workspace = true tracing-subscriber.workspace = true tracing-tree = "0.4.0" +naga = { version = "27.0.3", features = ["spv-in", "wgsl-out"] } +strum = { version = "0.27.2", features = ["derive"] } [dev-dependencies] pretty_assertions = "1.0" diff --git a/crates/rustc_codegen_spirv/src/lib.rs b/crates/rustc_codegen_spirv/src/lib.rs index 5c891706ea8..fb68dcf687b 100644 --- a/crates/rustc_codegen_spirv/src/lib.rs +++ b/crates/rustc_codegen_spirv/src/lib.rs @@ -127,6 +127,7 @@ mod custom_decorations; mod custom_insts; mod link; mod linker; +mod naga_transpile; mod spirv_type; mod spirv_type_constraints; mod symbols; diff --git a/crates/rustc_codegen_spirv/src/link.rs b/crates/rustc_codegen_spirv/src/link.rs index fd903a52d08..8526bc7d547 100644 --- a/crates/rustc_codegen_spirv/src/link.rs +++ b/crates/rustc_codegen_spirv/src/link.rs @@ -3,6 +3,7 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa; use crate::codegen_cx::{CodegenArgs, SpirvMetadata}; use crate::linker; +use crate::naga_transpile::should_transpile; use crate::target::{SpirvTarget, SpirvTargetVariant}; use ar::{Archive, GnuBuilder, Header}; use rspirv::binary::Assemble; @@ -319,6 +320,10 @@ fn post_link_single_module( drop(save_modules_timer); } + + if let Some(transpile) = should_transpile(sess) { + transpile(sess, cg_args, &spv_binary, out_filename).ok(); + } } fn do_spirv_opt( diff --git a/crates/rustc_codegen_spirv/src/naga_transpile.rs b/crates/rustc_codegen_spirv/src/naga_transpile.rs new file mode 100644 index 00000000000..407f9f0a0e5 --- /dev/null +++ b/crates/rustc_codegen_spirv/src/naga_transpile.rs @@ -0,0 +1,79 @@ +use crate::codegen_cx::CodegenArgs; +use crate::target::{NagaTarget, SpirvTarget}; +use rustc_session::Session; +use rustc_span::ErrorGuaranteed; +use std::path::Path; + +pub type NagaTranspile = fn( + sess: &Session, + cg_args: &CodegenArgs, + spv_binary: &[u32], + out_filename: &Path, +) -> Result<(), ErrorGuaranteed>; + +pub fn should_transpile(sess: &Session) -> Option { + let target = SpirvTarget::parse_target(sess.opts.target_triple.tuple()) + .expect("parsing should fail earlier"); + match target { + SpirvTarget::Naga(NagaTarget::NAGA_WGSL) => Some(transpile::wgsl_transpile), + _ => None, + } +} + +mod transpile { + use crate::codegen_cx::CodegenArgs; + use naga::error::ShaderError; + use naga::valid::Capabilities; + use rustc_session::Session; + use rustc_span::ErrorGuaranteed; + use std::path::Path; + + pub fn wgsl_transpile( + sess: &Session, + _cg_args: &CodegenArgs, + spv_binary: &[u32], + out_filename: &Path, + ) -> Result<(), ErrorGuaranteed> { + // these should be params via spirv-builder + let opts = naga::front::spv::Options::default(); + let capabilities = Capabilities::all(); + let writer_flags = naga::back::wgsl::WriterFlags::empty(); + + let module = naga::front::spv::parse_u8_slice(bytemuck::cast_slice(spv_binary), &opts) + .map_err(|err| { + sess.dcx().err(format!( + "Naga failed to parse spv: \n{}", + ShaderError { + source: String::new(), + label: None, + inner: Box::new(err), + } + )) + })?; + let mut validator = + naga::valid::Validator::new(naga::valid::ValidationFlags::default(), capabilities); + let info = validator.validate(&module).map_err(|err| { + sess.dcx().err(format!( + "Naga validation failed: \n{}", + ShaderError { + source: String::new(), + label: None, + inner: Box::new(err), + } + )) + })?; + + let wgsl_dst = out_filename.with_extension("wgsl"); + let wgsl = naga::back::wgsl::write_string(&module, &info, writer_flags).map_err(|err| { + sess.dcx() + .err(format!("Naga failed to write wgsl : \n{err}")) + })?; + + std::fs::write(&wgsl_dst, wgsl).map_err(|err| { + sess.dcx() + .err(format!("failed to write wgsl to file: {err}")) + })?; + + Ok(()) + } +} diff --git a/crates/rustc_codegen_spirv/src/target.rs b/crates/rustc_codegen_spirv/src/target.rs index 69def2137d0..332757a80ef 100644 --- a/crates/rustc_codegen_spirv/src/target.rs +++ b/crates/rustc_codegen_spirv/src/target.rs @@ -5,11 +5,16 @@ use std::cmp::Ordering; use std::fmt::{Debug, Display, Formatter}; use std::ops::{Deref, DerefMut}; use std::str::FromStr; +use strum::{Display, EnumString, IntoStaticStr}; #[derive(Clone, Eq, PartialEq)] pub enum TargetError { + /// If during parsing a target variant returns `UnknownTarget`, further variants will attempt to parse the string. + /// Returning another error means that you have recognized the target but something else is invalid, and we should + /// abort the parsing with your error. UnknownTarget(String), InvalidTargetVersion(SpirvTarget), + InvalidNagaVariant(String), } impl Display for TargetError { @@ -21,6 +26,9 @@ impl Display for TargetError { TargetError::InvalidTargetVersion(target) => { write!(f, "Invalid version in target `{}`", target.env()) } + TargetError::InvalidNagaVariant(target) => { + write!(f, "Unknown naga out variant `{target}`") + } } } } @@ -439,6 +447,63 @@ impl Display for OpenGLTarget { } } +/// A naga target +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub struct NagaTarget { + pub out: NagaOut, +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq, IntoStaticStr, Display, EnumString)] +#[allow(clippy::upper_case_acronyms)] +pub enum NagaOut { + #[strum(to_string = "wgsl")] + WGSL, +} + +impl NagaTarget { + pub const NAGA_WGSL: Self = NagaTarget::new(NagaOut::WGSL); + pub const ALL_NAGA_TARGETS: &'static [Self] = &[Self::NAGA_WGSL]; + /// emit spirv like naga targets were this target + pub const EMIT_SPIRV_LIKE: SpirvTarget = SpirvTarget::VULKAN_1_3; + + pub const fn new(out: NagaOut) -> Self { + Self { out } + } +} + +impl SpirvTargetVariant for NagaTarget { + fn validate(&self) -> Result<(), TargetError> { + Ok(()) + } + + fn to_spirv_tools(&self) -> spirv_tools::TargetEnv { + Self::EMIT_SPIRV_LIKE.to_spirv_tools() + } + + fn spirv_version(&self) -> SpirvVersion { + Self::EMIT_SPIRV_LIKE.spirv_version() + } +} + +impl FromStr for NagaTarget { + type Err = TargetError; + + fn from_str(s: &str) -> Result { + let s = s + .strip_prefix("naga-") + .ok_or_else(|| TargetError::UnknownTarget(s.to_owned()))?; + Ok(Self::new(FromStr::from_str(s).map_err(|_e| { + TargetError::InvalidNagaVariant(s.to_owned()) + })?)) + } +} + +impl Display for NagaTarget { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "naga-{}", self.out) + } +} + /// A rust-gpu target #[derive(Copy, Clone, Debug, Eq, PartialEq)] #[non_exhaustive] @@ -446,6 +511,7 @@ pub enum SpirvTarget { Universal(UniversalTarget), Vulkan(VulkanTarget), OpenGL(OpenGLTarget), + Naga(NagaTarget), } impl SpirvTarget { @@ -467,12 +533,15 @@ impl SpirvTarget { pub const OPENGL_4_2: Self = Self::OpenGL(OpenGLTarget::OPENGL_4_2); pub const OPENGL_4_3: Self = Self::OpenGL(OpenGLTarget::OPENGL_4_3); pub const OPENGL_4_5: Self = Self::OpenGL(OpenGLTarget::OPENGL_4_5); + pub const NAGA_WGSL: Self = Self::Naga(NagaTarget::NAGA_WGSL); + #[allow(clippy::match_same_arms)] pub const fn memory_model(&self) -> MemoryModel { match self { SpirvTarget::Universal(_) => MemoryModel::Simple, SpirvTarget::Vulkan(_) => MemoryModel::Vulkan, SpirvTarget::OpenGL(_) => MemoryModel::GLSL450, + SpirvTarget::Naga(_) => MemoryModel::Vulkan, } } } @@ -483,6 +552,7 @@ impl SpirvTargetVariant for SpirvTarget { SpirvTarget::Universal(t) => t.validate(), SpirvTarget::Vulkan(t) => t.validate(), SpirvTarget::OpenGL(t) => t.validate(), + SpirvTarget::Naga(t) => t.validate(), } } @@ -491,6 +561,7 @@ impl SpirvTargetVariant for SpirvTarget { SpirvTarget::Universal(t) => t.to_spirv_tools(), SpirvTarget::Vulkan(t) => t.to_spirv_tools(), SpirvTarget::OpenGL(t) => t.to_spirv_tools(), + SpirvTarget::Naga(t) => t.to_spirv_tools(), } } @@ -499,6 +570,7 @@ impl SpirvTargetVariant for SpirvTarget { SpirvTarget::Universal(t) => t.spirv_version(), SpirvTarget::Vulkan(t) => t.spirv_version(), SpirvTarget::OpenGL(t) => t.spirv_version(), + SpirvTarget::Naga(t) => t.spirv_version(), } } } @@ -513,6 +585,9 @@ impl SpirvTarget { if matches!(result, Err(TargetError::UnknownTarget(..))) { result = OpenGLTarget::from_str(s).map(Self::OpenGL); } + if matches!(result, Err(TargetError::UnknownTarget(..))) { + result = NagaTarget::from_str(s).map(Self::Naga); + } result } @@ -533,6 +608,7 @@ impl SpirvTarget { SpirvTarget::Universal(t) => t.to_string(), SpirvTarget::Vulkan(t) => t.to_string(), SpirvTarget::OpenGL(t) => t.to_string(), + SpirvTarget::Naga(t) => t.to_string(), } } @@ -555,6 +631,7 @@ impl SpirvTarget { .iter() .map(|t| Self::OpenGL(*t)), ) + .chain(NagaTarget::ALL_NAGA_TARGETS.iter().map(|t| Self::Naga(*t))) } } From 8847c19a7eb14d000c489e141d964cfc585c34a3 Mon Sep 17 00:00:00 2001 From: firestar99 Date: Mon, 16 Jun 2025 10:03:25 +0200 Subject: [PATCH 2/3] wgsl: hide naga behind feature --- crates/rustc_codegen_spirv/Cargo.toml | 3 ++- crates/rustc_codegen_spirv/src/link.rs | 2 +- .../rustc_codegen_spirv/src/naga_transpile.rs | 20 ++++++++++++++----- tests/compiletests/Cargo.toml | 2 +- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/crates/rustc_codegen_spirv/Cargo.toml b/crates/rustc_codegen_spirv/Cargo.toml index 1d8ae642f0b..aaaed94a7ff 100644 --- a/crates/rustc_codegen_spirv/Cargo.toml +++ b/crates/rustc_codegen_spirv/Cargo.toml @@ -29,6 +29,7 @@ use-compiled-tools = ["spirv-tools/use-compiled-tools"] # and will likely produce compile errors when built against a different toolchain. # Enable this feature to be able to experiment with other versions. skip-toolchain-check = [] +naga = ["dep:naga"] [dependencies] # HACK(eddyb) these only exist to unify features across dependency trees, @@ -61,7 +62,7 @@ itertools = "0.14.0" tracing.workspace = true tracing-subscriber.workspace = true tracing-tree = "0.4.0" -naga = { version = "27.0.3", features = ["spv-in", "wgsl-out"] } +naga = { version = "27.0.3", features = ["spv-in", "wgsl-out"], optional = true } strum = { version = "0.27.2", features = ["derive"] } [dev-dependencies] diff --git a/crates/rustc_codegen_spirv/src/link.rs b/crates/rustc_codegen_spirv/src/link.rs index 8526bc7d547..ebefbcae4a1 100644 --- a/crates/rustc_codegen_spirv/src/link.rs +++ b/crates/rustc_codegen_spirv/src/link.rs @@ -321,7 +321,7 @@ fn post_link_single_module( drop(save_modules_timer); } - if let Some(transpile) = should_transpile(sess) { + if let Ok(Some(transpile)) = should_transpile(sess) { transpile(sess, cg_args, &spv_binary, out_filename).ok(); } } diff --git a/crates/rustc_codegen_spirv/src/naga_transpile.rs b/crates/rustc_codegen_spirv/src/naga_transpile.rs index 407f9f0a0e5..6f6b8131107 100644 --- a/crates/rustc_codegen_spirv/src/naga_transpile.rs +++ b/crates/rustc_codegen_spirv/src/naga_transpile.rs @@ -11,15 +11,25 @@ pub type NagaTranspile = fn( out_filename: &Path, ) -> Result<(), ErrorGuaranteed>; -pub fn should_transpile(sess: &Session) -> Option { +pub fn should_transpile(sess: &Session) -> Result, ErrorGuaranteed> { let target = SpirvTarget::parse_target(sess.opts.target_triple.tuple()) .expect("parsing should fail earlier"); - match target { - SpirvTarget::Naga(NagaTarget::NAGA_WGSL) => Some(transpile::wgsl_transpile), - _ => None, - } + let result: Result, ()> = match target { + #[cfg(feature = "naga")] + SpirvTarget::Naga(NagaTarget::NAGA_WGSL) => Ok(Some(transpile::wgsl_transpile)), + #[cfg(not(feature = "naga"))] + SpirvTarget::Naga(NagaTarget::NAGA_WGSL) => Err(()), + _ => Ok(None), + }; + result.map_err(|_e| { + sess.dcx().err(format!( + "Target `{}` requires feature \"naga\" on rustc_codegen_spirv", + target.target() + )) + }) } +#[cfg(feature = "naga")] mod transpile { use crate::codegen_cx::CodegenArgs; use naga::error::ShaderError; diff --git a/tests/compiletests/Cargo.toml b/tests/compiletests/Cargo.toml index 8b88c9211b0..644473a6471 100644 --- a/tests/compiletests/Cargo.toml +++ b/tests/compiletests/Cargo.toml @@ -15,7 +15,7 @@ use-compiled-tools = ["rustc_codegen_spirv/use-compiled-tools"] [dependencies] compiletest = { version = "0.11.2", package = "compiletest_rs" } -rustc_codegen_spirv = { workspace = true } +rustc_codegen_spirv = { workspace = true, features = ["naga"] } rustc_codegen_spirv-types = { workspace = true } clap = { version = "4", features = ["derive"] } itertools = "0.14.0" From 7052f394abe194b2b58f81464a26faeaf697d893 Mon Sep 17 00:00:00 2001 From: Firestar99 Date: Thu, 3 Jul 2025 18:07:11 +0200 Subject: [PATCH 3/3] wgsl: enable naga feature by default, cargo-gpu can't handle it --- crates/rustc_codegen_spirv/Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/rustc_codegen_spirv/Cargo.toml b/crates/rustc_codegen_spirv/Cargo.toml index aaaed94a7ff..7cd9eaf1fbe 100644 --- a/crates/rustc_codegen_spirv/Cargo.toml +++ b/crates/rustc_codegen_spirv/Cargo.toml @@ -20,10 +20,10 @@ crate-type = ["dylib"] default = ["use-compiled-tools"] # If enabled, uses spirv-tools binaries installed in PATH, instead of # compiling and linking the spirv-tools C++ code -use-installed-tools = ["spirv-tools/use-installed-tools"] +use-installed-tools = ["spirv-tools/use-installed-tools", "naga"] # If enabled will compile and link the C++ code for the spirv tools, the compiled # version is preferred if both this and `use-installed-tools` are enabled -use-compiled-tools = ["spirv-tools/use-compiled-tools"] +use-compiled-tools = ["spirv-tools/use-compiled-tools", "naga"] # If enabled, this will not check whether the current rustc version is set to the # appropriate channel. rustc_cogeden_spirv requires a specific nightly version, # and will likely produce compile errors when built against a different toolchain.