
compiler: KeyError in MPI halo hoisting on nested HaloSpot (unguarded cond_mapper lookup) #2943

@ggorman


Summary

Compiling an Operator (default opt) that contains a nested HaloSpot (a HaloSpot whose subtree contains another HaloSpot) crashes in the MPI halo-hoisting pass:

File ".../devito/passes/iet/mpi.py", line 89, in _hoist_redundant_from_conditionals
    conditions = cond_mapper[hs0]
KeyError: <HaloSpot(B)>

Reproduced on 4.8.21 and confirmed present on current main.

Root cause

In _hoist_redundant_from_conditionals (devito/passes/iet/mpi.py), cond_mapper = _make_cond_mapper(iet) does not register nested HaloSpots as keys, but the loop iterates every halo_spot from _filter_iter_mapper and does an unguarded lookup:

for hs0 in halo_spots:
    conditions = cond_mapper[hs0]   # KeyError when hs0 is a nested HaloSpot
    if not conditions:
        continue

Elsewhere in the same file the same map is accessed defensively with cond_mapper.get(hs0) (in _merge_halospots, mpi.py:471), so line 89 simply lacks the .get() guard.
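The failure mode can be modelled without Devito. The sketch below is a toy, not the library's code: Node, all_nodes and make_mapper are hypothetical stand-ins. It assumes only what is stated above, namely that the mapper keys the outer/standalone nodes while the loop visits every node, nested ones included.

```python
class Node:
    """Hypothetical stand-in for an IET node such as a HaloSpot."""
    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

    def __repr__(self):
        return f"<Node({self.name})>"


def all_nodes(roots):
    """Depth-first list of every node, nested ones included."""
    found = []
    for node in roots:
        found.append(node)
        found.extend(all_nodes(node.children))
    return found


def make_mapper(roots):
    """Toy mapper: registers only the top-level nodes as keys."""
    return {node: ["cond"] for node in roots}


inner = Node("B")
outer = Node("A", children=[inner])
roots = [outer]

mapper = make_mapper(roots)
missing = []
for node in all_nodes(roots):
    try:
        mapper[node]            # unguarded lookup, as at mpi.py line 89
    except KeyError:
        missing.append(node)    # the nested node was never registered

print(missing)  # [<Node(B)>]
```

Only the nested node ends up in missing, which matches the KeyError: <HaloSpot(B)> in the traceback.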

Fix (one line)

conditions = cond_mapper.get(hs0)
if not conditions:
    continue

The existing if not conditions: continue already handles the empty/None case, so switching from [] to .get() is sufficient and behaviour-preserving for present keys. PR attached.
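A quick standalone check of that claim, using a plain dict with hypothetical keys in place of the real cond_mapper. For keys that are present, [] and .get() return the same object; for a missing key, .get() returns None, which the existing falsy check skips. The helper name hoistable is an assumption made for this illustration only.

```python
def hoistable(items, cond_mapper):
    """Return the items that survive the guarded lookup and the falsy check."""
    kept = []
    for item in items:
        conditions = cond_mapper.get(item)   # None for a missing key
        if not conditions:                   # skips None and empty containers
            continue
        kept.append((item, conditions))
    return kept


# Hypothetical keys and values, for illustration only
cond_mapper = {"hs_outer": ["c0"], "hs_empty": []}
items = ["hs_outer", "hs_empty", "hs_nested"]   # "hs_nested" is not a key

# Present keys: both lookups yield the very same object
for key in cond_mapper:
    assert cond_mapper[key] is cond_mapper.get(key)

print(hoistable(items, cond_mapper))  # [('hs_outer', ['c0'])]
```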

Minimal reproducer

Two accumulator Eq-families over a per-row SubDomain zone grid, each reading a second Function at an offset in a different subset of rows → interleaved, partially-overlapping HaloSpot families → a nested HaloSpot. The crash fires iff a nested HaloSpot is present. Reproduces at n=10, nrows=2. opt='noop' avoids it. The script also verifies the .get() fix via a _make_cond_mapper monkeypatch.

"""Minimal failing example — Devito MPI halo-optimizer KeyError on nested HaloSpots.

SYMPTOM (devito 4.8.21, also present on origin/main as of 2026-06-04):

    File ".../devito/passes/iet/mpi.py", line 89, in _hoist_redundant_from_conditionals
        conditions = cond_mapper[hs0]
    KeyError: <HaloSpot(B)>

raised at compile time (default `opt`) for an Operator that contains a **nested**
HaloSpot — a HaloSpot whose subtree contains another HaloSpot. `_make_cond_mapper`
keys the outer/standalone HaloSpots, but `_hoist_redundant_from_conditionals`
iterates ALL halo_spots from `_filter_iter_mapper` (including the nested one) and
does an UNGUARDED `cond_mapper[hs0]` (mpi.py:89) — vs the DEFENSIVE
`cond_mapper.get(hs0)` used a few lines down (mpi.py:471). The nested HaloSpot is
missing from cond_mapper -> KeyError.

MINIMAL TRIGGER (this script): two accumulator Eq-families (`acc1`, `acc2`) over a
per-row SubDomain zone grid, each reading a SECOND Function `B` at an offset but in
a DIFFERENT subset of x-rows. The two interleaved, partially-overlapping HaloSpot
families for `B` produce a nested HaloSpot -> the KeyError. Reproduces at n=10,
nrows=2 (18 Eqs). `nested_hs` counts the nested HaloSpots; the crash fires iff
nested_hs >= 1.

More generally this arises in any Operator with many SubDomain-restricted `Eq`s in
which several `Function`s are read at offsets across SubDomain boundaries, producing
interleaved HaloSpot families.

WORKAROUNDS / FIX:
  * `Operator(..., opt='noop')` avoids the halo-optimizer pass entirely.
    Numerics unchanged; only loop optimization is skipped.
  * One-line library fix: guard the lookup at mpi.py:89, mirroring :471 —
    `conditions = cond_mapper.get(hs0)` then `if not conditions: continue`.
    This script VERIFIES that fix by monkeypatching `_make_cond_mapper` to return a
    `defaultdict(list)` (so the missing-key lookup yields `[]` and is skipped),
    WITHOUT editing the installed package.

Run:  DEVITO_ARCH=clang DEVITO_LANGUAGE=openmp uv run python halo_keyerror_mfe.py
"""
import sys
from collections import defaultdict
import numpy as np
from devito import Grid, Function, Eq, Operator, SubDomain
import devito.passes.iet.mpi as MPI
from devito.ir.iet import FindNodes
from devito.ir.iet.nodes import HaloSpot

_seen = {}
_orig_hoist = MPI._hoist_redundant_from_conditionals


def _probe(iet):
    """Count nested HaloSpots (a HaloSpot inside another HaloSpot's subtree)."""
    hss = FindNodes(HaloSpot).visit(iet)
    _seen['nested'] = sum(1 for hs in hss
                          if any(h is not hs for h in FindNodes(HaloSpot).visit(hs)))
    _seen['total'] = len(hss)
    return _orig_hoist(iet)


MPI._hoist_redundant_from_conditionals = _probe


def _sd(name, sx, sy):
    class _SD(SubDomain):
        pass
    _SD.name = name

    def define(self, dims):
        x, y = dims
        return {x: sx, y: sy}
    _SD.define = define
    return _SD()


def _zones(n, nrows):
    z = [(str(i), ('middle', i, n-1-i)) for i in range(nrows)]
    z.append(("I", ('middle', nrows, nrows)))
    return z


def build(n=10, nrows=2, so=4):
    zx, zy = _zones(n, nrows), _zones(n, nrows)
    subs, row = [], {}
    for li, (lx, tx) in enumerate(zx):
        for (ly, ty) in zy:
            nm = f"z_{lx}_{ly}"
            subs.append(_sd(nm, tx, ty))
            row[nm] = li
    grid = Grid(shape=(n, n), extent=(float(n-1),)*2, dtype=np.float64,
                subdomains=tuple(subs))
    A = Function(name="A", grid=grid, space_order=so)
    B = Function(name="B", grid=grid, space_order=so)
    acc1 = Function(name="acc1", grid=grid, space_order=so)
    acc2 = Function(name="acc2", grid=grid, space_order=so)
    x, y = grid.dimensions
    eqs = []
    for sd in subs:
        i = row[sd.name]
        r1 = A[x-1, y] + A[x+1, y] + A[x, y+1]
        if i >= 1:
            r1 = r1 + B[x+2, y]              # acc1 reads B in rows >= 1
        eqs.append(Eq(acc1, r1, subdomain=sd))
    for sd in subs:
        i = row[sd.name]
        r2 = A[x, y-1] + A[x, y+1] + A[x+1, y]
        if i >= 2:
            r2 = r2 + B[x+3, y]             # acc2 reads B in rows >= 2 (different)
        eqs.append(Eq(acc2, r2, subdomain=sd))
    return eqs


def run(opt="advanced", **kw):
    _seen.clear()
    eqs = build(**kw)
    try:
        Operator(eqs, name="mfe", opt=opt).apply()
        return "OK"
    except KeyError as e:
        import traceback
        return ("KEYERR:cond_mapper" if "cond_mapper" in traceback.format_exc()
                else f"KEYERR:{e}")


if __name__ == "__main__":
    print("== 1. default opt='advanced' (expect KeyError) ==")
    r1 = run(opt="advanced", n=10, nrows=2)
    print(f"   nested_hs={_seen.get('nested')} total_hs={_seen.get('total')} -> {r1}")

    print("== 2. workaround opt='noop' (expect OK) ==")
    r2 = run(opt="noop", n=10, nrows=2)
    print(f"   nested_hs={_seen.get('nested')} total_hs={_seen.get('total')} -> {r2}")

    print("== 3. proposed fix: cond_mapper.get(...) via defaultdict monkeypatch, "
          "default opt (expect OK) ==")
    _orig_mcm = MPI._make_cond_mapper
    MPI._make_cond_mapper = lambda iet: defaultdict(list, _orig_mcm(iet))
    try:
        r3 = run(opt="advanced", n=10, nrows=2)
    finally:
        MPI._make_cond_mapper = _orig_mcm
    print(f"   nested_hs={_seen.get('nested')} total_hs={_seen.get('total')} -> {r3}")

    ok = (r1.startswith("KEYERR:cond_mapper") and r2 == "OK" and r3 == "OK")
    print(f"\nMFE {'CONFIRMED' if ok else 'NOT reproduced as expected'}: "
          f"advanced->{r1}, noop->{r2}, fixed->{r3}")
    sys.exit(0 if ok else 1)

Output:

== 1. default opt='advanced' (expect KeyError) ==
   nested_hs=1 total_hs=7 -> KEYERR:cond_mapper
== 2. workaround opt='noop' (expect OK) ==
   nested_hs=0 total_hs=6 -> OK
== 3. proposed fix: cond_mapper.get(...) via defaultdict monkeypatch, default opt (expect OK) ==
   nested_hs=1 total_hs=7 -> OK
MFE CONFIRMED: advanced->KEYERR:cond_mapper, noop->OK, fixed->OK
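
Step 3 stands in for the one-line fix by wrapping the mapper in a defaultdict(list). A small sketch of how that wrapper behaves, with hypothetical keys in place of HaloSpots. One difference from .get(): a defaultdict lookup also inserts the missing key. That side effect is harmless for this check, but it is not what .get() does.

```python
from collections import defaultdict

original = {"hs_outer": ["c0"]}            # hypothetical mapper contents
patched = defaultdict(list, original)      # same entries, [] for missing keys

assert patched["hs_outer"] == ["c0"]       # present keys are unchanged
assert patched["hs_nested"] == []          # missing key yields an empty list
assert "hs_nested" in patched              # ...and is inserted as a side effect

assert original.get("hs_nested") is None   # .get() returns None instead
assert "hs_nested" not in original         # ...and leaves the dict untouched

print(dict(patched))  # {'hs_outer': ['c0'], 'hs_nested': []}
```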
