Skip to content

[JAX] Expert Parallelism: JAX primitives + VJPs#3036

Open
phu0ngng wants to merge 22 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-jax
Open

[JAX] Expert Parallelism: JAX primitives + VJPs#3036
phu0ngng wants to merge 22 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-jax

Guard against None out_partition_spec in EP combine/dispatch-bwd part…

db8d43b
Select commit
Loading
Failed to load commit list.