feat(examples): L3 ring allreduce (chunked RS+AG, a2a3 verified)#975
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughThis pull request adds a complete new L3 worker example implementing distributed ring AllReduce with chunked reduce-scatter and allgather phases. It includes AICORE kernel primitives, a full kernel implementation, orchestration wiring, Python runtime setup with golden-output validation, and integration tests. ChangesRing Allreduce Distributed Example
Sequence DiagramsequenceDiagram
participant Python as Python Runtime
participant Worker
participant Orch as Orchestration Layer
participant AIVKernel as AIV Kernel
participant Rank0
participant Rank1
Python->>Worker: Initialize worker
Python->>Python: Allocate per-rank input/output tensors
Python->>Python: Allocate ring domain window and scratch buffer
Python->>Orch: Submit orchestration DAG with tensor/scalar args
Orch->>AIVKernel: rt_submit_aiv_task (3 tensors + 2 scalars)
AIVKernel->>Rank0: Validate nranks, bind scratch layout
Rank0->>Rank0: Stage input into chunk slots
par Reduce-Scatter Phase
Rank0->>Rank1: Publish chunk, barrier signal
Rank1-->>Rank0: Send left-neighbor chunk
Rank0->>Rank0: Load/accumulate tile with MTE flags
end
par Allgather Phase
Rank0->>Rank1: Publish reduced chunk for dissemination
Rank1-->>Rank0: Send chunk from previous round
Rank0->>Rank0: Store chunk to output slot
end
AIVKernel->>AIVKernel: Stage concatenated chunks to output tensor
AIVKernel->>Worker: Return
Python->>Python: Compute golden expected output
Python->>Python: Validate each rank output vs golden (1e-3 tolerance)
Python->>Worker: Close worker
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces a distributed ring AllReduce implementation, featuring chunked reduce-scatter and allgather algorithms. The code review feedback highlights several critical optimization and correctness improvements. First, the exchange buffer is completely unused by remote ranks and should be removed along with its redundant memory copies across the kernel, helper functions, and Python host code to improve performance and reduce scratch memory usage. Second, the kernel must explicitly zero-initialize the local signal slots to prevent undefined behavior, as device memory is not guaranteed to be zero-initialized. Finally, the unnecessary from __future__ import annotations import in main.py should be removed.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
New L3 example separate from mesh allreduce_distributed: stage-in, (P-1) reduce-scatter and (P-1) allgather ring rounds over HCCL window chunks with per-round TNOTIFY/TWAIT barriers. Same golden as mesh. P=2/P=4 pytest; default CLI devices 0-3.
75c3152 to
497ae58
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/workers/l3/allreduce_ring_distributed/main.py`:
- Around line 136-145: Validate the device_ids input at the top of run(): check
that device_ids is non-empty and that nranks = len(device_ids) is within the
supported range (e.g., between 2 and 16 as the example expects); if not, raise a
ValueError with a clear message so downstream calls (like scratch_float_elems)
don't hit ZeroDivisionError or unsupported configurations. Add this check
directly in run() before calling scratch_float_elems() or computing window_size,
referencing run() and scratch_float_elems() in the message so the caller can see
which entrypoint enforces the constraint.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: a173f9be-0a2b-4f97-bf62-6560ebde9a86
📒 Files selected for processing (7)
examples/workers/l3/README.mdexamples/workers/l3/allreduce_ring_distributed/__init__.pyexamples/workers/l3/allreduce_ring_distributed/kernels/aiv/allreduce_ring_common.hppexamples/workers/l3/allreduce_ring_distributed/kernels/aiv/allreduce_ring_kernel.cppexamples/workers/l3/allreduce_ring_distributed/kernels/orchestration/allreduce_ring_orch.cppexamples/workers/l3/allreduce_ring_distributed/main.pyexamples/workers/l3/allreduce_ring_distributed/test_allreduce.py
Drop RingZeroSignals (per-round barrier rows used once; zeroing raced peer notify and caused AICPU 507018 timeout). Recv via left neighbour chunks[] after barrier, not local exchange mirror (max golden diff 99 on second chunk). Size scratch CommBufferSpec to (P+1)*chunk elements. Align ring example with mesh L3 style: single allreduce_ring_kernel.cpp (no common header), phase banners, and matching orch/main.py comments.
497ae58 to
690efbc
Compare
Mirror parse_device_range() so pytest/CLI callers cannot pass an empty list or unsupported rank count into scratch_float_elems().
3a9b9b8 to
9c2236a
Compare
Summary
Reopens the work from #972 (that PR cannot be reopened after branch rebase).
Adds
examples/workers/l3/allreduce_ring_distributed/— chunked ring allreduce(RS + AG on a logical ring), separate from mesh
allreduce_distributed/.Closes / supersedes: #972
Ring uses +10624B HCCL window (chunked + per-round signals); mesh uses 4096B.
Algorithm
TNOTIFY/TWAITchunks[]viaCommRemotePtrafter each barrierTest plan
python examples/workers/l3/allreduce_ring_distributed/main.py -p a2a3 -d 5-6python examples/workers/l3/allreduce_ring_distributed/main.py -p a2a3 -d 0-3python examples/workers/l3/allreduce_distributed/main.py -p a2a3 -d 0-3