Skip to content

[JAX] Expert Parallelism: JAX primitives + VJPs#3036

Open
phu0ngng wants to merge 22 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-jax
Open

[JAX] Expert Parallelism: JAX primitives + VJPs#3036
phu0ngng wants to merge 22 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-jax

Conversation

@phu0ngng

@phu0ngng phu0ngng commented May 22, 2026

Copy link
Copy Markdown
Collaborator

Summary

Third PR in the TE Expert Parallelism (EP) series, built on top of #3034. Lands the JAX bindings: an XLA FFI layer over the nvte_ep_* C API, a Python wrapper with custom_vjp for autograd, mesh-aware sharding rules, a multi-process test suite, and an end-to-end MoE example. NCCL ncclEpDispatch/ncclEpCombine are exposed as XLA primitives and work with CUDA-graph capture.

Implementation

Public Python API (transformer_engine/jax/ep.py)

from transformer_engine.jax.ep import (
    EpHandle,        # opaque (id, handle_mem) pair from ep_prepare
    ep_bootstrap,    # one-shot per-process: init NCCL comm + nvte_ep_initialize
    ep_dispatch,     # custom_vjp-wrapped dispatch 
    ep_combine,      # custom_vjp-wrapped combine

ep_dispatch / ep_combine are jax.custom_vjp functions: forward is the FFI primitive, backward calls the matching nvte_ep_*_bwd FFI primitive directly (no ep_prepare in the bwd — routing state is already cached in handle.mem). Note that ep_dispatch also calls ep_prepare in the forward path, which all-gathers and preprocesses routing maps.

XLA FFI bindings (transformer_engine/jax/csrc/extensions/ep.cpp)

Five XLA_FFI_DEFINE_HANDLER_SYMBOL entries — EpPrepareHandler, EpDispatchHandler, EpCombineHandler, EpDispatchBwdHandler, EpCombineBwdHandler — each calling the corresponding nvte_ep_* C entry point. All marked FFI_CudaGraph_Traits so they capture cleanly. handle_id is a static FFI attribute baked at jit trace time.

Primitives + Python layer (transformer_engine/jax/cpp_extensions/ep.py, +951 lines)

Standard TE primitive plumbing: abstract_eval (shape/dtype inference), lowering, impl, outer_primitive registration, and partitioning rules so the EP collective is treated as a single sharded op by XLA (no spurious resharding around it).

Sharding (transformer_engine/jax/sharding.py, +12 lines)

Adds the EP mesh axis to the global mesh resource set so downstream sharding rules can reference it.

Build wiring (build_tools/jax.py, +41 lines)

Threads NCCL EP linkage through the JAX transformer_engine_jax extension. No new top-level build flags; rides on the parent PR's NVTE_BUILD_WITH_NCCL_EP.

Tests & example

  • tests/jax/test_multi_process_ep.py (+690 lines): 13 tests covering bootstrap, ep_prepare shape/handle contracts, primitive-level dispatch/combine identity (uniform + skewed routing), custom_vjp fwd+bwd correctness, and HLO inspection (must not insert XLA collectives outside the EP FFI).
  • tests/jax/multi_process_launch_ep.sh: 4-rank launcher; sets XLA_FLAGS to keep XLA command-buffer capture off for the EP FFI sequence (NCCL EP graph-destroy interaction).
  • examples/jax/ep/ep_moe.py (+394 lines) + run_test_ep.sh: end-to-end MoE with EP, dp=ep=2 mesh, includes a ref-comparison --check that verifies fwd+bwd vs a single-process reference.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps

greptile-apps Bot commented May 22, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR lands the JAX Expert Parallelism (EP) bindings: XLA FFI handlers over the nvte_ep_* C API, custom_vjp-wrapped Python API, mesh-aware SPMD sharding rules for all five EP ops (prepare, dispatch-fwd/bwd, combine-fwd/bwd), a multi-process test suite, and an end-to-end MoE example. It also consolidates the NCCL EP arch-check into a shared nccl_ep_enabled() helper to ensure build-flag consistency between setup.py and build_tools/jax.py.

  • Core Python layer (cpp_extensions/ep.py): Five BasePrimitive subclasses with abstract_eval, lowering, SPMD partition, and Shardy rules. A shape mismatch arises when dp_resource is set but topk_idx is sharded only on "ep"outer_abstract sizes outputs as num_ep_groups * ep_size but partition only divides by ep_size, producing mismatched per-shard shapes and a JAX compilation error.
  • Public API (ep.py): ep_bootstrap correctly validates mesh config, divisibility, and single-device-per-process constraints. _dispatch_bwd/_combine_bwd re-derive out_spec from global_mesh_resource() at backward-trace time instead of capturing it in the forward residuals, which is fragile across mesh-context changes.
  • C++ / FFI (ep.cpp): Five handlers with a shared EpInstantiateHandler tying EpResources lifetime to compiled executables. ~EpResources calls nvte_ep_shutdown() without holding the global mutex.

Confidence Score: 4/5

Safe for the standard compound (dp+ep) sharding path; a compilation error surfaces when dp is active and tokens are sharded on ep only.

EpPreparePrimitive.partition uses PartitionSpec(idx_spec[0]) to derive output sharding but outer_abstract fixed the global shape as num_ep_groups * ep_size. With ep-only input sharding when dp is set, JAX infers per-shard shape (num_ep_groups, ...) while sharded_impl produces (1, ...), causing a compile-time shape error.

transformer_engine/jax/cpp_extensions/ep.py — EpPreparePrimitive.partition and EpDispatchPrimitive.partition need either stronger _leading_axis_ok validation or output sharding derived from the full compound spec.

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/ep.py Core Python primitives (+968 lines): abstract eval, lowering, partition, sharding rules for all five EP ops. Shape mismatch possible when dp is set and topk_idx is sharded only on ep (not the compound (dp,ep)), causing a JAX compilation error. EpCombineBwdPrimitive.partition also omits the _ep_spec_ok guard present in every other primitive.
transformer_engine/jax/csrc/extensions/ep.cpp XLA FFI handlers for all five EP ops (+497 lines). Logic is sound; nvte_ep_shutdown() in ~EpResources is called without a lock and without a use-count check on concurrent EpInstanceState refs.
transformer_engine/jax/ep.py Public ep_bootstrap / ep_dispatch / ep_combine API (+318 lines). _dispatch_bwd and _combine_bwd recompute out_spec from global_mesh_resource() at backward-trace time instead of capturing it from the forward residuals, which is fragile if the mesh context changes between traces.
build_tools/utils.py Refactors NCCL EP arch-check logic into a shared nccl_ep_enabled() helper, eliminating the inconsistency between setup.py and build_tools/jax.py.
transformer_engine/jax/sharding.py Adds ep_resource field to MeshResource and ep_axis_size() helper (+12 lines). Change is additive and backward-compatible.
transformer_engine/jax/csrc/extensions/pybind.cpp Registers all five EP FFI handlers and exposes bootstrap/teardown helpers to Python under the NVTE_WITH_NCCL_EP guard.
tests/jax/test_multi_process_ep.py New 742-line multi-process test suite covering bootstrap, prepare, dispatch/combine identity, custom_vjp fwd+bwd correctness, and HLO reshard guard with both uniform and skewed routing patterns.
build_tools/jax.py Adds -DNVTE_WITH_NCCL_EP compile flag via nccl_ep_enabled(); build flag is now consistent with setup.py.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant PY as Python (ep.py)
    participant CPP as cpp_extensions/ep.py
    participant FFI as XLA FFI (ep.cpp)
    participant NCCL as NCCL EP Backend

    Note over PY: ep_bootstrap()
    PY->>FFI: set_ep_bootstrap_params(uid, ep_size, ...)
    FFI->>NCCL: ncclCommInitRank + nvte_ep_initialize

    Note over PY: Forward Pass (ep_dispatch)
    PY->>CPP: ep_prepare(cfg, topk_idx)
    CPP->>FFI: EpPrepareFFI
    FFI-->>CPP: token_counts, handle_mem
    PY->>CPP: ep_dispatch_fwd(cfg, handle_mem, ...)
    CPP->>FFI: EpDispatchFFI
    FFI->>NCCL: ncclEpDispatch (collective)
    NCCL-->>PY: recv_tokens, recv_topk_weights

    Note over PY: Expert Computation (user code)

    Note over PY: Forward Pass (ep_combine)
    PY->>CPP: ep_combine_fwd(cfg, handle_mem, expert_out, ...)
    CPP->>FFI: EpCombineFFI
    FFI->>NCCL: ncclEpCombine (collective)
    NCCL-->>PY: result

    Note over PY: Backward Pass
    PY->>CPP: ep_combine_bwd(cfg, handle_mem, g_result, ...)
    CPP->>FFI: EpCombineBwdFFI
    FFI-->>PY: grad_expert_out
    PY->>CPP: ep_dispatch_bwd(cfg, handle_mem, g_recv, ...)
    CPP->>FFI: EpDispatchBwdFFI
    FFI-->>PY: grad_tokens, grad_topk_weights
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant PY as Python (ep.py)
    participant CPP as cpp_extensions/ep.py
    participant FFI as XLA FFI (ep.cpp)
    participant NCCL as NCCL EP Backend

    Note over PY: ep_bootstrap()
    PY->>FFI: set_ep_bootstrap_params(uid, ep_size, ...)
    FFI->>NCCL: ncclCommInitRank + nvte_ep_initialize

    Note over PY: Forward Pass (ep_dispatch)
    PY->>CPP: ep_prepare(cfg, topk_idx)
    CPP->>FFI: EpPrepareFFI
    FFI-->>CPP: token_counts, handle_mem
    PY->>CPP: ep_dispatch_fwd(cfg, handle_mem, ...)
    CPP->>FFI: EpDispatchFFI
    FFI->>NCCL: ncclEpDispatch (collective)
    NCCL-->>PY: recv_tokens, recv_topk_weights

    Note over PY: Expert Computation (user code)

    Note over PY: Forward Pass (ep_combine)
    PY->>CPP: ep_combine_fwd(cfg, handle_mem, expert_out, ...)
    CPP->>FFI: EpCombineFFI
    FFI->>NCCL: ncclEpCombine (collective)
    NCCL-->>PY: result

    Note over PY: Backward Pass
    PY->>CPP: ep_combine_bwd(cfg, handle_mem, g_result, ...)
    CPP->>FFI: EpCombineBwdFFI
    FFI-->>PY: grad_expert_out
    PY->>CPP: ep_dispatch_bwd(cfg, handle_mem, g_recv, ...)
    CPP->>FFI: EpDispatchBwdFFI
    FFI-->>PY: grad_tokens, grad_topk_weights
Loading

Reviews (22): Last reviewed commit: "Guard against None out_partition_spec in..." | Re-trigger Greptile

Comment thread build_tools/jax.py Outdated
Comment thread build_tools/jax.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/ep.cpp Outdated
Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts,
Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) {
auto topk_dims = topk_idx.dimensions();
NVTE_CHECK(topk_dims.size() >= 2,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we return FFI InvalidArgument instead of a NVTE_CHECK for these inputs?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably a good idea. I suggest we make another follow-up MR to do so for all the FFIs.

@phu0ngng phu0ngng requested a review from tdophung May 22, 2026 15:51
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

I would appreciate your help to review this PR @tdophung @jberchtold-nvidia!
Please focus on the changes in the JAX side, as the TE/Common ones will be discussed in #3034

Comment thread examples/jax/ep/ep_moe.py Outdated
Comment thread tests/jax/multi_process_launch_ep.sh Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread examples/jax/ep/ep_moe.py
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/ep.py Outdated

@jberchtold-nvidia jberchtold-nvidia left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending CI

Comment thread transformer_engine/jax/ep.py
Comment thread transformer_engine/jax/ep.py
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/ep.cpp
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions.h Outdated
jberchtold-nvidia pushed a commit to jberchtold-nvidia/TransformerEngine that referenced this pull request Jun 5, 2026
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype
field. The C++ backend (ep_backend.cpp:349) enforces
    typeToSize(tok_dtype) <= typeToSize(max_token_dtype)
at every dispatch, and the field is also used at group create to size the
NCCL EP staging buffers (ep_backend.cpp:221-222).

PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written
before this field existed and never set it, so any JAX EP group landed with
the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from
JAX then failed immediately with:
    tokens dtype (6) wider than group max_token_dtype (0)

This commit threads max_token_dtype end-to-end:

  - transformer_engine/jax/csrc/extensions.h
    update SetEpBootstrapParams declaration to match the new arity.

  - transformer_engine/jax/csrc/extensions/ep.cpp
    add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams;
    forward it into NVTEEpGroupConfig in the EpResources ctor.

  - transformer_engine/jax/csrc/extensions/pybind.cpp
    add the matching pybind11::arg("max_token_dtype") = 0.

  - transformer_engine/jax/ep.py
    add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to
    NVTEDType int, forward to the C++ setter.

Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream.
See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/ep.cpp
@phu0ngng phu0ngng force-pushed the phuong/ep-3-jax branch 2 times, most recently from 06f8a13 to c34771d Compare June 10, 2026 15:24
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 10, 2026
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype
field. The C++ backend (ep_backend.cpp:349) enforces
    typeToSize(tok_dtype) <= typeToSize(max_token_dtype)
at every dispatch, and the field is also used at group create to size the
NCCL EP staging buffers (ep_backend.cpp:221-222).

PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written
before this field existed and never set it, so any JAX EP group landed with
the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from
JAX then failed immediately with:
    tokens dtype (6) wider than group max_token_dtype (0)

This commit threads max_token_dtype end-to-end:

  - transformer_engine/jax/csrc/extensions.h
    update SetEpBootstrapParams declaration to match the new arity.

  - transformer_engine/jax/csrc/extensions/ep.cpp
    add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams;
    forward it into NVTEEpGroupConfig in the EpResources ctor.

  - transformer_engine/jax/csrc/extensions/pybind.cpp
    add the matching pybind11::arg("max_token_dtype") = 0.

  - transformer_engine/jax/ep.py
    add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to
    NVTEDType int, forward to the C++ setter.

Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream.
See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 10, 2026
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with
EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied
the three deltas uniquely ours:

  * transformer_engine/jax/moe.py: replaces upstream's multi-backend
    MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted
    to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle
    (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed
    in place of handle, ep_prepare arg order swapped, top_k= dropped
    from ep_dispatch_bwd since it's now in cfg.
  * tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with
    ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped
    (no longer supported; ep_size is derived from mesh axes and the
    handle_mem reloc gating is gone).
  * tests/jax/run_te_ep_moe.sh: multi-process launcher (kept).

Pre-sync state preserved at branch
teddy/te_ep_integration.backup-pre-phuong-sync.
EOF
)
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L1

@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L1

phu0ngng and others added 21 commits June 24, 2026 00:22
…16 max_token_dtype

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… with_sharding_constraint

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…trap

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…EpLayerConfig type)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ives (lint 10.00)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…; define NVTE_WITH_NCCL_EP

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ract, drop dead helpers

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…e example) jax distributed suites

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ARY_PATH for libnccl_ep.so

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ck via nccl_ep_enabled()

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L1

Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
@phu0ngng phu0ngng added the 2.7.0 label Jun 24, 2026
…ition methods

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L1

@jberchtold-nvidia jberchtold-nvidia left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending CI

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants