diff --git a/.codecov.yml b/.codecov.yml index 5dc2d51f..20f08461 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -22,7 +22,7 @@ ignore: - "*/dist-packages/*" - "*/third_party/*" - "*/ark/*_test.*" - - "*/examples/*" + - "examples/**" - "*/python/unittest/*" - "*/ark/unittest/*" - "*/ark/ops/ops_test_common.*" diff --git a/.github/workflows/ut.yml b/.github/workflows/ut.yml index 1be09faf..c84925de 100644 --- a/.github/workflows/ut.yml +++ b/.github/workflows/ut.yml @@ -87,6 +87,16 @@ jobs: --verbose \ ../python/unittest/ + - name: Run Qwen3 Example Tests + if: github.event_name != 'schedule' + run: | + cd build + PYTHONPATH=$PWD/python ARK_ROOT=$PWD python3 -m pytest \ + --cov=../examples/qwen3 \ + --cov-report lcov:qwen3_coverage.info \ + --verbose \ + ../examples/qwen3/ + - name: C++ Coverage if: github.event_name != 'schedule' run: | @@ -111,7 +121,7 @@ jobs: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} run: | cd build - lcov -a cpp_coverage.info -a py_coverage.info -o coverage.info + lcov -a cpp_coverage.info -a py_coverage.info $([ -f qwen3_coverage.info ] && echo "-a qwen3_coverage.info") -o coverage.info bash <(curl -s https://codecov.io/bash) -f coverage.info || echo "Codecov did not collect coverage reports" - name: Install Python diff --git a/ark/codegen.cpp b/ark/codegen.cpp index dc080d60..c543caa6 100644 --- a/ark/codegen.cpp +++ b/ark/codegen.cpp @@ -315,16 +315,17 @@ std::string CodeGenerator::Impl::def_task(const Json &task_json) { size_t buffer_id = moff.buffer_id(); auto buf_info = buf_reg.get(buffer_id); if (buf_info && buf_info->is_external) { - ERR(InternalError, "cannot offset external buffer"); - } - size_t buffer_offset; - auto it = buffer_id_to_offset_.find(buffer_id); - if (it == buffer_id_to_offset_.end()) { - ERR(InternalError, "buffer ID not found: ", buffer_id); + // External buffer: offset relative to its own base. 
+ ss_desc << moff.value(); + } else { + auto it = buffer_id_to_offset_.find(buffer_id); + if (it == buffer_id_to_offset_.end()) { + ERR(InternalError, "buffer ID not found: ", buffer_id); + } + size_t buffer_offset = it->second; + size_t offset = buffer_offset + moff.value(); + ss_desc << offset; } - buffer_offset = it->second; - size_t offset = buffer_offset + moff.value(); - ss_desc << offset; } else { ss_desc << arg.serialize().begin().value(); } diff --git a/ark/codegen_test.cpp b/ark/codegen_test.cpp new file mode 100644 index 00000000..bb7b92a9 --- /dev/null +++ b/ark/codegen_test.cpp @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "codegen.hpp" + +#include +#include +#include + +#include "ark/model.hpp" +#include "ark/planner.hpp" +#include "buffer_registry.hpp" +#include "model/model_buffer.hpp" +#include "model/model_node.hpp" +#include "model/model_op.hpp" +#include "model/model_op_arg.hpp" +#include "model/model_tensor.hpp" +#include "unittest/unittest_utils.h" + +// Collect all buffer IDs referenced by OFFSET args in a plan's TaskInfos. +static std::set collect_offset_buffer_ids(const ark::Json &plan) { + std::set ids; + for (auto &ti : plan.at("TaskInfos")) { + for (auto &op_json : ti.at("Ops")) { + auto op = ark::ModelOp::deserialize(op_json); + auto args = op->impl_args(op_json.at("Config")); + for (auto &arg : args) { + if (arg.type_name() == "OFFSET") { + ids.insert(arg.value().buffer_id()); + } + } + } + } + return ids; +} + +// Collect all buffer IDs referenced by TENSOR args in a plan's TaskInfos. 
+static std::set collect_tensor_buffer_ids(const ark::Json &plan) { + std::set ids; + for (auto &ti : plan.at("TaskInfos")) { + for (auto &op_json : ti.at("Ops")) { + auto op = ark::ModelOp::deserialize(op_json); + auto args = op->impl_args(op_json.at("Config")); + for (auto &arg : args) { + if (arg.type_name() == "TENSOR") { + ids.insert( + arg.value()->buffer()->id()); + } + } + } + } + return ids; +} + +// Test 1: CodeGenerator exercises the external-buffer OFFSET path +// (codegen.cpp line 319: `ss_desc << moff.value();`). +ark::unittest::State test_codegen_external_buffer_offset() { + // Build a 2-rank model with a send_packet op on rank 0. + // send_packet's impl_args are two OFFSET args whose buffer IDs we will + // register as external in BufferRegistry before constructing CodeGenerator. + ark::Model model(0, 2); + ark::Tensor tns = model.tensor({1024}, ark::FP16); + model.send_packet(tns, 1, /*tag=*/0, /*flag=*/1); + + // Plan on GPU 0. + ark::Planner planner(model, 0); + auto plan = ark::Json::parse(planner.plan(false)); + + // Verify the plan has TaskInfos. + UNITTEST_TRUE(plan.contains("TaskInfos")); + UNITTEST_TRUE(plan["TaskInfos"].size() > 0); + + // Collect OFFSET and TENSOR buffer IDs from the plan. + auto offset_buf_ids = collect_offset_buffer_ids(plan); + auto tensor_buf_ids = collect_tensor_buffer_ids(plan); + UNITTEST_TRUE(offset_buf_ids.size() > 0); + + // Register every OFFSET buffer as external in BufferRegistry. + // Use a dummy non-null pointer; CodeGenerator only checks is_external, + // it does not dereference the pointer. + auto &buf_reg = ark::BufferRegistry::get_instance(); + for (size_t id : offset_buf_ids) { + buf_reg.set(id, reinterpret_cast(0x1), 0, /*is_external=*/true); + } + + // All referenced buffer IDs go into extra_buffer_ids (external). 
+ std::set extra; + extra.insert(offset_buf_ids.begin(), offset_buf_ids.end()); + extra.insert(tensor_buf_ids.begin(), tensor_buf_ids.end()); + + // Construct CodeGenerator — exercises the external OFFSET path. + ark::PlanJson pj(plan); + ark::CodeGenerator codegen(pj, /*buffer_id_to_offset=*/{}, extra); + + // Verify non-empty generated code. + std::string code = codegen.code(); + UNITTEST_TRUE(code.size() > 0); + + return ark::unittest::State::SUCCESS; +} + +// Test 2: CodeGenerator exercises the normal (non-external) OFFSET path +// (codegen.cpp lines 320-325: buffer_id_to_offset_ lookup). +// Also exercises Model::all_reduce_packet which covers the new +// `input = this->copy(input)` line in ops_all_reduce.cpp:57. +ark::unittest::State test_codegen_normal_offset() { + // Build a 2-rank model using all_reduce_packet (exercises + // ops_all_reduce.cpp line 57: `input = this->copy(input)`). + ark::Model model(0, 2); + ark::Tensor tns = model.tensor({1024}, ark::FP16); + model.all_reduce_packet(tns, 0, 2); + + // Plan on GPU 0. + ark::Planner planner(model, 0); + auto plan = ark::Json::parse(planner.plan(false)); + UNITTEST_TRUE(plan.contains("TaskInfos")); + UNITTEST_TRUE(plan["TaskInfos"].size() > 0); + + // Collect ALL buffer IDs (OFFSET + TENSOR) from the plan. + auto offset_buf_ids = collect_offset_buffer_ids(plan); + auto tensor_buf_ids = collect_tensor_buffer_ids(plan); + UNITTEST_TRUE(offset_buf_ids.size() > 0); + + // Put all buffer IDs in buffer_id_to_offset_ with offset 0. + // Do NOT register them as external in BufferRegistry. + std::map buf_id_to_offset; + for (size_t id : offset_buf_ids) { + buf_id_to_offset[id] = 0; + } + for (size_t id : tensor_buf_ids) { + buf_id_to_offset[id] = 0; + } + + // Construct CodeGenerator — exercises the normal OFFSET path + // (buffer_id_to_offset_ lookup, lines 320-325 of codegen.cpp). 
+ ark::PlanJson pj(plan); + ark::CodeGenerator codegen(pj, buf_id_to_offset, {}); + + std::string code = codegen.code(); + UNITTEST_TRUE(code.size() > 0); + + return ark::unittest::State::SUCCESS; +} + +// Test 3: CodeGenerator throws InternalError when an OFFSET arg's buffer ID +// is neither external nor in buffer_id_to_offset (codegen.cpp line 323). +ark::unittest::State test_codegen_missing_buffer_id() { + // Build a fresh model so its buffer IDs are new (not in BufferRegistry + // from test 1). + // Safe because ModelBuffer::curr_id is a monotonically-increasing static + // counter — IDs allocated here will never collide with test 1's entries. + ark::Model model(0, 2); + ark::Tensor tns = model.tensor({512}, ark::FP16); + model.send_packet(tns, 1, /*tag=*/0, /*flag=*/1); + + ark::Planner planner(model, 0); + auto plan = ark::Json::parse(planner.plan(false)); + + // Do NOT register any buffer in BufferRegistry. + // Do NOT populate buffer_id_to_offset. + // CodeGenerator should throw InternalError on the OFFSET lookup. + ark::PlanJson pj(plan); + UNITTEST_THROW(ark::CodeGenerator(pj, {}, {}), ark::InternalError); + + return ark::unittest::State::SUCCESS; +} + +int main() { + UNITTEST(test_codegen_external_buffer_offset); + UNITTEST(test_codegen_normal_offset); + UNITTEST(test_codegen_missing_buffer_id); + return 0; +} diff --git a/ark/ops/ops_all_reduce.cpp b/ark/ops/ops_all_reduce.cpp index 320194e4..99f2cb66 100644 --- a/ark/ops/ops_all_reduce.cpp +++ b/ark/ops/ops_all_reduce.cpp @@ -50,6 +50,12 @@ Tensor Model::all_reduce_packet(Tensor input, int rank, int rank_num, ERR(ModelError, "all_reduce_packet requires rank_num >= 2"); } + // Copy input into an internal buffer so it resides in mscclpp + // registered memory. External buffers (e.g. from torch tensors) + // are not part of the registered allocation, and putPackets needs + // to read from a registered source. 
+ input = this->copy(input); + if (output.is_null()) { output = this->tensor(input.shape(), input.data_type()); } diff --git a/ark/ops/ops_all_reduce_test.cpp b/ark/ops/ops_all_reduce_test.cpp index 14b9835c..2d7f9d65 100644 --- a/ark/ops/ops_all_reduce_test.cpp +++ b/ark/ops/ops_all_reduce_test.cpp @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +#include "ark/executor.hpp" +#include "gpu/gpu.hpp" #include "model/model_buffer.hpp" #include "model/model_node.hpp" #include "model/model_op.hpp" @@ -334,11 +336,109 @@ ark::unittest::State test_all_reduce_sm_8gpus() { return ark::unittest::SUCCESS; } +template +void test_all_reduce_packet_fused_internal(ark::DimType nelem) { + for (int gpu_id = 0; gpu_id < NumGpus; ++gpu_id) { + ark::unittest::spawn_process([gpu_id, nelem]() { + UNITTEST_SKIP(ark::unittest::get_gpu_count() < NumGpus); + // Each GPU's data is equal to its GPU ID + 1. + ark::Model m(gpu_id, NumGpus); + ark::Tensor ones = m.tensor({nelem}, ark::FP16); + ark::Tensor data = m.mul(ones, float(gpu_id + 1)); + ark::Tensor output = m.all_reduce_packet(data, gpu_id, NumGpus); + + std::vector ones_vec(ones.shape().nelems(), + ark::half_t(1.0f)); + auto result = ark::op_test( + "all_reduce_packet_fused", m, {ones}, {output}, + baseline_all_reduce, {ones_vec.data()}); + UNITTEST_LOG(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; + }); + } + ark::unittest::wait_all_processes(); +} + +// Variant with external-buffer (placeholder) input — exercises the +// codegen external-buffer OFFSET path added for all_reduce_packet's +// internal copy. Cannot use op_test() because placeholders require +// pre-allocated GPU memory; drive the executor manually instead. 
+template +void test_all_reduce_packet_fused_ext_internal(ark::DimType nelem) { + for (int gpu_id = 0; gpu_id < NumGpus; ++gpu_id) { + ark::unittest::spawn_process([gpu_id, nelem]() { + UNITTEST_SKIP(ark::unittest::get_gpu_count() < NumGpus); + + UNITTEST_EQ(ark::gpuSetDevice(gpu_id), ark::gpuSuccess); + + // Allocate GPU memory and fill with (gpu_id + 1). + ark::half_t *d_input = nullptr; + size_t nbytes = nelem * sizeof(ark::half_t); + UNITTEST_EQ(ark::gpuMalloc(&d_input, nbytes), ark::gpuSuccess); + std::vector h_input(nelem, + ark::half_t(float(gpu_id + 1))); + UNITTEST_EQ(ark::gpuMemcpy(d_input, h_input.data(), nbytes, + ark::gpuMemcpyHostToDevice), + ark::gpuSuccess); + + ark::Model m(gpu_id, NumGpus); + ark::Tensor input = + m.placeholder({nelem}, ark::FP16, {}, {}, {}, -1, d_input); + ark::Tensor output = m.all_reduce_packet(input, gpu_id, NumGpus); + + ark::DefaultExecutor exe(m, gpu_id); + exe.launch(); + exe.run(1); + exe.stop(); + + std::vector h_output(nelem); + exe.tensor_read(output, h_output); + + float expected = float(NumGpus * (NumGpus + 1)) / 2.0f; + for (ark::DimType i = 0; i < nelem; ++i) { + UNITTEST_EQ(float(h_output[i]), expected); + } + + UNITTEST_EQ(ark::gpuFree(d_input), ark::gpuSuccess); + return ark::unittest::SUCCESS; + }); + } + ark::unittest::wait_all_processes(); +} + +ark::unittest::State test_all_reduce_packet_fused_ext_2gpus() { + test_all_reduce_packet_fused_ext_internal<2>(4096); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_all_reduce_packet_fused_2gpus() { + test_all_reduce_packet_fused_internal<2>(4096); + test_all_reduce_packet_fused_internal<2>(8192); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_all_reduce_packet_fused_4gpus() { + test_all_reduce_packet_fused_internal<4>(2048); + test_all_reduce_packet_fused_internal<4>(8192); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_all_reduce_packet_fused_8gpus() { + test_all_reduce_packet_fused_internal<8>(2048); + 
test_all_reduce_packet_fused_internal<8>(8192); + return ark::unittest::SUCCESS; +} + int main() { UNITTEST(test_all_reduce_4gpus); UNITTEST(test_all_reduce_8gpus); UNITTEST(test_all_reduce_packet_4gpus); UNITTEST(test_all_reduce_packet_8gpus); + UNITTEST(test_all_reduce_packet_fused_ext_2gpus); + UNITTEST(test_all_reduce_packet_fused_2gpus); + UNITTEST(test_all_reduce_packet_fused_4gpus); + UNITTEST(test_all_reduce_packet_fused_8gpus); UNITTEST(test_all_reduce_sm_4gpus); UNITTEST(test_all_reduce_sm_8gpus); UNITTEST(test_all_reduce_inplace_2gpus); diff --git a/examples/qwen3/__init__.py b/examples/qwen3/__init__.py new file mode 100644 index 00000000..b25b6896 --- /dev/null +++ b/examples/qwen3/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Submodules: +# ark_allreduce - all-reduce wrapper (active) +# test_allreduce - tests (active) +# bench_allreduce - benchmarks (active) +# equiv - equivalence test helper (staged for inference pipeline) +# microbench - CUDA microbench harness (staged for profiling scripts) +# qwen3_config - model config dataclass (staged for inference pipeline) diff --git a/examples/qwen3/_env.py b/examples/qwen3/_env.py new file mode 100644 index 00000000..93b6bf82 --- /dev/null +++ b/examples/qwen3/_env.py @@ -0,0 +1,119 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Shared subprocess environment helpers for Qwen3 examples. + +Used by both ``bench_allreduce.py`` and ``test_allreduce.py`` to build +a consistent PYTHONPATH / CUDA_VISIBLE_DEVICES env for worker processes. +""" + +import glob +import importlib.util +import os +import sys + +# Repo root — used to locate the built ark Python package for subprocesses. +_REPO_ROOT = os.path.normpath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..") +) + +# Directory containing this file — propagated so workers can import +# sibling modules (microbench, qwen3_config, etc.) if needed. 
+_EXAMPLES_QWEN3_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def _has_compiled_ark(parent_dir: str) -> bool: + """Return True if *parent_dir*/ark/ contains the compiled C++ extension. + + The source tree's ``python/ark/`` has ``__init__.py`` but no compiled + ``core.cpython-*.so``. Adding it to PYTHONPATH causes workers to fail + with ``ModuleNotFoundError: No module named 'ark.core'``. + """ + ark_pkg = os.path.join(parent_dir, "ark") + if not os.path.isfile(os.path.join(ark_pkg, "__init__.py")): + return False + # Check for compiled extension (Linux .so, Windows .pyd) + return bool( + glob.glob(os.path.join(ark_pkg, "core*.so")) + or glob.glob(os.path.join(ark_pkg, "core*.pyd")) + ) + + +def _subprocess_env(world_size: int) -> dict: + """Build env dict for worker subprocesses. + + Resolution order for the ``ark`` package path: + 1. ``importlib.util.find_spec("ark")`` — wherever the parent already + resolved ark (handles build-tree, install, and namespace packages). + 2. ``$ARK_ROOT/python`` (CI sets ``ARK_ROOT=$PWD``). + 3. ``_REPO_ROOT/build/python`` or ``_REPO_ROOT/python``. + 4. inherited ``PYTHONPATH``. + + Also propagates the parent of the ``examples/qwen3/`` directory so workers + can import ``qwen3`` submodules (microbench, qwen3_config) when needed. + """ + extra = [] # type: list[str] + + # --- Primary: resolve from the running interpreter's import state --- + try: + spec = importlib.util.find_spec("ark") + if spec is not None: + if spec.submodule_search_locations: + # Regular package: parent of the package directory. + ark_pkg_dir = next(iter(spec.submodule_search_locations)) + ark_parent = os.path.dirname(ark_pkg_dir) + elif spec.origin: + # Single-file or namespace with origin. 
+ ark_parent = os.path.dirname(os.path.dirname(spec.origin)) + else: + ark_parent = None + if ark_parent and _has_compiled_ark(ark_parent): + if ark_parent not in extra: + extra.append(ark_parent) + except (ModuleNotFoundError, ValueError, TypeError): + pass + + # --- Secondary: scan sys.path for a compiled ark package --- + # When PYTHONPATH points at the source tree (e.g., /w/python), + # find_spec("ark") resolves to source-only ark/ (no core*.so). + # Keep searching for an installed/built ark with compiled extension. + for entry in sys.path: + if not entry: + continue + if _has_compiled_ark(entry): + if entry not in extra: + extra.append(entry) + break + + # --- Fallback: $ARK_ROOT/python --- + ark_root = os.environ.get("ARK_ROOT", "") + if ark_root: + ark_root_py = os.path.join(ark_root, "python") + if _has_compiled_ark(ark_root_py): + if ark_root_py not in extra: + extra.append(ark_root_py) + + # --- Fallback: repo build/python or python --- + for subdir in ("build/python", "python"): + candidate = os.path.join(_REPO_ROOT, subdir) + if _has_compiled_ark(candidate): + if candidate not in extra: + extra.append(candidate) + + # --- Propagate examples/qwen3 for sibling module imports --- + examples_parent = os.path.dirname(_EXAMPLES_QWEN3_DIR) + if examples_parent not in extra: + extra.append(examples_parent) + + # --- Inherited PYTHONPATH --- + existing = os.environ.get("PYTHONPATH", "") + if existing: + extra.append(existing) + + env = { + **os.environ, + "CUDA_VISIBLE_DEVICES": ",".join(str(i) for i in range(world_size)), + } + if extra: + env["PYTHONPATH"] = os.pathsep.join(extra) + return env diff --git a/examples/qwen3/ark_allreduce.py b/examples/qwen3/ark_allreduce.py new file mode 100644 index 00000000..75d90b8e --- /dev/null +++ b/examples/qwen3/ark_allreduce.py @@ -0,0 +1,101 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Qwen3 TP all-reduce wrapper via ARK mscclpp fused-packet API. 
+ +Wraps ``ark.all_reduce_packet`` for 2-D Qwen3 tensor-parallel shapes. +Flattens to 1-D for the packet API and reshapes output back to the +original shape. Includes alignment, dtype, and contiguity validation. + +Qwen3-8B TP all-reduce sites (both attention output and MLP output): + Prefill (B=1, S=2048): (2048, 4096) = 8,388,608 elements + Decode (B=1, S=1): (1, 4096) = 4,096 elements +Both divisible by 4 * world_size (32 for TP=8). +""" + +import torch + +import ark + + +def validate_allreduce_input(x: torch.Tensor, world_size: int) -> None: + """Validate that *x* is suitable for ``ark.all_reduce_packet``. + + Checks: + - dtype is float16 (packet API requirement) + - tensor is contiguous + - element count is divisible by ``4 * world_size`` + + Raises: + ValueError: on any failed check. + """ + if world_size < 1: + raise ValueError(f"world_size must be >= 1, got {world_size}") + if x.dtype != torch.float16: + raise ValueError(f"all_reduce_packet requires float16, got {x.dtype}") + if not x.is_contiguous(): + raise ValueError("all_reduce_packet requires a contiguous tensor") + divisor = 4 * world_size + if x.numel() % divisor != 0: + raise ValueError( + f"element count {x.numel()} is not divisible by " + f"4 * world_size = {divisor}" + ) + + +def ark_allreduce( + x: torch.Tensor, + rank: int, + world_size: int, +) -> "ark.Tensor": + """All-reduce a contiguous fp16 tensor via ARK fused-packet API. + + Flattens *x* to 1-D, calls ``ark.all_reduce_packet``, and returns + an ARK tensor whose ``.to_torch()`` yields a torch tensor with the + original shape restored. + + Note: sets ARK global state (init/rank/world_size) on each call; + intended for single-use model graph construction, not iterative use. + + Args: + x: fp16 contiguous CUDA tensor (any shape). + rank: Rank of the current process (0-indexed). + world_size: Total number of TP ranks. + + Returns: + ARK tensor wrapping the all-reduced result. Call ``.to_torch()`` + to materialise a torch tensor. 
The original shape is already + restored. + """ + validate_allreduce_input(x, world_size) + orig_shape = x.shape + x_flat = x.reshape(-1) + + ark.init() + ark.set_rank(rank) + ark.set_world_size(world_size) + result = ark.all_reduce_packet(x_flat, rank, world_size) + # Reshape back to original shape via ark.reshape + if len(orig_shape) > 1: + result = ark.reshape(result, list(orig_shape)) + return result + + +def torch_allreduce( + x: torch.Tensor, +) -> torch.Tensor: + """All-reduce via ``torch.distributed`` (NCCL backend). + + Requires ``torch.distributed`` to be initialised. Operates + in-place and returns the result tensor. + + Args: + x: Tensor on CUDA. Modified in-place. + + Returns: + The same tensor after in-place all-reduce. + """ + import torch.distributed as dist + + dist.all_reduce(x) + return x diff --git a/examples/qwen3/bench_allreduce.py b/examples/qwen3/bench_allreduce.py new file mode 100644 index 00000000..cf411cc0 --- /dev/null +++ b/examples/qwen3/bench_allreduce.py @@ -0,0 +1,201 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Microbenchmark: ARK fused-packet all-reduce at Qwen3 TP shapes. + +Measures latency for both prefill (2048, 4096) and decode (1, 4096) +shapes at TP=2 and TP=8. Run out-of-band on a multi-GPU node: + + # TP=2, 2 GPUs + python -m examples.qwen3.bench_allreduce --world-size 2 + + # TP=8, 8 GPUs (from repo root) + python -m examples.qwen3.bench_allreduce --world-size 8 + +Each rank is launched as a separate process to avoid CUDA context issues. +Uses torch.cuda.Event for steady-state timing. +""" + +import argparse +import os +import subprocess +import sys + +try: + from ._env import _subprocess_env +except ImportError: + # Standalone script mode (perf-gate invocation, not part of a package). 
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from _env import _subprocess_env + +_WORKER_SCRIPT = ''' +"""Worker for all-reduce microbenchmark.""" +import sys +import json +import os + +import torch +import ark +from ark.executor import Executor + +rank = int(sys.argv[1]) +world_size = int(sys.argv[2]) +n_elements = int(sys.argv[3]) +label = sys.argv[4] + +ark.init() +ark.set_rank(rank) +ark.set_world_size(world_size) + +x = torch.randn(n_elements, dtype=torch.float16, device=f"cuda:{rank}") + +result = ark.all_reduce_packet(x, rank, world_size) + +with ark.Runtime() as rt: + rt.launch(device_id=rank) + + # Warm up + for _ in range(5): + rt.run() + + # Measure + torch.cuda.synchronize(rank) + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + n_iters = 100 + start.record(torch.cuda.Stream(torch.device(f"cuda:{rank}"))) + for _ in range(n_iters): + rt.run() + end.record(torch.cuda.Stream(torch.device(f"cuda:{rank}"))) + torch.cuda.synchronize(rank) + + elapsed_ms = start.elapsed_time(end) + mean_us = elapsed_ms * 1000.0 / n_iters + + if rank == 0: + print(json.dumps({ + "label": label, + "world_size": world_size, + "n_elements": n_elements, + "mean_us": round(mean_us, 2), + "n_iters": n_iters, + })) + sys.stdout.flush() + +# Runtime.stop() only halts execution; Executor.reset() forces full mscclpp +# teardown before os._exit() skips Python's normal cleanup. +Executor.reset() +os._exit(0) +''' + +# Primary benchmark shape: decode (1, 4096) = 4096 elements. +# SGLang baseline: PROFILE.md total comm = 214.69 ms across a decode-dominated +# trace for Qwen3-8B TP=8 batch=1. Per-layer amortized cost = 214.69 / 36 +# layers ≈ 5.964 ms. The bench measures one all_reduce_packet call (replacing +# one layer's comm) and compares to this per-layer budget. 
+_SGLANG_DECODE_MS = 214.69 / 36.0 # ≈ 5.964 ms per layer (from PROFILE.md) + +SHAPES = [ + ("decode (1, 4096)", 4096), + ("prefill (2048, 4096)", 2048 * 4096), +] + + +def run_bench(world_size: int): + """Run all-reduce bench for all shapes at the given world_size. + + Returns a list of parsed JSON result dicts from rank-0 workers, + or an empty list if all workers failed. + """ + import json as _json + + results = [] + for label, n_elements in SHAPES: + procs = [] + for rank in range(world_size): + p = subprocess.Popen( + [ + sys.executable, + "-c", + _WORKER_SCRIPT, + str(rank), + str(world_size), + str(n_elements), + label, + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd="/", + env=_subprocess_env(world_size), + ) + procs.append(p) + + try: + for rank, p in enumerate(procs): + try: + stdout, stderr = p.communicate(timeout=300) + except subprocess.TimeoutExpired: + print( + f"ERROR rank={rank}: timed out after 300s", + file=sys.stderr, + ) + break + if p.returncode != 0: + print( + f"ERROR rank={rank}: {stderr.decode().strip()[-500:]}", + file=sys.stderr, + ) + if rank == 0 and stdout.strip(): + results.append(_json.loads(stdout.decode().strip())) + finally: + for p in procs: + p.kill() + p.wait() + + # Print summary table + print(f"\n{'=' * 60}") + print(f"ARK fused-packet all-reduce | TP={world_size}") + print(f"{'=' * 60}") + print(f"{'Shape':<30} {'Elements':>12} {'Latency (us)':>14}") + print(f"{'-' * 60}") + for d in results: + print(f"{d['label']:<30} {d['n_elements']:>12,} {d['mean_us']:>14.2f}") + print(f"{'=' * 60}\n") + + return results + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark ARK fused-packet all-reduce at Qwen3 TP shapes" + ) + parser.add_argument( + "--world-size", + type=int, + default=2, + help="Number of TP ranks (default: 2)", + ) + args = parser.parse_args() + results = run_bench(args.world_size) + + # Emit PERF_GATE line for the decode shape (primary gate metric). 
+ sglang_ms = _SGLANG_DECODE_MS + decode_results = [r for r in results if r["n_elements"] == 4096] + if decode_results: + ark_ms = decode_results[0]["mean_us"] / 1000.0 + else: + # No decode result from rank 0 (worker failure or timeout): + # report a sentinel latency so the gate ratio reads as a failure. + ark_ms = 999999.0 + ratio = ark_ms / sglang_ms if sglang_ms > 0 else 999999.0 + print( + f"PERF_GATE name=allreduce" + f" ark_ms={ark_ms:.4f}" + f" sglang_ms={sglang_ms:.4f}" + f" ratio={ratio:.4f}" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/qwen3/equiv.py b/examples/qwen3/equiv.py new file mode 100644 index 00000000..863e4ef5 --- /dev/null +++ b/examples/qwen3/equiv.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Equivalence-test helper: compare ARK output against torch reference. + +Provides richer mismatch diagnostics (bad-element count, per-tensor stats) +than torch.testing.assert_close, useful for debugging ARK-vs-reference +numerical differences. +""" + +import torch + + +def assert_close( + ark_out: torch.Tensor, + ref_out: torch.Tensor, + atol: float = 1e-2, + rtol: float = 1e-2, + msg: str = "", +) -> None: + """Assert that two tensors are element-wise close. + + On mismatch, reports shape, max absolute error, relative error, + and basic statistics for both tensors. + + Args: + ark_out: Tensor produced by the ARK implementation. + ref_out: Tensor produced by the torch reference. + atol: Absolute tolerance. + rtol: Relative tolerance. + msg: Optional context message for the assertion. + """ + if ark_out.shape != ref_out.shape: + raise AssertionError( + f"Shape mismatch: ark {ark_out.shape} vs ref {ref_out.shape}. 
{msg}" + ) + + ark_f = ark_out.float() + ref_f = ref_out.float() + + abs_diff = (ark_f - ref_f).abs() + max_abs = abs_diff.max().item() + ref_abs = ref_f.abs().clamp(min=1e-12) + max_rel = (abs_diff / ref_abs).max().item() + + close = abs_diff <= (atol + rtol * ref_abs) + if close.all(): + return + + n_bad = (~close).sum().item() + n_total = close.numel() + + detail = ( + f"Tensors not close. {n_bad}/{n_total} elements exceed tolerance " + f"(atol={atol}, rtol={rtol}).\n" + f" max |diff| = {max_abs:.6e}\n" + f" max |diff|/|ref|= {max_rel:.6e}\n" + f" ark stats: mean={ark_f.mean().item():.4e}, " + f"std={ark_f.std().item():.4e}, " + f"min={ark_f.min().item():.4e}, max={ark_f.max().item():.4e}\n" + f" ref stats: mean={ref_f.mean().item():.4e}, " + f"std={ref_f.std().item():.4e}, " + f"min={ref_f.min().item():.4e}, max={ref_f.max().item():.4e}" + ) + if msg: + detail = f"{msg}\n{detail}" + + raise AssertionError(detail) diff --git a/examples/qwen3/microbench.py b/examples/qwen3/microbench.py new file mode 100644 index 00000000..556a37e5 --- /dev/null +++ b/examples/qwen3/microbench.py @@ -0,0 +1,159 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Microbenchmark helper: CUDA-graph capture, L2 flush, steady-state timing. + +Follows the gpu-kernel-perf-bench methodology: +- L2 cache pollution buffer sized to 2x L2 cache. +- CUDA-graph capture for launch-overhead elimination. +- Pilot iteration tuning targeting 0.1-0.3 s total. +- cuda.Event timing for all measurements (pilot, calibration, and measured runs). +- Returns structured dict: mean_us, std_us, n_iters. 
+"""
+
+from typing import Callable, Dict
+
+import torch
+
+
+def _l2_flush_buffer(device: torch.device) -> torch.Tensor:
+    """Allocate a buffer exceeding 2x typical L2 cache (128 MB covers A100's 40 MB)."""
+    nbytes = 128 * 1024 * 1024  # 128 MB
+    return torch.empty(nbytes // 4, dtype=torch.float32, device=device)
+
+
+def _flush_l2(buf: torch.Tensor) -> None:
+    """Touch the L2-flush buffer to evict cached data."""
+    buf.zero_()
+
+
+def _determine_iters(
+    fn: Callable[[], None],
+    target_secs: float = 0.2,
+    device: torch.device = None,
+) -> int:
+    """Pilot run: find iteration count for ~target_secs total execution time."""
+    if device is None:
+        device = torch.device("cuda")
+
+    # Warm up
+    fn()
+    torch.cuda.synchronize(device)
+
+    # Time a single call
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    start.record()
+    fn()
+    end.record()
+    torch.cuda.synchronize(device)
+    single_ms = start.elapsed_time(end)
+
+    if single_ms <= 0:
+        return 100
+
+    n = max(1, int(target_secs * 1000 / single_ms))
+    return n
+
+
+def microbench(
+    fn: Callable[[], None],
+    device: torch.device = None,
+    n_iters: int = None,
+    use_cuda_graph: bool = True,
+    flush_l2: bool = True,
+) -> Dict[str, float]:
+    """Benchmark a CUDA callable and return timing statistics.
+
+    Args:
+        fn: Zero-argument callable that performs the GPU work.
+        device: CUDA device. Defaults to the current CUDA device.
+        n_iters: Override iteration count (else auto-tuned via pilot).
+        use_cuda_graph: Capture fn into a CUDA graph for replay.
+        flush_l2: Flush L2 cache before every measured run (one call per run).
+
+    Returns:
+        Dict with keys: mean_us, std_us, n_iters.
+    """
+    if device is None:
+        device = torch.device("cuda")
+
+    # Pilot: determine iteration count
+    if n_iters is None:
+        n_iters = _determine_iters(fn, device=device)
+    n_iters = max(n_iters, 1)
+
+    flush_buf = _l2_flush_buffer(device) if flush_l2 else None
+
+    if use_cuda_graph:
+        # Warm-up run for CUDA graph capture
+        torch.cuda.synchronize(device)
+        fn()
+        torch.cuda.synchronize(device)
+
+        # Capture
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph):
+            fn()
+
+        # Determine per-graph batch to keep each replay > 1 ms
+        graph.replay()
+        torch.cuda.synchronize(device)
+        start_ev = torch.cuda.Event(enable_timing=True)
+        end_ev = torch.cuda.Event(enable_timing=True)
+        start_ev.record()
+        graph.replay()
+        end_ev.record()
+        torch.cuda.synchronize(device)
+        replay_ms = start_ev.elapsed_time(end_ev)
+
+        per_graph = max(1, int(1.0 / max(replay_ms, 1e-6)))
+        # With L2 flush, each replay must start cold.
+        if flush_l2:
+            per_graph = 1
+            n_replays = n_iters
+        else:
+            n_replays = max(1, n_iters // per_graph)
+
+        replay_fn = graph.replay
+    else:
+        per_graph = 1
+        n_replays = n_iters
+        replay_fn = fn
+
+    # Warm-up replay (not measured)
+    for _ in range(per_graph):
+        replay_fn()
+    torch.cuda.synchronize(device)
+
+    # Measured runs with cuda.Event timing
+    start_ev = torch.cuda.Event(enable_timing=True)
+    end_ev = torch.cuda.Event(enable_timing=True)
+    times_us: list[float] = []
+    for _ in range(n_replays):
+        if flush_l2 and flush_buf is not None:
+            _flush_l2(flush_buf)
+            torch.cuda.synchronize(device)
+        start_ev.record()
+        for _ in range(per_graph):
+            replay_fn()
+        end_ev.record()
+        torch.cuda.synchronize(device)
+        times_us.append(start_ev.elapsed_time(end_ev) * 1000.0)  # ms -> us
+
+    mean_us = sum(times_us) / len(times_us) / per_graph
+    if len(times_us) > 1:
+        variance = sum((t / per_graph - mean_us) ** 2 for t in times_us) / (
+            len(times_us) - 1
+        )
+        std_us = variance**0.5
+    else:
+        std_us = 0.0
+
+    total_invocations = n_replays * per_graph
+
+    return {
+        "mean_us": mean_us,
+        "std_us": std_us,
+        "n_iters": total_invocations,
+    }
diff --git a/examples/qwen3/qwen3_config.py b/examples/qwen3/qwen3_config.py
new file mode 100644
index 00000000..de601b61
--- /dev/null
+++ b/examples/qwen3/qwen3_config.py
@@ -0,0 +1,28 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""Qwen3-8B model configuration as a parameterized dataclass."""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class Qwen3Config:
+    """Qwen3 model configuration with 8B defaults.
+
+    All fields are overridable. For example, a 32B variant is a one-liner:
+        Qwen3Config(n_layers=64, hidden_dim=5120, n_q_heads=64, n_kv_heads=8,
+                    intermediate_dim=25600)
+    """
+
+    n_layers: int = 36
+    hidden_dim: int = 4096
+    n_q_heads: int = 32
+    n_kv_heads: int = 8
+    head_dim: int = 128
+    intermediate_dim: int = 12288
+    vocab_size: int = 151936
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 1e6
+    max_seq_len: int = 4096
+    dtype: str = "float16"
diff --git a/examples/qwen3/test_allreduce.py b/examples/qwen3/test_allreduce.py
new file mode 100644
index 00000000..b31b7c5b
--- /dev/null
+++ b/examples/qwen3/test_allreduce.py
@@ -0,0 +1,435 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""Tests for ARK all-reduce wrapper at Qwen3 TP shapes.
+
+Two tiers:
+  - **CPU-only (always run in CI):** validation logic — alignment checks,
+    dtype guards, contiguity guards, flatten/reshape round-trip.
+  - **Multi-GPU (skip if ``torch.cuda.device_count() < 2``):** functional
+    correctness via ``subprocess`` — each rank fills its tensor with
+    ``rank + 1``, runs all-reduce, asserts output == sum(1..world_size).
+
+The 8-GPU CI runner executes both tiers.
+"""
+
+import importlib.util
+import os
+import subprocess
+import sys
+import tempfile
+
+import pytest
+import torch
+
+from .ark_allreduce import validate_allreduce_input
+from ._env import _EXAMPLES_QWEN3_DIR, _has_compiled_ark, _subprocess_env
+
+_CUDA = torch.cuda.is_available()
+_NUM_GPUS = torch.cuda.device_count() if _CUDA else 0
+
+requires_multi_gpu = pytest.mark.skipif(
+    _NUM_GPUS < 2,
+    reason=f"Need >= 2 GPUs, have {_NUM_GPUS}",
+)
+
+
+# -----------------------------------------------------------------------
+# Tier 1: CPU-only validation tests (always run)
+# -----------------------------------------------------------------------
+
+
+class TestValidation:
+    """Tests for ``validate_allreduce_input`` — no GPU required."""
+
+    def test_rejects_fp32(self):
+        """float32 dtype raises ValueError."""
+        x = torch.randn(4096, dtype=torch.float32)
+        with pytest.raises(ValueError, match="float16"):
+            validate_allreduce_input(x, world_size=2)
+
+    def test_rejects_bf16(self):
+        """bfloat16 dtype raises ValueError."""
+        x = torch.randn(4096, dtype=torch.bfloat16)
+        with pytest.raises(ValueError, match="float16"):
+            validate_allreduce_input(x, world_size=2)
+
+    def test_rejects_non_contiguous(self):
+        """Non-contiguous tensor raises ValueError."""
+        x = torch.randn(8, 4096, dtype=torch.float16)[:, ::2]
+        assert not x.is_contiguous()
+        with pytest.raises(ValueError, match="contiguous"):
+            validate_allreduce_input(x, world_size=2)
+
+    def test_rejects_bad_alignment_tp2(self):
+        """Element count not divisible by 4*2=8 raises ValueError."""
+        # 7 elements — not divisible by 8
+        x = torch.randn(7, dtype=torch.float16)
+        with pytest.raises(ValueError, match="divisible"):
+            validate_allreduce_input(x, world_size=2)
+
+    def test_rejects_bad_alignment_tp8(self):
+        """Element count not divisible by 4*8=32 raises ValueError."""
+        # 24 elements — divisible by 8 but not by 32
+        x = torch.randn(24, dtype=torch.float16)
+        with pytest.raises(ValueError, match="divisible"):
+            validate_allreduce_input(x, world_size=8)
+
+    def test_accepts_prefill_shape_tp8(self):
+        """Prefill shape (2048, 4096) with TP=8 passes validation."""
+        x = torch.randn(2048, 4096, dtype=torch.float16)
+        validate_allreduce_input(x, world_size=8)  # no exception
+
+    def test_accepts_decode_shape_tp8(self):
+        """Decode shape (1, 4096) with TP=8 passes validation."""
+        x = torch.randn(1, 4096, dtype=torch.float16)
+        validate_allreduce_input(x, world_size=8)  # no exception
+
+    def test_accepts_1d_tensor(self):
+        """1-D tensor with aligned count passes validation."""
+        x = torch.randn(4096, dtype=torch.float16)
+        validate_allreduce_input(x, world_size=2)  # no exception
+
+    def test_accepts_tp2(self):
+        """Element count divisible by 4*2=8 passes validation."""
+        x = torch.randn(32, dtype=torch.float16)
+        validate_allreduce_input(x, world_size=2)  # no exception
+
+    def test_rejects_world_size_zero(self):
+        """world_size=0 raises ValueError (avoids ZeroDivisionError)."""
+        x = torch.randn(4096, dtype=torch.float16)
+        with pytest.raises(ValueError, match="world_size"):
+            validate_allreduce_input(x, world_size=0)
+
+    def test_rejects_world_size_negative(self):
+        """Negative world_size raises ValueError."""
+        x = torch.randn(4096, dtype=torch.float16)
+        with pytest.raises(ValueError, match="world_size"):
+            validate_allreduce_input(x, world_size=-1)
+
+
+class TestFlattenReshapeLogic:
+    """Verify flatten/reshape round-trip logic used by ark_allreduce (CPU tensors, no ARK dependency)."""
+
+    def test_2d_roundtrip(self):
+        """Flatten to 1-D and reshape back preserves data and shape."""
+        shape = (2048, 4096)
+        x = torch.randn(shape, dtype=torch.float16)
+        x_flat = x.reshape(-1)
+        assert x_flat.shape == (2048 * 4096,)
+        x_back = x_flat.reshape(shape)
+        assert x_back.shape == shape
+        assert torch.equal(x, x_back)
+
+    def test_1d_roundtrip(self):
+        """1-D tensor reshape(-1) is a no-op."""
+        x = torch.randn(4096, dtype=torch.float16)
+        x_flat = x.reshape(-1)
+        assert torch.equal(x, x_flat)
+
+    def test_decode_shape(self):
+        """Decode shape (1, 4096) flattens to (4096,) and back."""
+        shape = (1, 4096)
+        x = torch.randn(shape, dtype=torch.float16)
+        x_flat = x.reshape(-1)
+        assert x_flat.shape == (4096,)
+        x_back = x_flat.reshape(shape)
+        assert torch.equal(x, x_back)
+
+
+class TestHasCompiledArk:
+    """Edge-case tests for ``_has_compiled_ark()``."""
+
+    def test_no_ark_subdir(self):
+        """Directory with no ark/ subdir returns False."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            assert _has_compiled_ark(tmpdir) is False
+
+    def test_source_tree_no_so(self):
+        """ark/__init__.py exists but no compiled .so/.pyd returns False."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ark_dir = os.path.join(tmpdir, "ark")
+            os.makedirs(ark_dir)
+            with open(os.path.join(ark_dir, "__init__.py"), "w") as f:
+                f.write("")
+            assert _has_compiled_ark(tmpdir) is False
+
+    def test_fake_compiled_so(self):
+        """ark/__init__.py + fake core*.so returns True."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ark_dir = os.path.join(tmpdir, "ark")
+            os.makedirs(ark_dir)
+            with open(os.path.join(ark_dir, "__init__.py"), "w") as f:
+                f.write("")
+            # Create a fake compiled extension
+            fake_so = os.path.join(
+                ark_dir, "core.cpython-312-x86_64-linux-gnu.so"
+            )
+            with open(fake_so, "w") as f:
+                f.write("")
+            assert _has_compiled_ark(tmpdir) is True
+
+    def test_real_build_tree(self):
+        """Real build tree (via importlib) returns True."""
+        spec = importlib.util.find_spec("ark")
+        if spec is None or not spec.submodule_search_locations:
+            pytest.skip("ark not importable in this environment")
+        ark_pkg_dir = next(iter(spec.submodule_search_locations))
+        ark_parent = os.path.dirname(ark_pkg_dir)
+        assert _has_compiled_ark(ark_parent) is True
+
+
+class TestSubprocessEnv:
+    """CPU-only tests for ``_subprocess_env()`` — no GPU required."""
+
+    def test_pythonpath_contains_ark_package(self):
+        """Returned env PYTHONPATH includes a dir where ark is importable."""
+        env = _subprocess_env(world_size=2)
+        pythonpath = env.get("PYTHONPATH", "")
+        paths = pythonpath.split(os.pathsep)
+        # At least one path must contain ark/__init__.py or ark/core*.so
+        found = False
+        for p in paths:
+            ark_dir = os.path.join(p, "ark")
+            if os.path.isfile(os.path.join(ark_dir, "__init__.py")):
+                found = True
+                break
+            # Also check for compiled extension (namespace package case)
+            if os.path.isdir(ark_dir):
+                import glob
+
+                if glob.glob(os.path.join(ark_dir, "core*.so")):
+                    found = True
+                    break
+        assert (
+            found
+        ), f"No ark-importable path found in PYTHONPATH: {pythonpath}"
+
+    def test_pythonpath_no_duplicates(self):
+        """PYTHONPATH entries are not duplicated by the resolution logic."""
+        env = _subprocess_env(world_size=2)
+        pythonpath = env.get("PYTHONPATH", "")
+        paths = pythonpath.split(os.pathsep)
+        # Filter out inherited PYTHONPATH (may have dupes we don't control)
+        inherited = os.environ.get("PYTHONPATH", "")
+        inherited_parts = (
+            set(inherited.split(os.pathsep)) if inherited else set()
+        )
+        own_paths = [p for p in paths if p not in inherited_parts]
+        assert len(own_paths) == len(
+            set(own_paths)
+        ), f"Duplicate entries in PYTHONPATH: {own_paths}"
+
+    def test_cuda_visible_devices(self):
+        """CUDA_VISIBLE_DEVICES matches requested world_size."""
+        env = _subprocess_env(world_size=4)
+        assert env["CUDA_VISIBLE_DEVICES"] == "0,1,2,3"
+
+    def test_examples_parent_in_pythonpath(self):
+        """examples/ parent dir is in PYTHONPATH for sibling imports."""
+        env = _subprocess_env(world_size=2)
+        pythonpath = env.get("PYTHONPATH", "")
+        examples_parent = os.path.dirname(_EXAMPLES_QWEN3_DIR)
+        assert examples_parent in pythonpath.split(os.pathsep)
+
+    def test_ark_root_fallback(self):
+        """ARK_ROOT env var adds $ARK_ROOT/python to PYTHONPATH when it has compiled ark."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Create a fake compiled ark under tmpdir/python/ark/
+            ark_dir = os.path.join(tmpdir, "python", "ark")
+            os.makedirs(ark_dir)
+            with open(os.path.join(ark_dir, "__init__.py"), "w") as f:
+                f.write("")
+            fake_so = os.path.join(
+                ark_dir, "core.cpython-312-x86_64-linux-gnu.so"
+            )
+            with open(fake_so, "w") as f:
+                f.write("")
+            old = os.environ.get("ARK_ROOT")
+            try:
+                os.environ["ARK_ROOT"] = tmpdir
+                env = _subprocess_env(world_size=2)
+                pythonpath = env.get("PYTHONPATH", "")
+                expected = os.path.join(tmpdir, "python")
+                assert expected in pythonpath.split(os.pathsep)
+            finally:
+                if old is None:
+                    os.environ.pop("ARK_ROOT", None)
+                else:
+                    os.environ["ARK_ROOT"] = old
+
+    def test_ark_root_without_compiled_ark(self):
+        """ARK_ROOT pointing to dir without compiled ark does not add to PYTHONPATH."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            old = os.environ.get("ARK_ROOT")
+            try:
+                os.environ["ARK_ROOT"] = tmpdir
+                env = _subprocess_env(world_size=2)
+                pythonpath = env.get("PYTHONPATH", "")
+                bad_path = os.path.join(tmpdir, "python")
+                assert bad_path not in pythonpath.split(os.pathsep)
+            finally:
+                if old is None:
+                    os.environ.pop("ARK_ROOT", None)
+                else:
+                    os.environ["ARK_ROOT"] = old
+
+
+# -----------------------------------------------------------------------
+# Tier 2: Multi-GPU functional tests (skip on 1-GPU CI)
+# -----------------------------------------------------------------------
+
+_ALLREDUCE_WORKER_SCRIPT = '''
+"""Worker script for multi-GPU all-reduce test.
+
+Launched as a subprocess to avoid CUDA context pollution in the test process.
+Each rank fills its tensor with (rank + 1), runs all-reduce, and checks
+that the result equals sum(1..world_size).
+"""
+import os, sys
+import torch
+import ark
+from ark.executor import Executor
+
+rank = int(sys.argv[1])
+world_size = int(sys.argv[2])
+n_elements = int(sys.argv[3])
+
+ark.init()
+ark.set_rank(rank)
+ark.set_world_size(world_size)
+
+# Fill with rank + 1
+x = torch.full((n_elements,), rank + 1, dtype=torch.float16, device=f"cuda:{rank}")
+
+result = ark.all_reduce_packet(x, rank, world_size)
+
+with ark.Runtime() as rt:
+    rt.launch(device_id=rank)
+    rt.run()
+    out = result.to_torch()
+
+# Runtime.stop() only halts execution; Executor.reset() forces full mscclpp
+# teardown before os._exit() skips Python's normal cleanup.
+Executor.reset()
+
+# Expected: sum of (1 + 2 + ... + world_size)
+expected = world_size * (world_size + 1) / 2
+if not torch.allclose(out, torch.full_like(out, expected), atol=1e-2, rtol=1e-2):
+    bad = (out - expected).abs().max().item()
+    print(f"FAIL rank={rank}: max_diff={bad}", file=sys.stderr)
+    sys.stderr.flush()
+    os._exit(1)
+print(f"PASS rank={rank}")
+sys.stdout.flush()
+os._exit(0)
+'''
+
+
+def _run_allreduce_subprocess(
+    world_size: int, n_elements: int, timeout: int = 120
+):
+    """Spawn *world_size* workers, each running the all-reduce script.
+
+    A rank that crashes usually leaves its peers blocked in the collective,
+    so when a rank times out, later ranks that already exited with an error
+    are reported too instead of being silently discarded.
+
+    Args:
+        world_size: Number of worker processes to launch, one per rank.
+        n_elements: Element count forwarded to every worker.
+        timeout: Seconds to wait for each rank before aborting the run.
+
+    Raises:
+        AssertionError: If any rank times out or exits non-zero.
+    """
+
+    def _describe(rank, proc, stderr):
+        # Only the tail of stderr is kept so the message stays readable.
+        tail = stderr.decode().strip()[-500:]
+        return f"rank {rank} exit={proc.returncode}: {tail}"
+
+    procs = []
+    for rank in range(world_size):
+        p = subprocess.Popen(
+            [
+                sys.executable,
+                "-c",
+                _ALLREDUCE_WORKER_SCRIPT,
+                str(rank),
+                str(world_size),
+                str(n_elements),
+            ],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            cwd="/",
+            env=_subprocess_env(world_size),
+        )
+        procs.append(p)
+
+    failures = []
+    try:
+        for rank, p in enumerate(procs):
+            try:
+                _, stderr = p.communicate(timeout=timeout)
+            except subprocess.TimeoutExpired:
+                failures.append(f"rank {rank} timed out after {timeout}s")
+                # Later ranks that already died on their own are the likely
+                # cause of the hang; surface their stderr before giving up.
+                for peer_rank in range(rank + 1, world_size):
+                    peer = procs[peer_rank]
+                    if peer.poll() not in (None, 0):
+                        _, peer_err = peer.communicate()
+                        failures.append(_describe(peer_rank, peer, peer_err))
+                break
+            if p.returncode != 0:
+                failures.append(_describe(rank, p, stderr))
+    finally:
+        # Reap every worker, including ranks still blocked in the collective.
+        for p in procs:
+            p.kill()
+            p.wait()
+
+    if failures:
+        raise AssertionError(
+            f"All-reduce failed for {len(failures)}/{world_size} ranks:\n"
+            + "\n".join(failures)
+        )
+
+
+# TODO(qwen3): test ark_allreduce() wrapper end-to-end.
+# The wrapper's init/validate/flatten/reduce/reshape pipeline is not yet
+# exercised by multi-GPU tests — workers call ark.all_reduce_packet()
+# directly for finer control over tensor construction and result checking.
+
+
+@requires_multi_gpu
+def test_allreduce_decode_tp2():
+    """All-reduce at decode shape (4096 elems) with TP=2."""
+    _run_allreduce_subprocess(world_size=2, n_elements=4096)
+
+
+@requires_multi_gpu
+def test_allreduce_prefill_tp2():
+    """All-reduce at prefill shape (8,388,608 elems) with TP=2."""
+    _run_allreduce_subprocess(world_size=2, n_elements=2048 * 4096)
+
+
+@requires_multi_gpu
+@pytest.mark.skipif(
+    _NUM_GPUS < 8,
+    reason=f"Need >= 8 GPUs, have {_NUM_GPUS}",
+)
+def test_allreduce_prefill_tp8():
+    """All-reduce at prefill shape (8,388,608 elems) with TP=8."""
+    _run_allreduce_subprocess(world_size=8, n_elements=2048 * 4096)
+
+
+@requires_multi_gpu
+@pytest.mark.skipif(
+    _NUM_GPUS < 8,
+    reason=f"Need >= 8 GPUs, have {_NUM_GPUS}",
+)
+def test_allreduce_decode_tp8():
+    """All-reduce at decode shape (4096 elems) with TP=8."""
+    _run_allreduce_subprocess(world_size=8, n_elements=4096)