From 7286c79abf93bcfde5b99ba40fabb8f6bca9f9b3 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Tue, 2 Jun 2026 12:53:42 -0700 Subject: [PATCH 1/2] Upload named-data constants in WebGPUGraph Summary: The Vulkan serializer that the WebGPU backend reuses stores every non-empty constant in the PTE's named-data map with `offset == UINT64_MAX` and a `named_key`, rather than inline in the VK00 blob. `WebGPUGraph::build` previously handled only inline constants, so a delegated op's constant weights were never uploaded and the op produced all zeros. `build` now also fetches named-data constants via `NamedDataMap::get_data`, mirroring the path `VulkanBackend` already uses. `aten.add` was unaffected since it has no constant tensors; the first consumer is the `rms_norm` op in the child diff. Differential Revision: D107288998 --- backends/webgpu/runtime/WebGPUBackend.cpp | 2 +- backends/webgpu/runtime/WebGPUGraph.cpp | 23 ++++++++++++++++++++++- backends/webgpu/runtime/WebGPUGraph.h | 9 ++++++++- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/backends/webgpu/runtime/WebGPUBackend.cpp b/backends/webgpu/runtime/WebGPUBackend.cpp index 5321c20aaa4..b4e3165d8f4 100644 --- a/backends/webgpu/runtime/WebGPUBackend.cpp +++ b/backends/webgpu/runtime/WebGPUBackend.cpp @@ -76,7 +76,7 @@ Result WebGPUBackend::init( } try { - graph->build(flatbuffer_data, constant_data); + graph->build(flatbuffer_data, constant_data, context.get_named_data_map()); } catch (const std::exception& e) { ET_LOG(Error, "WebGPU graph build failed: %s", e.what()); graph->~WebGPUGraph(); diff --git a/backends/webgpu/runtime/WebGPUGraph.cpp b/backends/webgpu/runtime/WebGPUGraph.cpp index 91404fb164f..855a0c8fae8 100644 --- a/backends/webgpu/runtime/WebGPUGraph.cpp +++ b/backends/webgpu/runtime/WebGPUGraph.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -93,7 +94,8 @@ WebGPUGraph::~WebGPUGraph() { void WebGPUGraph::build( const void* flatbuffer_data, - const uint8_t* 
constant_data) { + const uint8_t* constant_data, + const executorch::runtime::NamedDataMap* named_data_map) { if (!device_) { auto* ctx = get_default_webgpu_context(); if (ctx) { @@ -165,6 +167,25 @@ void WebGPUGraph::build( const uint8_t* src = constant_data + vk_bytes->offset(); wgpuQueueWriteBuffer( queue_, tensor.buffer, 0, src, tensor.nbytes); + } else if ( + vk_bytes->named_key() != nullptr && + named_data_map != nullptr) { + // Constant stored in the PTE named-data map. + auto buf = + named_data_map->get_data(vk_bytes->named_key()->c_str()); + if (buf.ok() && buf->size() >= tensor.nbytes) { + wgpuQueueWriteBuffer( + queue_, tensor.buffer, 0, buf->data(), tensor.nbytes); + buf->Free(); + } else { + throw std::runtime_error( + std::string("WebGPU: named constant '") + + vk_bytes->named_key()->c_str() + + "' missing or undersized in NamedDataMap"); + } + } else { + throw std::runtime_error( + "WebGPU: constant has no inline offset and no named-data key"); } } } diff --git a/backends/webgpu/runtime/WebGPUGraph.h b/backends/webgpu/runtime/WebGPUGraph.h index 3aa96917a4e..c600432966e 100644 --- a/backends/webgpu/runtime/WebGPUGraph.h +++ b/backends/webgpu/runtime/WebGPUGraph.h @@ -15,10 +15,14 @@ #include #include +#include + namespace executorch { namespace backends { namespace webgpu { +using executorch::runtime::NamedDataMap; + struct WebGPUTensor { WGPUBuffer buffer = nullptr; std::vector dims; @@ -66,7 +70,10 @@ class WebGPUGraph { // Build the graph from a deserialized VkGraph flatbuffer and constant data. // The flatbuffer_data pointer must remain valid during build(). - void build(const void* flatbuffer_data, const uint8_t* constant_data); + void build( + const void* flatbuffer_data, + const uint8_t* constant_data, + const NamedDataMap* named_data_map = nullptr); // Copy input tensor data from host pointers into GPU buffers. 
void copy_inputs(const std::vector>& inputs); From 158a7496e907163efcfb5895045d5593d6323e2e Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Tue, 2 Jun 2026 12:59:17 -0700 Subject: [PATCH 2/2] Add rms_norm op (#19893) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/19893 Adds the `et_vk.rms_norm.default` operator to the WebGPU backend: a WGSL compute shader using a cooperative tree reduction, one workgroup per row. The shader mirrors the Vulkan implementation (`backends/vulkan/runtime/graph/ops/impl/RmsNorm.cpp`, `backends/vulkan/runtime/graph/ops/glsl/rms_norm_buffer.glsl`); indexing assumes contiguous fp32 inputs. The handler fails loud (throws, mirroring Vulkan's `VK_CHECK_COND`) on invalid shape/dtype/dispatch-limit conditions, and defaults `eps` to the float32 machine epsilon. The weight constant is uploaded via the named-data path added in the parent diff. Differential Revision: D106887028 --- backends/webgpu/CMakeLists.txt | 41 +++- .../webgpu/runtime/ops/rms_norm/RmsNorm.cpp | 192 ++++++++++++++++++ .../webgpu/runtime/ops/rms_norm/rms_norm.wgsl | 72 +++++++ .../runtime/ops/rms_norm/rms_norm_wgsl.h | 93 +++++++++ backends/webgpu/test/native/test_rms_norm.cpp | 169 +++++++++++++++ backends/webgpu/test/ops/rms_norm/__init__.py | 0 .../webgpu/test/ops/rms_norm/test_rms_norm.py | 191 +++++++++++++++++ backends/webgpu/test/test_build_webgpu.sh | 28 ++- backends/webgpu/test/test_webgpu_native.cpp | 2 + 9 files changed, 781 insertions(+), 7 deletions(-) create mode 100644 backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp create mode 100644 backends/webgpu/runtime/ops/rms_norm/rms_norm.wgsl create mode 100644 backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h create mode 100644 backends/webgpu/test/native/test_rms_norm.cpp create mode 100644 backends/webgpu/test/ops/rms_norm/__init__.py create mode 100644 backends/webgpu/test/ops/rms_norm/test_rms_norm.py diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt 
index ab2da24a569..972518f1399 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -26,9 +26,13 @@ if(NOT TARGET vulkan_schema) endif() set(WEBGPU_SRCS - runtime/WebGPUBackend.cpp runtime/WebGPUGraph.cpp - runtime/WebGPUDelegateHeader.cpp runtime/WebGPUDevice.cpp - runtime/ops/OperatorRegistry.cpp runtime/ops/add/BinaryOp.cpp + runtime/WebGPUBackend.cpp + runtime/WebGPUGraph.cpp + runtime/WebGPUDelegateHeader.cpp + runtime/WebGPUDevice.cpp + runtime/ops/OperatorRegistry.cpp + runtime/ops/add/BinaryOp.cpp + runtime/ops/rms_norm/RmsNorm.cpp ) add_library(webgpu_backend ${WEBGPU_SRCS}) @@ -116,4 +120,35 @@ if(EXECUTORCH_BUILD_WEBGPU_TEST) target_compile_options(webgpu_native_test PRIVATE -fexceptions) set_property(TARGET webgpu_native_test PROPERTY CXX_STANDARD 17) + + add_executable(webgpu_rms_norm_test test/native/test_rms_norm.cpp) + + target_include_directories( + webgpu_rms_norm_test PRIVATE $ + "${WGPU_NATIVE_DIR}/include" + ) + + target_link_libraries( + webgpu_rms_norm_test + PRIVATE webgpu_backend + wgpu_native + executorch_core + extension_module_static + extension_data_loader + extension_tensor + portable_kernels + portable_ops_lib + ) + + if(APPLE) + target_link_libraries( + webgpu_rms_norm_test PRIVATE "-framework Metal" "-framework QuartzCore" + "-framework CoreGraphics" + ) + else() + target_link_libraries(webgpu_rms_norm_test PRIVATE dl m pthread) + endif() + + target_compile_options(webgpu_rms_norm_test PRIVATE -fexceptions) + set_property(TARGET webgpu_rms_norm_test PROPERTY CXX_STANDARD 17) endif() diff --git a/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp b/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp new file mode 100644 index 00000000000..3dbf444b772 --- /dev/null +++ b/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp @@ -0,0 +1,192 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include + +#include +#include +#include + +namespace executorch::backends::webgpu { + +namespace { + +// Uniform layout matching the WGSL Params struct (16-byte aligned). +struct RmsNormParams { + uint32_t num_rows; + uint32_t row_width; + float epsilon; + uint32_t _pad; +}; +static_assert(sizeof(RmsNormParams) == 16, "RmsNormParams must be 16 bytes"); + +void rms_norm_impl(WebGPUGraph& graph, const std::vector& args) { + // et_vk.rms_norm.default args: [in, weight, eps, out] + const int in_id = args.at(0); + const int weight_id = args.at(1); + const int eps_id = args.at(2); + const int out_id = args.at(3); + + WGPUDevice device = graph.device(); + + // Get epsilon (Double from a Python float; defaults to float32 eps) + float epsilon = 1.1920928955078125e-07f; + if (graph.get_value_type(eps_id) == WebGPUGraph::ValueType::Double) { + epsilon = static_cast(graph.get_double(eps_id)); + } else if (graph.get_value_type(eps_id) == WebGPUGraph::ValueType::Int) { + epsilon = static_cast(graph.get_int(eps_id)); + } + + // row_width = last dim; num_rows = product of the rest (PyTorch NCHW order) + const auto& in_tensor = graph.get_tensor(in_id); + if (in_tensor.dims.empty() || in_tensor.nbytes == 0) { + throw std::runtime_error("WebGPU rms_norm: empty input"); + } + const uint32_t row_width = static_cast(in_tensor.dims.back()); + if (row_width == 0) { + throw std::runtime_error("WebGPU rms_norm: zero row width"); + } + uint64_t in_numel = 1; + for (int64_t d : in_tensor.dims) { + in_numel *= static_cast(d); + } + // fp32-only shader: bail if the bytes don't match an fp32 element count. 
+ if (in_tensor.nbytes != in_numel * sizeof(float)) { + throw std::runtime_error("WebGPU rms_norm: fp32-only (byte-size mismatch)"); + } + const uint32_t num_rows = static_cast(in_numel / row_width); + if (num_rows == 0) { + throw std::runtime_error("WebGPU rms_norm: zero rows"); + } + + // Create uniform buffer for params + RmsNormParams params = {}; + params.num_rows = num_rows; + params.row_width = row_width; + params.epsilon = epsilon; + + WGPUBufferDescriptor uniform_desc = {}; + uniform_desc.size = sizeof(RmsNormParams); + uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst; + uniform_desc.mappedAtCreation = true; + WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc); + void* mapped = + wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(RmsNormParams)); + std::memcpy(mapped, ¶ms, sizeof(RmsNormParams)); + wgpuBufferUnmap(uniform_buffer); + + graph.add_uniform_buffer_bytes(sizeof(RmsNormParams)); + + // Create shader module from built-in WGSL source + WGPUShaderSourceWGSL wgsl_desc = {}; + wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL; + wgsl_desc.code = {kRmsNormWGSL, WGPU_STRLEN}; + + WGPUShaderModuleDescriptor shader_desc = {}; + shader_desc.nextInChain = &wgsl_desc.chain; + WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc); + + // Create bind group layout: out (rw) + in/weight (ro storage) + params + WGPUBindGroupLayoutEntry entries[4] = {}; + + // t_out - storage buffer, read-write + entries[0].binding = 0; + entries[0].visibility = WGPUShaderStage_Compute; + entries[0].buffer.type = WGPUBufferBindingType_Storage; + + // t_in - storage buffer, read-only + entries[1].binding = 1; + entries[1].visibility = WGPUShaderStage_Compute; + entries[1].buffer.type = WGPUBufferBindingType_ReadOnlyStorage; + + // t_weight - storage buffer, read-only + entries[2].binding = 2; + entries[2].visibility = WGPUShaderStage_Compute; + entries[2].buffer.type = WGPUBufferBindingType_ReadOnlyStorage; + + // 
params - uniform buffer + entries[3].binding = 3; + entries[3].visibility = WGPUShaderStage_Compute; + entries[3].buffer.type = WGPUBufferBindingType_Uniform; + + WGPUBindGroupLayoutDescriptor bgl_desc = {}; + bgl_desc.entryCount = 4; + bgl_desc.entries = entries; + WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc); + + // Create pipeline layout + WGPUPipelineLayoutDescriptor pl_desc = {}; + pl_desc.bindGroupLayoutCount = 1; + pl_desc.bindGroupLayouts = &bgl; + WGPUPipelineLayout pipeline_layout = + wgpuDeviceCreatePipelineLayout(device, &pl_desc); + + // Create compute pipeline + WGPUComputePipelineDescriptor pipeline_desc = {}; + pipeline_desc.layout = pipeline_layout; + pipeline_desc.compute.module = shader; + pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN}; + WGPUComputePipeline pipeline = + wgpuDeviceCreateComputePipeline(device, &pipeline_desc); + + // Create bind group with actual buffers + const auto& out_tensor = graph.get_tensor(out_id); + const auto& weight_tensor = graph.get_tensor(weight_id); + + WGPUBindGroupEntry bg_entries[4] = {}; + + bg_entries[0].binding = 0; + bg_entries[0].buffer = out_tensor.buffer; + bg_entries[0].size = out_tensor.nbytes; + + bg_entries[1].binding = 1; + bg_entries[1].buffer = in_tensor.buffer; + bg_entries[1].size = in_tensor.nbytes; + + bg_entries[2].binding = 2; + bg_entries[2].buffer = weight_tensor.buffer; + bg_entries[2].size = weight_tensor.nbytes; + + bg_entries[3].binding = 3; + bg_entries[3].buffer = uniform_buffer; + bg_entries[3].size = sizeof(RmsNormParams); + + WGPUBindGroupDescriptor bg_desc = {}; + bg_desc.layout = bgl; + bg_desc.entryCount = 4; + bg_desc.entries = bg_entries; + WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); + + // One workgroup per row (kRmsNormWorkgroupSize threads cooperate per row) + static_assert( + kRmsNormWorkgroupSize == 64, + "must match @workgroup_size and WG_SIZE in rms_norm.wgsl"); + if (num_rows > 65535u) { + throw 
std::runtime_error( + "WebGPU rms_norm: num_rows exceeds the 1D dispatch limit (65535)"); + } + graph.add_dispatch({pipeline, bind_group, num_rows}); + + // Release intermediate objects (pipeline and bind_group are kept by dispatch) + wgpuShaderModuleRelease(shader); + wgpuBindGroupLayoutRelease(bgl); + wgpuPipelineLayoutRelease(pipeline_layout); + // uniform_buffer is kept alive by the bind group +} + +} // namespace + +WEBGPU_REGISTER_OPERATORS { + WEBGPU_REGISTER_OP(et_vk.rms_norm.default, rms_norm_impl); +} + +} // namespace executorch::backends::webgpu diff --git a/backends/webgpu/runtime/ops/rms_norm/rms_norm.wgsl b/backends/webgpu/runtime/ops/rms_norm/rms_norm.wgsl new file mode 100644 index 00000000000..4bd5618596f --- /dev/null +++ b/backends/webgpu/runtime/ops/rms_norm/rms_norm.wgsl @@ -0,0 +1,72 @@ +@group(0) @binding(0) var t_out: array; +@group(0) @binding(1) var t_in: array; +@group(0) @binding(2) var t_weight: array; + +struct Params { + num_rows: u32, + row_width: u32, + epsilon: f32, + _pad: u32, +} +@group(0) @binding(3) var params: Params; + +const WG_SIZE: u32 = 64u; + +var shared_sum: array; + +fn reduce_shared(worker_id: u32) { + workgroupBarrier(); + var stride: u32 = WG_SIZE / 2u; + loop { + if (stride == 0u) { + break; + } + if (worker_id < stride) { + shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride]; + } + workgroupBarrier(); + stride = stride >> 1u; + } +} + +@compute @workgroup_size(64, 1, 1) +fn main( + @builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + let row_idx = wid.x; + let worker_id = lid.x; + + if (row_idx >= params.num_rows) { + return; + } + + let base = row_idx * params.row_width; + + var local_sq_sum: f32 = 0.0; + var x: u32 = worker_id; + loop { + if (x >= params.row_width) { + break; + } + let v = t_in[base + x]; + local_sq_sum = local_sq_sum + v * v; + x = x + WG_SIZE; + } + + shared_sum[worker_id] = local_sq_sum; + reduce_shared(worker_id); + + let mean_sq = 
shared_sum[0] / f32(params.row_width); + let rstd = inverseSqrt(mean_sq + params.epsilon); + + x = worker_id; + loop { + if (x >= params.row_width) { + break; + } + let v = t_in[base + x]; + let w = t_weight[x]; + t_out[base + x] = v * rstd * w; + x = x + WG_SIZE; + } +} diff --git a/backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h b/backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h new file mode 100644 index 00000000000..3a2424f6a93 --- /dev/null +++ b/backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h @@ -0,0 +1,93 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch::backends::webgpu { + +// WGSL shader source for rms_norm: y = x * w * rsqrt(mean(x^2) + eps) +inline constexpr const char* kRmsNormWGSL = R"( +@group(0) @binding(0) var t_out: array; +@group(0) @binding(1) var t_in: array; +@group(0) @binding(2) var t_weight: array; + +struct Params { + num_rows: u32, + row_width: u32, + epsilon: f32, + _pad: u32, +} +@group(0) @binding(3) var params: Params; + +const WG_SIZE: u32 = 64u; + +var shared_sum: array; + +fn reduce_shared(worker_id: u32) { + workgroupBarrier(); + var stride: u32 = WG_SIZE / 2u; + loop { + if (stride == 0u) { + break; + } + if (worker_id < stride) { + shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride]; + } + workgroupBarrier(); + stride = stride >> 1u; + } +} + +@compute @workgroup_size(64, 1, 1) +fn main( + @builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + let row_idx = wid.x; + let worker_id = lid.x; + + if (row_idx >= params.num_rows) { + return; + } + + let base = row_idx * params.row_width; + + var local_sq_sum: f32 = 0.0; + var x: u32 = worker_id; + loop { + if (x >= params.row_width) { + break; + } + let v = t_in[base + x]; + local_sq_sum = 
local_sq_sum + v * v; + x = x + WG_SIZE; + } + + shared_sum[worker_id] = local_sq_sum; + reduce_shared(worker_id); + + let mean_sq = shared_sum[0] / f32(params.row_width); + let rstd = inverseSqrt(mean_sq + params.epsilon); + + x = worker_id; + loop { + if (x >= params.row_width) { + break; + } + let v = t_in[base + x]; + let w = t_weight[x]; + t_out[base + x] = v * rstd * w; + x = x + WG_SIZE; + } +} +)"; + +inline constexpr uint32_t kRmsNormWorkgroupSize = 64; + +} // namespace executorch::backends::webgpu diff --git a/backends/webgpu/test/native/test_rms_norm.cpp b/backends/webgpu/test/native/test_rms_norm.cpp new file mode 100644 index 00000000000..92f50facdf9 --- /dev/null +++ b/backends/webgpu/test/native/test_rms_norm.cpp @@ -0,0 +1,169 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::webgpu; +using namespace executorch::extension; +using namespace executorch::runtime; + +namespace { + +struct RmsNormCase { + const char* name; + std::array sizes; +}; + +// Mirrors test_rms_norm.py _CASES; the .py writes per-case .pte/input/golden. 
+constexpr RmsNormCase kRmsNormCases[] = { + {"baseline", {1, 1, 7, 896}}, + {"width_eq_wg", {1, 1, 1, 64}}, + {"width_lt_wg", {1, 1, 1, 32}}, + {"width_1", {1, 1, 1, 1}}, + {"width_100", {1, 1, 1, 100}}, + {"width_130", {1, 1, 1, 130}}, + {"rank4_guard", {1, 5, 4, 128}}, + {"many_rows", {1, 1, 1024, 64}}, + {"distinct_rows", {1, 1, 5, 256}}, + {"single_row", {1, 1, 1, 896}}, + {"mixed_sign", {1, 1, 4, 128}}, + {"large_4096", {1, 1, 1, 4096}}, + {"large_8192", {1, 1, 1, 8192}}, + {"weight_zeros_neg", {1, 1, 1, 128}}, +}; + +std::vector read_f32_bin(const std::string& path) { + std::ifstream f(path, std::ios::binary | std::ios::ate); + if (!f) { + return {}; + } + const std::streamsize bytes = f.tellg(); + f.seekg(0); + std::vector data(static_cast(bytes) / sizeof(float)); + f.read(reinterpret_cast(data.data()), bytes); + return data; +} + +bool run_case(const std::string& dir, const RmsNormCase& tc) { + printf("\n--- Test: rms_norm[%s] ---\n", tc.name); + const std::string base = dir + "/" + tc.name; + std::vector input = read_f32_bin(base + ".input.bin"); + std::vector golden = read_f32_bin(base + ".golden.bin"); + if (input.empty() || golden.empty()) { + printf("FAIL: could not read input/golden for %s\n", tc.name); + return false; + } + + Module module(base + ".pte"); + if (module.load_forward() != Error::Ok) { + printf("FAIL: could not load %s.pte\n", tc.name); + return false; + } + + std::vector sizes(tc.sizes.begin(), tc.sizes.end()); + size_t expected = 1; + for (int32_t d : tc.sizes) { + expected *= static_cast(d); + } + if (input.size() != expected) { + printf( + "FAIL: input numel %zu != expected %zu for %s\n", + input.size(), + expected, + tc.name); + return false; + } + auto x = make_tensor_ptr(sizes, std::vector(input)); + auto result = module.forward({EValue(x)}); + if (!result.ok()) { + printf("FAIL: forward failed (error %d)\n", (int)result.error()); + return false; + } + + const auto& outputs = result.get(); + if (outputs.empty() || 
!outputs[0].isTensor()) { + printf("FAIL: no tensor output\n"); + return false; + } + const auto& out_tensor = outputs[0].toTensor(); + if (static_cast(out_tensor.numel()) != golden.size()) { + printf( + "FAIL: output numel %zu != golden %zu\n", + (size_t)out_tensor.numel(), + golden.size()); + return false; + } + const float* out_data = out_tensor.const_data_ptr(); + + float max_abs_err = 0.0f; + float max_rel_err = 0.0f; + for (size_t i = 0; i < golden.size(); i++) { + const float abs_err = std::abs(out_data[i] - golden[i]); + max_abs_err = std::max(max_abs_err, abs_err); + const float denom = std::max(std::abs(golden[i]), 1e-6f); + max_rel_err = std::max(max_rel_err, abs_err / denom); + } + printf( + "Max abs error: %e Max rel error: %e (%zu elements)\n", + max_abs_err, + max_rel_err, + golden.size()); + if (max_abs_err > 1e-3f || max_rel_err > 1e-3f) { + printf("FAIL: rms_norm[%s] exceeds tolerance 1e-3\n", tc.name); + return false; + } + printf("PASS: rms_norm[%s]\n", tc.name); + return true; +} + +} // namespace + +int main(int argc, char** argv) { + std::string dir = "/tmp/rmsn"; + if (argc > 1) { + dir = argv[1]; + } + if (const char* env = std::getenv("WEBGPU_RMS_NORM_DIR")) { + dir = env; + } + + WebGPUContext ctx; + try { + ctx = create_webgpu_context(); + } catch (const std::exception& e) { + printf("SKIP: %s\n", e.what()); + return 0; + } + set_default_webgpu_context(&ctx); + printf("WebGPU device acquired (native); case dir: %s\n", dir.c_str()); + + bool ok = true; + for (const auto& tc : kRmsNormCases) { + ok = run_case(dir, tc) && ok; + } + + set_default_webgpu_context(nullptr); + destroy_webgpu_context(ctx); + + if (!ok) { + return 1; + } + printf("\nAll rms_norm tests passed\n"); + return 0; +} diff --git a/backends/webgpu/test/ops/rms_norm/__init__.py b/backends/webgpu/test/ops/rms_norm/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backends/webgpu/test/ops/rms_norm/test_rms_norm.py 
b/backends/webgpu/test/ops/rms_norm/test_rms_norm.py new file mode 100644 index 00000000000..e231a3b03c0 --- /dev/null +++ b/backends/webgpu/test/ops/rms_norm/test_rms_norm.py @@ -0,0 +1,191 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""fp32 RMSNorm export tests via VulkanPartitioner. + +Verifies the export side only; numerics are checked in the native test +`test/native/test_rms_norm.cpp`. +""" + +import os +import unittest + +import torch +from executorch.backends.vulkan import VulkanPartitioner +from executorch.exir import to_edge_transform_and_lower + + +class RmsNormModule(torch.nn.Module): + """Standard RMSNorm with learnable per-feature weight.""" + + def __init__(self, hidden_size: int, eps: float = 1e-5) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_f32 = x.to(torch.float32) + var = x_f32.pow(2).mean(dim=-1, keepdim=True) + x_norm = x_f32 * torch.rsqrt(var + self.eps) + return (x_norm * self.weight).to(x.dtype) + + +class TestRmsNorm(unittest.TestCase): + def _export_and_check(self, model, example_inputs) -> None: + ep = torch.export.export(model, example_inputs) + et_program = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + + found_vulkan = False + for plan in et_program.executorch_program.execution_plan: + for delegate in plan.delegates: + if delegate.id == "VulkanBackend": + found_vulkan = True + break + self.assertTrue(found_vulkan, "Expected VulkanBackend delegate in .pte") + self.assertGreater(len(et_program.buffer), 100) + + def test_rms_norm_basic_small(self) -> None: + self._export_and_check(RmsNormModule(64), (torch.randn(1, 1, 1, 64),)) + + def test_rms_norm_llm_hidden(self) -> None: + # LLM-typical hidden size.
+ self._export_and_check(RmsNormModule(896), (torch.randn(1, 1, 1, 896),)) + + def test_rms_norm_multi_row(self) -> None: + # Multiple rows along the seq-len dimension (prefill-style). + self._export_and_check(RmsNormModule(896), (torch.randn(1, 1, 7, 896),)) + + def test_rms_norm_4d(self) -> None: + # 4D shape similar to QK norm with multiple Z slices. + self._export_and_check(RmsNormModule(128), (torch.randn(1, 5, 4, 128),)) + + +def export_rms_norm_model(output_path: str) -> None: + """Export the RMSNorm model to .pte for the native runtime test.""" + hidden = 896 + seq_len = 7 + model = RmsNormModule(hidden, eps=1e-6) + # Fix the weight to a known value the native test reconstructs. + with torch.no_grad(): + model.weight.copy_(torch.linspace(0.5, 1.5, hidden, dtype=torch.float32)) + example_inputs = (torch.randn(1, 1, seq_len, hidden),) + ep = torch.export.export(model, example_inputs) + et_program = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + with open(output_path, "wb") as f: + f.write(et_program.buffer) + print(f"Exported {output_path}") + + +def _ramp(shape) -> torch.Tensor: + """Deterministic linear ramp in [-1, 1] reshaped to `shape`.""" + n = 1 + for d in shape: + n *= d + return torch.linspace(-1.0, 1.0, n, dtype=torch.float32).reshape(shape) + + +def _linspace_weight(hidden: int) -> torch.Tensor: + return torch.linspace(0.5, 1.5, hidden, dtype=torch.float32) + + +def _distinct_rows(shape) -> torch.Tensor: + """Each row is a ramp scaled by 10^(r-2) so rows differ sharply in magnitude.""" + rows, width = shape[-2], shape[-1] + base = torch.linspace(-1.0, 1.0, width, dtype=torch.float32) + stacked = torch.stack([base * (10.0 ** (r - 2)) for r in range(rows)]) + return stacked.reshape(shape) + + +def _mixed_sign(shape) -> torch.Tensor: + """Row 0 all-negative, row 1 near-zero (eps-dominated), row 2 mixed, row 3 positive.""" + width = shape[-1] + base = torch.linspace(0.1, 1.0, width, dtype=torch.float32) + 
sign = torch.tensor([1.0, -1.0], dtype=torch.float32).repeat(width // 2) + stacked = torch.stack( + [-base, torch.full((width,), 1e-4, dtype=torch.float32), base * sign, base] + ) + return stacked.reshape(shape) + + +def _weight_zeros_neg(hidden: int) -> torch.Tensor: + """Spans negatives to positives with forced zeros (no weight>0 assumption).""" + w = torch.linspace(-1.0, 1.0, hidden, dtype=torch.float32).clone() + w[0] = 0.0 + w[hidden // 2] = 0.0 + return w + + +# Coverage cases (ssjia, D106887028): each bakes weight+shape -> own .pte; eps=1e-6. +_CASES = [ + {"name": "baseline", "shape": (1, 1, 7, 896)}, + {"name": "width_eq_wg", "shape": (1, 1, 1, 64)}, + {"name": "width_lt_wg", "shape": (1, 1, 1, 32)}, + { + "name": "width_1", + "shape": (1, 1, 1, 1), + "weight_fn": lambda h: torch.tensor([1.3], dtype=torch.float32), + "input_fn": lambda s: torch.tensor([0.7], dtype=torch.float32).reshape(s), + }, + {"name": "width_100", "shape": (1, 1, 1, 100)}, + {"name": "width_130", "shape": (1, 1, 1, 130)}, + {"name": "rank4_guard", "shape": (1, 5, 4, 128)}, + {"name": "many_rows", "shape": (1, 1, 1024, 64)}, + {"name": "distinct_rows", "shape": (1, 1, 5, 256), "input_fn": _distinct_rows}, + {"name": "single_row", "shape": (1, 1, 1, 896)}, + {"name": "mixed_sign", "shape": (1, 1, 4, 128), "input_fn": _mixed_sign}, + {"name": "large_4096", "shape": (1, 1, 1, 4096)}, + {"name": "large_8192", "shape": (1, 1, 1, 8192)}, + { + "name": "weight_zeros_neg", + "shape": (1, 1, 1, 128), + "weight_fn": _weight_zeros_neg, + }, +] + + +def export_rms_norm_cases(out_dir: str) -> None: + """Export every coverage case plus its torch golden for the native test. + + Writes `<name>.pte`, `<name>.input.bin`, `<name>.golden.bin` (raw little-endian + fp32) under `out_dir` for each case in `_CASES`.
+ """ + os.makedirs(out_dir, exist_ok=True) + for case in _CASES: + shape = case["shape"] + hidden = shape[-1] + weight_fn = case.get("weight_fn", _linspace_weight) + input_fn = case.get("input_fn", _ramp) + + model = RmsNormModule(hidden, eps=1e-6) + with torch.no_grad(): + model.weight.copy_(weight_fn(hidden)) + x = input_fn(shape) + with torch.no_grad(): + golden = model(x) + + ep = torch.export.export(model, (x,)) + et_program = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + + name = case["name"] + with open(os.path.join(out_dir, f"{name}.pte"), "wb") as f: + f.write(et_program.buffer) + x.detach().cpu().numpy().astype("/dev/null || sysctl -n hw.ncpu) # ── Step 1: Python export tests ────────────────────────────────────────────── -echo "=== Step 1: Run Python export test ===" +echo "=== Step 1: Run Python export tests ===" $PYTHON_EXECUTABLE -m pytest "${SCRIPT_DIR}/ops/add/test_add.py" -v +# Non-fatal: a rms_norm pytest failure skips the rms_norm native test below +# rather than aborting the whole run. 
+RMS_NORM_PYTEST_OK=1 +$PYTHON_EXECUTABLE -m pytest "${SCRIPT_DIR}/ops/rms_norm/test_rms_norm.py" -v \ + || RMS_NORM_PYTEST_OK=0 # ── Step 2: Export .pte model ───────────────────────────────────────────────── echo "=== Step 2: Export test models ===" PTE_MODEL="/tmp/webgpu_add_test.pte" PTE_CHAINED_MODEL="/tmp/webgpu_chained_add_test.pte" +RMS_NORM_DIR="/tmp/rmsn" cd "${EXECUTORCH_ROOT}" $PYTHON_EXECUTABLE -c " from executorch.backends.webgpu.test.ops.add.test_add import export_add_model, export_chained_add_model export_add_model('${PTE_MODEL}') export_chained_add_model('${PTE_CHAINED_MODEL}') " +if [[ "${RMS_NORM_PYTEST_OK}" == "1" ]]; then + $PYTHON_EXECUTABLE -c " +from executorch.backends.webgpu.test.ops.rms_norm.test_rms_norm import export_rms_norm_cases +export_rms_norm_cases('${RMS_NORM_DIR}') +" || { echo "WARN: rms_norm export failed; skipping rms_norm native test"; RMS_NORM_PYTEST_OK=0; } +fi # ── Step 3: Native build + test (wgpu-native) ──────────────────────────────── @@ -59,10 +71,18 @@ cmake \ "${EXECUTORCH_ROOT}" cmake --build "${NATIVE_BUILD_DIR}" --target webgpu_native_test -j${NPROC} +cmake --build "${NATIVE_BUILD_DIR}" --target webgpu_rms_norm_test -j${NPROC} -echo "=== Step 4: Run native test ===" -WEBGPU_TEST_MODEL="${PTE_MODEL}" \ -WEBGPU_TEST_CHAINED_MODEL="${PTE_CHAINED_MODEL}" \ +echo "=== Step 4: Run native tests ===" +env \ + WEBGPU_TEST_MODEL="${PTE_MODEL}" \ + WEBGPU_TEST_CHAINED_MODEL="${PTE_CHAINED_MODEL}" \ "${NATIVE_BUILD_DIR}/backends/webgpu/webgpu_native_test" +if [[ "${RMS_NORM_PYTEST_OK}" == "1" ]]; then + "${NATIVE_BUILD_DIR}/backends/webgpu/webgpu_rms_norm_test" "${RMS_NORM_DIR}" +else + echo "(skipping rms_norm native test: pytest or export did not complete)" +fi + echo "=== Done ===" diff --git a/backends/webgpu/test/test_webgpu_native.cpp b/backends/webgpu/test/test_webgpu_native.cpp index d3005debf37..5b9d538223e 100644 --- a/backends/webgpu/test/test_webgpu_native.cpp +++ b/backends/webgpu/test/test_webgpu_native.cpp @@ 
-10,10 +10,12 @@ #include #include +#include #include #include #include #include +#include using namespace executorch::backends::webgpu; using namespace executorch::extension;