
Fix incorrect rank handling in aten_repeat_interleave_self_int ONNX export #2932 #2936

Open
PratikWayase wants to merge 6 commits into microsoft:main from PratikWayase:fix-repeat-interleave-shape-inference

Conversation

@PratikWayase


Summary

This PR fixes a ShapeInferenceError that occurs during the ONNX export of torch.repeat_interleave on 1D tensors. It also implements support for the dim=None case, which previously raised a NotImplementedError.

Problem

When exporting a model that calls torch.repeat_interleave on a 1D tensor, the exporter returned the intermediate tiled tensor through an Identity node. That tensor has rank 2, while the expected output has rank 1, so strict shape inference failed:

[ShapeInferenceError] Inferred shape and existing shape differ in rank: (2) vs (1)
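
The rank mismatch can be reproduced with plain shape arithmetic. The sketch below is an illustration only: `trace_shapes` is a hypothetical helper, not part of onnxscript, and it assumes the unsqueeze, expand, reshape sequence described in this PR.

```python
def trace_shapes(shape, repeats, dim):
    """Return the shapes after unsqueeze, expand, and the final reshape.

    Hypothetical helper for illustration; it works on shape tuples only.
    """
    rank = len(shape)
    pos_dim = (dim + rank) % rank  # normalize a negative dim
    # Unsqueeze inserts a size-1 axis right after the repeated axis.
    unsqueezed = shape[: pos_dim + 1] + (1,) + shape[pos_dim + 1 :]
    # Expanding that new axis to `repeats` duplicates each slice.
    tiled = shape[: pos_dim + 1] + (repeats,) + shape[pos_dim + 1 :]
    # The final reshape merges the repeated axis with the new axis.
    final = shape[:pos_dim] + (shape[pos_dim] * repeats,) + shape[pos_dim + 1 :]
    return unsqueezed, tiled, final


unsqueezed, tiled, final = trace_shapes((3,), repeats=2, dim=0)
print(unsqueezed, tiled, final)  # (3, 1) (3, 2) (6,)
print(len(tiled), len(final))    # 2 1
```

Returning the tiled tensor directly yields rank 2 for a 1D input, while the reshaped result has rank 1, which corresponds to the two ranks named in the error message.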

Solution

  • Handled dim=None: Modified the logic in core.py to flatten the input tensor (op.Reshape(self, [-1])) and set dim = 0 and self_rank = 1 when dim is None. This allows the existing expansion logic to process it correctly.
  • Removed faulty shortcut: Deleted the incorrect if self_rank == 1: return op.Identity(tiled) block. This block was returning a tensor with an inflated rank. The code now correctly falls through to the final_shape calculation and op.Reshape to ensure the output rank perfectly matches the input rank.
  • Updated Tests: Removed the .skip() markers in ops_test_data.py that were intentionally bypassing tests for the dim=None case. This enables full test coverage for this newly supported scenario.
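
The intended semantics, including the dim=None flattening, can be sketched without any tensor library. This is a plain-Python reference under stated assumptions, not the code in core.py: `repeat_interleave_int` is a hypothetical name, and tensors are modelled as a flat row-major list plus a shape tuple.

```python
from math import prod


def repeat_interleave_int(data, shape, repeats, dim=None):
    """Reference semantics for an integer `repeats` (illustrative only).

    `data` is a flat, row-major list of values and `shape` is its shape tuple.
    """
    if dim is None:
        # Mirror the PR: flatten the input, then repeat along axis 0.
        shape = (len(data),)
        dim = 0
    rank = len(shape)
    pos_dim = (dim + rank) % rank
    # In row-major order, each slice along `pos_dim` is a contiguous block
    # of `inner` values; repeating the block repeats the slice.
    inner = prod(shape[pos_dim + 1 :])
    out = []
    if inner:  # inner == 0 means the tensor holds no values
        for start in range(0, len(data), inner):
            out.extend(data[start : start + inner] * repeats)
    out_shape = shape[:pos_dim] + (shape[pos_dim] * repeats,) + shape[pos_dim + 1 :]
    return out, out_shape


values = [0, 1, 2, 3, 4, 5]  # the 2x3 tensor [[0, 1, 2], [3, 4, 5]]
print(repeat_interleave_int(values, (2, 3), 2, dim=0))
# ([0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5], (4, 3))
print(repeat_interleave_int(values, (2, 3), 2))
# ([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5], (12,))
```

In every case the output shape has the same rank as the (possibly flattened) input, which is the property the removed Identity shortcut violated.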

Testing

Related Issue

Fixes #2932

Comment thread onnxscript/function_libs/torch_lib/ops/core.py Fixed
Comment thread onnxscript/function_libs/torch_lib/ops/core.py Fixed

Copilot AI left a comment

Contributor


Pull request overview

This PR updates the Torch-to-ONNX export implementation of aten::repeat_interleave to fix incorrect rank handling for 1D inputs and to add support for dim=None, aligning exported graphs with PyTorch semantics and preventing strict ONNX shape inference failures.

Changes:

  • Add handling for dim=None by flattening the input and routing through the existing expand/reshape logic.
  • Remove the prior self_rank == 1 shortcut that produced an incorrect output rank.
  • Enable previously skipped dim=None test cases in ops_test_data.py.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| onnxscript/function_libs/torch_lib/ops/core.py | Adjusts aten_repeat_interleave_self_int rank/dim handling and removes the problematic 1D Identity shortcut. |
| tests/function_libs/torch_lib/ops_test_data.py | Unskips dim=None test cases so they execute as part of the existing test suite. |

Comments suppressed due to low confidence (1)

onnxscript/function_libs/torch_lib/ops/core.py:8304

  • The dim is None branch introduces redundant/ineffective self_rank logic and leaves dim unnormalized. self_rank is assigned in the branch and then immediately overwritten by self_rank = len(self.shape), and negative dim values are converted to pos_dim but dim is still later used (via end=dim), which can slice the wrong prefix when dim < 0. Normalize dim to pos_dim once and compute self_rank only once (after the optional flatten).
    if dim is None:
        self = op.Reshape(self,[-1])
        dim = 0
        self_rank = 1
    else:
        self_rank = len(self.shape)

    self_rank = len(self.shape)
    pos_dim = (dim + self_rank) % self_rank
    unsqueezed = op.Unsqueeze(self, [pos_dim + 1])
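
One way to follow this suggestion is to resolve dim and the rank in a single place before building the graph. The sketch below only illustrates that bookkeeping: `resolve_dim` is a hypothetical helper that does not exist in onnxscript, and the range check is an added assumption.

```python
def resolve_dim(dim, input_rank):
    """Return (needs_flatten, pos_dim, effective_rank).

    Hypothetical helper: computes the rank once and normalizes `dim` once,
    so later code only ever uses the non-negative `pos_dim`.
    """
    if dim is None:
        # dim=None: the input is flattened, so the op runs on axis 0 of a 1D tensor.
        return True, 0, 1
    if not -input_rank <= dim < input_rank:
        raise IndexError(f"dim {dim} is out of range for rank {input_rank}")
    return False, (dim + input_rank) % input_rank, input_rank


print(resolve_dim(None, 3))  # (True, 0, 1)
print(resolve_dim(-1, 2))    # (False, 1, 2)
```

With this shape of helper, the flatten step, the Unsqueeze axis, and the final shape slices would all derive from the same `pos_dim` and rank values.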

@codecov

codecov Bot commented Jun 15, 2026


Codecov Report

❌ Patch coverage is 0% with 7 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.65%. Comparing base (5989b56) to head (90ba456).
⚠️ Report is 1 commits behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| onnxscript/function_libs/torch_lib/ops/core.py | 0.00% | 7 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2936      +/-   ##
==========================================
- Coverage   72.66%   72.65%   -0.01%     
==========================================
  Files         259      259              
  Lines       31748    31752       +4     
  Branches     3005     3005              
==========================================
  Hits        23069    23069              
- Misses       7660     7664       +4     
  Partials     1019     1019              


Comment thread tests/function_libs/torch_lib/ops_test.py Outdated
Comment thread onnxscript/function_libs/torch_lib/ops/core.py Fixed
Comment thread onnxscript/function_libs/torch_lib/ops/core.py Fixed
Comment thread tests/function_libs/torch_lib/ops_test.py Fixed
Comment thread tests/function_libs/torch_lib/ops_test.py Fixed
Comment thread tests/function_libs/torch_lib/ops_test.py Fixed
Comment thread tests/function_libs/torch_lib/ops_test.py Fixed
Comment thread tests/function_libs/torch_lib/ops_test.py Fixed
Comment thread tests/function_libs/torch_lib/ops_test.py Fixed
Comment thread tests/function_libs/torch_lib/ops_test.py Fixed
Comment thread tests/function_libs/torch_lib/ops_test.py Fixed

Copilot AI left a comment

Contributor


Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

Comment on lines +8279 to +8283
"""repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor

The trick is to repeat in one direction orthogonal to reshape.

.. code-block:: python

x = torch.tensor([[0, 1, 2], [3, 4, 5]])
x.repeat_interleave(2, dim=0)

@justinchuby

Collaborator

The rank fix is correct: removing the if self_rank == 1: return op.Identity(tiled) early return (which returned a rank-2 result and caused the (2)-vs-(1) shape-inference error) lets execution fall through to the final_shape Concat+Reshape, giving the correct rank-1 output. I traced the dim=None and dim=0/1 cases for 1D/2D inputs, and all match torch.repeat_interleave. The removed .skip(dim is None) entries now run real OpInfo numerical parity checks, plus the new e2e test_repeat_interleave_int_dim_none.

The 3 red CI jobs are all py311-torch-nightly failing on test_output_match_opinfo__logit_*, an unrelated torch-nightly numerical flake (this PR doesn't touch logit, and the same tests fail on main).

Minor non-blocking nit: the docstring (~core.py L8278) lost its blank lines, so the .. code-block:: python RST is now malformed. Worth restoring. LGTM otherwise.
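
For reference, reStructuredText expects a blank line before a directive, another blank line between the directive and its content, and the content indented relative to the directive. The sketch below shows that layout on a hypothetical stub (`repeat_interleave_stub` is not the real function, and its summary line is a placeholder); the snippet inside the docstring is taken from the lines quoted in this thread.

```python
import inspect


def repeat_interleave_stub():
    """One-line summary of the function.

    The trick is to repeat in one direction orthogonal to reshape.

    .. code-block:: python

        x = torch.tensor([[0, 1, 2], [3, 4, 5]])
        x.repeat_interleave(2, dim=0)
    """


# Check the layout: blank line before and after the directive, indented body.
doc_lines = inspect.cleandoc(repeat_interleave_stub.__doc__).splitlines()
directive = doc_lines.index(".. code-block:: python")
print(repr(doc_lines[directive - 1]), repr(doc_lines[directive + 1]))  # '' ''
print(doc_lines[directive + 2].startswith("    "))  # True
```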


print(result)
"""



raise NotImplementedError("No conversion available yet when dim is None.")
flat_self = op.Reshape(self, [-1])
unsqueezed = op.Unsqueeze(flat_self, [1])


op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))),
axis=0,
)

self_rank = len(self.shape)
pos_dim = (dim + self_rank) % self_rank
unsqueezed = op.Unsqueeze(self, [pos_dim + 1])

op.Constant(value=ir.tensor([1] * (self_rank - pos_dim), dtype=INT64.dtype)),
axis=0,
)


op.Shape(self, start=pos_dim + 1),
axis=0,
)
)

Labels

module: torchlib Related to the torch/aten function lib in development

Projects

Development

Successfully merging this pull request may close these issues.

bug in onnx export of aten_repeat_interleave_self_int results in incorrect code

4 participants