Fix incorrect rank handling in aten_repeat_interleave_self_int ONNX export #2932#2936
Fix incorrect rank handling in aten_repeat_interleave_self_int ONNX export #2932#2936PratikWayase wants to merge 6 commits into
Conversation
There was a problem hiding this comment.
Pull request overview
This PR updates the Torch-to-ONNX export implementation of aten::repeat_interleave to fix incorrect rank handling for 1D inputs and to add support for dim=None, aligning exported graphs with PyTorch semantics and preventing strict ONNX shape inference failures.
Changes:
- Add handling for
dim=Noneby flattening the input and routing through the existing expand/reshape logic. - Remove the prior
self_rank == 1shortcut that produced an incorrect output rank. - Enable previously skipped
dim=Nonetest cases inops_test_data.py.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| onnxscript/function_libs/torch_lib/ops/core.py | Adjusts aten_repeat_interleave_self_int rank/dim handling and removes the problematic 1D Identity shortcut. |
| tests/function_libs/torch_lib/ops_test_data.py | Unskips dim=None test cases so they execute as part of the existing test suite. |
Comments suppressed due to low confidence (1)
onnxscript/function_libs/torch_lib/ops/core.py:8304
- The
dim is Nonebranch introduces redundant/ineffectiveself_ranklogic and leavesdimunnormalized.self_rankis assigned in the branch and then immediately overwritten byself_rank = len(self.shape), and negativedimvalues are converted topos_dimbutdimis still later used (viaend=dim), which can slice the wrong prefix whendim < 0. Normalizedimtopos_dimonce and computeself_rankonly once (after the optional flatten).
if dim is None:
self = op.Reshape(self,[-1])
dim = 0
self_rank = 1
else:
self_rank = len(self.shape)
self_rank = len(self.shape)
pos_dim = (dim + self_rank) % self_rank
unsqueezed = op.Unsqueeze(self, [pos_dim + 1])
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2936 +/- ##
==========================================
- Coverage 72.66% 72.65% -0.01%
==========================================
Files 259 259
Lines 31748 31752 +4
Branches 3005 3005
==========================================
Hits 23069 23069
- Misses 7660 7664 +4
Partials 1019 1019 ☔ View full report in Codecov by Harness. |
| """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor | ||
|
|
||
| The trick is to repeat in one direction orthogonal to reshape. | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| x = torch.tensor([[0, 1, 2], [3, 4, 5]]) | ||
| x.repeat_interleave(2, dim=0) | ||
|
|
||
| x = torch.tensor([[0, 1, 2], [3, 4, 5]]) | ||
| x.repeat_interleave(2, dim=0) |
|
The rank fix is correct — removing the |
|
|
||
| print(result) | ||
| """ | ||
|
|
|
|
||
| print(result) | ||
| """ | ||
|
|
| raise NotImplementedError("No conversion available yet when dim is None.") | ||
| flat_self = op.Reshape(self, [-1]) | ||
| unsqueezed = op.Unsqueeze(flat_self, [1]) | ||
|
|
| raise NotImplementedError("No conversion available yet when dim is None.") | ||
| flat_self = op.Reshape(self, [-1]) | ||
| unsqueezed = op.Unsqueeze(flat_self, [1]) | ||
|
|
| op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))), | ||
| axis=0, | ||
| ) | ||
|
|
| self_rank = len(self.shape) | ||
| pos_dim = (dim + self_rank) % self_rank | ||
| unsqueezed = op.Unsqueeze(self, [pos_dim + 1]) | ||
|
|
| op.Constant(value=ir.tensor([1] * (self_rank - pos_dim), dtype=INT64.dtype)), | ||
| axis=0, | ||
| ) | ||
|
|
| op.Constant(value=ir.tensor([1] * (self_rank - pos_dim), dtype=INT64.dtype)), | ||
| axis=0, | ||
| ) | ||
|
|
| op.Shape(self, start=pos_dim + 1), | ||
| axis=0, | ||
| ) | ||
| ) |
| op.Shape(self, start=pos_dim + 1), | ||
| axis=0, | ||
| ) | ||
| ) |
Summary
This PR fixes a
ShapeInferenceErrorthat occurs during the ONNX export oftorch.repeat_interleaveon 1D tensors. Additionally, it implements full support for thedim=Nonecase, which was previously raising aNotImplementedError.Problem
When exporting a model using
torch.repeat_interleaveon a 1D tensor, the exporter incorrectly returned anIdentitynode with a rank of 2 instead of 1. This caused a strict shape inference failure:Solution
dim=None: Modified the logic incore.pyto flatten the input tensor (op.Reshape(self, [-1])) and setdim = 0andself_rank = 1whendimisNone. This allows the existing expansion logic to process it correctly.if self_rank == 1: return op.Identity(tiled)block. This block was returning a tensor with an inflated rank. The code now correctly falls through to thefinal_shapecalculation andop.Reshapeto ensure the output rank perfectly matches the input rank..skip()markers inops_test_data.pythat were intentionally bypassing tests for thedim=Nonecase. This enables full test coverage for this newly supported scenario.Testing
pytest tests/function_libs/torch_lib/ops_test.py -k "repeat_interleave" -v, resulting in 4 passed, 0 failed. The previously skippeddim=Nonetest cases now execute and pass successfully.Related Issue
Fixes #2932