Skip to content

Fix aten_stft ONNX spec violations#2943

Open
justinchuby wants to merge 1 commit into
mainfrom
justinchu/fix-stft-spec-2942
Open

Fix aten_stft ONNX spec violations#2943
justinchuby wants to merge 1 commit into
mainfrom
justinchu/fix-stft-spec-2942

Conversation

@justinchuby

Copy link
Copy Markdown
Collaborator

Summary

aten_stft emitted an ONNX STFT node that violated the operator spec in two ways (reported in #2942 by @bas-aarts):

  1. frame_step / frame_length type mismatch. frame_step was built via op.Reshape(hop_length, [1]), producing a rank-1 tensor, while frame_length (n_fft) is a rank-0 scalar. The spec (T2 for both) requires them to share the same type/rank.
  2. Signal rank. The ONNX STFT signal must be rank 3 ([batch, signal_length, 1] for real input). torch.stft accepts rank-1 or rank-2 signals, and the code only unsqueezed a rank-1 signal up to rank 2, leaving the signal one dimension short of the spec.

Fix

  • Pass hop_length directly as frame_step (a rank-0 scalar, matching n_fft) and remove the frame_step_const = op.Reshape(...) line. hop_length is always a Python int in this trace-only function (input arg or n_fft // 4), so it is emitted as a rank-0 scalar exactly like n_fft.
  • Add self = op.Unsqueeze(self, [-1]) after the existing batch-dim handling so the signal becomes rank 3 for both rank-1 and rank-2 inputs.

Rank trace (verified by inspecting the emitted STFT node)

Before → after, STFT inputs [signal, frame_step, window, frame_length]:

input buggy rank fixed rank spec
signal 2 3 3
frame_step 1 0 scalar
window 1 1 1
frame_length 0 0 scalar

End-to-end shapes (fixed): rank-1 input [L] → batch unsqueeze [1, L] → trailing unsqueeze [1, L, 1] → STFT [1, frames, bins, 2] → Transpose [1, bins, frames, 2] → squeeze batch [bins, frames, 2]. Rank-2 input [B, L][B, L, 1][B, frames, bins, 2][B, bins, frames, 2]. Both match torch.stft's real output. The normalized and onesided paths are unaffected (only dtype/scaling, not rank).

Tests

  • The existing OpInfo (ops.aten.stft) and _testing.assert_onnx_program value-comparison harnesses pass both before and after — they don't strictly enforce the STFT rank, which is why this spec bug went unnoticed.
  • Added test_aten_stft_emits_spec_compliant_node (parameterized for rank-1 and rank-2 inputs) in tests/function_libs/torch_lib/e2e_ops_tests.py, which asserts the emitted STFT node's signal is rank 3 and frame_step/frame_length are both scalar. This test fails on the old code (2 != 3) and passes with the fix.

Ran:

  • pytest tests/function_libs/torch_lib/ops_test.py -k stft → 2 passed, 1 skipped, 2 xfailed
  • pytest tests/function_libs/torch_lib/e2e_ops_tests.py -k stft → 6 passed
  • lintrunner -a → no lint issues

Fixes #2942

…nk-3 signal)

The ONNX STFT op requires:
- a rank-3 signal of shape [batch, signal_length, 1], and
- frame_step and frame_length to share the same (scalar) type.

aten_stft previously passed a rank-1 frame_step (Reshape of hop_length)
while frame_length (n_fft) was a rank-0 scalar, and only reshaped the
signal up to rank 2. This produced STFT nodes that violate the spec.

Fix by passing hop_length directly as the scalar frame_step and adding a
trailing [1] dimension to the signal so it is rank 3 for both rank-1 and
rank-2 torch.stft inputs. Adds an e2e regression test asserting the
emitted STFT node is spec-compliant for rank-1 and rank-2 inputs.

Fixes #2942

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@codecov

codecov Bot commented Jun 19, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 66.66667% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 72.66%. Comparing base (029441f) to head (0f94339).
⚠️ Report is 12 commits behind head on main.
✅ All tests successful. No failed tests found.

Files with missing lines Patch % Lines
onnxscript/function_libs/torch_lib/ops/core.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2943      +/-   ##
==========================================
+ Coverage   72.61%   72.66%   +0.04%     
==========================================
  Files         259      259              
  Lines       31597    31748     +151     
  Branches     2973     3005      +32     
==========================================
+ Hits        22945    23069     +124     
- Misses       7643     7660      +17     
- Partials     1009     1019      +10     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

@justinchuby justinchuby changed the title Fix aten_stft ONNX spec violations (Fixes #2942) Fix aten_stft ONNX spec violations Jun 19, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Development

Successfully merging this pull request may close these issues.

ONNX export at aten_stft generates STFT layer that violates the ONNX spec

1 participant