Fix scatter_reduce(reduce="mean") producing incorrect ONNX export results #2829

Open
Copilot wants to merge 4 commits into main from copilot/fix-scatter-reduce-onnx-export

Copilot AI commented Feb 25, 2026

Contributor
  • Fix aten_scatter_reduce in core.py to implement reduce="mean" correctly
    • include_self=True: sum/count approach including self in both
    • include_self=False: use Where to preserve self[i] for positions with no scattered values (fixes CI failure where original values were incorrectly replaced with 0)
  • Fix lint error in e2e_ops_tests.py (forward method signature too long)
  • Update xfail in ops_test_data.py to remove general mean xfail
  • Add e2e tests for scatter_reduce with reduce="mean" for both include_self values
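
The sum/count approach described in the list above can be sketched in plain Python. This is a minimal illustration, not the code in core.py: the helper name `scatter_mean_dim0` and the nested-list representation are assumptions made for the example, and it only covers 2-D inputs scattered along dim 0.

```python
def scatter_mean_dim0(self_rows, index, src, include_self):
    """Hypothetical reference for scatter_reduce(dim=0, reduce="mean") on nested lists."""
    n_rows, n_cols = len(self_rows), len(self_rows[0])
    sums = [[0.0] * n_cols for _ in range(n_rows)]
    counts = [[0] * n_cols for _ in range(n_rows)]

    # Accumulate the scattered values and how many landed on each position.
    for i, row in enumerate(src):
        for j, value in enumerate(row):
            target = index[i][j]
            sums[target][j] += value
            counts[target][j] += 1

    out = []
    for r in range(n_rows):
        out_row = []
        for c in range(n_cols):
            total, count = sums[r][c], counts[r][c]
            if include_self:
                # self contributes one extra element to both sum and count.
                total += self_rows[r][c]
                count += 1
            if count == 0:
                # Nothing was scattered here: keep the original self value.
                out_row.append(self_rows[r][c])
            else:
                out_row.append(total / count)
        out.append(out_row)
    return out


# Inputs taken from the issue's reproduction script.
h = [[1.0, 10.0], [3.0, 30.0], [5.0, 50.0], [7.0, 70.0], [2.0, 20.0], [4.0, 40.0]]
batch = [0, 0, 1, 1, 2, 2]
index = [[b, b] for b in batch]
zeros = [[0.0, 0.0] for _ in range(max(batch) + 1)]

without_self = scatter_mean_dim0(zeros, index, h, include_self=False)
with_self = scatter_mean_dim0(zeros, index, h, include_self=True)
```

With the issue's inputs and include_self=False, `without_self` matches the grouped means that eager PyTorch reports: [[2, 20], [6, 60], [3, 30]].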
Original prompt

This section details the original issue you should resolve

<issue_title>torch.export + torch.onnx.export(dynamo=True) gives incorrect results for scatter_reduce_(reduce="mean")</issue_title>
<issue_description>### 🐛 Describe the bug

Observed behaviour:

  • scatter_reduce(mean) shows a large mismatch after export:
    • max_abs_diff: 10.0
    • mean_abs_diff: 5.5
  • Equivalent sum/count control matches exactly:
    • max_abs_diff: 0.0
    • mean_abs_diff: 0.0

Expected behaviour:

  • ONNX output should match eager PyTorch semantics (within normal floating-point tolerance) for:
    • scatter_reduce_(reduce="mean", include_self=False)

Impact:

  • Silent numerical correctness issue (wrong predictions without a crash).

Environment:

  • torch: 2.7.1
  • onnxruntime: 1.24.1
  • python: 3.12
  • os: macOS (darwin 24.6.0)

Code example:

import numpy as np
import onnxruntime as ort
import torch


class ScatterMeanModel(torch.nn.Module):
    def forward(self, h: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        # h: [N, F], batch: [N] with group IDs in [0, G-1]
        index = batch.unsqueeze(1).repeat(1, h.shape[1])
        groups = batch.max().int() + 1
        out = torch.zeros(groups, h.shape[1], dtype=h.dtype, device=h.device)
        out = out.scatter_reduce_(0, index, h, reduce="mean", include_self=False)
        return out


class ScatterSumDivCountModel(torch.nn.Module):
    # Mathematically equivalent grouped mean = sum / count.
    def forward(self, h: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        index = batch.unsqueeze(1).repeat(1, h.shape[1])
        groups = batch.max().int() + 1

        sums = torch.zeros(groups, h.shape[1], dtype=h.dtype, device=h.device)
        sums = sums.scatter_reduce_(0, index, h, reduce="sum", include_self=False)

        ones = torch.ones(h.shape[0], 1, dtype=h.dtype, device=h.device)
        counts = torch.zeros(groups, 1, dtype=h.dtype, device=h.device)
        counts = counts.scatter_reduce_(
            0, batch.unsqueeze(1), ones, reduce="sum", include_self=False
        )
        return sums / counts


def run(model: torch.nn.Module) -> tuple[np.ndarray, np.ndarray, float, float]:
    model.eval()
    h = torch.tensor(
        [
            [1.0, 10.0],
            [3.0, 30.0],
            [5.0, 50.0],
            [7.0, 70.0],
            [2.0, 20.0],
            [4.0, 40.0],
        ],
        dtype=torch.float32,
    )
    batch = torch.tensor([0, 0, 1, 1, 2, 2], dtype=torch.int64)

    with torch.inference_mode():
        pt = model(h, batch).cpu().numpy()

    exported = torch.export.export(model, (h, batch), strict=False)
    onnx_program = torch.onnx.export(exported, f=None, dynamo=True)

    sess = ort.InferenceSession(
        onnx_program.model_proto.SerializeToString(),
        providers=["CPUExecutionProvider"],
    )
    input_names = [i.name for i in sess.get_inputs()]
    ort_out = sess.run(
        None, {input_names[0]: h.numpy(), input_names[1]: batch.numpy()}
    )[0]

    diff = np.abs(pt - ort_out)
    return pt, ort_out, float(diff.max()), float(diff.mean())


print("torch:", torch.__version__)
print("onnxruntime:", ort.__version__)

pt, ort_out, max_abs, mean_abs = run(ScatterMeanModel())
print("\n=== scatter_reduce(mean) ===")
print("PyTorch output:\n", pt)
print("ONNX Runtime output:\n", ort_out)
print("max_abs_diff:", max_abs)
print("mean_abs_diff:", mean_abs)

pt2, ort_out2, max_abs2, mean_abs2 = run(ScatterSumDivCountModel())
print("\n=== sum/count control ===")
print("PyTorch output:\n", pt2)
print("ONNX Runtime output:\n", ort_out2)
print("max_abs_diff:", max_abs2)
print("mean_abs_diff:", mean_abs2)

Example output:

python tmp/repro_pytorch_scatter_reduce_mean_onnx.py
torch: 2.7.1
onnxruntime: 1.24.1
W0225 11:07:22.309000 94867 torch/onnx/_internal/exporter/_registration.py:103] torchvision is not installed. Skipping torchvision::nms
W0225 11:07:22.310000 94867 torch/onnx/_internal/exporter/_registration.py:103] torchvision is not installed. Skipping torchvision::roi_align
W0225 11:07:22.310000 94867 torch/onnx/_internal/exporter/_registration.py:103] torchvision is not installed. Skipping torchvision::roi_pool
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅

=== scatter_reduce(mean) ===
PyTorch output:
 [[ 2. 20.]
 [ 6. 60.]
 [ 3. 30.]]
ONNX Runtime output:
 [[ 3. 30.]
 [ 7. 70.]
 [ 4. 40.]]
max_abs_diff: 10.0
mean_abs_diff: 5.5
W0225 11:07:22.633000 94867 torch/onnx/_internal/exporter/_registration.py:103] torchvision is not installed. Skipping torchvision::nms
W0225 11:07:22.634000 94867 torch/onnx/_internal/exporter/_registration.py:103] torchvision is not installed. Skipping torchvision::roi_align
W0225 11:07:22.634000 94867 torch/onnx/_internal/exporter/_registration.py:103] torchvision is not installed. Skipping torchvision::roi_pool
[tor...

</details>



<!-- START COPILOT CODING AGENT SUFFIX -->

- Fixes pytorch/pytorch#175738

ONNX ScatterElements doesn't support 'mean' reduction. Previously the
code mapped 'mean' to 'none' (no reduction), which just overwrote values
instead of computing the mean.

The fix implements mean as sum/count:
- scatter_sum: ScatterElements with reduction='add' onto zeros
- scatter_count: ScatterElements of ones with reduction='add' onto zeros
- For include_self=True: add self to sum and 1 to count
- For include_self=False: use max(count, 1) to avoid div-by-zero
  (positions with count=0 also have sum=0, so 0/1=0 is correct)

Also removes the xfail for scatter_reduce mean in ops_test_data.py and
adds e2e tests for both include_self=True and include_self=False cases.
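
The effect of mapping 'mean' to 'none' can be checked against the numbers in the issue. The sketch below is plain Python, not ONNX Runtime, and it assumes that duplicate indices are applied in order so the last write wins. That assumption is consistent with the output reported in the issue, but it is not something a runtime is known to guarantee here.

```python
# Inputs taken from the issue's reproduction script.
h = [[1.0, 10.0], [3.0, 30.0], [5.0, 50.0], [7.0, 70.0], [2.0, 20.0], [4.0, 40.0]]
batch = [0, 0, 1, 1, 2, 2]
groups = max(batch) + 1

# Overwrite semantics: each later row for a group replaces the earlier one.
overwritten = [[0.0, 0.0] for _ in range(groups)]
for row, group in zip(h, batch):
    overwritten[group] = list(row)

# Grouped means reported for eager PyTorch in the issue.
expected_mean = [[2.0, 20.0], [6.0, 60.0], [3.0, 30.0]]

diffs = [
    abs(a - b)
    for got_row, want_row in zip(overwritten, expected_mean)
    for a, b in zip(got_row, want_row)
]
max_abs_diff = max(diffs)
mean_abs_diff = sum(diffs) / len(diffs)
```

Under that assumption the overwritten result is [[3, 30], [7, 70], [4, 40]], and the differences come out to the 10.0 and 5.5 reported in the issue.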

Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Copilot AI changed the title from "[WIP] Fix incorrect results for scatter_reduce with ONNX export" to "Fix scatter_reduce(reduce="mean") producing incorrect ONNX export results" on Feb 25, 2026
Copilot AI requested a review from justinchuby February 25, 2026 15:17
codecov Bot commented Feb 25, 2026


Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 72.66%. Comparing base (ef2bc22) to head (1879fe5).
⚠️ Report is 56 commits behind head on main.
✅ All tests successful. No failed tests found.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2829      +/-   ##
==========================================
+ Coverage   71.78%   72.66%   +0.88%     
==========================================
  Files         239      259      +20     
  Lines       28989    31764    +2775     
  Branches     2859     3008     +149     
==========================================
+ Hits        20809    23082    +2273     
- Misses       7209     7662     +453     
- Partials      971     1020      +49     


@justinchuby

Collaborator

@copilot CI errors

Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>

Copilot AI commented Mar 3, 2026

Contributor Author

@copilot CI errors

Fixed in 6614c55. There were two CI failures:

  1. Test failures (scatter_reduce_mean_cpu_float32): The include_self=False path was returning 0 for positions with no scattered values instead of preserving the original self[i]. Fixed by using Where(count==0, self, mean_vals).

  2. Lint failure (RUFF-FORMAT): The forward method signature in e2e_ops_tests.py was too long. Split into multiple lines.
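
The difference between the two include_self=False strategies in the first failure can be shown on a small 1-D case. This is a hedged sketch in plain Python rather than the ONNX graph the PR builds. The helper names `finalize_clamped` and `finalize_where` are hypothetical, and the sums and counts are assumed to have been accumulated already.

```python
def finalize_clamped(sums, counts):
    # Earlier approach: divide by max(count, 1), so untouched positions become 0.
    return [s / max(c, 1) for s, c in zip(sums, counts)]


def finalize_where(self_vals, sums, counts):
    # Where(count == 0, self, mean): untouched positions keep their self value.
    return [x if c == 0 else s / c for x, s, c in zip(self_vals, sums, counts)]


self_vals = [9.0, 9.0, 9.0]  # non-zero self makes the difference visible
sums = [4.0, 0.0, 6.0]       # position 1 received no scattered values
counts = [2, 0, 2]

clamped = finalize_clamped(sums, counts)
preserved = finalize_where(self_vals, sums, counts)
```

Here `clamped` is [2.0, 0.0, 3.0] while `preserved` is [2.0, 9.0, 3.0]. The two only agree when self is zero at the untouched positions, which is why the issue's zero-initialized example would not expose the problem.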

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@justinchuby justinchuby marked this pull request as ready for review June 19, 2026 18:06