From 992ad445577cd4150360938eae34b7fa5535dd6f Mon Sep 17 00:00:00 2001 From: Jonathan-Weinstein-AMD Date: Fri, 15 May 2026 12:16:01 -0700 Subject: [PATCH] zstd: Replace 5-permute WAVE_PROPAGATE_STEP series with 1-permute zstdgpu_WavePropogateFseTableIndex. In `zstdgpu_ShaderEntry_ParseCompressedBlocks()`, replace both series of 5 `WAVE_PROPAGATE_STEP`s ``` WAVE_PROPAGATE_STEP(x, 2) WAVE_PROPAGATE_STEP(x, 4) WAVE_PROPAGATE_STEP(x, 8) WAVE_PROPAGATE_STEP(x, 16) WAVE_PROPAGATE_STEP(x, 32) ``` with ``` x = zstdgpu_WavePropogateFseTableIndex(x) ``` which is leaner since it does a single lane permute, instead of 5. It also maybe makes it a bit clearer what is being computed. --- NOTE: `zstdgpu_WavePropogateFseTableIndex` doesn't handle active lanes at or beyond index 32, but that should be easy to add if desired, and neither did the original (it would need a `WAVE_PROPAGATE_STEP(x, 64)` for wave64). A pre-processor check/`#error` on `kzstdgpu_TgSizeX_ParseCompressedBlocks` was added, similar to the existing in `ZstdGpuPrefixSequenceOffsets.hlsl` on `kzstdgpu_TgSizeX_PrefixSequenceOffsets`. **Testing** I dispatched this HLSL shader with 256\*256 = 2\**16 total groups in a side-app and inspected the output. For the first 8 lanes in a wave, this tests all combinations of `{ Unused, Repeat, 2, 3 }`. This isn't a perfect test, but I also ran `zstdgpu_demo --chk-gpu` with some real inputs. ```hlsl // NOTE: makes lane masks easier; WaveTryReplicateFillerUpwardsToHoles doesn't bother with wave64+: static const uint32_t kzstdgpu_TgSizeX_ParseCompressedBlocks = 32; static const uint32_t kzstdgpu_FseProbTableIndex_Unused = 0x3fffffff; static const uint32_t kzstdgpu_FseProbTableIndex_Repeat = kzstdgpu_FseProbTableIndex_Unused - 1; uint WavePropogateFseTableIndex_Reference(uint x) { const uint32_t blockSize = min(WaveGetLaneCount(), kzstdgpu_TgSizeX_ParseCompressedBlocks); #define WAVE_SHUFFLE(v, and_mask, or_mask, xor_mask) WaveReadLaneAt(v, ((WaveGetLaneIndex() & (and_mask)) | (or_mask)) ^ (xor_mask)) #define WAVE_BROADCAST(v, group_size, group_lane) WAVE_SHUFFLE(v, ~(group_size - 1u), group_lane, 0) #define WAVE_PROPAGATE_STEP(p, group_size) \ if (blockSize >= group_size /** this condition is expected to be a compile-time condition, so no real branch */) \ { \ /* for every group of `group_size` consecutive lanes, broadcast the value from the last lane of the "odd" sub-group of 2x smaller size) */ \ uint32_t b = WAVE_BROADCAST(p, group_size, group_size / 2u - 1u); \ /* for every group of `group_size` consecutive lanes */ \ /* propagate element from the last lane of the "odd" sub-group of 2x smaller size */ \ /* into all elements of the "even" sub-group of 2x smaller size when propagated value makes sense */\ [flatten] if ((WaveGetLaneIndex() & (group_size / 2u))) \ { \ /* We propagate only non-Repeat and not-Unused values to lanes containing Repeat/Unused values*/\ if (p >= kzstdgpu_FseProbTableIndex_Repeat && b < kzstdgpu_FseProbTableIndex_Repeat) \ p = b; \ } \ } WAVE_PROPAGATE_STEP(x, 2) WAVE_PROPAGATE_STEP(x, 4) WAVE_PROPAGATE_STEP(x, 8) WAVE_PROPAGATE_STEP(x, 16) WAVE_PROPAGATE_STEP(x, 32) return x; #undef WAVE_PROPAGATE_STEP #undef WAVE_BROADCAST #undef WAVE_SHUFFLE } // Active lanes either contain a "filler" xor a "hole" value. // // If a lane with a hole value can't have a filler value propagated to it from a lower lane, // its value is unchanged (remains a hole). // // NOTE: ensure kzstdgpu_TgSizeX_ParseCompressedBlocks <= 32 // so HLSL lane masks are easy to work with. // // Example with lower lane IDs on the left for "Wave8" where filler values are even integers (holes are odd integers): // input = { 1, 4, 3, 3, 6, 8, 5, 5 } // output = { 1, 4, 4, 4, 6, 8, 8, 8 } uint WaveTryReplicateFillerUpwardsToHoles(uint v_value, bool v_isFiller) { const uint s_hasFillerMask = WaveActiveBallot(v_isFiller).x; // assume <= Wave32 const uint v_selfMask = 1u << WaveGetLaneIndex(); uint v_srcLanesMask = s_hasFillerMask & (v_selfMask - 1); // If this lane already has a filler value, or it has no lane with a filler value to read from, make it read from itself: if (v_isFiller || v_srcLanesMask == 0) { v_srcLanesMask = v_selfMask; } return WaveReadLaneAt(v_value, firstbithigh(v_srcLanesMask)); } uint WavePropogateFseTableIndex_V2(uint tableIndex) { const bool isFiller = tableIndex < kzstdgpu_FseProbTableIndex_Repeat; return WaveTryReplicateFillerUpwardsToHoles(tableIndex, isFiller); } RWStructuredBuffer uav : register(u4, space2); [numthreads(kzstdgpu_TgSizeX_ParseCompressedBlocks, 1, 1)] void main(uint2 combinationKey2 : SV_GroupId, uint threadIdInGroup : SV_GroupThreadID) { const uint combinationKey = combinationKey2.y * 256 + combinationKey2.x; // Lets test all combinations of "wave8" (to deal with less data) with up to 4 values per lane. // The shader is actually wave32; we don't really care about values at lane index 8+. // For a combination key of 0b00'00'00'00'11'00'10'00, the output should be [Unused, 2, 2, 3, 3, 3, 3, 3] const uint slotId = (threadIdInGroup % 8u); uint v = (combinationKey >> (slotId * 2)) & 0x3; if (v == 0) { v = kzstdgpu_FseProbTableIndex_Unused; } else if (v == 1) { v = kzstdgpu_FseProbTableIndex_Repeat; } const uint output_ref = WavePropogateFseTableIndex_Reference(v); const uint output_v2 = WavePropogateFseTableIndex_V2(v); if (output_ref != output_v2) { uav[0] = 0xEEEEEEEE; } uav[(combinationKey + 1) * kzstdgpu_TgSizeX_ParseCompressedBlocks + threadIdInGroup] = output_v2; } ``` --- zstd/zstdgpu/zstdgpu_shaders.h | 77 ++++++++++++++++++---------------- zstd/zstdgpu/zstdgpu_structs.h | 2 +- 2 files changed, 41 insertions(+), 38 deletions(-) diff --git a/zstd/zstdgpu/zstdgpu_shaders.h b/zstd/zstdgpu/zstdgpu_shaders.h index 83e2cd8..7aba197 100644 --- a/zstd/zstdgpu/zstdgpu_shaders.h +++ b/zstd/zstdgpu/zstdgpu_shaders.h @@ -777,6 +777,43 @@ static void zstdgpu_ParseFseHeader(ZSTDGPU_PARAM_INOUT(zstdgpu_Forward_BitBuffer outFseInfo[outFseTableIndex] = zstdgpu_CreateFseInfo(symbol, accuracyLog2); } + +// Active lanes either contain a "filler" xor a "hole" value. +// +// If a lane with a hole value can't have a filler value propagated to it from a lower lane, +// its value is unchanged (remains a hole). +// +// NOTE: ensure kzstdgpu_TgSizeX_ParseCompressedBlocks <= 32 +// so HLSL lane masks are easy to work with. +// +// Example with lower lane IDs on the left for "Wave8" where filler values are even integers (holes are odd integers): +// input = { 1, 4, 3, 3, 6, 8, 5, 5 } +// output = { 1, 4, 4, 4, 6, 8, 8, 8 } +inline uint32_t zstdgpu_WaveReplicateFillerUpwardsToHoles(uint32_t v_value, bool v_isFiller) +{ + const uint32_t s_hasFillerMask = WaveActiveBallot(v_isFiller).x; // assume <= Wave32 + const uint32_t v_selfMask = 1u << WaveGetLaneIndex(); + + uint32_t v_srcLanesMask = s_hasFillerMask & (v_selfMask - 1); + // If this lane already has a filler value, or it has no lane with a filler value to read from, make it read from itself: + if (v_isFiller || v_srcLanesMask == 0) + { + v_srcLanesMask = v_selfMask; + } + + return WaveReadLaneAt(v_value, zstdgpu_FindFirstBitHiU32(v_srcLanesMask)); +} + +inline uint32_t zstdgpu_WavePropogateFseTableIndex(uint32_t tableIndex) +{ +#if (kzstdgpu_TgSizeX_ParseCompressedBlocks - 1u) >= 32u + // Parsing compressed blocks can be divergent, so probably don't want a large thread group anyway. + #error "kzstdgpu_TgSizeX_ParseCompressedBlocks must be in [1:32], else implement WaveActiveBallot.y[zw] handling." +#endif + const bool isFiller = tableIndex < kzstdgpu_FseProbTableIndex_Repeat; + return zstdgpu_WaveReplicateFillerUpwardsToHoles(tableIndex, isFiller); +} + static void zstdgpu_ShaderEntry_ParseCompressedBlocks(ZSTDGPU_PARAM_INOUT(zstdgpu_ParseCompressedBlocks_SRT) srt, uint32_t threadId) { if (threadId >= srt.compressedBlockCount) @@ -1247,26 +1284,6 @@ static void zstdgpu_ShaderEntry_ParseCompressedBlocks(ZSTDGPU_PARAM_INOUT(zstdgp const uint32_t lastLocalIndex = WaveActiveCountBits(true) - 1u; - #define WAVE_SHUFFLE(v, and_mask, or_mask, xor_mask) WaveReadLaneAt(v, ((WaveGetLaneIndex() & (and_mask)) | (or_mask)) ^ (xor_mask)) - - #define WAVE_BROADCAST(v, group_size, group_lane) WAVE_SHUFFLE(v, ~(group_size - 1u), group_lane, 0) - - #define WAVE_PROPAGATE_STEP(p, group_size) \ - if (blockSize >= group_size /** this condition is expected to be a compile-time condition, so no real branch */) \ - { \ - /* for every group of `group_size` consecutive lanes, broadcast the value from the last lane of the "odd" sub-group of 2x smaller size) */ \ - uint32_t b = WAVE_BROADCAST(p, group_size, group_size / 2u - 1u); \ - /* for every group of `group_size` consecutive lanes */ \ - /* propagate element from the last lane of the "odd" sub-group of 2x smaller size */ \ - /* into all elements of the "even" sub-group of 2x smaller size when propagated value makes sense */\ - [flatten] if ((WaveGetLaneIndex() & (group_size / 2u))) \ - { \ - /* We propagate only non-Repeat and not-Unused values to lanes containing Repeat/Unused values*/\ - if (p >= kzstdgpu_FseProbTableIndex_Repeat && b < kzstdgpu_FseProbTableIndex_Repeat) \ - p = b; \ - } \ - } - // To propagate FSE table indices, we use a variant of "Decoupled Lookback" // 1. Each block (a group of `blockSize` threads) looks at indices of each type of FSE table // and checks for each of FSE table type if there's any FSE table "index" that is not `Unused` @@ -1325,12 +1342,7 @@ static void zstdgpu_ShaderEntry_ParseCompressedBlocks(ZSTDGPU_PARAM_INOUT(zstdgp #define LOOKBACK_STORE_EARLY_ANY_VALID(name) \ if (WaveActiveAnyTrue(indexValid##name)) \ { \ - uint32_t x = outBlockData.fseTableIndex##name; \ - WAVE_PROPAGATE_STEP(x, 2) \ - WAVE_PROPAGATE_STEP(x, 4) \ - WAVE_PROPAGATE_STEP(x, 8) \ - WAVE_PROPAGATE_STEP(x, 16) \ - WAVE_PROPAGATE_STEP(x, 32) \ + const uint32_t x = zstdgpu_WavePropogateFseTableIndex(outBlockData.fseTableIndex##name);\ const uint32_t xLast = WaveReadLaneAt(x, lastLocalIndex); \ if (WaveIsFirstLane()) \ { \ @@ -1451,15 +1463,10 @@ static void zstdgpu_ShaderEntry_ParseCompressedBlocks(ZSTDGPU_PARAM_INOUT(zstdgp // NOTE(pamartis): Because the first lane containining "non-Unused" index was set to something other than `Repeat`, // we can propagate indices across the wave (if needed of course, if the wave needs that -- contains any number of lanes with `Repeat` indices) #define PROPAGATE_ACROSS_WAVE_IF_NEEDED(name) \ - const bool needPropagateAcrossWave##name = fseTableIndexPropagated##name == kzstdgpu_FseProbTableIndex_Repeat; \ + const bool needPropagateAcrossWave##name = fseTableIndexPropagated##name == kzstdgpu_FseProbTableIndex_Repeat; \ if (WaveActiveAnyTrue(needPropagateAcrossWave##name)) \ { \ - uint32_t x = fseTableIndexPropagated##name; \ - WAVE_PROPAGATE_STEP(x, 2) \ - WAVE_PROPAGATE_STEP(x, 4) \ - WAVE_PROPAGATE_STEP(x, 8) \ - WAVE_PROPAGATE_STEP(x, 16) \ - WAVE_PROPAGATE_STEP(x, 32) \ + const uint32_t x = zstdgpu_WavePropogateFseTableIndex(fseTableIndexPropagated##name); \ if (needPropagateAcrossWave##name) \ { \ fseTableIndexPropagated##name = x; \ @@ -1478,10 +1485,6 @@ static void zstdgpu_ShaderEntry_ParseCompressedBlocks(ZSTDGPU_PARAM_INOUT(zstdgp outBlockData.fseTableIndexOffs = fseTableIndexPropagatedOffs; outBlockData.fseTableIndexMLen = fseTableIndexPropagatedMLen; - #undef WAVE_PROPAGATE_STEP - #undef WAVE_BROADCAST - #undef WAVE_SHUFFLE - #else // use static variables on CPU because this function is expected to be called in a loop for all compressed blocks static uint32_t lastHufWIndex = kzstdgpu_FseProbTableIndex_Unused; diff --git a/zstd/zstdgpu/zstdgpu_structs.h b/zstd/zstdgpu/zstdgpu_structs.h index 2f3f724..de7f544 100644 --- a/zstd/zstdgpu/zstdgpu_structs.h +++ b/zstd/zstdgpu/zstdgpu_structs.h @@ -380,7 +380,7 @@ static const uint32_t kzstdgpu_TgSizeX_PrefixSum = 64; static const uint32_t kzstdgpu_TgSizeX_PrefixSum = 32; #endif -static const uint32_t kzstdgpu_TgSizeX_ParseCompressedBlocks = 32; +#define kzstdgpu_TgSizeX_ParseCompressedBlocks 32 // #define since dxc may lack static_assert static const uint32_t kzstdgpu_TgSizeX_Memset = 64; // NOTE(pamartis): The rationale behind the below choice of TG sizes is the following