Fix aten_stft ONNX spec violations#2943
Open
justinchuby wants to merge 1 commit into
Open
Conversation
…nk-3 signal) The ONNX STFT op requires: - a rank-3 signal of shape [batch, signal_length, 1], and - frame_step and frame_length to share the same (scalar) type. aten_stft previously passed a rank-1 frame_step (Reshape of hop_length) while frame_length (n_fft) was a rank-0 scalar, and only reshaped the signal up to rank 2. This produced STFT nodes that violate the spec. Fix by passing hop_length directly as the scalar frame_step and adding a trailing [1] dimension to the signal so it is rank 3 for both rank-1 and rank-2 torch.stft inputs. Adds an e2e regression test asserting the emitted STFT node is spec-compliant for rank-1 and rank-2 inputs. Fixes #2942 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2943 +/- ##
==========================================
+ Coverage 72.61% 72.66% +0.04%
==========================================
Files 259 259
Lines 31597 31748 +151
Branches 2973 3005 +32
==========================================
+ Hits 22945 23069 +124
- Misses 7643 7660 +17
- Partials 1009 1019 +10 ☔ View full report in Codecov by Harness. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
aten_stftemitted an ONNXSTFTnode that violated the operator spec in two ways (reported in #2942 by @bas-aarts):frame_step/frame_lengthtype mismatch.frame_stepwas built viaop.Reshape(hop_length, [1]), producing a rank-1 tensor, whileframe_length(n_fft) is a rank-0 scalar. The spec (T2for both) requires them to share the same type/rank.STFTsignal must be rank 3 ([batch, signal_length, 1]for real input).torch.stftaccepts rank-1 or rank-2 signals, and the code only unsqueezed a rank-1 signal up to rank 2, leaving the signal one dimension short of the spec.Fix
hop_lengthdirectly asframe_step(a rank-0 scalar, matchingn_fft) and remove theframe_step_const = op.Reshape(...)line.hop_lengthis always a Python int in this trace-only function (input arg orn_fft // 4), so it is emitted as a rank-0 scalar exactly liken_fft.self = op.Unsqueeze(self, [-1])after the existing batch-dim handling so the signal becomes rank 3 for both rank-1 and rank-2 inputs.Rank trace (verified by inspecting the emitted STFT node)
Before → after, STFT inputs
[signal, frame_step, window, frame_length]:End-to-end shapes (fixed): rank-1 input
[L]→ batch unsqueeze[1, L]→ trailing unsqueeze[1, L, 1]→ STFT[1, frames, bins, 2]→ Transpose[1, bins, frames, 2]→ squeeze batch[bins, frames, 2]. Rank-2 input[B, L]→[B, L, 1]→[B, frames, bins, 2]→[B, bins, frames, 2]. Both matchtorch.stft's real output. Thenormalizedandonesidedpaths are unaffected (only dtype/scaling, not rank).Tests
ops.aten.stft) and_testing.assert_onnx_programvalue-comparison harnesses pass both before and after — they don't strictly enforce the STFT rank, which is why this spec bug went unnoticed.test_aten_stft_emits_spec_compliant_node(parameterized for rank-1 and rank-2 inputs) intests/function_libs/torch_lib/e2e_ops_tests.py, which asserts the emitted STFT node'ssignalis rank 3 andframe_step/frame_lengthare both scalar. This test fails on the old code (2 != 3) and passes with the fix.Ran:
pytest tests/function_libs/torch_lib/ops_test.py -k stft→ 2 passed, 1 skipped, 2 xfailedpytest tests/function_libs/torch_lib/e2e_ops_tests.py -k stft→ 6 passedlintrunner -a→ no lint issuesFixes #2942