Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions src/underworld3/function/_function.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def global_evaluate_nd( expr,
check_extrapolated=False,
force_l2=False,
smoothing=1e-6,
local_fallback=True,
):

"""
Expand All @@ -368,6 +369,21 @@ def global_evaluate_nd( expr,
Users should typically use :func:`underworld3.function.global_evaluate`
which provides automatic unit handling and a cleaner interface.

Contract: this is a faithful *parallel* counterpart of :func:`evaluate` —
a query point is interpolated wherever in the mesh it lands (on any rank),
a point just outside the mesh is extrapolated from its nearest cell (the
globally nearest by a centroid-distance heuristic, with the lowest rank as
the tie-break in parallel),
and ``check_extrapolated`` returns an inside/outside flag per point. The
result is independent of the number of ranks (up to the rank-local
extrapolation residual near partition seams). Points that no rank can
locate in-cell are resolved by a best-claim reduction over ranks (see the
out-of-domain block below); pass ``local_fallback=False`` to restore the
legacy behaviour where such points returned silently-wrong values. The
``GE_LOCAL_FALLBACK`` environment variable, if set, overrides the kwarg
(an operator escape hatch retained from the parallel-deadlock debugging
history; the kwarg is the supported control surface).

Note it is not efficient to call this function to evaluate an expression at
a single coordinate. Instead the user should provide a numpy array of all
coordinates requiring evaluation.
Expand Down Expand Up @@ -520,6 +536,109 @@ def global_evaluate_nd( expr,
return_value[index, :, :] = data_container.array[:, :, :]
return_mask[index] = is_extrapolated.array[:]

# ------------------------------------------------------------------
# Out-of-domain extrapolation — keep the parallel result a faithful
# match for the serial ``evaluate()`` contract: interpolate a point
# wherever it lands across ranks, extrapolate a point just outside the
# mesh, and flag inside/outside.
Comment on lines +539 to +543
#
# After the migrate round-trip, a query point that NO rank could locate
# in one of its cells returns flagged-extrapolated but valued from
# whichever rank the bare dm.migrate happened to strand it on — typically
# a geometrically far, WRONG cell (the classic symptom is an annulus
# boundary point reading a value from the opposite side of the domain).
# Serial ``evaluate()`` instead extrapolates from the TRUE nearest cell.
# Restore that contract with a "best-claim" reduction over the (small,
# boundary-layer) stranded set:
#
# 1. allgather the extrapolated points so every rank holds the SAME
# global set;
# 2. each rank reports, per point, its nearest-local-cell distance and
# its LOCAL rbf extrapolation of the field there;
# 3. Allreduce(MIN distance) + Allreduce(MIN rank) tie-break picks the
# rank whose nearest cell is globally closest, and Allreduce(SUM of
# the winner-only value/flag) scatters that rank's extrapolation back.
#
# A point some rank actually contains (distance ~ 0) naturally wins, so
# only genuinely-stranded points are corrected. Cost is O(boundary points)
# — no dense global tree, no exhaustive search.
#
# DEADLOCK SAFETY — read before editing. Every collective here (allgather,
# Allreduce) runs unconditionally on the IDENTICAL global set on every
# rank, so all ranks stay in lockstep (n_ext_total is itself a reduced
# value, so the `> 0` guard is taken identically everywhere). The per-rank
# value MUST come from the LOCAL rbf path (rbf=True): the FE interpolation
# path (petsc_interpolate / DMInterpolation) is itself collective and would
# desync here, because each rank classifies the same global set against its
# own domain (different interior-point counts) → hang. Never route the
# fallback value through FE interpolation.
#
# Serial is left untouched (the serial path above already extrapolates from
# the true nearest cell). Escape hatch: GE_LOCAL_FALLBACK=0 restores the
# legacy (silently-wrong out-of-domain) behaviour; default on.
# ------------------------------------------------------------------
import os
# The kwarg is the supported control; an explicitly-set env var overrides
# it (operator escape hatch — see DEADLOCK SAFETY / contract docstring).
_env_fallback = os.environ.get("GE_LOCAL_FALLBACK")
if _env_fallback is not None:
_local_fallback = _env_fallback.strip().lower() not in ("0", "off", "false", "no", "")
else:
_local_fallback = bool(local_fallback)
if uw.mpi.size > 1 and _local_fallback:
from mpi4py import MPI

comm = uw.mpi.comm
ext_idx = np.where(return_mask[:, 0, 0])[0]
ext_coords = np.ascontiguousarray(coords_array[ext_idx], dtype=np.double)

counts = np.array(comm.allgather(ext_coords.shape[0]), dtype=int)
n_ext_total = int(counts.sum())

if n_ext_total > 0:
parts = comm.allgather(ext_coords)
all_ext = np.concatenate(
[p for p in parts if p.size], axis=0).reshape(n_ext_total, -1)

# This rank's local rbf extrapolation of the global set. NON-collective
# value path — see DEADLOCK SAFETY above (must be rbf=True, never FE).
ext_vals, ext_flag = evaluate_nd(
expr, all_ext, rbf=True, evalf=False, verbose=False,
check_extrapolated=True,)
ext_vals = np.ascontiguousarray(
np.asarray(ext_vals, dtype=np.double).reshape((n_ext_total,) + expr_shape))
ext_flag = np.asarray(ext_flag).reshape(n_ext_total).astype(np.int32)

# Nearest-local-cell distance for every point (local kd-tree query).
mesh._build_kd_tree_index()
dist2, _ = mesh._centroid_index.query(all_ext, k=1, sqr_dists=True)
dist2 = np.ascontiguousarray(np.asarray(dist2, dtype=np.double).ravel())
Comment on lines +605 to +615

# Globally-nearest cell per point, lowest rank as the tie-break.
min_dist2 = np.empty(n_ext_total, dtype=np.double)
comm.Allreduce([dist2, MPI.DOUBLE], [min_dist2, MPI.DOUBLE], op=MPI.MIN)
my_claim = np.where(dist2 <= min_dist2 * (1.0 + 1e-12) + 1e-300,
comm.rank, comm.size).astype(np.int32)
win_rank = np.empty(n_ext_total, dtype=np.int32)
comm.Allreduce([my_claim, MPI.INT], [win_rank, MPI.INT], op=MPI.MIN)
i_win = (win_rank == comm.rank)

# Winner contributes value+flag, everyone else zero; SUM selects it.
contrib_val = np.ascontiguousarray(
np.where(i_win[:, None, None], ext_vals, 0.0))
best_val = np.empty_like(contrib_val)
comm.Allreduce([contrib_val, MPI.DOUBLE], [best_val, MPI.DOUBLE], op=MPI.SUM)
contrib_flag = np.where(i_win, ext_flag, 0).astype(np.int32)
best_flag = np.empty(n_ext_total, dtype=np.int32)
comm.Allreduce([contrib_flag, MPI.INT], [best_flag, MPI.INT], op=MPI.SUM)

# Scatter this rank's segment of the global set back to its points.
offset = int(counts[:comm.rank].sum())
seg = slice(offset, offset + ext_coords.shape[0])
if ext_idx.size:
return_value[ext_idx, :, :] = best_val[seg]
return_mask[ext_idx, 0, 0] = best_flag[seg].astype(bool)

if not check_extrapolated:
return return_value
else:
Expand Down
14 changes: 14 additions & 0 deletions src/underworld3/function/functions_unit_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ def _global_evaluate_impl(
# Expert overrides (override mode settings)
rbf=None,
force_l2=None,
local_fallback=True,
):
"""
Global evaluate with automatic unit-aware results.
Expand Down Expand Up @@ -434,6 +435,12 @@ def _global_evaluate_impl(
Expert override: Force RBF interpolation everywhere. Overrides mode.
force_l2 : bool, optional
Expert override: Force L2 projection path. Overrides mode.
local_fallback : bool, optional
Parallel-only. When True (default), query points that no rank can
locate in-cell are resolved by a best-claim reduction so the parallel
result matches serial ``evaluate`` (extrapolate from the true nearest
cell). Set False to restore the legacy silently-wrong out-of-domain
behaviour. The ``GE_LOCAL_FALLBACK`` env var, if set, overrides this.

Returns
-------
Expand Down Expand Up @@ -560,6 +567,7 @@ def _global_evaluate_impl(
check_extrapolated=check_extrapolated,
force_l2=force_l2_flag,
smoothing=smoothing,
local_fallback=local_fallback,
)

# Step 2: Re-dimensionalize and wrap with units (GATEWAY PRINCIPLE)
Expand Down Expand Up @@ -865,6 +873,7 @@ def global_evaluate(
rbf=None,
force_l2=None,
monotone=False,
local_fallback=True,
):
"""Parallel-safe evaluate with automatic unit-aware results.

Expand All @@ -879,6 +888,10 @@ def global_evaluate(
Opt-in bounded (monotone) interpolation post-process. See
:func:`evaluate` for semantics. Not supported together with
``check_extrapolated`` (raises ``NotImplementedError``).
local_fallback : bool, optional
Parallel-only best-claim resolution of out-of-domain points so the
result matches serial ``evaluate``. Default True. See
:func:`_global_evaluate_impl`.
"""
# Validate up front so an unknown option or the unsupported
# monotone + check_extrapolated combination fails fast (no wasted eval).
Expand All @@ -903,6 +916,7 @@ def global_evaluate(
smoothing=smoothing,
rbf=rbf,
force_l2=force_l2,
local_fallback=local_fallback,
)

if monotone_mode is None:
Expand Down
Loading