Skip to content

strided BLAS DGEMM path for ToT einsum contractions#559

Open
zhihao-deng wants to merge 15 commits into
masterfrom
zhihao/feature/strided-dgemm
Open

strided BLAS DGEMM path for ToT einsum contractions#559
zhihao-deng wants to merge 15 commits into
masterfrom
zhihao/feature/strided-dgemm

Conversation

@zhihao-deng
Copy link
Copy Markdown
Contributor

Summary

Lift per-cell ToT (ArenaTensor) einsum work to BLAS-3 GEMM wherever possible,
instead of looping per-cell ops.

Recast the following ArenaTensor einsum cases as strided GEMM:

  • hce+e (core ce+e, inner outer-product): ride the outer-contraction
    index into BLAS K
  • hce+ce (core ce+ce, inner contraction — guarded subset, not the general
    case): ride the outer-external index into BLAS M.

Everything outside these guarded regimes keeps the existing per-cell path, so
behavior is unchanged elsewhere.

Guards

A strided GEMM fires only when the cell run is "clean": all cells present,
uniform inner size, and a single constant inter-cell stride. Empty inners punch
holes that break contiguity, so we fall to segmented kernels: walk each run
and emit one strided GEMM per maximal contiguous segment of present cells,
skipping the holes (accumulating with β=1 across segments).

Notes

Still carries env-gated diagnostics (TA_GEMM_TIMING, TA_STRIDED_DGEMM_VERBOSE,
and the TA_STRIDED_DGEMM_COUNT build counters)

zhihao-deng and others added 15 commits May 30, 2026 19:40
Route the regime-A hc+e einsum (outer Hadamard + outer contraction,
inner outer-product) through the landed arena_strided_dgemm_ce_e core
(M=N=1, K=tile volume) in run_regime_a_arena, replacing the per-cell
rank-1 dger loop with one strided DGEMM per outer-contraction tile.
Gated to view+double arena ToT contraction with num_contract_ranks()==0;
all other kinds keep the per-cell path. Adds a regime_a_strided_disabled()
kill switch, tile/e2e/differential/edge tests, and a strided-vs-per-cell
benchmark (~7.3x on a C6H14-like shape).
… ranks

The einsum_tot arena-matches-owning tests iterate over all result tile
ordinals but only inspect tiles local to the calling rank, then assert
the per-rank elements_compared / result_outer_cells_seen counts (and the
fatal BOOST_REQUIRE_GT(elements_compared, 0u)) against the global
expected totals. That holds under np=1 (all tiles local) but fails under
np=2: each rank sees only its share, and a rank owning no result tiles
trips the REQUIRE_GT.

All-reduce the accumulators (gop.sum on the counts, gop.max on
max_abs_diff) before the assertions so every rank checks the true global
totals. Fixes the 14 np=2 einsum_tot failures.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants