fix(networks): only instantiate cross-attention layers when with_cross_attention=True#8873
Conversation
…s_attention=True TransformerBlock was unconditionally creating norm_cross_attn and cross_attn even when with_cross_attention=False, causing unused parameters to appear in model.parameters() and named_parameters(). Guard the instantiation with an if-block so those layers only exist when they are actually used. Closes Project-MONAI#8845 Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
📝 WalkthroughWalkthrough
Estimated code review effort🎯 2 (Simple) | ⏱️ ~8 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (2)
tests/networks/blocks/test_transformerblock.py (2)
67-74: ⚡ Quick winStrengthen the enabled-case parameter assertion.
any("cross_attn" in n ...)also matchesnorm_cross_attn; this can pass without provingcross_attn.*exists.Proposed test tightening
self.assertTrue( - any("cross_attn" in n for n in param_names), - "Expected cross-attention parameters not found when with_cross_attention=True", + any(n.startswith("cross_attn.") for n in param_names), + "Expected cross_attn parameters not found when with_cross_attention=True", + ) + self.assertTrue( + any(n.startswith("norm_cross_attn.") for n in param_names), + "Expected norm_cross_attn parameters not found when with_cross_attention=True", )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/networks/blocks/test_transformerblock.py` around lines 67 - 74, The assertion in test_cross_attention_params_when_enabled is too permissive because any("cross_attn" in n ...) matches names like norm_cross_attn; update the check to assert presence of actual cross-attention parameter namespace by matching a stricter pattern (e.g., names that start with "cross_attn" or contain ".cross_attn."/"cross_attn."), e.g., replace the any(...) with a predicate that uses n.startswith("cross_attn") or checks for ".cross_attn" to ensure TransformerBlock(with_cross_attention=True) truly registers parameters under the cross_attn module namespace.
57-58: ⚡ Quick winUse Google-style docstrings for the new test methods.
Current one-line docstrings don’t follow the required structured format.
As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."
Also applies to: 67-68
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/networks/blocks/test_transformerblock.py` around lines 57 - 58, Replace the one-line docstrings in the test methods with Google-style docstrings: for the test_no_cross_attention_params_when_disabled (and the adjacent test at the following definition), expand the docstring to a short summary line followed by sections "Args:", "Returns:", and "Raises:" (even if empty) and describe the test inputs (if any), the expected outcome (no cross-attention parameters registered), and that the test returns None; keep the summary concise and use the function names test_no_cross_attention_params_when_disabled and the adjacent test to locate and update the docstrings accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@tests/networks/blocks/test_transformerblock.py`:
- Around line 67-74: The assertion in test_cross_attention_params_when_enabled
is too permissive because any("cross_attn" in n ...) matches names like
norm_cross_attn; update the check to assert presence of actual cross-attention
parameter namespace by matching a stricter pattern (e.g., names that start with
"cross_attn" or contain ".cross_attn."/"cross_attn."), e.g., replace the
any(...) with a predicate that uses n.startswith("cross_attn") or checks for
".cross_attn" to ensure TransformerBlock(with_cross_attention=True) truly
registers parameters under the cross_attn module namespace.
- Around line 57-58: Replace the one-line docstrings in the test methods with
Google-style docstrings: for the test_no_cross_attention_params_when_disabled
(and the adjacent test at the following definition), expand the docstring to a
short summary line followed by sections "Args:", "Returns:", and "Raises:" (even
if empty) and describe the test inputs (if any), the expected outcome (no
cross-attention parameters registered), and that the test returns None; keep the
summary concise and use the function names
test_no_cross_attention_params_when_disabled and the adjacent test to locate and
update the docstrings accordingly.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 88e3a7d4-470d-4b08-968c-6ac5fb14919a
📒 Files selected for processing (2)
monai/networks/blocks/transformerblock.pytests/networks/blocks/test_transformerblock.py
|
The failing CI checks are caused by a transient upstream dependency regression in Root cause: Evidence:
Fix scope: This is a separate issue from this PR. The cleanest mitigation would be to pin This PR's tests pass locally: |
Status & BlockersAll code changes are complete and tests pass locally (4/4). Two things are blocking merge: 1. CI Failures (external regression — not our code)The failing CI jobs ( The failure chain is: This is pre-existing on the 2. Awaiting Required ReviewThis PR needs at least one approving review from a code owner before it can be merged. @KumoLiu @ericspod @Nic-Ma — could one of you take a look when you get a chance? The change is in
|
I'm seeing the same issue and am working on a fix right now on a PR I hope to integrate soon. |
Summary
Fixes #8845
TransformerBlockwas unconditionally creatingnorm_cross_attnandcross_attnin__init__regardless of thewith_cross_attentionflag. This caused unused parameters to appear inmodel.parameters()/model.named_parameters(), increasing model size and producing confusing "no gradient" warnings during training.Root cause — two layers created unconditionally:
Fix — guard instantiation with the same flag already used in
forward:Changes
monai/networks/blocks/transformerblock.py: wrap cross-attention layer init inif with_cross_attention:tests/networks/blocks/test_transformerblock.py: add two regression tests verifying parameter presence/absenceTest plan
test_no_cross_attention_params_when_disabled— asserts nocross_attn/norm_cross_attnparams whenwith_cross_attention=Falsetest_cross_attention_params_when_enabled— asserts cross-attn params exist whenwith_cross_attention=Truetest_ill_arg— existing validation tests unchangedtest_shapetests cover bothwith_cross_attention=True/Falseforward passeshttps://claude.ai/code/session_01LV2dy8NFh3smu9f2RfgFvs